From 869fa1d06744760955a5943d613520cb473cd336 Mon Sep 17 00:00:00 2001 From: Brayan Jules Date: Tue, 15 Aug 2023 17:54:54 -0400 Subject: [PATCH 001/103] feat(rust): utf8 to temporal casting --- crates/polars-core/src/chunked_array/cast.rs | 31 ++++++++++++++++++++ py-polars/tests/unit/test_lazy.py | 2 +- py-polars/tests/unit/test_queries.py | 27 +++++++++++++++++ 3 files changed, 59 insertions(+), 1 deletion(-) diff --git a/crates/polars-core/src/chunked_array/cast.rs b/crates/polars-core/src/chunked_array/cast.rs index 6a93bc7f0364..0acdf51739f3 100644 --- a/crates/polars-core/src/chunked_array/cast.rs +++ b/crates/polars-core/src/chunked_array/cast.rs @@ -7,6 +7,7 @@ use arrow::compute::cast::CastOptions; use crate::chunked_array::categorical::CategoricalChunkedBuilder; #[cfg(feature = "timezones")] use crate::chunked_array::temporal::validate_time_zone; +use crate::prelude::DataType::Datetime; use crate::prelude::*; pub(crate) fn cast_chunks( @@ -203,6 +204,36 @@ impl ChunkCast for Utf8Chunked { polars_bail!(ComputeError: "expected 'precision' or 'scale' when casting to Decimal") }, }, + #[cfg(feature = "dtype-date")] + DataType::Date => { + let result = cast_chunks(&self.chunks, &data_type, true)?; + let out = Series::try_from((self.name(), result))?; + Ok(out) + }, + #[cfg(feature = "dtype-datetime")] + DataType::Datetime(tu, tz) => { + let out = match tz { + #[cfg(feature = "timezones")] + Some(tz) => { + validate_time_zone(tz)?; + let result = cast_chunks( + &self.chunks, + &Datetime(TimeUnit::Nanoseconds, Some(tz.clone())), + true, + )?; + Series::try_from((self.name(), result)) + }, + _ => { + let result = cast_chunks( + &self.chunks, + &Datetime(TimeUnit::Nanoseconds, None), + true, + )?; + Series::try_from((self.name(), result)) + }, + }; + out + }, _ => cast_impl(self.name(), &self.chunks, data_type), } } diff --git a/py-polars/tests/unit/test_lazy.py b/py-polars/tests/unit/test_lazy.py index f3fc2cdcab14..47a108580083 100644 --- a/py-polars/tests/unit/test_lazy.py +++ b/py-polars/tests/unit/test_lazy.py @@ -1313,7 +1313,7 @@ def test_quadratic_behavior_4736() -> None: ldf.select(reduce(add, (pl.col(fld) for fld in ldf.columns))) -@pytest.mark.parametrize("input_dtype", [pl.Utf8, pl.Int64, pl.Float64]) +@pytest.mark.parametrize("input_dtype", [pl.Int64, pl.Float64]) def test_from_epoch(input_dtype: pl.PolarsDataType) -> None: ldf = pl.LazyFrame( [ diff --git a/py-polars/tests/unit/test_queries.py b/py-polars/tests/unit/test_queries.py index 946c0f42e6f6..18dc04b589a9 100644 --- a/py-polars/tests/unit/test_queries.py +++ b/py-polars/tests/unit/test_queries.py @@ -364,3 +364,30 @@ def test_datetime_supertype_5236() -> None: ) assert out.shape == (0, 2) assert out.dtypes == [pl.Datetime("ns", "UTC")] * 2 + + +def test_utf8_date() -> None: + df = pl.DataFrame({"x1": ["2021-01-01"]}).with_columns( + **{"x1-date": pl.col("x1").cast(pl.Date)} + ) + out = df.select(pl.col("x1-date")) + assert out.shape == (1, 1) + assert out.dtypes == [pl.Date] + + +def test_utf8_datetime() -> None: + df = pl.DataFrame( + {"x1": ["2021-12-19T16:39:57-02:00", "2022-12-19T16:39:57"]} + ).with_columns( + **{ + "x1-datetime-ns": pl.col("x1").cast(pl.Datetime(time_unit="ns")), + "x1-datetime-ms": pl.col("x1").cast(pl.Datetime(time_unit="ms")), + "x1-datetime-us": pl.col("x1").cast(pl.Datetime(time_unit="us")), + } + ) + + out = df.select( + pl.col("x1-datetime-ns"), pl.col("x1-datetime-ms"), pl.col("x1-datetime-us") + ) + assert out.shape == (2, 3) + assert out.dtypes == [pl.Datetime, pl.Datetime, 
pl.Datetime] From bb731535c90a82da6d1df2420048a64363aab3a7 Mon Sep 17 00:00:00 2001 From: Brayan Jules Date: Tue, 15 Aug 2023 17:54:54 -0400 Subject: [PATCH 002/103] feat(rust): utf8 to temporal casting --- crates/polars-core/src/chunked_array/cast.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/crates/polars-core/src/chunked_array/cast.rs b/crates/polars-core/src/chunked_array/cast.rs index 0acdf51739f3..27ecd5bb6e7b 100644 --- a/crates/polars-core/src/chunked_array/cast.rs +++ b/crates/polars-core/src/chunked_array/cast.rs @@ -206,12 +206,12 @@ impl ChunkCast for Utf8Chunked { }, #[cfg(feature = "dtype-date")] DataType::Date => { - let result = cast_chunks(&self.chunks, &data_type, true)?; + let result = cast_chunks(&self.chunks, data_type, true)?; let out = Series::try_from((self.name(), result))?; Ok(out) }, #[cfg(feature = "dtype-datetime")] - DataType::Datetime(tu, tz) => { + DataType::Datetime(_tu, tz) => { let out = match tz { #[cfg(feature = "timezones")] Some(tz) => { From 34ba75a35fc9d7cf7610bdd4d0ba186c2168fbb6 Mon Sep 17 00:00:00 2001 From: Brayan Jules Date: Mon, 9 Oct 2023 23:37:07 -0300 Subject: [PATCH 003/103] feat: utf8 to timestamp/date casting support Support for different timeunits added in nano-arrow --- crates/nano-arrow/src/compute/cast/mod.rs | 36 +++++++++++++-- crates/nano-arrow/src/compute/cast/utf8_to.rs | 32 +++++++------ crates/nano-arrow/src/temporal_conversions.rs | 46 +++++++------------ crates/polars-core/src/chunked_array/cast.rs | 11 ++--- py-polars/tests/unit/test_queries.py | 21 ++++++++- 5 files changed, 91 insertions(+), 55 deletions(-) diff --git a/crates/nano-arrow/src/compute/cast/mod.rs b/crates/nano-arrow/src/compute/cast/mod.rs index f13a638a9c0d..0ac8e7b8085e 100644 --- a/crates/nano-arrow/src/compute/cast/mod.rs +++ b/crates/nano-arrow/src/compute/cast/mod.rs @@ -578,9 +578,23 @@ pub fn cast(array: &dyn Array, to_type: &DataType, options: CastOptions) -> Resu LargeUtf8 => Ok(Box::new(utf8_to_large_utf8( array.as_any().downcast_ref().unwrap(), ))), - Timestamp(TimeUnit::Nanosecond, None) => utf8_to_naive_timestamp_ns_dyn::(array), + Timestamp(TimeUnit::Nanosecond, None) => { + utf8_to_naive_timestamp_dyn::(array, TimeUnit::Nanosecond) + }, + Timestamp(TimeUnit::Millisecond, None) => { + utf8_to_naive_timestamp_dyn::(array, TimeUnit::Millisecond) + }, + Timestamp(TimeUnit::Microsecond, None) => { + utf8_to_naive_timestamp_dyn::(array, TimeUnit::Microsecond) + }, Timestamp(TimeUnit::Nanosecond, Some(tz)) => { - utf8_to_timestamp_ns_dyn::(array, tz.clone()) + utf8_to_timestamp_dyn::(array, tz.clone(), TimeUnit::Nanosecond) + }, + Timestamp(TimeUnit::Millisecond, Some(tz)) => { + utf8_to_timestamp_dyn::(array, tz.clone(), TimeUnit::Millisecond) + }, + Timestamp(TimeUnit::Microsecond, Some(tz)) => { + utf8_to_timestamp_dyn::(array, tz.clone(), TimeUnit::Microsecond) }, _ => Err(Error::NotYetImplemented(format!( "Casting from {from_type:?} to {to_type:?} not supported", @@ -605,9 +619,23 @@ pub fn cast(array: &dyn Array, to_type: &DataType, options: CastOptions) -> Resu to_type.clone(), ) .boxed()), - Timestamp(TimeUnit::Nanosecond, None) => utf8_to_naive_timestamp_ns_dyn::(array), + Timestamp(TimeUnit::Nanosecond, None) => { + utf8_to_naive_timestamp_dyn::(array, TimeUnit::Nanosecond) + }, + Timestamp(TimeUnit::Millisecond, None) => { + utf8_to_naive_timestamp_dyn::(array, TimeUnit::Millisecond) + }, + Timestamp(TimeUnit::Microsecond, None) => { + utf8_to_naive_timestamp_dyn::(array, TimeUnit::Microsecond) + }, 
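            // (annotation, not in the original patch) This TimeUnit-parameterized
            // dispatch covers Nanosecond, Millisecond and Microsecond only;
            // Timestamp(TimeUnit::Second, _) still falls through to the
            // NotYetImplemented arm below. A hypothetical Second arm would
            // mirror the ones above:
            //     Timestamp(TimeUnit::Second, None) => {
            //         utf8_to_naive_timestamp_dyn::<i64>(array, TimeUnit::Second)
            //     },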
Timestamp(TimeUnit::Nanosecond, Some(tz)) => { - utf8_to_timestamp_ns_dyn::(array, tz.clone()) + utf8_to_timestamp_dyn::(array, tz.clone(), TimeUnit::Nanosecond) + }, + Timestamp(TimeUnit::Millisecond, Some(tz)) => { + utf8_to_timestamp_dyn::(array, tz.clone(), TimeUnit::Millisecond) + }, + Timestamp(TimeUnit::Microsecond, Some(tz)) => { + utf8_to_timestamp_dyn::(array, tz.clone(), TimeUnit::Microsecond) }, _ => Err(Error::NotYetImplemented(format!( "Casting from {from_type:?} to {to_type:?} not supported", diff --git a/crates/nano-arrow/src/compute/cast/utf8_to.rs b/crates/nano-arrow/src/compute/cast/utf8_to.rs index 9c86ff85da54..c1d2cfa73414 100644 --- a/crates/nano-arrow/src/compute/cast/utf8_to.rs +++ b/crates/nano-arrow/src/compute/cast/utf8_to.rs @@ -2,12 +2,12 @@ use chrono::Datelike; use super::CastOptions; use crate::array::*; -use crate::datatypes::DataType; +use crate::datatypes::{DataType, TimeUnit}; use crate::error::Result; use crate::offset::Offset; use crate::temporal_conversions::{ - utf8_to_naive_timestamp_ns as utf8_to_naive_timestamp_ns_, - utf8_to_timestamp_ns as utf8_to_timestamp_ns_, EPOCH_DAYS_FROM_CE, + utf8_to_naive_timestamp as utf8_to_naive_timestamp_, utf8_to_timestamp as utf8_to_timestamp_, + EPOCH_DAYS_FROM_CE, }; use crate::types::NativeType; @@ -110,34 +110,40 @@ pub fn utf8_to_dictionary( Ok(array.into()) } -pub(super) fn utf8_to_naive_timestamp_ns_dyn( +pub(super) fn utf8_to_naive_timestamp_dyn( from: &dyn Array, + tu: TimeUnit, ) -> Result> { let from = from.as_any().downcast_ref().unwrap(); - Ok(Box::new(utf8_to_naive_timestamp_ns::(from))) + Ok(Box::new(utf8_to_naive_timestamp::(from, tu))) } -/// [`crate::temporal_conversions::utf8_to_timestamp_ns`] applied for RFC3339 formatting -pub fn utf8_to_naive_timestamp_ns(from: &Utf8Array) -> PrimitiveArray { - utf8_to_naive_timestamp_ns_(from, RFC3339) +/// [`crate::temporal_conversions::utf8_to_timestamp`] applied for RFC3339 formatting +pub fn utf8_to_naive_timestamp( + from: &Utf8Array, + tu: TimeUnit, +) -> PrimitiveArray { + utf8_to_naive_timestamp_(from, RFC3339, tu) } -pub(super) fn utf8_to_timestamp_ns_dyn( +pub(super) fn utf8_to_timestamp_dyn( from: &dyn Array, timezone: String, + tu: TimeUnit, ) -> Result> { let from = from.as_any().downcast_ref().unwrap(); - utf8_to_timestamp_ns::(from, timezone) + utf8_to_timestamp::(from, timezone, tu) .map(Box::new) .map(|x| x as Box) } -/// [`crate::temporal_conversions::utf8_to_timestamp_ns`] applied for RFC3339 formatting -pub fn utf8_to_timestamp_ns( +/// [`crate::temporal_conversions::utf8_to_timestamp`] applied for RFC3339 formatting +pub fn utf8_to_timestamp( from: &Utf8Array, timezone: String, + tu: TimeUnit, ) -> Result> { - utf8_to_timestamp_ns_(from, RFC3339, timezone) + utf8_to_timestamp_(from, RFC3339, timezone, tu) } /// Conversion of utf8 diff --git a/crates/nano-arrow/src/temporal_conversions.rs b/crates/nano-arrow/src/temporal_conversions.rs index 5058d1d887bd..8ba3d2523678 100644 --- a/crates/nano-arrow/src/temporal_conversions.rs +++ b/crates/nano-arrow/src/temporal_conversions.rs @@ -323,17 +323,6 @@ pub fn parse_offset(offset: &str) -> Result { .expect("FixedOffset::east out of bounds")) } -/// Parses `value` to `Option` consistent with the Arrow's definition of timestamp with timezone. -/// `tz` must be built from `timezone` (either via [`parse_offset`] or `chrono-tz`). 
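// (annotation, not in the original patch) The ns-only scalar helpers below are
// removed; callers now use the TimeUnit-parameterized utf8_to_timestamp_scalar
// and utf8_to_naive_timestamp_scalar directly, passing &TimeUnit::Nanosecond
// to reproduce the old behavior.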
-#[inline] -pub fn utf8_to_timestamp_ns_scalar( - value: &str, - fmt: &str, - tz: &T, -) -> Option { - utf8_to_timestamp_scalar(value, fmt, tz, &TimeUnit::Nanosecond) -} - /// Parses `value` to `Option` consistent with the Arrow's definition of timestamp with timezone. /// `tz` must be built from `timezone` (either via [`parse_offset`] or `chrono-tz`). /// Returns in scale `tz` of `TimeUnit`. @@ -364,12 +353,6 @@ pub fn utf8_to_timestamp_scalar( } } -/// Parses `value` to `Option` consistent with the Arrow's definition of timestamp without timezone. -#[inline] -pub fn utf8_to_naive_timestamp_ns_scalar(value: &str, fmt: &str) -> Option { - utf8_to_naive_timestamp_scalar(value, fmt, &TimeUnit::Nanosecond) -} - /// Parses `value` to `Option` consistent with the Arrow's definition of timestamp without timezone. /// Returns in scale `tz` of `TimeUnit`. #[inline] @@ -388,18 +371,18 @@ pub fn utf8_to_naive_timestamp_scalar(value: &str, fmt: &str, tu: &TimeUnit) -> .ok() } -fn utf8_to_timestamp_ns_impl( +fn utf8_to_timestamp_impl( array: &Utf8Array, fmt: &str, timezone: String, tz: T, + tu: TimeUnit, ) -> PrimitiveArray { let iter = array .iter() - .map(|x| x.and_then(|x| utf8_to_timestamp_ns_scalar(x, fmt, &tz))); + .map(|x| x.and_then(|x| utf8_to_timestamp_scalar(x, fmt, &tz, &tu))); - PrimitiveArray::from_trusted_len_iter(iter) - .to(DataType::Timestamp(TimeUnit::Nanosecond, Some(timezone))) + PrimitiveArray::from_trusted_len_iter(iter).to(DataType::Timestamp(tu, Some(timezone))) } /// Parses `value` to a [`chrono_tz::Tz`] with the Arrow's definition of timestamp with a timezone. @@ -413,13 +396,14 @@ pub fn parse_offset_tz(timezone: &str) -> Result { #[cfg(feature = "chrono-tz")] #[cfg_attr(docsrs, doc(cfg(feature = "chrono-tz")))] -fn chrono_tz_utf_to_timestamp_ns( +fn chrono_tz_utf_to_timestamp( array: &Utf8Array, fmt: &str, timezone: String, + tu: TimeUnit, ) -> Result> { let tz = parse_offset_tz(&timezone)?; - Ok(utf8_to_timestamp_ns_impl(array, fmt, timezone, tz)) + Ok(utf8_to_timestamp_impl(array, fmt, timezone, tz, tu)) } #[cfg(not(feature = "chrono-tz"))] @@ -436,22 +420,23 @@ fn chrono_tz_utf_to_timestamp_ns( /// Parses a [`Utf8Array`] to a timeozone-aware timestamp, i.e. [`PrimitiveArray`] with type `Timestamp(Nanosecond, Some(timezone))`. /// # Implementation /// * parsed values with timezone other than `timezone` are converted to `timezone`. -/// * parsed values without timezone are null. Use [`utf8_to_naive_timestamp_ns`] to parse naive timezones. +/// * parsed values without timezone are null. Use [`utf8_to_naive_timestamp`] to parse naive timezones. /// * Null elements remain null; non-parsable elements are null. /// The feature `"chrono-tz"` enables IANA and zoneinfo formats for `timezone`. /// # Error /// This function errors iff `timezone` is not parsable to an offset. -pub fn utf8_to_timestamp_ns( +pub fn utf8_to_timestamp( array: &Utf8Array, fmt: &str, timezone: String, + tu: TimeUnit, ) -> Result> { let tz = parse_offset(timezone.as_str()); if let Ok(tz) = tz { - Ok(utf8_to_timestamp_ns_impl(array, fmt, timezone, tz)) + Ok(utf8_to_timestamp_impl(array, fmt, timezone, tz, tu)) } else { - chrono_tz_utf_to_timestamp_ns(array, fmt, timezone) + chrono_tz_utf_to_timestamp(array, fmt, timezone, tu) } } @@ -459,15 +444,16 @@ pub fn utf8_to_timestamp_ns( /// [`PrimitiveArray`] with type `Timestamp(Nanosecond, None)`. /// Timezones are ignored. /// Null elements remain null; non-parsable elements are set to null. 
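// (annotation, not in the original patch) The naive parser below likewise
// gains a TimeUnit parameter, so its result dtype becomes Timestamp(tu, None)
// rather than the hard-coded Timestamp(Nanosecond, None).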
-pub fn utf8_to_naive_timestamp_ns( +pub fn utf8_to_naive_timestamp( array: &Utf8Array, fmt: &str, + tu: TimeUnit, ) -> PrimitiveArray { let iter = array .iter() - .map(|x| x.and_then(|x| utf8_to_naive_timestamp_ns_scalar(x, fmt))); + .map(|x| x.and_then(|x| utf8_to_naive_timestamp_scalar(x, fmt, &tu))); - PrimitiveArray::from_trusted_len_iter(iter).to(DataType::Timestamp(TimeUnit::Nanosecond, None)) + PrimitiveArray::from_trusted_len_iter(iter).to(DataType::Timestamp(tu, None)) } fn add_month(year: i32, month: u32, months: i32) -> chrono::NaiveDate { diff --git a/crates/polars-core/src/chunked_array/cast.rs b/crates/polars-core/src/chunked_array/cast.rs index 65626ae990c7..67e5588fa9e4 100644 --- a/crates/polars-core/src/chunked_array/cast.rs +++ b/crates/polars-core/src/chunked_array/cast.rs @@ -204,24 +204,21 @@ impl ChunkCast for Utf8Chunked { Ok(out) }, #[cfg(feature = "dtype-datetime")] - DataType::Datetime(_tu, tz) => { + DataType::Datetime(tu, tz) => { let out = match tz { #[cfg(feature = "timezones")] Some(tz) => { validate_time_zone(tz)?; let result = cast_chunks( &self.chunks, - &Datetime(TimeUnit::Nanoseconds, Some(tz.clone())), + &Datetime(tu.to_owned(), Some(tz.clone())), true, )?; Series::try_from((self.name(), result)) }, _ => { - let result = cast_chunks( - &self.chunks, - &Datetime(TimeUnit::Nanoseconds, None), - true, - )?; + let result = + cast_chunks(&self.chunks, &Datetime(tu.to_owned(), None), true)?; Series::try_from((self.name(), result)) }, }; diff --git a/py-polars/tests/unit/test_queries.py b/py-polars/tests/unit/test_queries.py index 8ececf3a535e..13ec031ffc54 100644 --- a/py-polars/tests/unit/test_queries.py +++ b/py-polars/tests/unit/test_queries.py @@ -371,6 +371,7 @@ def test_shift_drop_nulls_10875() -> None: "a" ].to_list() == [1, 2] + def test_utf8_date() -> None: df = pl.DataFrame({"x1": ["2021-01-01"]}).with_columns( **{"x1-date": pl.col("x1").cast(pl.Date)} @@ -382,7 +383,25 @@ def test_utf8_date() -> None: def test_utf8_datetime() -> None: df = pl.DataFrame( - {"x1": ["2021-12-19T16:39:57-02:00", "2022-12-19T16:39:57"]} + {"x1": ["2021-12-19T00:39:57", "2022-12-19T16:39:57"]} + ).with_columns( + **{ + "x1-datetime-ns": pl.col("x1").cast(pl.Datetime(time_unit="ns")), + "x1-datetime-ms": pl.col("x1").cast(pl.Datetime(time_unit="ms")), + "x1-datetime-us": pl.col("x1").cast(pl.Datetime(time_unit="us")), + } + ) + + out = df.select( + pl.col("x1-datetime-ns"), pl.col("x1-datetime-ms"), pl.col("x1-datetime-us") + ) + assert out.shape == (2, 3) + assert out.dtypes == [pl.Datetime, pl.Datetime, pl.Datetime] + + +def test_utf8_datetime_timezone() -> None: + df = pl.DataFrame( + {"x1": ["1996-12-19T16:39:57-02:00", "2022-12-19T00:39:57-03:00"]} ).with_columns( **{ "x1-datetime-ns": pl.col("x1").cast(pl.Datetime(time_unit="ns")), From 8be15bd89d2e4ade5f740b43daed25f75f3344fd Mon Sep 17 00:00:00 2001 From: Brayan Jules Date: Mon, 16 Oct 2023 08:57:57 -0300 Subject: [PATCH 004/103] feat: added missing tests for failure scenarios, also fixed casting from int to date. 
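
The guard added here makes the utf8 cast branch on content: date-like strings
take the new temporal parser, while all-numeric strings keep the previous
numeric path, which `from_epoch` relies on for Utf8 input (re-enabled in
test_lazy.py below). A rough sketch of the resulting behavior, assuming the
polars Python API exercised by the tests in this series:

    import polars as pl
    import pytest

    # date-like strings parse via the new utf8 -> temporal path
    pl.DataFrame({"x1": ["2021-01-01"]}).with_columns(pl.col("x1").cast(pl.Date))

    # malformed input raises ComputeError (covered by test_wrong_utf8_date below)
    with pytest.raises(pl.ComputeError):
        pl.DataFrame({"x1": ["2021-01-aa"]}).with_columns(pl.col("x1").cast(pl.Date))
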
--- crates/nano-arrow/src/temporal_conversions.rs | 3 +- crates/polars-core/src/chunked_array/cast.rs | 5 ++- .../src/chunked_array/temporal/mod.rs | 14 +++++++ py-polars/tests/unit/test_lazy.py | 2 +- py-polars/tests/unit/test_queries.py | 38 ++++++++++++++++--- 5 files changed, 53 insertions(+), 9 deletions(-) diff --git a/crates/nano-arrow/src/temporal_conversions.rs b/crates/nano-arrow/src/temporal_conversions.rs index 8ba3d2523678..8ab1b3ec3ffa 100644 --- a/crates/nano-arrow/src/temporal_conversions.rs +++ b/crates/nano-arrow/src/temporal_conversions.rs @@ -407,10 +407,11 @@ fn chrono_tz_utf_to_timestamp( } #[cfg(not(feature = "chrono-tz"))] -fn chrono_tz_utf_to_timestamp_ns( +fn chrono_tz_utf_to_timestamp( _: &Utf8Array, _: &str, timezone: String, + _: TimeUnit, ) -> Result> { Err(Error::InvalidArgumentError(format!( "timezone \"{timezone}\" cannot be parsed (feature chrono-tz is not active)", diff --git a/crates/polars-core/src/chunked_array/cast.rs b/crates/polars-core/src/chunked_array/cast.rs index 67e5588fa9e4..45a61d169601 100644 --- a/crates/polars-core/src/chunked_array/cast.rs +++ b/crates/polars-core/src/chunked_array/cast.rs @@ -5,6 +5,7 @@ use arrow::compute::cast::CastOptions; #[cfg(feature = "dtype-categorical")] use crate::chunked_array::categorical::CategoricalChunkedBuilder; +use crate::chunked_array::temporal::{validate_is_number}; #[cfg(feature = "timezones")] use crate::chunked_array::temporal::validate_time_zone; use crate::prelude::DataType::Datetime; @@ -198,13 +199,13 @@ impl ChunkCast for Utf8Chunked { }, }, #[cfg(feature = "dtype-date")] - DataType::Date => { + DataType::Date if !validate_is_number(&self.chunks) => { let result = cast_chunks(&self.chunks, data_type, true)?; let out = Series::try_from((self.name(), result))?; Ok(out) }, #[cfg(feature = "dtype-datetime")] - DataType::Datetime(tu, tz) => { + DataType::Datetime(tu, tz) if !validate_is_number(&self.chunks) => { let out = match tz { #[cfg(feature = "timezones")] Some(tz) => { diff --git a/crates/polars-core/src/chunked_array/temporal/mod.rs b/crates/polars-core/src/chunked_array/temporal/mod.rs index 737ff5086d47..9e0759a9b31d 100644 --- a/crates/polars-core/src/chunked_array/temporal/mod.rs +++ b/crates/polars-core/src/chunked_array/temporal/mod.rs @@ -15,6 +15,7 @@ use chrono::NaiveDateTime; use chrono::NaiveTime; #[cfg(feature = "timezones")] use chrono_tz::Tz; +use polars_arrow::prelude::ArrayRef; #[cfg(feature = "dtype-time")] pub use time::time_to_time64ns; @@ -35,3 +36,16 @@ pub(crate) fn validate_time_zone(tz: &str) -> PolarsResult<()> { }, } } + +pub(crate) fn validate_is_number(vec_array: &Vec) -> bool { + vec_array.iter().all(|array|is_parsable_as_number(array)) +} + +fn is_parsable_as_number(array: &ArrayRef) -> bool { + if let Some(array) = array.as_any().downcast_ref::() { + array.iter().all(|value| value.expect("Unable to parse int string to datetime").parse::().is_ok()) + } else { + false + } +} + diff --git a/py-polars/tests/unit/test_lazy.py b/py-polars/tests/unit/test_lazy.py index 5a7ba0ee4c74..39e46f6a0846 100644 --- a/py-polars/tests/unit/test_lazy.py +++ b/py-polars/tests/unit/test_lazy.py @@ -1328,7 +1328,7 @@ def test_quadratic_behavior_4736() -> None: ldf.select(reduce(add, (pl.col(fld) for fld in ldf.columns))) -@pytest.mark.parametrize("input_dtype", [pl.Int64, pl.Float64]) +@pytest.mark.parametrize("input_dtype", [pl.Int64, pl.Float64, pl.Utf8]) def test_from_epoch(input_dtype: pl.PolarsDataType) -> None: ldf = pl.LazyFrame( [ diff --git 
a/py-polars/tests/unit/test_queries.py b/py-polars/tests/unit/test_queries.py index 13ec031ffc54..8db9c1283904 100644 --- a/py-polars/tests/unit/test_queries.py +++ b/py-polars/tests/unit/test_queries.py @@ -5,8 +5,10 @@ import numpy as np import pandas as pd +import pytest import polars as pl +from polars import ComputeError from polars.testing import assert_frame_equal @@ -381,6 +383,13 @@ def test_utf8_date() -> None: assert out.dtypes == [pl.Date] +def test_wrong_utf8_date() -> None: + df = pl.DataFrame({"x1": ["2021-01-aa"]}) + + with pytest.raises(ComputeError): + df.with_columns(**{"x1-date": pl.col("x1").cast(pl.Date)}) + + def test_utf8_datetime() -> None: df = pl.DataFrame( {"x1": ["2021-12-19T00:39:57", "2022-12-19T16:39:57"]} @@ -399,19 +408,38 @@ def test_utf8_datetime() -> None: assert out.dtypes == [pl.Datetime, pl.Datetime, pl.Datetime] +def test_wrong_utf8_datetime() -> None: + df = pl.DataFrame({"x1": ["2021-12-19 00:39:57", "2022-12-19 16:39:57"]}) + with pytest.raises(ComputeError): + df.with_columns( + **{"x1-datetime-ns": pl.col("x1").cast(pl.Datetime(time_unit="ns"))} + ) + + def test_utf8_datetime_timezone() -> None: df = pl.DataFrame( - {"x1": ["1996-12-19T16:39:57-02:00", "2022-12-19T00:39:57-03:00"]} + {"x1": ["1996-12-19T16:39:57 +00:00", "2022-12-19T00:39:57 +00:00"]} ).with_columns( **{ - "x1-datetime-ns": pl.col("x1").cast(pl.Datetime(time_unit="ns")), - "x1-datetime-ms": pl.col("x1").cast(pl.Datetime(time_unit="ms")), - "x1-datetime-us": pl.col("x1").cast(pl.Datetime(time_unit="us")), + "x1-datetime-ns": pl.col("x1").cast( + pl.Datetime(time_unit="ns", time_zone="America/Caracas") + ), + "x1-datetime-ms": pl.col("x1").cast( + pl.Datetime(time_unit="ms", time_zone="America/Santiago") + ), + "x1-datetime-us": pl.col("x1").cast( + pl.Datetime(time_unit="us", time_zone="UTC") + ), } ) out = df.select( pl.col("x1-datetime-ns"), pl.col("x1-datetime-ms"), pl.col("x1-datetime-us") ) + assert out.shape == (2, 3) - assert out.dtypes == [pl.Datetime, pl.Datetime, pl.Datetime] + assert out.dtypes == [ + pl.Datetime("ns", "America/Caracas"), + pl.Datetime("ms", "America/Santiago"), + pl.Datetime("us", "UTC"), + ] From de17db621854fa8af5128ad9b2c839b5e673b204 Mon Sep 17 00:00:00 2001 From: Brayan Jules Date: Mon, 16 Oct 2023 09:31:44 -0300 Subject: [PATCH 005/103] fix: fixed issue regarding arrow libraries import and code formatting --- crates/polars-core/src/chunked_array/cast.rs | 2 +- .../polars-core/src/chunked_array/temporal/mod.rs | 14 +++++++++----- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/crates/polars-core/src/chunked_array/cast.rs b/crates/polars-core/src/chunked_array/cast.rs index e539182a234d..214ed5f19482 100644 --- a/crates/polars-core/src/chunked_array/cast.rs +++ b/crates/polars-core/src/chunked_array/cast.rs @@ -5,7 +5,7 @@ use arrow::compute::cast::CastOptions; #[cfg(feature = "dtype-categorical")] use crate::chunked_array::categorical::CategoricalChunkedBuilder; -use crate::chunked_array::temporal::{validate_is_number}; +use crate::chunked_array::temporal::validate_is_number; #[cfg(feature = "timezones")] use crate::chunked_array::temporal::validate_time_zone; use crate::prelude::DataType::Datetime; diff --git a/crates/polars-core/src/chunked_array/temporal/mod.rs b/crates/polars-core/src/chunked_array/temporal/mod.rs index 9e0759a9b31d..c6ea220b7d21 100644 --- a/crates/polars-core/src/chunked_array/temporal/mod.rs +++ b/crates/polars-core/src/chunked_array/temporal/mod.rs @@ -15,13 +15,13 @@ use chrono::NaiveDateTime; use 
chrono::NaiveTime; #[cfg(feature = "timezones")] use chrono_tz::Tz; -use polars_arrow::prelude::ArrayRef; #[cfg(feature = "dtype-time")] pub use time::time_to_time64ns; pub use self::conversion::*; #[cfg(feature = "timezones")] use crate::prelude::{polars_bail, PolarsResult}; +use crate::prelude::{ArrayRef, LargeStringArray}; pub fn unix_time() -> NaiveDateTime { NaiveDateTime::from_timestamp_opt(0, 0).unwrap() @@ -38,14 +38,18 @@ pub(crate) fn validate_time_zone(tz: &str) -> PolarsResult<()> { } pub(crate) fn validate_is_number(vec_array: &Vec) -> bool { - vec_array.iter().all(|array|is_parsable_as_number(array)) + vec_array.iter().all(|array| is_parsable_as_number(array)) } fn is_parsable_as_number(array: &ArrayRef) -> bool { - if let Some(array) = array.as_any().downcast_ref::() { - array.iter().all(|value| value.expect("Unable to parse int string to datetime").parse::().is_ok()) + if let Some(array) = array.as_any().downcast_ref::() { + array.iter().all(|value| { + value + .expect("Unable to parse int string to datetime") + .parse::() + .is_ok() + }) } else { false } } - From aea80f46f6b69f23c7c849e51e31e00a632d6058 Mon Sep 17 00:00:00 2001 From: nameexhaustion Date: Tue, 17 Oct 2023 04:32:55 +1100 Subject: [PATCH 006/103] fix(rust,python): only exclude final output names of group_by key expressions (#11768) --- crates/polars-lazy/src/frame/mod.rs | 10 +++++----- .../src/logical_plan/projection.rs | 19 ++++--------------- .../tests/unit/operations/test_group_by.py | 9 +++++++++ 3 files changed, 18 insertions(+), 20 deletions(-) diff --git a/crates/polars-lazy/src/frame/mod.rs b/crates/polars-lazy/src/frame/mod.rs index f08a16d28bf7..261fbaaaf517 100644 --- a/crates/polars-lazy/src/frame/mod.rs +++ b/crates/polars-lazy/src/frame/mod.rs @@ -31,7 +31,7 @@ use polars_plan::global::FETCH_ROWS; #[cfg(any(feature = "ipc", feature = "parquet", feature = "csv"))] use polars_plan::logical_plan::collect_fingerprints; use polars_plan::logical_plan::optimize; -use polars_plan::utils::expr_to_leaf_column_names; +use polars_plan::utils::expr_output_name; use smartstring::alias::String as SmartString; use crate::fallible; @@ -1674,10 +1674,10 @@ impl LazyGroupBy { let keys = self .keys .iter() - .flat_map(|k| expr_to_leaf_column_names(k).into_iter()) + .filter_map(|expr| expr_output_name(expr).ok()) .collect::>(); - self.agg([col("*").exclude(&keys).head(n).keep_name()]) + self.agg([col("*").exclude(&keys).head(n)]) .explode([col("*").exclude(&keys)]) } @@ -1686,10 +1686,10 @@ impl LazyGroupBy { let keys = self .keys .iter() - .flat_map(|k| expr_to_leaf_column_names(k).into_iter()) + .filter_map(|expr| expr_output_name(expr).ok()) .collect::>(); - self.agg([col("*").exclude(&keys).tail(n).keep_name()]) + self.agg([col("*").exclude(&keys).tail(n)]) .explode([col("*").exclude(&keys)]) } diff --git a/crates/polars-plan/src/logical_plan/projection.rs b/crates/polars-plan/src/logical_plan/projection.rs index 436c31b354c2..3c9832e4b4fa 100644 --- a/crates/polars-plan/src/logical_plan/projection.rs +++ b/crates/polars-plan/src/logical_plan/projection.rs @@ -4,6 +4,7 @@ use polars_core::utils::get_supertype; use super::*; use crate::prelude::function_expr::FunctionExpr; +use crate::utils::expr_output_name; /// This replace the wildcard Expr with a Column Expr. It also removes the Exclude Expr from the /// expression chain. 
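// (annotation, not in the original patch) In the hunk below, group keys are
// excluded by their *final output name* via expr_output_name instead of by
// walking alias chains manually. For a key like col("x").alias("key"), only
// "key" is excluded from wildcard expansion, so "x" itself remains available
// to the aggregation (exercised by test_group_by_with_expr_as_key below).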
@@ -354,21 +355,9 @@ fn prepare_excluded( } // exclude group_by keys - for mut expr in keys.iter() { - // Allow a number of aliases of a column expression, still exclude column from aggregation - loop { - match expr { - Expr::Column(name) => { - exclude.insert(name.clone()); - break; - }, - Expr::Alias(e, _) => { - expr = e; - }, - _ => { - break; - }, - } + for expr in keys.iter() { + if let Ok(name) = expr_output_name(expr) { + exclude.insert(name.clone()); } } Ok(exclude) diff --git a/py-polars/tests/unit/operations/test_group_by.py b/py-polars/tests/unit/operations/test_group_by.py index d3fb4d6dcb49..728aad9ee81b 100644 --- a/py-polars/tests/unit/operations/test_group_by.py +++ b/py-polars/tests/unit/operations/test_group_by.py @@ -788,3 +788,12 @@ def test_group_by_list_scalar_11749() -> None: "group_name": ["a;b", "c;d"], "eq": [[True, True, True, True], [True, False]], } + + +def test_group_by_with_expr_as_key() -> None: + gb = pl.select(x=1).group_by(pl.col("x").alias("key")) + assert gb.agg(pl.all().first()).frame_equal(gb.agg(pl.first("x"))) + + # tests: 11766 + assert gb.head(0).frame_equal(gb.agg(pl.col("x").head(0)).explode("x")) + assert gb.tail(0).frame_equal(gb.agg(pl.col("x").tail(0)).explode("x")) From 20a3991d676f42e1f2c55104042303a54fa4e6dc Mon Sep 17 00:00:00 2001 From: Stijn de Gooijer Date: Mon, 16 Oct 2023 19:35:09 +0200 Subject: [PATCH 007/103] depr(python): Rename `group_by_rolling` to `rolling` (#11761) --- .../reference/dataframe/modify_select.rst | 1 + .../reference/lazyframe/modify_select.rst | 1 + py-polars/polars/dataframe/frame.py | 69 +++++++++++++-- py-polars/polars/dataframe/group_by.py | 6 +- py-polars/polars/expr/expr.py | 18 ++-- py-polars/polars/lazyframe/frame.py | 87 +++++++++++++++---- py-polars/src/lazyframe.rs | 2 +- .../tests/parametric/test_groupby_rolling.py | 4 +- .../tests/unit/datatypes/test_temporal.py | 22 ++--- .../unit/operations/map/test_map_groups.py | 6 +- .../unit/operations/rolling/test_rolling.py | 28 +++--- .../tests/unit/operations/test_group_by.py | 27 +++++- .../unit/operations/test_group_by_dynamic.py | 2 +- .../unit/operations/test_group_by_rolling.py | 50 +++++------ 14 files changed, 227 insertions(+), 96 deletions(-) diff --git a/py-polars/docs/source/reference/dataframe/modify_select.rst b/py-polars/docs/source/reference/dataframe/modify_select.rst index c3d1f1b91c8c..dec81a18a308 100644 --- a/py-polars/docs/source/reference/dataframe/modify_select.rst +++ b/py-polars/docs/source/reference/dataframe/modify_select.rst @@ -47,6 +47,7 @@ Manipulation/selection DataFrame.replace DataFrame.replace_at_idx DataFrame.reverse + DataFrame.rolling DataFrame.row DataFrame.rows DataFrame.rows_by_key diff --git a/py-polars/docs/source/reference/lazyframe/modify_select.rst b/py-polars/docs/source/reference/lazyframe/modify_select.rst index 1a1482ec4623..19cef033426f 100644 --- a/py-polars/docs/source/reference/lazyframe/modify_select.rst +++ b/py-polars/docs/source/reference/lazyframe/modify_select.rst @@ -36,6 +36,7 @@ Manipulation/selection LazyFrame.merge_sorted LazyFrame.rename LazyFrame.reverse + LazyFrame.rolling LazyFrame.select LazyFrame.select_seq LazyFrame.set_sorted diff --git a/py-polars/polars/dataframe/frame.py b/py-polars/polars/dataframe/frame.py index 86d9051e557f..aca2118fd704 100644 --- a/py-polars/polars/dataframe/frame.py +++ b/py-polars/polars/dataframe/frame.py @@ -5125,7 +5125,7 @@ def group_by( """ return GroupBy(self, by, *more_by, maintain_order=maintain_order) - def group_by_rolling( + def rolling( self, 
index_column: IntoExpr, *, @@ -5177,7 +5177,7 @@ def group_by_rolling( not be 24 hours, due to daylight savings). Similarly for "calendar week", "calendar month", "calendar quarter", and "calendar year". - In case of a group_by_rolling on an integer column, the windows are defined by: + In case of a rolling operation on an integer column, the windows are defined by: - **"1i" # length 1** - **"10i" # length 10** @@ -5190,7 +5190,7 @@ def group_by_rolling( This column must be sorted in ascending order (or, if `by` is specified, then it must be sorted in ascending order within each group). - In case of a rolling group by on indices, dtype needs to be one of + In case of a rolling operation on indices, dtype needs to be one of {Int32, Int64}. Note that Int32 gets temporarily cast to Int64, so if performance matters use an Int64 column. period @@ -5232,7 +5232,7 @@ def group_by_rolling( >>> df = pl.DataFrame({"dt": dates, "a": [3, 7, 5, 9, 2, 1]}).with_columns( ... pl.col("dt").str.strptime(pl.Datetime).set_sorted() ... ) - >>> out = df.group_by_rolling(index_column="dt", period="2d").agg( + >>> out = df.rolling(index_column="dt", period="2d").agg( ... [ ... pl.sum("a").alias("sum_a"), ... pl.min("a").alias("min_a"), @@ -5370,7 +5370,7 @@ def group_by_dynamic( See Also -------- - group_by_rolling + rolling Notes ----- @@ -9906,7 +9906,7 @@ def groupby( """ return self.group_by(by, *more_by, maintain_order=maintain_order) - @deprecate_renamed_function("group_by_rolling", version="0.19.0") + @deprecate_renamed_function("rolling", version="0.19.0") def groupby_rolling( self, index_column: IntoExpr, @@ -9921,7 +9921,60 @@ def groupby_rolling( Create rolling groups based on a time, Int32, or Int64 column. .. deprecated:: 0.19.0 - This method has been renamed to :func:`DataFrame.group_by_rolling`. + This method has been renamed to :func:`DataFrame.rolling`. + + Parameters + ---------- + index_column + Column used to group based on the time window. + Often of type Date/Datetime. + This column must be sorted in ascending order (or, if `by` is specified, + then it must be sorted in ascending order within each group). + + In case of a rolling group by on indices, dtype needs to be one of + {Int32, Int64}. Note that Int32 gets temporarily cast to Int64, so if + performance matters use an Int64 column. + period + length of the window - must be non-negative + offset + offset of the window. Default is -period + closed : {'right', 'left', 'both', 'none'} + Define which sides of the temporal interval are closed (inclusive). + by + Also group by this column/these columns + check_sorted + When the ``by`` argument is given, polars can not check sortedness + by the metadata and has to do a full scan on the index column to + verify data is sorted. This is expensive. If you are sure the + data within the by groups is sorted, you can set this to ``False``. + Doing so incorrectly will lead to incorrect output + + """ + return self.rolling( + index_column, + period=period, + offset=offset, + closed=closed, + by=by, + check_sorted=check_sorted, + ) + + @deprecate_renamed_function("rolling", version="0.19.9") + def group_by_rolling( + self, + index_column: IntoExpr, + *, + period: str | timedelta, + offset: str | timedelta | None = None, + closed: ClosedInterval = "right", + by: IntoExpr | Iterable[IntoExpr] | None = None, + check_sorted: bool = True, + ) -> RollingGroupBy: + """ + Create rolling groups based on a time, Int32, or Int64 column. + + .. 
deprecated:: 0.19.9 + This method has been renamed to :func:`DataFrame.rolling`. Parameters ---------- @@ -9950,7 +10003,7 @@ def groupby_rolling( Doing so incorrectly will lead to incorrect output """ - return self.group_by_rolling( + return self.rolling( index_column, period=period, offset=offset, diff --git a/py-polars/polars/dataframe/group_by.py b/py-polars/polars/dataframe/group_by.py index e648b149970d..549c78cbd3c3 100644 --- a/py-polars/polars/dataframe/group_by.py +++ b/py-polars/polars/dataframe/group_by.py @@ -800,7 +800,7 @@ def __iter__(self) -> Self: groups_df = ( self.df.lazy() .with_row_count(name=temp_col) - .group_by_rolling( + .rolling( index_column=self.time_column, period=self.period, offset=self.offset, @@ -859,7 +859,7 @@ def agg( """ return ( self.df.lazy() - .group_by_rolling( + .rolling( index_column=self.time_column, period=self.period, offset=self.offset, @@ -903,7 +903,7 @@ def map_groups( """ return ( self.df.lazy() - .group_by_rolling( + .rolling( index_column=self.time_column, period=self.period, offset=self.offset, diff --git a/py-polars/polars/expr/expr.py b/py-polars/polars/expr/expr.py index cc9735faf2c1..e9eb033ad1af 100644 --- a/py-polars/polars/expr/expr.py +++ b/py-polars/polars/expr/expr.py @@ -3225,7 +3225,7 @@ def rolling( not be 24 hours, due to daylight savings). Similarly for "calendar week", "calendar month", "calendar quarter", and "calendar year". - In case of a group_by_rolling on an integer column, the windows are defined by: + In case of a rolling operation on an integer column, the windows are defined by: - "1i" # length 1 - "10i" # length 10 @@ -5530,7 +5530,7 @@ def rolling_min( Notes ----- If you want to compute multiple aggregation statistics over the same dynamic - window, consider using `group_by_rolling` this method can cache the window size + window, consider using `rolling` - this method can cache the window size computation. Examples @@ -5736,7 +5736,7 @@ def rolling_max( Notes ----- If you want to compute multiple aggregation statistics over the same dynamic - window, consider using `group_by_rolling` this method can cache the window size + window, consider using `rolling` - this method can cache the window size computation. Examples @@ -5973,7 +5973,7 @@ def rolling_mean( Notes ----- If you want to compute multiple aggregation statistics over the same dynamic - window, consider using `group_by_rolling` this method can cache the window size + window, consider using `rolling` - this method can cache the window size computation. Examples @@ -6206,7 +6206,7 @@ def rolling_sum( Notes ----- If you want to compute multiple aggregation statistics over the same dynamic - window, consider using `group_by_rolling` this method can cache the window size + window, consider using `rolling` - this method can cache the window size computation. Examples @@ -6442,7 +6442,7 @@ def rolling_std( Notes ----- If you want to compute multiple aggregation statistics over the same dynamic - window, consider using `group_by_rolling` this method can cache the window size + window, consider using `rolling` - this method can cache the window size computation. Examples @@ -6678,7 +6678,7 @@ def rolling_var( Notes ----- If you want to compute multiple aggregation statistics over the same dynamic - window, consider using `group_by_rolling` this method can cache the window size + window, consider using `rolling` - this method can cache the window size computation. 
Examples @@ -6917,7 +6917,7 @@ def rolling_median( Notes ----- If you want to compute multiple aggregation statistics over the same dynamic - window, consider using `group_by_rolling` this method can cache the window size + window, consider using `rolling` - this method can cache the window size computation. Examples @@ -7082,7 +7082,7 @@ def rolling_quantile( Notes ----- If you want to compute multiple aggregation statistics over the same dynamic - window, consider using `group_by_rolling` this method can cache the window size + window, consider using `rolling` - this method can cache the window size computation. Examples diff --git a/py-polars/polars/lazyframe/frame.py b/py-polars/polars/lazyframe/frame.py index 6943598beffc..3a53048b723d 100644 --- a/py-polars/polars/lazyframe/frame.py +++ b/py-polars/polars/lazyframe/frame.py @@ -2946,7 +2946,7 @@ def group_by( lgb = self._ldf.group_by(exprs, maintain_order) return LazyGroupBy(lgb) - def group_by_rolling( + def rolling( self, index_column: IntoExpr, *, @@ -2998,7 +2998,7 @@ def group_by_rolling( not be 24 hours, due to daylight savings). Similarly for "calendar week", "calendar month", "calendar quarter", and "calendar year". - In case of a group_by_rolling on an integer column, the windows are defined by: + In case of a rolling operation on an integer column, the windows are defined by: - "1i" # length 1 - "10i" # length 10 @@ -3054,19 +3054,14 @@ def group_by_rolling( ... pl.col("dt").str.strptime(pl.Datetime).set_sorted() ... ) >>> out = ( - ... df.group_by_rolling(index_column="dt", period="2d") + ... df.rolling(index_column="dt", period="2d") ... .agg( - ... [ - ... pl.sum("a").alias("sum_a"), - ... pl.min("a").alias("min_a"), - ... pl.max("a").alias("max_a"), - ... ] + ... pl.sum("a").alias("sum_a"), + ... pl.min("a").alias("min_a"), + ... pl.max("a").alias("max_a"), ... ) ... .collect() ... ) - >>> assert out["sum_a"].to_list() == [3, 10, 15, 24, 11, 1] - >>> assert out["max_a"].to_list() == [3, 7, 7, 9, 9, 1] - >>> assert out["min_a"].to_list() == [3, 3, 3, 3, 2, 1] >>> out shape: (6, 4) ┌─────────────────────┬───────┬───────┬───────┐ @@ -3091,7 +3086,7 @@ def group_by_rolling( period = _timedelta_to_pl_duration(period) offset = _timedelta_to_pl_duration(offset) - lgb = self._ldf.group_by_rolling( + lgb = self._ldf.rolling( index_column, period, offset, closed, pyexprs_by, check_sorted ) return LazyGroupBy(lgb) @@ -3198,7 +3193,7 @@ def group_by_dynamic( See Also -------- - group_by_rolling + rolling Notes ----- @@ -5862,7 +5857,7 @@ def groupby( """ return self.group_by(by, *more_by, maintain_order=maintain_order) - @deprecate_renamed_function("group_by_rolling", version="0.19.0") + @deprecate_renamed_function("rolling", version="0.19.0") def groupby_rolling( self, index_column: IntoExpr, @@ -5877,7 +5872,67 @@ def groupby_rolling( Create rolling groups based on a time, Int32, or Int64 column. .. deprecated:: 0.19.0 - This method has been renamed to :func:`LazyFrame.group_by_rolling`. + This method has been renamed to :func:`LazyFrame.rolling`. + + Parameters + ---------- + index_column + Column used to group based on the time window. + Often of type Date/Datetime. + This column must be sorted in ascending order (or, if `by` is specified, + then it must be sorted in ascending order within each group). + + In case of a rolling group by on indices, dtype needs to be one of + {Int32, Int64}. Note that Int32 gets temporarily cast to Int64, so if + performance matters use an Int64 column. 
+ period + length of the window - must be non-negative + offset + offset of the window. Default is -period + closed : {'right', 'left', 'both', 'none'} + Define which sides of the temporal interval are closed (inclusive). + by + Also group by this column/these columns + check_sorted + When the ``by`` argument is given, polars can not check sortedness + by the metadata and has to do a full scan on the index column to + verify data is sorted. This is expensive. If you are sure the + data within the by groups is sorted, you can set this to ``False``. + Doing so incorrectly will lead to incorrect output + + Returns + ------- + LazyGroupBy + Object you can call ``.agg`` on to aggregate by groups, the result + of which will be sorted by `index_column` (but note that if `by` columns are + passed, it will only be sorted within each `by` group). + + """ + return self.rolling( + index_column, + period=period, + offset=offset, + closed=closed, + by=by, + check_sorted=check_sorted, + ) + + @deprecate_renamed_function("rolling", version="0.19.9") + def group_by_rolling( + self, + index_column: IntoExpr, + *, + period: str | timedelta, + offset: str | timedelta | None = None, + closed: ClosedInterval = "right", + by: IntoExpr | Iterable[IntoExpr] | None = None, + check_sorted: bool = True, + ) -> LazyGroupBy: + """ + Create rolling groups based on a time, Int32, or Int64 column. + + .. deprecated:: 0.19.9 + This method has been renamed to :func:`LazyFrame.rolling`. Parameters ---------- @@ -5913,7 +5968,7 @@ def groupby_rolling( passed, it will only be sorted within each `by` group). """ - return self.group_by_rolling( + return self.rolling( index_column, period=period, offset=offset, diff --git a/py-polars/src/lazyframe.rs b/py-polars/src/lazyframe.rs index 003266bbe02d..64815d0f551c 100644 --- a/py-polars/src/lazyframe.rs +++ b/py-polars/src/lazyframe.rs @@ -661,7 +661,7 @@ impl PyLazyFrame { PyLazyGroupBy { lgb: Some(lazy_gb) } } - fn group_by_rolling( + fn rolling( &mut self, index_column: PyExpr, period: &str, diff --git a/py-polars/tests/parametric/test_groupby_rolling.py b/py-polars/tests/parametric/test_groupby_rolling.py index e5b0e8716c53..62c1b188c12e 100644 --- a/py-polars/tests/parametric/test_groupby_rolling.py +++ b/py-polars/tests/parametric/test_groupby_rolling.py @@ -28,7 +28,7 @@ data=st.data(), time_unit=strategy_time_unit, ) -def test_group_by_rolling( +def test_rolling( period: str, offset: str, closed: ClosedInterval, @@ -57,7 +57,7 @@ def test_group_by_rolling( ) ) df = dataframe.sort("ts") - result = df.group_by_rolling("ts", period=period, offset=offset, closed=closed).agg( + result = df.rolling("ts", period=period, offset=offset, closed=closed).agg( pl.col("value") ) diff --git a/py-polars/tests/unit/datatypes/test_temporal.py b/py-polars/tests/unit/datatypes/test_temporal.py index 290b2087fbc2..aa79cfa82c3e 100644 --- a/py-polars/tests/unit/datatypes/test_temporal.py +++ b/py-polars/tests/unit/datatypes/test_temporal.py @@ -612,7 +612,7 @@ def test_rolling() -> None: period: str | timedelta for period in ("2d", timedelta(days=2)): # type: ignore[assignment] - out = df.group_by_rolling(index_column="dt", period=period).agg( + out = df.rolling(index_column="dt", period=period).agg( [ pl.sum("a").alias("sum_a"), pl.min("a").alias("min_a"), @@ -820,7 +820,7 @@ def test_asof_join_tolerance_grouper() -> None: def test_rolling_group_by_by_argument() -> None: df = pl.DataFrame({"times": range(10), "groups": [1] * 4 + [2] * 6}) - out = df.group_by_rolling("times", period="5i", 
by=["groups"]).agg( + out = df.rolling("times", period="5i", by=["groups"]).agg( pl.col("times").alias("agg_list") ) @@ -846,7 +846,7 @@ def test_rolling_group_by_by_argument() -> None: assert_frame_equal(out, expected) -def test_group_by_rolling_mean_3020() -> None: +def test_rolling_mean_3020() -> None: df = pl.DataFrame( { "Date": [ @@ -864,7 +864,7 @@ def test_group_by_rolling_mean_3020() -> None: period: str | timedelta for period in ("1w", timedelta(days=7)): # type: ignore[assignment] - result = df.group_by_rolling(index_column="Date", period=period).agg( + result = df.rolling(index_column="Date", period=period).agg( pl.col("val").mean().alias("val_mean") ) expected = pl.DataFrame( @@ -1275,7 +1275,7 @@ def test_unique_counts_on_dates() -> None: } -def test_group_by_rolling_by_ordering() -> None: +def test_rolling_by_ordering() -> None: # we must check that the keys still match the time labels after the rolling window # with a `by` argument. df = pl.DataFrame( @@ -1294,7 +1294,7 @@ def test_group_by_rolling_by_ordering() -> None: } ).set_sorted("dt") - assert df.group_by_rolling( + assert df.rolling( index_column="dt", period="2m", closed="both", @@ -1321,7 +1321,7 @@ def test_group_by_rolling_by_ordering() -> None: } -def test_group_by_rolling_by_() -> None: +def test_rolling_by_() -> None: df = pl.DataFrame({"group": pl.arange(0, 3, eager=True)}).join( pl.DataFrame( { @@ -1334,13 +1334,13 @@ def test_group_by_rolling_by_() -> None: ) out = ( df.sort("datetime") - .group_by_rolling(index_column="datetime", by="group", period=timedelta(days=3)) + .rolling(index_column="datetime", by="group", period=timedelta(days=3)) .agg([pl.count().alias("count")]) ) expected = ( df.sort(["group", "datetime"]) - .group_by_rolling(index_column="datetime", by="group", period="3d") + .rolling(index_column="datetime", by="group", period="3d") .agg([pl.count().alias("count")]) ) assert_frame_equal(out.sort(["group", "datetime"]), expected) @@ -2590,7 +2590,7 @@ def test_rolling_group_by_empty_groups_by_take_6330() -> None: .set_sorted("Date") ) assert ( - df.group_by_rolling( + df.rolling( index_column="Date", period="2i", offset="-2i", @@ -2752,7 +2752,7 @@ def test_pytime_conversion(tm: time) -> None: assert s.to_list() == [tm] -def test_group_by_rolling_duplicates() -> None: +def test_rolling_duplicates() -> None: df = pl.DataFrame( { "ts": [datetime(2000, 1, 1, 0, 0), datetime(2000, 1, 1, 0, 0)], diff --git a/py-polars/tests/unit/operations/map/test_map_groups.py b/py-polars/tests/unit/operations/map/test_map_groups.py index 64cc7c443adf..90699300ea9d 100644 --- a/py-polars/tests/unit/operations/map/test_map_groups.py +++ b/py-polars/tests/unit/operations/map/test_map_groups.py @@ -49,9 +49,7 @@ def function(df: pl.DataFrame) -> pl.DataFrame: pl.col("b").max(), ) - result = df.group_by_rolling("a", period="2i").map_groups( - function, schema=df.schema - ) + result = df.rolling("a", period="2i").map_groups(function, schema=df.schema) expected = pl.DataFrame( [ @@ -162,7 +160,7 @@ def test_apply_deprecated() -> None: with pytest.deprecated_call(): df.group_by("a").apply(lambda x: x) with pytest.deprecated_call(): - df.group_by_rolling("a", period="2i").apply(lambda x: x, schema=None) + df.rolling("a", period="2i").apply(lambda x: x, schema=None) with pytest.deprecated_call(): df.group_by_dynamic("a", every="2i").apply(lambda x: x, schema=None) with pytest.deprecated_call(): diff --git a/py-polars/tests/unit/operations/rolling/test_rolling.py b/py-polars/tests/unit/operations/rolling/test_rolling.py 
index 639b7b297031..0f763f1a04c2 100644 --- a/py-polars/tests/unit/operations/rolling/test_rolling.py +++ b/py-polars/tests/unit/operations/rolling/test_rolling.py @@ -36,7 +36,7 @@ def example_df() -> pl.DataFrame: ["1d", "2d", "3d", timedelta(days=1), timedelta(days=2), timedelta(days=3)], ) @pytest.mark.parametrize("closed", ["left", "right", "none", "both"]) -def test_rolling_kernels_and_group_by_rolling( +def test_rolling_kernels_and_rolling( example_df: pl.DataFrame, period: str | timedelta, closed: ClosedInterval ) -> None: out1 = example_df.set_sorted("dt").select( @@ -56,7 +56,7 @@ def test_rolling_kernels_and_group_by_rolling( ) out2 = ( example_df.set_sorted("dt") - .group_by_rolling("dt", period=period, closed=closed) + .rolling("dt", period=period, closed=closed) .agg( [ pl.col("values").sum().alias("sum"), @@ -145,7 +145,7 @@ def test_rolling_negative_offset( "value": [1, 2, 3, 4], } ) - result = df.group_by_rolling("ts", period="2d", offset=offset, closed=closed).agg( + result = df.rolling("ts", period="2d", offset=offset, closed=closed).agg( pl.col("value") ) expected = pl.DataFrame( @@ -271,7 +271,7 @@ def test_rolling_group_by_extrema() -> None: ).with_columns(pl.col("col1").reverse().alias("row_nr")) assert ( - df.group_by_rolling( + df.rolling( index_column="row_nr", period="3i", ) @@ -310,7 +310,7 @@ def test_rolling_group_by_extrema() -> None: ).with_columns(pl.col("col1").alias("row_nr")) assert ( - df.group_by_rolling( + df.rolling( index_column="row_nr", period="3i", ) @@ -348,7 +348,7 @@ def test_rolling_group_by_extrema() -> None: ).with_columns(pl.col("col1").sort().alias("row_nr")) assert ( - df.group_by_rolling( + df.rolling( index_column="row_nr", period="3i", ) @@ -379,7 +379,7 @@ def test_rolling_slice_pushdown() -> None: df = pl.DataFrame({"a": [1, 2, 3], "b": ["a", "a", "b"], "c": [1, 3, 5]}).lazy() df = ( df.sort("a") - .group_by_rolling( + .rolling( "a", by="b", period="2i", @@ -407,7 +407,7 @@ def test_overlapping_groups_4628() -> None: } ) assert ( - df.group_by_rolling(index_column=pl.col("index").set_sorted(), period="3i").agg( + df.rolling(index_column=pl.col("index").set_sorted(), period="3i").agg( [ pl.col("val").diff(n=1).alias("val.diff"), (pl.col("val") - pl.col("val").shift(1)).alias("val - val.shift"), @@ -473,7 +473,7 @@ def test_rolling_var_numerical_stability_5197() -> None: assert res[:4] == [None] * 4 -def test_group_by_rolling_iter() -> None: +def test_rolling_iter() -> None: df = pl.DataFrame( { "date": [date(2020, 1, 1), date(2020, 1, 2), date(2020, 1, 5)], @@ -485,7 +485,7 @@ def test_group_by_rolling_iter() -> None: # Without 'by' argument result1 = [ (name, data.shape) - for name, data in df.group_by_rolling(index_column="date", period="2d") + for name, data in df.rolling(index_column="date", period="2d") ] expected1 = [ (date(2020, 1, 1), (1, 3)), @@ -497,7 +497,7 @@ def test_group_by_rolling_iter() -> None: # With 'by' argument result2 = [ (name, data.shape) - for name, data in df.group_by_rolling(index_column="date", period="2d", by="a") + for name, data in df.rolling(index_column="date", period="2d", by="a") ] expected2 = [ ((1, date(2020, 1, 1)), (1, 3)), @@ -507,18 +507,18 @@ def test_group_by_rolling_iter() -> None: assert result2 == expected2 -def test_group_by_rolling_negative_period() -> None: +def test_rolling_negative_period() -> None: df = pl.DataFrame({"ts": [datetime(2020, 1, 1)], "value": [1]}).with_columns( pl.col("ts").set_sorted() ) with pytest.raises( ComputeError, match="rolling window period should be 
strictly positive" ): - df.group_by_rolling("ts", period="-1d", offset="-1d").agg(pl.col("value")) + df.rolling("ts", period="-1d", offset="-1d").agg(pl.col("value")) with pytest.raises( ComputeError, match="rolling window period should be strictly positive" ): - df.lazy().group_by_rolling("ts", period="-1d", offset="-1d").agg( + df.lazy().rolling("ts", period="-1d", offset="-1d").agg( pl.col("value") ).collect() with pytest.raises(ComputeError, match="window size should be strictly positive"): diff --git a/py-polars/tests/unit/operations/test_group_by.py b/py-polars/tests/unit/operations/test_group_by.py index 728aad9ee81b..aa638d3b29ff 100644 --- a/py-polars/tests/unit/operations/test_group_by.py +++ b/py-polars/tests/unit/operations/test_group_by.py @@ -744,7 +744,32 @@ def test_groupby_rolling_deprecated() -> None: .collect() ) - expected = df.group_by_rolling("date", period="2d").agg(pl.sum("value")) + expected = df.rolling("date", period="2d").agg(pl.sum("value")) + assert_frame_equal(result, expected, check_row_order=False) + assert_frame_equal(result_lazy, expected, check_row_order=False) + + +def test_group_by_rolling_deprecated() -> None: + df = pl.DataFrame( + { + "date": pl.datetime_range( + datetime(2020, 1, 1), datetime(2020, 1, 5), eager=True + ), + "value": [1, 2, 3, 4, 5], + } + ) + + with pytest.deprecated_call(): + result = df.group_by_rolling("date", period="2d").agg(pl.sum("value")) + with pytest.deprecated_call(): + result_lazy = ( + df.lazy() + .groupby_rolling("date", period="2d") + .agg(pl.sum("value")) + .collect() + ) + + expected = df.rolling("date", period="2d").agg(pl.sum("value")) assert_frame_equal(result, expected, check_row_order=False) assert_frame_equal(result_lazy, expected, check_row_order=False) diff --git a/py-polars/tests/unit/operations/test_group_by_dynamic.py b/py-polars/tests/unit/operations/test_group_by_dynamic.py index d8eff5bc96e8..3862901e5a7b 100644 --- a/py-polars/tests/unit/operations/test_group_by_dynamic.py +++ b/py-polars/tests/unit/operations/test_group_by_dynamic.py @@ -364,7 +364,7 @@ def test_sorted_flag_group_by_dynamic() -> None: ) -def test_group_by_rolling_dynamic_sortedness_check() -> None: +def test_rolling_dynamic_sortedness_check() -> None: # when the by argument is passed, the sortedness flag # will be unset as the take shuffles data, so we must explicitly # check the sortedness diff --git a/py-polars/tests/unit/operations/test_group_by_rolling.py b/py-polars/tests/unit/operations/test_group_by_rolling.py index 8220dbba6042..435b7c34bdec 100644 --- a/py-polars/tests/unit/operations/test_group_by_rolling.py +++ b/py-polars/tests/unit/operations/test_group_by_rolling.py @@ -32,7 +32,7 @@ def test_rolling_group_by_overlapping_groups() -> None: ( df.with_row_count() .with_columns(pl.col("row_nr").cast(pl.Int32)) - .group_by_rolling( + .rolling( index_column="row_nr", period="5i", ) @@ -48,7 +48,7 @@ def test_rolling_group_by_overlapping_groups() -> None: @pytest.mark.parametrize("lazy", [True, False]) -def test_group_by_rolling_agg_input_types(lazy: bool) -> None: +def test_rolling_agg_input_types(lazy: bool) -> None: df = pl.DataFrame({"index_column": [0, 1, 2, 3], "b": [1, 3, 1, 2]}).set_sorted( "index_column" ) @@ -56,24 +56,24 @@ def test_group_by_rolling_agg_input_types(lazy: bool) -> None: for bad_param in bad_agg_parameters(): with pytest.raises(TypeError): # noqa: PT012 - result = df_or_lazy.group_by_rolling( - index_column="index_column", period="2i" - ).agg(bad_param) + result = 
df_or_lazy.rolling(index_column="index_column", period="2i").agg( + bad_param + ) if lazy: result.collect() # type: ignore[union-attr] expected = pl.DataFrame({"index_column": [0, 1, 2, 3], "b": [1, 4, 4, 3]}) for good_param in good_agg_parameters(): - result = df_or_lazy.group_by_rolling( - index_column="index_column", period="2i" - ).agg(good_param) + result = df_or_lazy.rolling(index_column="index_column", period="2i").agg( + good_param + ) if lazy: result = result.collect() # type: ignore[union-attr] assert_frame_equal(result, expected) -def test_group_by_rolling_negative_offset_3914() -> None: +def test_rolling_negative_offset_3914() -> None: df = pl.DataFrame( { "datetime": pl.datetime_range( @@ -81,7 +81,7 @@ def test_group_by_rolling_negative_offset_3914() -> None: ), } ) - assert df.group_by_rolling(index_column="datetime", period="2d", offset="-4d").agg( + assert df.rolling(index_column="datetime", period="2d", offset="-4d").agg( pl.count().alias("count") )["count"].to_list() == [0, 0, 1, 2, 2] @@ -91,7 +91,7 @@ def test_group_by_rolling_negative_offset_3914() -> None: } ) - assert df.group_by_rolling(index_column="ints", period="2i", offset="-5i").agg( + assert df.rolling(index_column="ints", period="2i", offset="-5i").agg( [pl.col("ints").alias("matches")] )["matches"].to_list() == [ [], @@ -118,7 +118,7 @@ def test_group_by_rolling_negative_offset_3914() -> None: @pytest.mark.parametrize("time_zone", [None, "US/Central"]) -def test_group_by_rolling_negative_offset_crossing_dst(time_zone: str | None) -> None: +def test_rolling_negative_offset_crossing_dst(time_zone: str | None) -> None: df = pl.DataFrame( { "datetime": pl.datetime_range( @@ -131,9 +131,9 @@ def test_group_by_rolling_negative_offset_crossing_dst(time_zone: str | None) -> "value": [1, 4, 9, 155], } ) - result = df.group_by_rolling( - index_column="datetime", period="2d", offset="-1d" - ).agg(pl.col("value")) + result = df.rolling(index_column="datetime", period="2d", offset="-1d").agg( + pl.col("value") + ) expected = pl.DataFrame( { "datetime": pl.datetime_range( @@ -163,7 +163,7 @@ def test_group_by_rolling_negative_offset_crossing_dst(time_zone: str | None) -> ("1d", "none", [[9], [155], [], []]), ], ) -def test_group_by_rolling_non_negative_offset_9077( +def test_rolling_non_negative_offset_9077( time_zone: str | None, offset: str, closed: ClosedInterval, @@ -181,7 +181,7 @@ def test_group_by_rolling_non_negative_offset_9077( "value": [1, 4, 9, 155], } ) - result = df.group_by_rolling( + result = df.rolling( index_column="datetime", period="2d", offset=offset, closed=closed ).agg(pl.col("value")) expected = pl.DataFrame( @@ -199,7 +199,7 @@ def test_group_by_rolling_non_negative_offset_9077( assert_frame_equal(result, expected) -def test_group_by_rolling_dynamic_sortedness_check() -> None: +def test_rolling_dynamic_sortedness_check() -> None: # when the by argument is passed, the sortedness flag # will be unset as the take shuffles data, so we must explicitly # check the sortedness @@ -211,19 +211,17 @@ def test_group_by_rolling_dynamic_sortedness_check() -> None: ) with pytest.raises(pl.ComputeError, match=r"input data is not sorted"): - df.group_by_rolling("idx", period="2i", by="group").agg( - pl.col("idx").alias("idx1") - ) + df.rolling("idx", period="2i", by="group").agg(pl.col("idx").alias("idx1")) # no `by` argument with pytest.raises( pl.InvalidOperationError, match=r"argument in operation 'group_by_rolling' is not explicitly sorted", ): - df.group_by_rolling("idx", 
period="2i").agg(pl.col("idx").alias("idx1")) + df.rolling("idx", period="2i").agg(pl.col("idx").alias("idx1")) -def test_group_by_rolling_empty_groups_9973() -> None: +def test_rolling_empty_groups_9973() -> None: dt1 = date(2001, 1, 1) dt2 = date(2001, 1, 2) @@ -250,7 +248,7 @@ def test_group_by_rolling_empty_groups_9973() -> None: } ) - out = data.group_by_rolling( + out = data.rolling( index_column="date", by="id", period="2d", @@ -262,7 +260,7 @@ def test_group_by_rolling_empty_groups_9973() -> None: assert_frame_equal(out, expected) -def test_group_by_rolling_duplicates_11281() -> None: +def test_rolling_duplicates_11281() -> None: df = pl.DataFrame( { "ts": [ @@ -276,6 +274,6 @@ def test_group_by_rolling_duplicates_11281() -> None: "val": [1, 2, 2, 2, 3, 4], } ).sort("ts") - result = df.group_by_rolling("ts", period="1d", closed="left").agg(pl.col("val")) + result = df.rolling("ts", period="1d", closed="left").agg(pl.col("val")) expected = df.with_columns(val=pl.Series([[], [1], [1], [1], [2, 2, 2], [3]])) assert_frame_equal(result, expected) From 7efc54e71cf1e3b08d9e7681b0d1b65a5f3ee3f2 Mon Sep 17 00:00:00 2001 From: Romano Vacca Date: Mon, 16 Oct 2023 20:13:21 +0200 Subject: [PATCH 008/103] refactor(rust): Make all emw function expr non-anonymous (#11638) Co-authored-by: Weijie Guo --- .../src/legacy/kernels/ewm/mod.rs | 5 +++- .../polars-plan/src/dsl/function_expr/ewm.rs | 13 ++++++++ .../polars-plan/src/dsl/function_expr/mod.rs | 26 ++++++++++++++++ .../src/dsl/function_expr/schema.rs | 6 ++++ crates/polars-plan/src/dsl/mod.rs | 30 ++----------------- 5 files changed, 52 insertions(+), 28 deletions(-) create mode 100644 crates/polars-plan/src/dsl/function_expr/ewm.rs diff --git a/crates/polars-arrow/src/legacy/kernels/ewm/mod.rs b/crates/polars-arrow/src/legacy/kernels/ewm/mod.rs index 8f45bbbef2fb..5984106f1521 100644 --- a/crates/polars-arrow/src/legacy/kernels/ewm/mod.rs +++ b/crates/polars-arrow/src/legacy/kernels/ewm/mod.rs @@ -2,9 +2,12 @@ mod average; mod variance; pub use average::*; +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; pub use variance::*; -#[derive(Debug, Copy, Clone)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Debug, Copy, Clone, PartialEq)] #[must_use] pub struct EWMOptions { pub alpha: f64, diff --git a/crates/polars-plan/src/dsl/function_expr/ewm.rs b/crates/polars-plan/src/dsl/function_expr/ewm.rs new file mode 100644 index 000000000000..a26285eef33a --- /dev/null +++ b/crates/polars-plan/src/dsl/function_expr/ewm.rs @@ -0,0 +1,13 @@ +use super::*; + +pub(super) fn ewm_mean(s: &Series, options: EWMOptions) -> PolarsResult { + s.ewm_mean(options) +} + +pub(super) fn ewm_std(s: &Series, options: EWMOptions) -> PolarsResult { + s.ewm_std(options) +} + +pub(super) fn ewm_var(s: &Series, options: EWMOptions) -> PolarsResult { + s.ewm_var(options) +} diff --git a/crates/polars-plan/src/dsl/function_expr/mod.rs b/crates/polars-plan/src/dsl/function_expr/mod.rs index 7f51397f8dfe..24a61df540f0 100644 --- a/crates/polars-plan/src/dsl/function_expr/mod.rs +++ b/crates/polars-plan/src/dsl/function_expr/mod.rs @@ -18,6 +18,8 @@ mod cum; #[cfg(feature = "temporal")] mod datetime; mod dispatch; +#[cfg(feature = "ewma")] +mod ewm; mod fill_null; #[cfg(feature = "fused")] mod fused; @@ -259,6 +261,18 @@ pub enum FunctionExpr { SumHorizontal, MaxHorizontal, MinHorizontal, + #[cfg(feature = "ewma")] + EwmMean { + options: EWMOptions, + }, + #[cfg(feature = "ewma")] + EwmStd { + options: EWMOptions, + }, + #[cfg(feature = 
"ewma")] + EwmVar { + options: EWMOptions, + }, } impl Hash for FunctionExpr { @@ -433,6 +447,12 @@ impl Display for FunctionExpr { SumHorizontal => "sum_horizontal", MaxHorizontal => "max_horizontal", MinHorizontal => "min_horizontal", + #[cfg(feature = "ewma")] + EwmMean { .. } => "ewm_mean", + #[cfg(feature = "ewma")] + EwmStd { .. } => "ewm_std", + #[cfg(feature = "ewma")] + EwmVar { .. } => "ewm_var", }; write!(f, "{s}") } @@ -755,6 +775,12 @@ impl From for SpecialEq> { SumHorizontal => map_as_slice!(dispatch::sum_horizontal), MaxHorizontal => wrap!(dispatch::max_horizontal), MinHorizontal => wrap!(dispatch::min_horizontal), + #[cfg(feature = "ewma")] + EwmMean { options } => map!(ewm::ewm_mean, options), + #[cfg(feature = "ewma")] + EwmStd { options } => map!(ewm::ewm_std, options), + #[cfg(feature = "ewma")] + EwmVar { options } => map!(ewm::ewm_var, options), } } } diff --git a/crates/polars-plan/src/dsl/function_expr/schema.rs b/crates/polars-plan/src/dsl/function_expr/schema.rs index 3d8996e74431..3cb2e2a852f7 100644 --- a/crates/polars-plan/src/dsl/function_expr/schema.rs +++ b/crates/polars-plan/src/dsl/function_expr/schema.rs @@ -242,6 +242,12 @@ impl FunctionExpr { SumHorizontal => mapper.map_to_supertype(), MaxHorizontal => mapper.map_to_supertype(), MinHorizontal => mapper.map_to_supertype(), + #[cfg(feature = "ewma")] + EwmMean { .. } => mapper.map_to_float_dtype(), + #[cfg(feature = "ewma")] + EwmStd { .. } => mapper.map_to_float_dtype(), + #[cfg(feature = "ewma")] + EwmVar { .. } => mapper.map_to_float_dtype(), } } } diff --git a/crates/polars-plan/src/dsl/mod.rs b/crates/polars-plan/src/dsl/mod.rs index 3b335e6d5475..aff4c2378a12 100644 --- a/crates/polars-plan/src/dsl/mod.rs +++ b/crates/polars-plan/src/dsl/mod.rs @@ -1666,43 +1666,19 @@ impl Expr { #[cfg(feature = "ewma")] /// Calculate the exponentially-weighted moving average. pub fn ewm_mean(self, options: EWMOptions) -> Self { - use DataType::*; - self.apply( - move |s| s.ewm_mean(options).map(Some), - GetOutput::map_dtype(|dt| match dt { - Float64 | Float32 => dt.clone(), - _ => Float64, - }), - ) - .with_fmt("ewm_mean") + self.apply_private(FunctionExpr::EwmMean { options }) } #[cfg(feature = "ewma")] /// Calculate the exponentially-weighted moving standard deviation. pub fn ewm_std(self, options: EWMOptions) -> Self { - use DataType::*; - self.apply( - move |s| s.ewm_std(options).map(Some), - GetOutput::map_dtype(|dt| match dt { - Float64 | Float32 => dt.clone(), - _ => Float64, - }), - ) - .with_fmt("ewm_std") + self.apply_private(FunctionExpr::EwmStd { options }) } #[cfg(feature = "ewma")] /// Calculate the exponentially-weighted moving variance. pub fn ewm_var(self, options: EWMOptions) -> Self { - use DataType::*; - self.apply( - move |s| s.ewm_var(options).map(Some), - GetOutput::map_dtype(|dt| match dt { - Float64 | Float32 => dt.clone(), - _ => Float64, - }), - ) - .with_fmt("ewm_var") + self.apply_private(FunctionExpr::EwmVar { options }) } /// Returns whether any of the values in the column are `true`. 
From 85898361874675bac01cb15ae41405023ecd25b0 Mon Sep 17 00:00:00 2001 From: cmdlineluser <99486669+cmdlineluser@users.noreply.github.com> Date: Mon, 16 Oct 2023 19:14:51 +0100 Subject: [PATCH 009/103] feat(python,rust,cli): add `DATE` function for SQL (#11541) --- crates/polars-sql/Cargo.toml | 2 +- crates/polars-sql/src/functions.rs | 44 +++++++++++++++++++++++++++- py-polars/tests/unit/sql/test_sql.py | 25 ++++++++++++++++ 3 files changed, 69 insertions(+), 2 deletions(-) diff --git a/crates/polars-sql/Cargo.toml b/crates/polars-sql/Cargo.toml index 8e2feb97ba53..b03757ec4435 100644 --- a/crates/polars-sql/Cargo.toml +++ b/crates/polars-sql/Cargo.toml @@ -12,7 +12,7 @@ description = "SQL transpiler for Polars. Converts SQL to Polars logical plans" arrow = { workspace = true } polars-core = { workspace = true } polars-error = { workspace = true } -polars-lazy = { workspace = true, features = ["strings", "cross_join", "trigonometry", "abs", "round_series", "log", "regex", "is_in", "meta", "cum_agg"] } +polars-lazy = { workspace = true, features = ["strings", "cross_join", "trigonometry", "abs", "round_series", "log", "regex", "is_in", "meta", "cum_agg", "dtype-date"] } polars-plan = { workspace = true } rand = { workspace = true } diff --git a/crates/polars-sql/src/functions.rs b/crates/polars-sql/src/functions.rs index 70f9cabe7669..84f6987c3e0b 100644 --- a/crates/polars-sql/src/functions.rs +++ b/crates/polars-sql/src/functions.rs @@ -2,8 +2,8 @@ use polars_core::prelude::{polars_bail, polars_err, PolarsResult}; use polars_lazy::dsl::Expr; use polars_plan::dsl::{coalesce, count, when}; use polars_plan::logical_plan::LiteralValue; -use polars_plan::prelude::lit; use polars_plan::prelude::LiteralValue::Null; +use polars_plan::prelude::{lit, StrptimeOptions}; use sqlparser::ast::{ Expr as SqlExpr, Function as SQLFunction, FunctionArg, FunctionArgExpr, Value as SqlValue, WindowSpec, WindowType, @@ -217,6 +217,16 @@ pub(crate) enum PolarsSqlFunctions { /// ``` Radians, + // ---- + // Date Functions + // ---- + /// SQL 'date' function + /// ```sql + /// SELECT DATE('2021-03-15') from df; + /// SELECT DATE('2021-03', '%Y-%m') from df; + /// ``` + Date, + // ---- // String functions // ---- @@ -471,6 +481,7 @@ impl PolarsSqlFunctions { "cot", "cotd", "count", + "date", "degrees", "ends_with", "exp", @@ -559,6 +570,11 @@ impl PolarsSqlFunctions { "nullif" => Self::NullIf, "coalesce" => Self::Coalesce, + // ---- + // Date functions + // ---- + "date" => Self::Date, + // ---- // String functions // ---- @@ -718,6 +734,14 @@ impl SqlFunctionVisitor<'_> { }), _ => polars_bail!(InvalidOperation:"Invalid number of arguments for RegexpLike: {}",function.args.len()), }, + Date => match function.args.len() { + 1 => self.visit_unary(|e| e.str().to_date(StrptimeOptions::default())), + 2 => self.visit_binary(|e, fmt| e.str().to_date(fmt)), + _ => polars_bail!(InvalidOperation: + "Invalid number of arguments for Date: {}", + function.args.len() + ), + }, RTrim => match function.args.len() { 1 => self.visit_unary(|e| e.str().strip_chars_end(lit(Null))), 2 => self.visit_binary(|e, s| e.str().strip_chars_end(s)), @@ -1076,6 +1100,24 @@ impl FromSqlExpr for String { } } +impl FromSqlExpr for StrptimeOptions { + fn from_sql_expr(expr: &SqlExpr, _: &mut SQLContext) -> PolarsResult + where + Self: Sized, + { + match expr { + SqlExpr::Value(v) => match v { + SqlValue::SingleQuotedString(s) => Ok(StrptimeOptions { + format: Some(s.clone()), + ..StrptimeOptions::default() + }), + _ => polars_bail!(ComputeError: 
"can't parse literal {:?}", v), + }, + _ => polars_bail!(ComputeError: "can't parse literal {:?}", expr), + } + } +} + impl FromSqlExpr for Expr { fn from_sql_expr(expr: &SqlExpr, ctx: &mut SQLContext) -> PolarsResult where diff --git a/py-polars/tests/unit/sql/test_sql.py b/py-polars/tests/unit/sql/test_sql.py index aad05b5f7b5e..73faef160377 100644 --- a/py-polars/tests/unit/sql/test_sql.py +++ b/py-polars/tests/unit/sql/test_sql.py @@ -1,5 +1,6 @@ from __future__ import annotations +import datetime import math from pathlib import Path @@ -1208,3 +1209,27 @@ def test_sql_unary_ops_8890(match_float: bool) -> None: "c": [-3, -3], "d": [4, 4], } + + +def test_sql_date() -> None: + df = pl.DataFrame( + { + "date": [ + datetime.date(2021, 3, 15), + datetime.date(2021, 3, 28), + datetime.date(2021, 4, 4), + ], + "version": ["0.0.1", "0.7.3", "0.7.4"], + } + ) + + with pl.SQLContext(df=df, eager_execution=True) as ctx: + expected = pl.DataFrame({"date": [True, False, False]}) + assert ctx.execute("SELECT date < DATE('2021-03-20') from df").frame_equal( + expected + ) + + expected = pl.DataFrame({"literal": ["2023-03-01"]}) + assert pl.select( + pl.sql_expr("""CAST(DATE('2023-03', '%Y-%m') as STRING)""") + ).frame_equal(expected) From 6e886f9e427f2518fca2684abd6d96ac6ef6bdc9 Mon Sep 17 00:00:00 2001 From: Ritchie Vink Date: Mon, 16 Oct 2023 22:16:18 +0400 Subject: [PATCH 010/103] chore: more granular polars-ops imports (#11760) --- crates/polars-lazy/src/physical_plan/executors/join.rs | 2 ++ .../polars-lazy/src/physical_plan/expressions/aggregation.rs | 2 ++ crates/polars-lazy/src/physical_plan/expressions/binary.rs | 2 ++ crates/polars-lazy/src/physical_plan/expressions/sort.rs | 1 + crates/polars-lazy/src/physical_plan/expressions/window.rs | 1 + crates/polars-lazy/src/prelude.rs | 3 ++- 6 files changed, 10 insertions(+), 1 deletion(-) diff --git a/crates/polars-lazy/src/physical_plan/executors/join.rs b/crates/polars-lazy/src/physical_plan/executors/join.rs index fa84d46e7a84..5898aea109ca 100644 --- a/crates/polars-lazy/src/physical_plan/executors/join.rs +++ b/crates/polars-lazy/src/physical_plan/executors/join.rs @@ -1,3 +1,5 @@ +use polars_ops::frame::DataFrameJoinOps; + use super::*; pub struct JoinExec { diff --git a/crates/polars-lazy/src/physical_plan/expressions/aggregation.rs b/crates/polars-lazy/src/physical_plan/expressions/aggregation.rs index cd22b26d6a21..be48122eb520 100644 --- a/crates/polars-lazy/src/physical_plan/expressions/aggregation.rs +++ b/crates/polars-lazy/src/physical_plan/expressions/aggregation.rs @@ -11,6 +11,8 @@ use polars_core::prelude::*; use polars_core::utils::NoNull; #[cfg(feature = "dtype-struct")] use polars_core::POOL; +#[cfg(feature = "propagate_nans")] +use polars_ops::prelude::nan_propagating_aggregate; use crate::physical_plan::state::ExecutionState; use crate::physical_plan::PartitionedAggregation; diff --git a/crates/polars-lazy/src/physical_plan/expressions/binary.rs b/crates/polars-lazy/src/physical_plan/expressions/binary.rs index b6cf6fcd177f..ef20e7abdc33 100644 --- a/crates/polars-lazy/src/physical_plan/expressions/binary.rs +++ b/crates/polars-lazy/src/physical_plan/expressions/binary.rs @@ -3,6 +3,8 @@ use std::sync::Arc; use polars_core::frame::group_by::GroupsProxy; use polars_core::prelude::*; use polars_core::POOL; +#[cfg(feature = "round_series")] +use polars_ops::prelude::floor_div_series; use crate::physical_plan::state::ExecutionState; use crate::prelude::*; diff --git a/crates/polars-lazy/src/physical_plan/expressions/sort.rs 
b/crates/polars-lazy/src/physical_plan/expressions/sort.rs index 0047a4d9e118..3fdabd22ee3e 100644 --- a/crates/polars-lazy/src/physical_plan/expressions/sort.rs +++ b/crates/polars-lazy/src/physical_plan/expressions/sort.rs @@ -4,6 +4,7 @@ use arrow::legacy::utils::CustomIterTools; use polars_core::frame::group_by::GroupsProxy; use polars_core::prelude::*; use polars_core::POOL; +use polars_ops::chunked_array::ListNameSpaceImpl; use rayon::prelude::*; use crate::physical_plan::state::ExecutionState; diff --git a/crates/polars-lazy/src/physical_plan/expressions/window.rs b/crates/polars-lazy/src/physical_plan/expressions/window.rs index 4f98e63ad7d4..e74d4c800210 100644 --- a/crates/polars-lazy/src/physical_plan/expressions/window.rs +++ b/crates/polars-lazy/src/physical_plan/expressions/window.rs @@ -11,6 +11,7 @@ use polars_core::{downcast_as_macro_arg_physical, POOL}; use polars_ops::frame::join::{ default_join_ids, private_left_join_multiple_keys, ChunkJoinOptIds, JoinValidation, }; +use polars_ops::frame::SeriesJoin; use polars_utils::format_smartstring; use polars_utils::sort::perfect_sort; use polars_utils::sync::SyncPtr; diff --git a/crates/polars-lazy/src/prelude.rs b/crates/polars-lazy/src/prelude.rs index 81c31fe943db..a5baeda1ea78 100644 --- a/crates/polars-lazy/src/prelude.rs +++ b/crates/polars-lazy/src/prelude.rs @@ -1,5 +1,6 @@ -pub(crate) use polars_ops::prelude::*; pub use polars_ops::prelude::{JoinArgs, JoinType, JoinValidation}; +#[cfg(feature = "rank")] +pub use polars_ops::prelude::{RankMethod, RankOptions}; pub use polars_plan::logical_plan::{ AnonymousScan, AnonymousScanOptions, Literal, LiteralValue, LogicalPlan, Null, NULL, }; From 5d48cc800bc9c71fe6d4ff97b96d7fed4601793b Mon Sep 17 00:00:00 2001 From: Ritchie Vink Date: Mon, 16 Oct 2023 23:49:28 +0400 Subject: [PATCH 011/103] feat(python): primitive kwargs in plugins (#11268) --- crates/polars-arrow/src/ffi/schema.rs | 4 ++ crates/polars-ffi/src/lib.rs | 16 +++++ .../polars-plan/src/dsl/function_expr/mod.rs | 26 ++++++- .../src/dsl/function_expr/plugin.rs | 72 ++++++++++++++++--- .../src/dsl/function_expr/schema.rs | 2 +- py-polars/polars/expr/expr.py | 14 +++- py-polars/src/expr/general.rs | 8 ++- 7 files changed, 123 insertions(+), 19 deletions(-) diff --git a/crates/polars-arrow/src/ffi/schema.rs b/crates/polars-arrow/src/ffi/schema.rs index ebbf5b8f6c76..ded8215ca3ce 100644 --- a/crates/polars-arrow/src/ffi/schema.rs +++ b/crates/polars-arrow/src/ffi/schema.rs @@ -145,6 +145,10 @@ impl ArrowSchema { } } + pub fn is_null(&self) -> bool { + self.private_data.is_null() + } + /// returns the format of this schema. 
pub(crate) fn format(&self) -> &str {
         assert!(!self.format.is_null());
diff --git a/crates/polars-ffi/src/lib.rs b/crates/polars-ffi/src/lib.rs
index 699d5e7a7fd5..59b2b0b8d9e9 100644
--- a/crates/polars-ffi/src/lib.rs
+++ b/crates/polars-ffi/src/lib.rs
@@ -24,6 +24,22 @@ pub struct SeriesExport {
     private_data: *mut std::os::raw::c_void,
 }
 
+impl SeriesExport {
+    pub fn empty() -> Self {
+        Self {
+            field: std::ptr::null_mut(),
+            arrays: std::ptr::null_mut(),
+            len: 0,
+            release: None,
+            private_data: std::ptr::null_mut(),
+        }
+    }
+
+    pub fn is_null(&self) -> bool {
+        self.private_data.is_null()
+    }
+}
+
 impl Drop for SeriesExport {
     fn drop(&mut self) {
         if let Some(release) = self.release {
diff --git a/crates/polars-plan/src/dsl/function_expr/mod.rs b/crates/polars-plan/src/dsl/function_expr/mod.rs
index 24a61df540f0..cd71fc6c96a0 100644
--- a/crates/polars-plan/src/dsl/function_expr/mod.rs
+++ b/crates/polars-plan/src/dsl/function_expr/mod.rs
@@ -248,9 +248,15 @@ pub enum FunctionExpr {
     },
     SetSortedFlag(IsSorted),
     #[cfg(feature = "ffi_plugin")]
+    /// Creating this node is unsafe.
+    /// This will lead to calls over FFI.
     FfiPlugin {
+        /// Shared library.
         lib: Arc<str>,
+        /// Identifier in the shared lib.
         symbol: Arc<str>,
+        /// Pickle serialized keyword arguments.
+        kwargs: Arc<[u8]>,
     },
     BackwardFill {
         limit: FillNullLimit,
@@ -309,7 +315,12 @@ impl Hash for FunctionExpr {
             #[cfg(feature = "dtype-categorical")]
             FunctionExpr::Categorical(f) => f.hash(state),
             #[cfg(feature = "ffi_plugin")]
-            FunctionExpr::FfiPlugin { lib, symbol } => {
+            FunctionExpr::FfiPlugin {
+                lib,
+                symbol,
+                kwargs,
+            } => {
+                kwargs.hash(state);
                 lib.hash(state);
                 symbol.hash(state);
             },
@@ -767,8 +778,17 @@ impl From<FunctionExpr> for SpecialEq<Arc<dyn SeriesUdf>> {
             },
             SetSortedFlag(sorted) => map!(dispatch::set_sorted_flag, sorted),
             #[cfg(feature = "ffi_plugin")]
-            FfiPlugin { lib, symbol, .. } => unsafe {
-                map_as_slice!(plugin::call_plugin, lib.as_ref(), symbol.as_ref())
+            FfiPlugin {
+                lib,
+                symbol,
+                kwargs,
+            } => unsafe {
+                map_as_slice!(
+                    plugin::call_plugin,
+                    lib.as_ref(),
+                    symbol.as_ref(),
+                    kwargs.as_ref()
+                )
             },
             BackwardFill { limit } => map!(dispatch::backward_fill, limit),
             ForwardFill { limit } => map!(dispatch::forward_fill, limit),
diff --git a/crates/polars-plan/src/dsl/function_expr/plugin.rs b/crates/polars-plan/src/dsl/function_expr/plugin.rs
index 6c8113a54aac..85fea0edf7b8 100644
--- a/crates/polars-plan/src/dsl/function_expr/plugin.rs
+++ b/crates/polars-plan/src/dsl/function_expr/plugin.rs
@@ -1,3 +1,4 @@
+use std::ffi::CString;
 use std::sync::RwLock;
 
 use arrow::ffi::{import_field_from_c, ArrowSchema};
@@ -30,24 +31,59 @@ fn get_lib(lib: &str) -> PolarsResult<&'static Library> {
     }
 }
 
-pub(super) unsafe fn call_plugin(s: &[Series], lib: &str, symbol: &str) -> PolarsResult<Series> {
+unsafe fn retrieve_error_msg(lib: &Library) -> CString {
+    let symbol: libloading::Symbol<unsafe extern "C" fn() -> *mut std::os::raw::c_char> =
+        lib.get(b"get_last_error_message\0").unwrap();
+    let msg_ptr = symbol();
+    CString::from_raw(msg_ptr)
+}
+
+pub(super) unsafe fn call_plugin(
+    s: &[Series],
+    lib: &str,
+    symbol: &str,
+    kwargs: &[u8],
+) -> PolarsResult<Series> {
     let lib = get_lib(lib)?;
 
+    // *const SeriesExport: pointer to Box<[SeriesExport]>
+    // usize: length of that pointer
+    // *const u8: pointer to &[u8]
+    // usize: length of the u8 slice
+    // *mut SeriesExport: pointer where return value should be written.
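+    //
+    // NOTE: the callee is assumed to uphold exactly this C ABI; a mismatched
+    // signature is undefined behaviour, which is why creating an `FfiPlugin`
+    // node is unsafe.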
let symbol: libloading::Symbol< - unsafe extern "C" fn(*const SeriesExport, usize) -> SeriesExport, + unsafe extern "C" fn(*const SeriesExport, usize, *const u8, usize, *mut SeriesExport), > = lib.get(symbol.as_bytes()).unwrap(); - let n_args = s.len(); - let input = s.iter().map(export_series).collect::>(); + let input_len = s.len(); let slice_ptr = input.as_ptr(); - let out = symbol(slice_ptr, n_args); + let kwargs_ptr = kwargs.as_ptr(); + let kwargs_len = kwargs.len(); + + let mut return_value = SeriesExport::empty(); + let return_value_ptr = &mut return_value as *mut SeriesExport; + symbol( + slice_ptr, + input_len, + kwargs_ptr, + kwargs_len, + return_value_ptr, + ); + + // The inputs get dropped when the ffi side calls the drop callback. for e in input { std::mem::forget(e); } - import_series(out) + if !return_value.is_null() { + import_series(return_value) + } else { + let msg = retrieve_error_msg(lib); + let msg = msg.to_string_lossy(); + polars_bail!(ComputeError: "the plugin failed with message: {}", msg) + } } pub(super) unsafe fn plugin_field( @@ -57,8 +93,12 @@ pub(super) unsafe fn plugin_field( ) -> PolarsResult { let lib = get_lib(lib)?; - let symbol: libloading::Symbol ArrowSchema> = - lib.get(symbol.as_bytes()).unwrap(); + // *const ArrowSchema: pointer to heap Box + // usize: length of the boxed slice + // *mut ArrowSchema: pointer where the return value can be written + let symbol: libloading::Symbol< + unsafe extern "C" fn(*const ArrowSchema, usize, *mut ArrowSchema), + > = lib.get(symbol.as_bytes()).unwrap(); // we deallocate the fields buffer let fields = fields @@ -68,8 +108,18 @@ pub(super) unsafe fn plugin_field( .into_boxed_slice(); let n_args = fields.len(); let slice_ptr = fields.as_ptr(); - let out = symbol(slice_ptr, n_args); - let arrow_field = import_field_from_c(&out)?; - Ok(Field::from(&arrow_field)) + let mut return_value = ArrowSchema::empty(); + let return_value_ptr = &mut return_value as *mut ArrowSchema; + symbol(slice_ptr, n_args, return_value_ptr); + + if !return_value.is_null() { + let arrow_field = import_field_from_c(&return_value)?; + let out = Field::from(&arrow_field); + Ok(out) + } else { + let msg = retrieve_error_msg(lib); + let msg = msg.to_string_lossy(); + polars_bail!(ComputeError: "the plugin failed with message: {}", msg) + } } diff --git a/crates/polars-plan/src/dsl/function_expr/schema.rs b/crates/polars-plan/src/dsl/function_expr/schema.rs index 3cb2e2a852f7..9baa4a0e5c89 100644 --- a/crates/polars-plan/src/dsl/function_expr/schema.rs +++ b/crates/polars-plan/src/dsl/function_expr/schema.rs @@ -234,7 +234,7 @@ impl FunctionExpr { Random { .. } => mapper.with_same_dtype(), SetSortedFlag(_) => mapper.with_same_dtype(), #[cfg(feature = "ffi_plugin")] - FfiPlugin { lib, symbol } => unsafe { + FfiPlugin { lib, symbol, .. } => unsafe { plugin::plugin_field(fields, lib, &format!("__polars_field_{}", symbol.as_ref())) }, BackwardFill { .. } => mapper.with_same_dtype(), diff --git a/py-polars/polars/expr/expr.py b/py-polars/polars/expr/expr.py index e9eb033ad1af..a86da90bc996 100644 --- a/py-polars/polars/expr/expr.py +++ b/py-polars/polars/expr/expr.py @@ -9508,10 +9508,11 @@ def is_last(self) -> Self: def _register_plugin( self, + *, lib: str, symbol: str, args: list[IntoExpr] | None = None, - *, + kwargs: dict[Any, Any] | None = None, is_elementwise: bool = False, input_wildcard_expansion: bool = False, auto_explode: bool = False, @@ -9536,6 +9537,9 @@ def _register_plugin( Function to load. 
args Arguments (other than self) passed to this function. + These arguments have to be of type Expression. + kwargs + Non-expression arguments. They must be JSON serializable. is_elementwise If the function only operates on scalars this will trigger fast paths. @@ -9552,11 +9556,19 @@ def _register_plugin( args = [] else: args = [parse_as_expression(a) for a in args] + if kwargs is None: + serialized_kwargs = b"" + else: + import pickle + + serialized_kwargs = pickle.dumps(kwargs, protocol=2) + return self._from_pyexpr( self._pyexpr.register_plugin( lib, symbol, args, + serialized_kwargs, is_elementwise, input_wildcard_expansion, auto_explode, diff --git a/py-polars/src/expr/general.rs b/py-polars/src/expr/general.rs index 88173745d9b1..a14424027441 100644 --- a/py-polars/src/expr/general.rs +++ b/py-polars/src/expr/general.rs @@ -892,11 +892,12 @@ impl PyExpr { lib: &str, symbol: &str, args: Vec, + kwargs: Vec, is_elementwise: bool, input_wildcard_expansion: bool, auto_explode: bool, cast_to_supertypes: bool, - ) -> Self { + ) -> PyResult { use polars_plan::prelude::*; let inner = self.inner.clone(); @@ -911,11 +912,12 @@ impl PyExpr { input.push(a.inner) } - Expr::Function { + Ok(Expr::Function { input, function: FunctionExpr::FfiPlugin { lib: Arc::from(lib), symbol: Arc::from(symbol), + kwargs: Arc::from(kwargs), }, options: FunctionOptions { collect_groups, @@ -925,6 +927,6 @@ impl PyExpr { ..Default::default() }, } - .into() + .into()) } } From 17e14028669eec4e439e1efe7c7b19257e7fd191 Mon Sep 17 00:00:00 2001 From: Brayan Jules Date: Mon, 16 Oct 2023 18:45:03 -0300 Subject: [PATCH 012/103] fix: fixed validate_is_number import issue, also added missing dataframe validation on unit tests --- crates/polars-core/src/chunked_array/cast.rs | 3 +- .../src/chunked_array/temporal/mod.rs | 4 +- py-polars/tests/unit/test_queries.py | 49 +++++++++++++++++-- 3 files changed, 48 insertions(+), 8 deletions(-) diff --git a/crates/polars-core/src/chunked_array/cast.rs b/crates/polars-core/src/chunked_array/cast.rs index 214ed5f19482..365cd88cafb3 100644 --- a/crates/polars-core/src/chunked_array/cast.rs +++ b/crates/polars-core/src/chunked_array/cast.rs @@ -5,9 +5,8 @@ use arrow::compute::cast::CastOptions; #[cfg(feature = "dtype-categorical")] use crate::chunked_array::categorical::CategoricalChunkedBuilder; -use crate::chunked_array::temporal::validate_is_number; #[cfg(feature = "timezones")] -use crate::chunked_array::temporal::validate_time_zone; +use crate::chunked_array::temporal::{validate_is_number, validate_time_zone}; use crate::prelude::DataType::Datetime; use crate::prelude::*; diff --git a/crates/polars-core/src/chunked_array/temporal/mod.rs b/crates/polars-core/src/chunked_array/temporal/mod.rs index c6ea220b7d21..3b6a38aede8b 100644 --- a/crates/polars-core/src/chunked_array/temporal/mod.rs +++ b/crates/polars-core/src/chunked_array/temporal/mod.rs @@ -37,8 +37,8 @@ pub(crate) fn validate_time_zone(tz: &str) -> PolarsResult<()> { } } -pub(crate) fn validate_is_number(vec_array: &Vec) -> bool { - vec_array.iter().all(|array| is_parsable_as_number(array)) +pub(crate) fn validate_is_number(vec_array: &[ArrayRef]) -> bool { + vec_array.iter().all(is_parsable_as_number) } fn is_parsable_as_number(array: &ArrayRef) -> bool { diff --git a/py-polars/tests/unit/test_queries.py b/py-polars/tests/unit/test_queries.py index 8db9c1283904..af623feb0e2c 100644 --- a/py-polars/tests/unit/test_queries.py +++ b/py-polars/tests/unit/test_queries.py @@ -1,6 +1,6 @@ from __future__ import annotations -from 
datetime import datetime, timedelta +from datetime import datetime, timedelta, date from typing import Any import numpy as np @@ -378,9 +378,11 @@ def test_utf8_date() -> None: df = pl.DataFrame({"x1": ["2021-01-01"]}).with_columns( **{"x1-date": pl.col("x1").cast(pl.Date)} ) + expected = pl.DataFrame({"x1-date":[date(2021,1,1)]}) out = df.select(pl.col("x1-date")) assert out.shape == (1, 1) assert out.dtypes == [pl.Date] + assert_frame_equal(expected, out) def test_wrong_utf8_date() -> None: @@ -400,12 +402,26 @@ def test_utf8_datetime() -> None: "x1-datetime-us": pl.col("x1").cast(pl.Datetime(time_unit="us")), } ) + first_row = datetime(year=2021, month=12, day=19, hour=00, minute=39, second=57) + second_row = datetime(year=2022, month=12, day=19, hour=16, minute=39, second=57) + expected = pl.DataFrame( + { + "x1-datetime-ns": [first_row, second_row], + "x1-datetime-ms": [first_row, second_row], + "x1-datetime-us": [first_row, second_row] + } + ).select( + pl.col("x1-datetime-ns").dt.cast_time_unit("ns"), + pl.col("x1-datetime-ms").dt.cast_time_unit("ms"), + pl.col("x1-datetime-us").dt.cast_time_unit("us"), + ) out = df.select( pl.col("x1-datetime-ns"), pl.col("x1-datetime-ms"), pl.col("x1-datetime-us") ) assert out.shape == (2, 3) assert out.dtypes == [pl.Datetime, pl.Datetime, pl.Datetime] + assert_frame_equal(expected, out) def test_wrong_utf8_datetime() -> None: @@ -417,22 +433,46 @@ def test_wrong_utf8_datetime() -> None: def test_utf8_datetime_timezone() -> None: + ccs_tz = "America/Caracas" + stg_tz = "America/Santiago" + utc_tz = "UTC" df = pl.DataFrame( {"x1": ["1996-12-19T16:39:57 +00:00", "2022-12-19T00:39:57 +00:00"]} ).with_columns( **{ "x1-datetime-ns": pl.col("x1").cast( - pl.Datetime(time_unit="ns", time_zone="America/Caracas") + pl.Datetime(time_unit="ns", time_zone=ccs_tz) ), "x1-datetime-ms": pl.col("x1").cast( - pl.Datetime(time_unit="ms", time_zone="America/Santiago") + pl.Datetime(time_unit="ms", time_zone=stg_tz) ), "x1-datetime-us": pl.col("x1").cast( - pl.Datetime(time_unit="us", time_zone="UTC") + pl.Datetime(time_unit="us", time_zone=utc_tz) ), } ) + expected = pl.DataFrame( + { + "x1-datetime-ns": [ + datetime(year=1996, month=12, day=19, hour=12, minute=39, second=57), + datetime(year=2022, month=12, day=18, hour=20, minute=39, second=57), + ], + "x1-datetime-ms": [ + datetime(year=1996, month=12, day=19, hour=13, minute=39, second=57), + datetime(year=2022, month=12, day=18, hour=21, minute=39, second=57), + ], + "x1-datetime-us": [ + datetime(year=1996, month=12, day=19, hour=16, minute=39, second=57), + datetime(year=2022, month=12, day=19, hour=00, minute=39, second=57), + ], + } + ).select( + pl.col("x1-datetime-ns").dt.cast_time_unit("ns").dt.replace_time_zone(ccs_tz), + pl.col("x1-datetime-ms").dt.cast_time_unit("ms").dt.replace_time_zone(stg_tz), + pl.col("x1-datetime-us").dt.cast_time_unit("us").dt.replace_time_zone(utc_tz), + ) + out = df.select( pl.col("x1-datetime-ns"), pl.col("x1-datetime-ms"), pl.col("x1-datetime-us") ) @@ -443,3 +483,4 @@ def test_utf8_datetime_timezone() -> None: pl.Datetime("ms", "America/Santiago"), pl.Datetime("us", "UTC"), ] + assert_frame_equal(expected, out) From befe30893a3e69c68b7efa5d6dc26b8bc7039146 Mon Sep 17 00:00:00 2001 From: Stijn de Gooijer Date: Tue, 17 Oct 2023 00:52:10 +0200 Subject: [PATCH 013/103] refactor(python): Rename `IntegralType` to `IntegerType` (#11773) --- py-polars/polars/datatypes/__init__.py | 4 +- py-polars/polars/datatypes/classes.py | 53 +++++++++++++------------- 
py-polars/polars/type_aliases.py | 4 +- 3 files changed, 30 insertions(+), 31 deletions(-) diff --git a/py-polars/polars/datatypes/__init__.py b/py-polars/polars/datatypes/__init__.py index dc69e38eca39..9c5136f8d5e6 100644 --- a/py-polars/polars/datatypes/__init__.py +++ b/py-polars/polars/datatypes/__init__.py @@ -18,7 +18,7 @@ Int16, Int32, Int64, - IntegralType, + IntegerType, List, Null, NumericType, @@ -93,7 +93,7 @@ "Int32", "Int64", "Int8", - "IntegralType", + "IntegerType", "List", "Null", "NumericType", diff --git a/py-polars/polars/datatypes/classes.py b/py-polars/polars/datatypes/classes.py index 39e16d6d4ff2..0520866d76a8 100644 --- a/py-polars/polars/datatypes/classes.py +++ b/py-polars/polars/datatypes/classes.py @@ -45,24 +45,23 @@ def __repr__(cls) -> str: def _string_repr(cls) -> str: return _dtype_str_repr(cls) - def base_type(cls) -> DataTypeClass: - """Return the base type.""" - return cls + # Methods below defined here in signature only to satisfy mypy - @classproperty - def is_nested(self) -> bool: - """Check if this data type is nested.""" - return False + @classmethod + def base_type(cls) -> DataTypeClass: # noqa: D102 + ... @classmethod - def is_(cls, other: PolarsDataType) -> bool: - """Check if this DataType is the same as another DataType.""" - return cls == other and hash(cls) == hash(other) + def is_(cls, other: PolarsDataType) -> bool: # noqa: D102 + ... @classmethod - def is_not(cls, other: PolarsDataType) -> bool: - """Check if this DataType is NOT the same as another DataType.""" - return not cls.is_(other) + def is_not(cls, other: PolarsDataType) -> bool: # noqa: D102 + ... + + @classproperty + def is_nested(self) -> bool: # noqa: D102 + ... class DataType(metaclass=DataTypeClass): @@ -97,11 +96,6 @@ def base_type(cls) -> DataTypeClass: """ return cls - @classproperty - def is_nested(self) -> bool: - """Check if this data type is nested.""" - return False - @classinstmethod # type: ignore[arg-type] def is_(self, other: PolarsDataType) -> bool: """ @@ -148,6 +142,11 @@ def is_not(self, other: PolarsDataType) -> bool: """ return not self.is_(other) + @classproperty + def is_nested(self) -> bool: + """Check if this data type is nested.""" + return False + def _custom_reconstruct( cls: type[Any], base: type[Any], state: Any @@ -200,7 +199,7 @@ class NumericType(DataType): """Base class for numeric data types.""" -class IntegralType(NumericType): +class IntegerType(NumericType): """Base class for integral data types.""" @@ -225,35 +224,35 @@ def is_nested(self) -> bool: return True -class Int8(IntegralType): +class Int8(IntegerType): """8-bit signed integer type.""" -class Int16(IntegralType): +class Int16(IntegerType): """16-bit signed integer type.""" -class Int32(IntegralType): +class Int32(IntegerType): """32-bit signed integer type.""" -class Int64(IntegralType): +class Int64(IntegerType): """64-bit signed integer type.""" -class UInt8(IntegralType): +class UInt8(IntegerType): """8-bit unsigned integer type.""" -class UInt16(IntegralType): +class UInt16(IntegerType): """16-bit unsigned integer type.""" -class UInt32(IntegralType): +class UInt32(IntegerType): """32-bit unsigned integer type.""" -class UInt64(IntegralType): +class UInt64(IntegerType): """64-bit unsigned integer type.""" diff --git a/py-polars/polars/type_aliases.py b/py-polars/polars/type_aliases.py index ea6764e1d2af..a1bbb246b1bc 100644 --- a/py-polars/polars/type_aliases.py +++ b/py-polars/polars/type_aliases.py @@ -22,7 +22,7 @@ import sys from polars import DataFrame, Expr, LazyFrame, 
Series - from polars.datatypes import DataType, DataTypeClass, IntegralType, TemporalType + from polars.datatypes import DataType, DataTypeClass, IntegerType, TemporalType from polars.dependencies import numpy as np from polars.dependencies import pandas as pd from polars.dependencies import pyarrow as pa @@ -36,7 +36,7 @@ # Data types PolarsDataType: TypeAlias = Union["DataTypeClass", "DataType"] PolarsTemporalType: TypeAlias = Union[Type["TemporalType"], "TemporalType"] -PolarsIntegerType: TypeAlias = Union[Type["IntegralType"], "IntegralType"] +PolarsIntegerType: TypeAlias = Union[Type["IntegerType"], "IntegerType"] OneOrMoreDataTypes: TypeAlias = Union[PolarsDataType, Iterable[PolarsDataType]] PythonDataType: TypeAlias = Union[ Type[int], From 084a7e1c4838d3e5c98d8fda1bdad5646babd938 Mon Sep 17 00:00:00 2001 From: Brayan Jules Date: Mon, 16 Oct 2023 19:52:53 -0300 Subject: [PATCH 014/103] fix: fixed linter issues. --- crates/polars-core/src/chunked_array/cast.rs | 4 +++- py-polars/tests/unit/test_queries.py | 6 +++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/crates/polars-core/src/chunked_array/cast.rs b/crates/polars-core/src/chunked_array/cast.rs index 365cd88cafb3..96f5fa1ead04 100644 --- a/crates/polars-core/src/chunked_array/cast.rs +++ b/crates/polars-core/src/chunked_array/cast.rs @@ -5,8 +5,10 @@ use arrow::compute::cast::CastOptions; #[cfg(feature = "dtype-categorical")] use crate::chunked_array::categorical::CategoricalChunkedBuilder; +#[cfg(feature = "temporal")] +use crate::chunked_array::temporal::validate_is_number; #[cfg(feature = "timezones")] -use crate::chunked_array::temporal::{validate_is_number, validate_time_zone}; +use crate::chunked_array::temporal::validate_time_zone; use crate::prelude::DataType::Datetime; use crate::prelude::*; diff --git a/py-polars/tests/unit/test_queries.py b/py-polars/tests/unit/test_queries.py index af623feb0e2c..48e116c39656 100644 --- a/py-polars/tests/unit/test_queries.py +++ b/py-polars/tests/unit/test_queries.py @@ -1,6 +1,6 @@ from __future__ import annotations -from datetime import datetime, timedelta, date +from datetime import date, datetime, timedelta from typing import Any import numpy as np @@ -378,7 +378,7 @@ def test_utf8_date() -> None: df = pl.DataFrame({"x1": ["2021-01-01"]}).with_columns( **{"x1-date": pl.col("x1").cast(pl.Date)} ) - expected = pl.DataFrame({"x1-date":[date(2021,1,1)]}) + expected = pl.DataFrame({"x1-date": [date(2021, 1, 1)]}) out = df.select(pl.col("x1-date")) assert out.shape == (1, 1) assert out.dtypes == [pl.Date] @@ -408,7 +408,7 @@ def test_utf8_datetime() -> None: { "x1-datetime-ns": [first_row, second_row], "x1-datetime-ms": [first_row, second_row], - "x1-datetime-us": [first_row, second_row] + "x1-datetime-us": [first_row, second_row], } ).select( pl.col("x1-datetime-ns").dt.cast_time_unit("ns"), From 0cfce614f15ec10ee95fa740d079ea13ba5f3b4b Mon Sep 17 00:00:00 2001 From: Weijie Guo Date: Tue, 17 Oct 2023 12:19:33 +0800 Subject: [PATCH 015/103] chore(rust): Move cum_agg to polars-ops (#11770) --- crates/polars-core/Cargo.toml | 3 - .../src/chunked_array/ops/cum_agg.rs | 176 ------------ .../polars-core/src/chunked_array/ops/mod.rs | 22 -- .../src/series/implementations/dates_time.rs | 11 - .../src/series/implementations/datetime.rs | 16 -- .../src/series/implementations/duration.rs | 16 -- .../src/series/implementations/floats.rs | 10 - .../src/series/implementations/mod.rs | 10 - crates/polars-core/src/series/mod.rs | 88 ------ 
crates/polars-core/src/series/series_trait.rs | 12 - crates/polars-ops/Cargo.toml | 1 + crates/polars-ops/src/series/ops/cum_agg.rs | 268 ++++++++++++++++++ crates/polars-ops/src/series/ops/mod.rs | 4 + crates/polars-plan/Cargo.toml | 2 +- .../polars-plan/src/dsl/function_expr/cum.rs | 20 +- .../polars-plan/src/dsl/function_expr/mod.rs | 16 ++ .../src/dsl/function_expr/schema.rs | 5 + crates/polars-plan/src/dsl/mod.rs | 7 +- crates/polars/Cargo.toml | 2 +- crates/polars/src/lib.rs | 6 +- 20 files changed, 310 insertions(+), 385 deletions(-) delete mode 100644 crates/polars-core/src/chunked_array/ops/cum_agg.rs create mode 100644 crates/polars-ops/src/series/ops/cum_agg.rs diff --git a/crates/polars-core/Cargo.toml b/crates/polars-core/Cargo.toml index 012c5687d959..8a14cce9c024 100644 --- a/crates/polars-core/Cargo.toml +++ b/crates/polars-core/Cargo.toml @@ -89,8 +89,6 @@ reinterpret = [] take_opt_iter = [] # allow group_by operation on list type group_by_list = [] -# cumsum, cummin, etc. -cum_agg = [] # rolling window functions rolling_window = [] diff = [] @@ -150,7 +148,6 @@ docs-selection = [ "asof_join", "dot_product", "row_hash", - "cum_agg", "rolling_window", "diff", "moment", diff --git a/crates/polars-core/src/chunked_array/ops/cum_agg.rs b/crates/polars-core/src/chunked_array/ops/cum_agg.rs deleted file mode 100644 index a1b0f2e65ee0..000000000000 --- a/crates/polars-core/src/chunked_array/ops/cum_agg.rs +++ /dev/null @@ -1,176 +0,0 @@ -use std::iter::FromIterator; -use std::ops::{Add, AddAssign, Mul}; - -use num_traits::Bounded; - -use crate::prelude::*; -use crate::utils::CustomIterTools; - -fn det_max(state: &mut T, v: Option) -> Option> -where - T: Copy + PartialOrd + AddAssign + Add, -{ - match v { - Some(v) => { - if v > *state { - *state = v - } - Some(Some(*state)) - }, - None => Some(None), - } -} - -fn det_min(state: &mut T, v: Option) -> Option> -where - T: Copy + PartialOrd + AddAssign + Add, -{ - match v { - Some(v) => { - if v < *state { - *state = v - } - Some(Some(*state)) - }, - None => Some(None), - } -} - -fn det_sum(state: &mut Option, v: Option) -> Option> -where - T: Copy + PartialOrd + AddAssign + Add, -{ - match (*state, v) { - (Some(state_inner), Some(v)) => { - *state = Some(state_inner + v); - Some(*state) - }, - (None, Some(v)) => { - *state = Some(v); - Some(*state) - }, - (_, None) => Some(None), - } -} - -fn det_prod(state: &mut Option, v: Option) -> Option> -where - T: Copy + PartialOrd + Mul, -{ - match (*state, v) { - (Some(state_inner), Some(v)) => { - *state = Some(state_inner * v); - Some(*state) - }, - (None, Some(v)) => { - *state = Some(v); - Some(*state) - }, - (_, None) => Some(None), - } -} - -impl ChunkCumAgg for ChunkedArray -where - T: PolarsNumericType, - ChunkedArray: FromIterator>, -{ - fn cummax(&self, reverse: bool) -> ChunkedArray { - let init = Bounded::min_value(); - - let mut ca: Self = match reverse { - false => self.into_iter().scan(init, det_max).collect_trusted(), - true => self - .into_iter() - .rev() - .scan(init, det_max) - .collect_reversed(), - }; - - ca.rename(self.name()); - ca - } - - fn cummin(&self, reverse: bool) -> ChunkedArray { - let init = Bounded::max_value(); - let mut ca: Self = match reverse { - false => self.into_iter().scan(init, det_min).collect_trusted(), - true => self - .into_iter() - .rev() - .scan(init, det_min) - .collect_reversed(), - }; - - ca.rename(self.name()); - ca - } - - fn cumsum(&self, reverse: bool) -> ChunkedArray { - let init = None; - let mut ca: Self = match reverse { - false => 
self.into_iter().scan(init, det_sum).collect_trusted(), - true => self - .into_iter() - .rev() - .scan(init, det_sum) - .collect_reversed(), - }; - - ca.rename(self.name()); - ca - } - - fn cumprod(&self, reverse: bool) -> ChunkedArray { - let init = None; - let mut ca: Self = match reverse { - false => self.into_iter().scan(init, det_prod).collect_trusted(), - true => self - .into_iter() - .rev() - .scan(init, det_prod) - .collect_reversed(), - }; - - ca.rename(self.name()); - ca - } -} - -#[cfg(test)] -mod test { - use crate::prelude::*; - - #[test] - #[cfg(feature = "dtype-u8")] - fn test_cummax() { - let ca = UInt8Chunked::new("foo", &[None, Some(1), Some(3), None, Some(1)]); - let out = ca.cummax(true); - assert_eq!(Vec::from(&out), &[None, Some(3), Some(3), None, Some(1)]); - let out = ca.cummax(false); - assert_eq!(Vec::from(&out), &[None, Some(1), Some(3), None, Some(3)]); - } - - #[test] - #[cfg(feature = "dtype-u8")] - fn test_cummin() { - let ca = UInt8Chunked::new("foo", &[None, Some(1), Some(3), None, Some(2)]); - let out = ca.cummin(true); - assert_eq!(Vec::from(&out), &[None, Some(1), Some(2), None, Some(2)]); - let out = ca.cummin(false); - assert_eq!(Vec::from(&out), &[None, Some(1), Some(1), None, Some(1)]); - } - - #[test] - fn test_cumsum() { - let ca = Int32Chunked::new("foo", &[None, Some(1), Some(3), None, Some(1)]); - let out = ca.cumsum(true); - assert_eq!(Vec::from(&out), &[None, Some(5), Some(4), None, Some(1)]); - let out = ca.cumsum(false); - assert_eq!(Vec::from(&out), &[None, Some(1), Some(4), None, Some(5)]); - - // just check if the trait bounds allow for floats - let ca = Float32Chunked::new("foo", &[None, Some(1.0), Some(3.0), None, Some(1.0)]); - let _out = ca.cumsum(false); - } -} diff --git a/crates/polars-core/src/chunked_array/ops/mod.rs b/crates/polars-core/src/chunked_array/ops/mod.rs index 9a89da39968d..6a54b4ae9b5a 100644 --- a/crates/polars-core/src/chunked_array/ops/mod.rs +++ b/crates/polars-core/src/chunked_array/ops/mod.rs @@ -16,8 +16,6 @@ pub mod arity; mod bit_repr; pub(crate) mod chunkops; pub(crate) mod compare_inner; -#[cfg(feature = "cum_agg")] -mod cum_agg; #[cfg(feature = "dtype-decimal")] mod decimal; pub(crate) mod downcast; @@ -91,26 +89,6 @@ pub trait ChunkAnyValue { fn get_any_value(&self, index: usize) -> PolarsResult; } -#[cfg(feature = "cum_agg")] -pub trait ChunkCumAgg { - /// Get an array with the cumulative max computed at every element - fn cummax(&self, _reverse: bool) -> ChunkedArray { - panic!("operation cummax not supported for this dtype") - } - /// Get an array with the cumulative min computed at every element - fn cummin(&self, _reverse: bool) -> ChunkedArray { - panic!("operation cummin not supported for this dtype") - } - /// Get an array with the cumulative sum computed at every element - fn cumsum(&self, _reverse: bool) -> ChunkedArray { - panic!("operation cumsum not supported for this dtype") - } - /// Get an array with the cumulative product computed at every element - fn cumprod(&self, _reverse: bool) -> ChunkedArray { - panic!("operation cumprod not supported for this dtype") - } -} - /// Explode/ flatten a List or Utf8 Series pub trait ChunkExplode { fn explode(&self) -> PolarsResult { diff --git a/crates/polars-core/src/series/implementations/dates_time.rs b/crates/polars-core/src/series/implementations/dates_time.rs index 120f7a1e11b4..831495f19e9f 100644 --- a/crates/polars-core/src/series/implementations/dates_time.rs +++ b/crates/polars-core/src/series/implementations/dates_time.rs @@ -53,17 +53,6 
@@ macro_rules! impl_dyn_series { .into_series() } - #[cfg(feature = "cum_agg")] - fn _cummax(&self, reverse: bool) -> Series { - self.0.cummax(reverse).$into_logical().into_series() - } - - #[cfg(feature = "cum_agg")] - fn _cummin(&self, reverse: bool) -> Series { - self.0.cummin(reverse).$into_logical().into_series() - } - - #[cfg(feature = "zip_with")] fn zip_with_same_type( &self, diff --git a/crates/polars-core/src/series/implementations/datetime.rs b/crates/polars-core/src/series/implementations/datetime.rs index 59b3bce8a5e2..071da77dc53d 100644 --- a/crates/polars-core/src/series/implementations/datetime.rs +++ b/crates/polars-core/src/series/implementations/datetime.rs @@ -49,22 +49,6 @@ impl private::PrivateSeries for SeriesWrap { .into_series() } - #[cfg(feature = "cum_agg")] - fn _cummax(&self, reverse: bool) -> Series { - self.0 - .cummax(reverse) - .into_datetime(self.0.time_unit(), self.0.time_zone().clone()) - .into_series() - } - - #[cfg(feature = "cum_agg")] - fn _cummin(&self, reverse: bool) -> Series { - self.0 - .cummin(reverse) - .into_datetime(self.0.time_unit(), self.0.time_zone().clone()) - .into_series() - } - #[cfg(feature = "zip_with")] fn zip_with_same_type(&self, mask: &BooleanChunked, other: &Series) -> PolarsResult { let other = other.to_physical_repr().into_owned(); diff --git a/crates/polars-core/src/series/implementations/duration.rs b/crates/polars-core/src/series/implementations/duration.rs index 834c6c57c181..c3dca8662f0a 100644 --- a/crates/polars-core/src/series/implementations/duration.rs +++ b/crates/polars-core/src/series/implementations/duration.rs @@ -44,22 +44,6 @@ impl private::PrivateSeries for SeriesWrap { .into_series() } - #[cfg(feature = "cum_agg")] - fn _cummax(&self, reverse: bool) -> Series { - self.0 - .cummax(reverse) - .into_duration(self.0.time_unit()) - .into_series() - } - - #[cfg(feature = "cum_agg")] - fn _cummin(&self, reverse: bool) -> Series { - self.0 - .cummin(reverse) - .into_duration(self.0.time_unit()) - .into_series() - } - fn _set_flags(&mut self, flags: Settings) { self.0.deref_mut().set_flags(flags) } diff --git a/crates/polars-core/src/series/implementations/floats.rs b/crates/polars-core/src/series/implementations/floats.rs index 92b7cb018def..0e4949ad4ae3 100644 --- a/crates/polars-core/src/series/implementations/floats.rs +++ b/crates/polars-core/src/series/implementations/floats.rs @@ -40,16 +40,6 @@ macro_rules! impl_dyn_series { self.0.explode_by_offsets(offsets) } - #[cfg(feature = "cum_agg")] - fn _cummax(&self, reverse: bool) -> Series { - self.0.cummax(reverse).into_series() - } - - #[cfg(feature = "cum_agg")] - fn _cummin(&self, reverse: bool) -> Series { - self.0.cummin(reverse).into_series() - } - unsafe fn equal_element( &self, idx_self: usize, diff --git a/crates/polars-core/src/series/implementations/mod.rs b/crates/polars-core/src/series/implementations/mod.rs index 15dc5dfd86f6..86cb1bc8efe5 100644 --- a/crates/polars-core/src/series/implementations/mod.rs +++ b/crates/polars-core/src/series/implementations/mod.rs @@ -105,16 +105,6 @@ macro_rules! 
impl_dyn_series { self.0.explode_by_offsets(offsets) } - #[cfg(feature = "cum_agg")] - fn _cummax(&self, reverse: bool) -> Series { - self.0.cummax(reverse).into_series() - } - - #[cfg(feature = "cum_agg")] - fn _cummin(&self, reverse: bool) -> Series { - self.0.cummin(reverse).into_series() - } - unsafe fn equal_element( &self, idx_self: usize, diff --git a/crates/polars-core/src/series/mod.rs b/crates/polars-core/src/series/mod.rs index 2ec370820c67..d8672a93ae56 100644 --- a/crates/polars-core/src/series/mod.rs +++ b/crates/polars-core/src/series/mod.rs @@ -618,94 +618,6 @@ impl Series { } } - /// Get an array with the cumulative max computed at every element. - pub fn cummax(&self, _reverse: bool) -> Series { - #[cfg(feature = "cum_agg")] - { - self._cummax(_reverse) - } - #[cfg(not(feature = "cum_agg"))] - { - panic!("activate 'cum_agg' feature") - } - } - - /// Get an array with the cumulative min computed at every element. - pub fn cummin(&self, _reverse: bool) -> Series { - #[cfg(feature = "cum_agg")] - { - self._cummin(_reverse) - } - #[cfg(not(feature = "cum_agg"))] - { - panic!("activate 'cum_agg' feature") - } - } - - /// Get an array with the cumulative sum computed at every element - /// - /// If the [`DataType`] is one of `{Int8, UInt8, Int16, UInt16}` the `Series` is - /// first cast to `Int64` to prevent overflow issues. - #[allow(unused_variables)] - pub fn cumsum(&self, reverse: bool) -> Series { - #[cfg(feature = "cum_agg")] - { - use DataType::*; - match self.dtype() { - Boolean => self.cast(&DataType::UInt32).unwrap().cumsum(reverse), - Int8 | UInt8 | Int16 | UInt16 => { - let s = self.cast(&Int64).unwrap(); - s.cumsum(reverse) - }, - Int32 => self.i32().unwrap().cumsum(reverse).into_series(), - UInt32 => self.u32().unwrap().cumsum(reverse).into_series(), - UInt64 => self.u64().unwrap().cumsum(reverse).into_series(), - Int64 => self.i64().unwrap().cumsum(reverse).into_series(), - Float32 => self.f32().unwrap().cumsum(reverse).into_series(), - Float64 => self.f64().unwrap().cumsum(reverse).into_series(), - #[cfg(feature = "dtype-duration")] - Duration(tu) => { - let ca = self.to_physical_repr(); - let ca = ca.i64().unwrap(); - ca.cumsum(reverse).cast(&Duration(*tu)).unwrap() - }, - dt => panic!("cumsum not supported for dtype: {dt:?}"), - } - } - #[cfg(not(feature = "cum_agg"))] - { - panic!("activate 'cum_agg' feature") - } - } - - /// Get an array with the cumulative product computed at every element. - /// - /// If the [`DataType`] is one of `{Int8, UInt8, Int16, UInt16, Int32, UInt32}` the `Series` is - /// first cast to `Int64` to prevent overflow issues. - #[allow(unused_variables)] - pub fn cumprod(&self, reverse: bool) -> Series { - #[cfg(feature = "cum_agg")] - { - use DataType::*; - match self.dtype() { - Boolean => self.cast(&DataType::Int64).unwrap().cumprod(reverse), - Int8 | UInt8 | Int16 | UInt16 | Int32 | UInt32 => { - let s = self.cast(&Int64).unwrap(); - s.cumprod(reverse) - }, - Int64 => self.i64().unwrap().cumprod(reverse).into_series(), - UInt64 => self.u64().unwrap().cumprod(reverse).into_series(), - Float32 => self.f32().unwrap().cumprod(reverse).into_series(), - Float64 => self.f64().unwrap().cumprod(reverse).into_series(), - dt => panic!("cumprod not supported for dtype: {dt:?}"), - } - } - #[cfg(not(feature = "cum_agg"))] - { - panic!("activate 'cum_agg' feature") - } - } - /// Get the product of an array. 
/// /// If the [`DataType`] is one of `{Int8, UInt8, Int16, UInt16}` the `Series` is diff --git a/crates/polars-core/src/series/series_trait.rs b/crates/polars-core/src/series/series_trait.rs index a99a59dc7999..f7ab23947447 100644 --- a/crates/polars-core/src/series/series_trait.rs +++ b/crates/polars-core/src/series/series_trait.rs @@ -88,18 +88,6 @@ pub(crate) mod private { invalid_operation_panic!(explode_by_offsets, self) } - /// Get an array with the cumulative max computed at every element - #[cfg(feature = "cum_agg")] - fn _cummax(&self, _reverse: bool) -> Series { - panic!("operation cummax not supported for this dtype") - } - - /// Get an array with the cumulative min computed at every element - #[cfg(feature = "cum_agg")] - fn _cummin(&self, _reverse: bool) -> Series { - panic!("operation cummin not supported for this dtype") - } - unsafe fn equal_element( &self, _idx_self: usize, diff --git a/crates/polars-ops/Cargo.toml b/crates/polars-ops/Cargo.toml index c4a4e35bc3b1..821ad9290c82 100644 --- a/crates/polars-ops/Cargo.toml +++ b/crates/polars-ops/Cargo.toml @@ -111,3 +111,4 @@ is_in = ["polars-core/reinterpret"] convert_index = [] repeat_by = [] peaks = [] +cum_agg = [] diff --git a/crates/polars-ops/src/series/ops/cum_agg.rs b/crates/polars-ops/src/series/ops/cum_agg.rs new file mode 100644 index 000000000000..12faeb193ac5 --- /dev/null +++ b/crates/polars-ops/src/series/ops/cum_agg.rs @@ -0,0 +1,268 @@ +use std::iter::FromIterator; +use std::ops::{Add, AddAssign, Mul}; + +use num_traits::Bounded; +use polars_core::prelude::*; +use polars_core::utils::{CustomIterTools, NoNull}; +use polars_core::with_match_physical_numeric_polars_type; + +fn det_max(state: &mut T, v: Option) -> Option> +where + T: Copy + PartialOrd + AddAssign + Add, +{ + match v { + Some(v) => { + if v > *state { + *state = v + } + Some(Some(*state)) + }, + None => Some(None), + } +} + +fn det_min(state: &mut T, v: Option) -> Option> +where + T: Copy + PartialOrd + AddAssign + Add, +{ + match v { + Some(v) => { + if v < *state { + *state = v + } + Some(Some(*state)) + }, + None => Some(None), + } +} + +fn det_sum(state: &mut Option, v: Option) -> Option> +where + T: Copy + PartialOrd + AddAssign + Add, +{ + match (*state, v) { + (Some(state_inner), Some(v)) => { + *state = Some(state_inner + v); + Some(*state) + }, + (None, Some(v)) => { + *state = Some(v); + Some(*state) + }, + (_, None) => Some(None), + } +} + +fn det_prod(state: &mut Option, v: Option) -> Option> +where + T: Copy + PartialOrd + Mul, +{ + match (*state, v) { + (Some(state_inner), Some(v)) => { + *state = Some(state_inner * v); + Some(*state) + }, + (None, Some(v)) => { + *state = Some(v); + Some(*state) + }, + (_, None) => Some(None), + } +} + +fn cummax_numeric(ca: &ChunkedArray, reverse: bool) -> ChunkedArray +where + T: PolarsNumericType, + ChunkedArray: FromIterator>, +{ + let init = Bounded::min_value(); + + let out: ChunkedArray = match reverse { + false => ca.into_iter().scan(init, det_max).collect_trusted(), + true => ca.into_iter().rev().scan(init, det_max).collect_reversed(), + }; + out.with_name(ca.name()) +} + +fn cummin_numeric(ca: &ChunkedArray, reverse: bool) -> ChunkedArray +where + T: PolarsNumericType, + ChunkedArray: FromIterator>, +{ + let init = Bounded::max_value(); + let out: ChunkedArray = match reverse { + false => ca.into_iter().scan(init, det_min).collect_trusted(), + true => ca.into_iter().rev().scan(init, det_min).collect_reversed(), + }; + out.with_name(ca.name()) +} + +fn cumsum_numeric(ca: &ChunkedArray, 
reverse: bool) -> ChunkedArray +where + T: PolarsNumericType, + ChunkedArray: FromIterator>, +{ + let init = None; + let out: ChunkedArray = match reverse { + false => ca.into_iter().scan(init, det_sum).collect_trusted(), + true => ca.into_iter().rev().scan(init, det_sum).collect_reversed(), + }; + out.with_name(ca.name()) +} + +fn cumprod_numeric(ca: &ChunkedArray, reverse: bool) -> ChunkedArray +where + T: PolarsNumericType, + ChunkedArray: FromIterator>, +{ + let init = None; + let out: ChunkedArray = match reverse { + false => ca.into_iter().scan(init, det_prod).collect_trusted(), + true => ca.into_iter().rev().scan(init, det_prod).collect_reversed(), + }; + out.with_name(ca.name()) +} + +/// Get an array with the cumulative product computed at every element. +/// +/// If the [`DataType`] is one of `{Int8, UInt8, Int16, UInt16, Int32, UInt32}` the `Series` is +/// first cast to `Int64` to prevent overflow issues. +pub fn cumprod(s: &Series, reverse: bool) -> PolarsResult { + use DataType::*; + let out = match s.dtype() { + Boolean | Int8 | UInt8 | Int16 | UInt16 | Int32 | UInt32 => { + let s = s.cast(&Int64)?; + cumprod_numeric(s.i64()?, reverse).into_series() + }, + Int64 => cumprod_numeric(s.i64()?, reverse).into_series(), + UInt64 => cumprod_numeric(s.u64()?, reverse).into_series(), + Float32 => cumprod_numeric(s.f32()?, reverse).into_series(), + Float64 => cumprod_numeric(s.f64()?, reverse).into_series(), + dt => polars_bail!(opq = cumprod, dt), + }; + Ok(out) +} + +/// Get an array with the cumulative sum computed at every element +/// +/// If the [`DataType`] is one of `{Int8, UInt8, Int16, UInt16}` the `Series` is +/// first cast to `Int64` to prevent overflow issues. +pub fn cumsum(s: &Series, reverse: bool) -> PolarsResult { + use DataType::*; + let out = match s.dtype() { + Boolean => { + let s = s.cast(&UInt32)?; + cumsum_numeric(s.u32()?, reverse).into_series() + }, + Int8 | UInt8 | Int16 | UInt16 => { + let s = s.cast(&Int64)?; + cumsum_numeric(s.i64()?, reverse).into_series() + }, + Int32 => cumsum_numeric(s.i32()?, reverse).into_series(), + UInt32 => cumsum_numeric(s.u32()?, reverse).into_series(), + Int64 => cumsum_numeric(s.i64()?, reverse).into_series(), + UInt64 => cumsum_numeric(s.u64()?, reverse).into_series(), + Float32 => cumsum_numeric(s.f32()?, reverse).into_series(), + Float64 => cumsum_numeric(s.f64()?, reverse).into_series(), + #[cfg(feature = "dtype-duration")] + Duration(tu) => { + let s = s.to_physical_repr(); + let ca = s.i64()?; + cumsum_numeric(ca, reverse).cast(&Duration(*tu))? + }, + dt => polars_bail!(opq = cumsum, dt), + }; + Ok(out) +} + +/// Get an array with the cumulative min computed at every element. +pub fn cummin(s: &Series, reverse: bool) -> PolarsResult { + let original_type = s.dtype(); + let s = s.to_physical_repr(); + match s.dtype() { + dt if dt.is_numeric() => { + with_match_physical_numeric_polars_type!(s.dtype(), |$T| { + let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref(); + let out = cummin_numeric(ca, reverse).into_series(); + if original_type.is_logical(){ + out.cast(original_type) + }else{ + Ok(out) + } + }) + }, + dt => polars_bail!(opq = cummin, dt), + } +} + +/// Get an array with the cumulative max computed at every element. 
+pub fn cummax(s: &Series, reverse: bool) -> PolarsResult<Series> { + let original_type = s.dtype(); + let s = s.to_physical_repr(); + match s.dtype() { + dt if dt.is_numeric() => { + with_match_physical_numeric_polars_type!(s.dtype(), |$T| { + let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref(); + let out = cummax_numeric(ca, reverse).into_series(); + if original_type.is_logical() { + out.cast(original_type) + } else { + Ok(out) + } + }) + }, + dt => polars_bail!(opq = cummax, dt), + } +} + +pub fn cumcount(s: &Series, reverse: bool) -> PolarsResult<Series> { + if reverse { + let ca: NoNull<UInt32Chunked> = (0u32..s.len() as u32).rev().collect(); + let mut ca = ca.into_inner(); + ca.rename(s.name()); + Ok(ca.into_series()) + } else { + let ca: NoNull<UInt32Chunked> = (0u32..s.len() as u32).collect(); + let mut ca = ca.into_inner(); + ca.rename(s.name()); + Ok(ca.into_series()) + } +} + +#[cfg(test)] +mod test { + use polars_core::prelude::*; + + #[test] + #[cfg(feature = "dtype-u8")] + fn test_cummax() { + let ca = UInt8Chunked::new("foo", &[None, Some(1), Some(3), None, Some(1)]); + let out = ca.cummax(true); + assert_eq!(Vec::from(&out), &[None, Some(3), Some(3), None, Some(1)]); + let out = ca.cummax(false); + assert_eq!(Vec::from(&out), &[None, Some(1), Some(3), None, Some(3)]); + } + + #[test] + #[cfg(feature = "dtype-u8")] + fn test_cummin() { + let ca = UInt8Chunked::new("foo", &[None, Some(1), Some(3), None, Some(2)]); + let out = ca.cummin(true); + assert_eq!(Vec::from(&out), &[None, Some(1), Some(2), None, Some(2)]); + let out = ca.cummin(false); + assert_eq!(Vec::from(&out), &[None, Some(1), Some(1), None, Some(1)]); + } + + #[test] + fn test_cumsum() { + let ca = Int32Chunked::new("foo", &[None, Some(1), Some(3), None, Some(1)]); + let out = ca.cumsum(true); + assert_eq!(Vec::from(&out), &[None, Some(5), Some(4), None, Some(1)]); + let out = ca.cumsum(false); + assert_eq!(Vec::from(&out), &[None, Some(1), Some(4), None, Some(5)]); + + // just check if the trait bounds allow for floats + let ca = Float32Chunked::new("foo", &[None, Some(1.0), Some(3.0), None, Some(1.0)]); + let _out = ca.cumsum(false); + } +} diff --git a/crates/polars-ops/src/series/ops/mod.rs b/crates/polars-ops/src/series/ops/mod.rs index d4c10d7fd078..55fc68dc48eb 100644 --- a/crates/polars-ops/src/series/ops/mod.rs +++ b/crates/polars-ops/src/series/ops/mod.rs @@ -3,6 +3,8 @@ mod approx_algo; mod approx_unique; mod arg_min_max; mod clip; +#[cfg(feature = "cum_agg")] +mod cum_agg; #[cfg(feature = "cutqcut")] mod cut; #[cfg(feature = "round_series")] @@ -39,6 +41,8 @@ pub use approx_algo::*; pub use approx_unique::*; pub use arg_min_max::ArgAgg; pub use clip::*; +#[cfg(feature = "cum_agg")] +pub use cum_agg::*; #[cfg(feature = "cutqcut")] pub use cut::*; #[cfg(feature = "round_series")] diff --git a/crates/polars-plan/Cargo.toml b/crates/polars-plan/Cargo.toml index 89811eca50eb..a68252bf0daf 100644 --- a/crates/polars-plan/Cargo.toml +++ b/crates/polars-plan/Cargo.toml @@ -98,7 +98,7 @@ asof_join = ["polars-core/asof_join", "polars-time", "polars-ops/asof_join"] concat_str = [] range = [] mode = ["polars-ops/mode"] -cum_agg = ["polars-core/cum_agg"] +cum_agg = ["polars-ops/cum_agg"] interpolate = ["polars-ops/interpolate"] rolling_window = [ "polars-core/rolling_window", diff --git a/crates/polars-plan/src/dsl/function_expr/cum.rs b/crates/polars-plan/src/dsl/function_expr/cum.rs index d8ac6434809b..0ba1a0f6281d 100644 --- a/crates/polars-plan/src/dsl/function_expr/cum.rs +++ b/crates/polars-plan/src/dsl/function_expr/cum.rs @@ -1,33 +1,23 @@ use
super::*; pub(super) fn cumcount(s: &Series, reverse: bool) -> PolarsResult<Series> { - if reverse { - let ca: NoNull<UInt32Chunked> = (0u32..s.len() as u32).rev().collect(); - let mut ca = ca.into_inner(); - ca.rename(s.name()); - Ok(ca.into_series()) - } else { - let ca: NoNull<UInt32Chunked> = (0u32..s.len() as u32).collect(); - let mut ca = ca.into_inner(); - ca.rename(s.name()); - Ok(ca.into_series()) - } + polars_ops::prelude::cumcount(s, reverse) } pub(super) fn cumsum(s: &Series, reverse: bool) -> PolarsResult<Series> { - Ok(s.cumsum(reverse)) + polars_ops::prelude::cumsum(s, reverse) } pub(super) fn cumprod(s: &Series, reverse: bool) -> PolarsResult<Series> { - Ok(s.cumprod(reverse)) + polars_ops::prelude::cumprod(s, reverse) } pub(super) fn cummin(s: &Series, reverse: bool) -> PolarsResult<Series> { - Ok(s.cummin(reverse)) + polars_ops::prelude::cummin(s, reverse) } pub(super) fn cummax(s: &Series, reverse: bool) -> PolarsResult<Series> { - Ok(s.cummax(reverse)) + polars_ops::prelude::cummax(s, reverse) } pub(super) mod dtypes { diff --git a/crates/polars-plan/src/dsl/function_expr/mod.rs b/crates/polars-plan/src/dsl/function_expr/mod.rs index cd71fc6c96a0..cf6d269e4668 100644 --- a/crates/polars-plan/src/dsl/function_expr/mod.rs +++ b/crates/polars-plan/src/dsl/function_expr/mod.rs @@ -14,6 +14,7 @@ mod clip; mod coerce; mod concat; mod correlation; +#[cfg(feature = "cum_agg")] mod cum; #[cfg(feature = "temporal")] mod datetime; @@ -152,18 +153,23 @@ pub enum FunctionExpr { #[cfg(feature = "top_k")] TopK(bool), Shift(i64), + #[cfg(feature = "cum_agg")] Cumcount { reverse: bool, }, + #[cfg(feature = "cum_agg")] Cumsum { reverse: bool, }, + #[cfg(feature = "cum_agg")] Cumprod { reverse: bool, }, + #[cfg(feature = "cum_agg")] Cummin { reverse: bool, }, + #[cfg(feature = "cum_agg")] Cummax { reverse: bool, }, @@ -385,10 +391,15 @@ impl Display for FunctionExpr { } }, Shift(_) => "shift", + #[cfg(feature = "cum_agg")] Cumcount { .. } => "cumcount", + #[cfg(feature = "cum_agg")] Cumsum { .. } => "cumsum", + #[cfg(feature = "cum_agg")] Cumprod { .. } => "cumprod", + #[cfg(feature = "cum_agg")] Cummin { .. } => "cummin", + #[cfg(feature = "cum_agg")] Cummax { .. } => "cummax", #[cfg(feature = "dtype-struct")] ValueCounts { .. } => "value_counts", @@ -679,10 +690,15 @@ impl From<FunctionExpr> for SpecialEq<Arc<dyn SeriesUdf>> { map_as_slice!(top_k, descending) }, Shift(periods) => map!(dispatch::shift, periods), + #[cfg(feature = "cum_agg")] Cumcount { reverse } => map!(cum::cumcount, reverse), + #[cfg(feature = "cum_agg")] Cumsum { reverse } => map!(cum::cumsum, reverse), + #[cfg(feature = "cum_agg")] Cumprod { reverse } => map!(cum::cumprod, reverse), + #[cfg(feature = "cum_agg")] Cummin { reverse } => map!(cum::cummin, reverse), + #[cfg(feature = "cum_agg")] Cummax { reverse } => map!(cum::cummax, reverse), #[cfg(feature = "dtype-struct")] ValueCounts { sort, parallel } => map!(dispatch::value_counts, sort, parallel), diff --git a/crates/polars-plan/src/dsl/function_expr/schema.rs b/crates/polars-plan/src/dsl/function_expr/schema.rs index 9baa4a0e5c89..a787b50d8987 100644 --- a/crates/polars-plan/src/dsl/function_expr/schema.rs +++ b/crates/polars-plan/src/dsl/function_expr/schema.rs @@ -122,10 +122,15 @@ impl FunctionExpr { Boolean(func) => func.get_field(mapper), #[cfg(feature = "dtype-categorical")] Categorical(func) => func.get_field(mapper), + #[cfg(feature = "cum_agg")] Cumcount { .. } => mapper.with_dtype(IDX_DTYPE), + #[cfg(feature = "cum_agg")] Cumsum { .. } => mapper.map_dtype(cum::dtypes::cumsum), + #[cfg(feature = "cum_agg")] Cumprod { ..
} => mapper.map_dtype(cum::dtypes::cumprod), + #[cfg(feature = "cum_agg")] Cummin { .. } => mapper.with_same_dtype(), + #[cfg(feature = "cum_agg")] Cummax { .. } => mapper.with_same_dtype(), #[cfg(feature = "approx_unique")] ApproxNUnique => mapper.with_dtype(IDX_DTYPE), diff --git a/crates/polars-plan/src/dsl/mod.rs b/crates/polars-plan/src/dsl/mod.rs index aff4c2378a12..e9e689513c74 100644 --- a/crates/polars-plan/src/dsl/mod.rs +++ b/crates/polars-plan/src/dsl/mod.rs @@ -54,7 +54,7 @@ use polars_core::prelude::*; #[cfg(feature = "diff")] use polars_core::series::ops::NullBehavior; use polars_core::series::IsSorted; -use polars_core::utils::{try_get_supertype, NoNull}; +use polars_core::utils::try_get_supertype; #[cfg(feature = "rolling_window")] use polars_time::prelude::SeriesOpsTime; pub(crate) use selector::Selector; @@ -743,26 +743,31 @@ impl Expr { } /// Cumulatively count values from 0 to len. + #[cfg(feature = "cum_agg")] pub fn cumcount(self, reverse: bool) -> Self { self.apply_private(FunctionExpr::Cumcount { reverse }) } /// Get an array with the cumulative sum computed at every element. + #[cfg(feature = "cum_agg")] pub fn cumsum(self, reverse: bool) -> Self { self.apply_private(FunctionExpr::Cumsum { reverse }) } /// Get an array with the cumulative product computed at every element. + #[cfg(feature = "cum_agg")] pub fn cumprod(self, reverse: bool) -> Self { self.apply_private(FunctionExpr::Cumprod { reverse }) } /// Get an array with the cumulative min computed at every element. + #[cfg(feature = "cum_agg")] pub fn cummin(self, reverse: bool) -> Self { self.apply_private(FunctionExpr::Cummin { reverse }) } /// Get an array with the cumulative max computed at every element. + #[cfg(feature = "cum_agg")] pub fn cummax(self, reverse: bool) -> Self { self.apply_private(FunctionExpr::Cummax { reverse }) } diff --git a/crates/polars/Cargo.toml b/crates/polars/Cargo.toml index 19de558f8177..b8c78f4ef0ca 100644 --- a/crates/polars/Cargo.toml +++ b/crates/polars/Cargo.toml @@ -139,7 +139,7 @@ string_encoding = ["polars-ops/string_encoding", "polars-core/strings"] binary_encoding = ["polars-ops/binary_encoding"] group_by_list = ["polars-core/group_by_list", "polars-ops/group_by_list"] lazy_regex = ["polars-lazy?/regex"] -cum_agg = ["polars-core/cum_agg", "polars-core/cum_agg"] +cum_agg = ["polars-ops/cum_agg", "polars-lazy?/cum_agg"] rolling_window = ["polars-core/rolling_window", "polars-lazy?/rolling_window", "polars-time/rolling_window"] interpolate = ["polars-ops/interpolate", "polars-lazy?/interpolate"] rank = ["polars-lazy?/rank", "polars-ops/rank"] diff --git a/crates/polars/src/lib.rs b/crates/polars/src/lib.rs index 0eaf13c040ce..c04fb989e47d 100644 --- a/crates/polars/src/lib.rs +++ b/crates/polars/src/lib.rs @@ -278,9 +278,9 @@ //! - `fmt` - Activate [`DataFrame`] formatting //! //! [`UInt64Chunked`]: crate::datatypes::UInt64Chunked -//! [`cumsum`]: crate::series::Series::cumsum -//! [`cummin`]: crate::series::Series::cummin -//! [`cummax`]: crate::series::Series::cummax +//! [`cumsum`]: polars_ops::prelude::cumsum +//! [`cummin`]: polars_ops::prelude::cummin +//! [`cummax`]: polars_ops::prelude::cummax //! [`rolling_mean`]: crate::series::Series#method.rolling_mean //! [`diff`]: crate::series::Series::diff //! 
[`List`]: crate::datatypes::DataType::List From 6f258312b1863eff72e9109635ffd7da87ffefbb Mon Sep 17 00:00:00 2001 From: Stijn de Gooijer Date: Tue, 17 Oct 2023 06:21:28 +0200 Subject: [PATCH 016/103] depr(python): Deprecate `use_pyarrow` param for `Series.to_list` (#11784) --- py-polars/polars/series/series.py | 13 ++++++++++--- py-polars/tests/unit/series/test_series.py | 22 +++++++++++++--------- 2 files changed, 23 insertions(+), 12 deletions(-) diff --git a/py-polars/polars/series/series.py b/py-polars/polars/series/series.py index d6036a83415c..3a01f34b02e3 100644 --- a/py-polars/polars/series/series.py +++ b/py-polars/polars/series/series.py @@ -3811,7 +3811,7 @@ def to_physical(self) -> Series: """ - def to_list(self, *, use_pyarrow: bool = False) -> list[Any]: + def to_list(self, *, use_pyarrow: bool | None = None) -> list[Any]: """ Convert this Series to a Python List. This operation clones data. @@ -3829,8 +3829,15 @@ def to_list(self, *, use_pyarrow: bool = False) -> list[Any]: """ - if use_pyarrow: - return self.to_arrow().to_pylist() + if use_pyarrow is not None: + issue_deprecation_warning( + "The parameter `use_pyarrow` for `Series.to_list` is deprecated." + " Call the method without `use_pyarrow` to silence this warning.", + version="0.19.9", + ) + if use_pyarrow: + return self.to_arrow().to_pylist() + return self._s.to_list() def rechunk(self, *, in_place: bool = False) -> Self: diff --git a/py-polars/tests/unit/series/test_series.py b/py-polars/tests/unit/series/test_series.py index 387fcec9a603..52aba0dde302 100644 --- a/py-polars/tests/unit/series/test_series.py +++ b/py-polars/tests/unit/series/test_series.py @@ -602,21 +602,25 @@ def test_to_pandas() -> None: pass -def test_to_python() -> None: - a = pl.Series("a", range(20)) - b = a.to_list() - assert isinstance(b, list) - assert len(b) == 20 - - b = a.to_list(use_pyarrow=True) - assert isinstance(b, list) - assert len(b) == 20 +def test_series_to_list() -> None: + s = pl.Series("a", range(20)) + result = s.to_list() + assert isinstance(result, list) + assert len(result) == 20 a = pl.Series("a", [1, None, 2]) assert a.null_count() == 1 assert a.to_list() == [1, None, 2] +def test_series_to_list_use_pyarrow_deprecated() -> None: + s = pl.Series("a", range(20)) + with pytest.deprecated_call(): + result = s.to_list(use_pyarrow=True) + assert isinstance(result, list) + assert len(result) == 20 + + def test_to_struct() -> None: s = pl.Series("nums", ["12 34", "56 78", "90 00"]).str.extract_all(r"\d+") From 2c6c9bd3f8cacae7e2aabcd2dc422f3e92e776f0 Mon Sep 17 00:00:00 2001 From: Stijn de Gooijer Date: Tue, 17 Oct 2023 06:22:49 +0200 Subject: [PATCH 017/103] refactor(python): Fix Exception module paths (#11785) --- py-polars/src/error.rs | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/py-polars/src/error.rs b/py-polars/src/error.rs index 4f747dfe3118..07046772316a 100644 --- a/py-polars/src/error.rs +++ b/py-polars/src/error.rs @@ -64,17 +64,17 @@ impl Debug for PyPolarsErr { } } -create_exception!(exceptions, ColumnNotFoundError, PyException); -create_exception!(exceptions, ComputeError, PyException); -create_exception!(exceptions, DuplicateError, PyException); -create_exception!(exceptions, InvalidOperationError, PyException); -create_exception!(exceptions, NoDataError, PyException); -create_exception!(exceptions, OutOfBoundsError, PyException); -create_exception!(exceptions, SchemaError, PyException); -create_exception!(exceptions, SchemaFieldNotFoundError, PyException); 
-create_exception!(exceptions, ShapeError, PyException); -create_exception!(exceptions, StringCacheMismatchError, PyException); -create_exception!(exceptions, StructFieldNotFoundError, PyException); +create_exception!(polars.exceptions, ColumnNotFoundError, PyException); +create_exception!(polars.exceptions, ComputeError, PyException); +create_exception!(polars.exceptions, DuplicateError, PyException); +create_exception!(polars.exceptions, InvalidOperationError, PyException); +create_exception!(polars.exceptions, NoDataError, PyException); +create_exception!(polars.exceptions, OutOfBoundsError, PyException); +create_exception!(polars.exceptions, SchemaError, PyException); +create_exception!(polars.exceptions, SchemaFieldNotFoundError, PyException); +create_exception!(polars.exceptions, ShapeError, PyException); +create_exception!(polars.exceptions, StringCacheMismatchError, PyException); +create_exception!(polars.exceptions, StructFieldNotFoundError, PyException); #[macro_export] macro_rules! raise_err( From 6b929f2920f68de78f0cc3c68f892c28bf422a10 Mon Sep 17 00:00:00 2001 From: Stijn de Gooijer Date: Tue, 17 Oct 2023 06:23:23 +0200 Subject: [PATCH 018/103] perf(python): Improve `DataFrame.get_column` performance by ~35% (#11783) --- py-polars/polars/dataframe/frame.py | 16 ++++++++-------- py-polars/src/dataframe.rs | 2 +- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/py-polars/polars/dataframe/frame.py b/py-polars/polars/dataframe/frame.py index aca2118fd704..572d9bfdebaa 100644 --- a/py-polars/polars/dataframe/frame.py +++ b/py-polars/polars/dataframe/frame.py @@ -1666,7 +1666,7 @@ def __getitem__( # select single column # df["foo"] if isinstance(item, str): - return wrap_s(self._df.column(item)) + return self.get_column(item) # df[idx] if isinstance(item, int): @@ -1864,7 +1864,7 @@ def item(self, row: int | None = None, column: int | str | None = None) -> Any: s = ( self._df.select_at_idx(column) if isinstance(column, int) - else self._df.column(column) + else self._df.get_column(column) ) if s is None: raise IndexError(f"column index {column!r} is out of bounds") @@ -6650,13 +6650,17 @@ def get_columns(self) -> list[Series]: def get_column(self, name: str) -> Series: """ - Get a single column as Series by name. + Get a single column by name. Parameters ---------- name : str Name of the column to retrieve. 
+ Returns + ------- + Series + See Also -------- to_series @@ -6674,11 +6678,7 @@ def get_column(self, name: str) -> Series: ] """ - if not isinstance(name, str): - raise TypeError( - f"column name {name!r} should be be a string, but is {type(name).__name__!r}" - ) - return self[name] + return wrap_s(self._df.get_column(name)) def fill_null( self, diff --git a/py-polars/src/dataframe.rs b/py-polars/src/dataframe.rs index 582253078265..98e8bae5ce89 100644 --- a/py-polars/src/dataframe.rs +++ b/py-polars/src/dataframe.rs @@ -1030,7 +1030,7 @@ impl PyDataFrame { self.df.find_idx_by_name(name) } - pub fn column(&self, name: &str) -> PyResult<PySeries> { + pub fn get_column(&self, name: &str) -> PyResult<PySeries> { let series = self .df .column(name) From dcec1e800045c038050381845fe5b347bb9d24c8 Mon Sep 17 00:00:00 2001 From: nameexhaustion Date: Tue, 17 Oct 2023 15:24:19 +1100 Subject: [PATCH 019/103] fix(rust,python): make `PyLazyGroupby` reusable (#11769) --- py-polars/src/lazygroupby.rs | 8 ++++---- py-polars/tests/unit/operations/test_group_by.py | 7 +++++++ 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/py-polars/src/lazygroupby.rs b/py-polars/src/lazygroupby.rs index d74163b8e43b..2364fad0094d 100644 --- a/py-polars/src/lazygroupby.rs +++ b/py-polars/src/lazygroupby.rs @@ -19,18 +19,18 @@ pub struct PyLazyGroupBy { #[pymethods] impl PyLazyGroupBy { fn agg(&mut self, aggs: Vec<PyExpr>) -> PyLazyFrame { - let lgb = self.lgb.take().unwrap(); + let lgb = self.lgb.clone().unwrap(); let aggs = aggs.to_exprs(); lgb.agg(aggs).into() } fn head(&mut self, n: usize) -> PyLazyFrame { - let lgb = self.lgb.take().unwrap(); + let lgb = self.lgb.clone().unwrap(); lgb.head(Some(n)).into() } fn tail(&mut self, n: usize) -> PyLazyFrame { - let lgb = self.lgb.take().unwrap(); + let lgb = self.lgb.clone().unwrap(); lgb.tail(Some(n)).into() } fn apply( &mut self, lambda: PyObject, schema: Option<Wrap<Schema>>, ) -> PyResult<PyLazyFrame> { - let lgb = self.lgb.take().unwrap(); + let lgb = self.lgb.clone().unwrap(); let schema = match schema { Some(schema) => Arc::new(schema.0), None => LazyFrame::from(lgb.logical_plan.clone()) diff --git a/py-polars/tests/unit/operations/test_group_by.py b/py-polars/tests/unit/operations/test_group_by.py index aa638d3b29ff..1902459bb05a 100644 --- a/py-polars/tests/unit/operations/test_group_by.py +++ b/py-polars/tests/unit/operations/test_group_by.py @@ -822,3 +822,10 @@ def test_group_by_with_expr_as_key() -> None: # tests: 11766 assert gb.head(0).frame_equal(gb.agg(pl.col("x").head(0)).explode("x")) assert gb.tail(0).frame_equal(gb.agg(pl.col("x").tail(0)).explode("x")) + + +def test_lazy_group_by_reuse_11767() -> None: + lgb = pl.select(x=1).lazy().group_by("x") + a = lgb.count() + b = lgb.count() + assert a.collect().frame_equal(b.collect()) From ef503c3d0c60f5b8266c8d02582caeaf0f4b7026 Mon Sep 17 00:00:00 2001 From: Stijn de Gooijer Date: Tue, 17 Oct 2023 07:00:28 +0200 Subject: [PATCH 020/103] fix(python): Fix values printed by `assert_*_equal` AssertionError when `exact=False` (#11781) --- py-polars/docs/source/reference/testing.rst | 2 + py-polars/polars/expr/expr.py | 2 +- py-polars/polars/series/series.py | 2 +- py-polars/polars/testing/__init__.py | 2 - py-polars/polars/testing/asserts.py | 187 ++++++++---------- .../polars/testing/parametric/primitives.py | 3 +- py-polars/pyproject.toml | 1 + .../tests/unit/{ => testing}/test_testing.py | 20 +- 8 files changed, 100 insertions(+), 119 deletions(-) rename py-polars/tests/unit/{ => testing}/test_testing.py (98%) diff --git
a/py-polars/docs/source/reference/testing.rst b/py-polars/docs/source/reference/testing.rst index 4e268cec77dd..78ce4c96a0bd 100644 --- a/py-polars/docs/source/reference/testing.rst +++ b/py-polars/docs/source/reference/testing.rst @@ -25,7 +25,9 @@ Polars provides some standard asserts for use with unit tests: :toctree: api/ testing.assert_frame_equal + testing.assert_frame_not_equal testing.assert_series_equal + testing.assert_series_not_equal Parametric testing diff --git a/py-polars/polars/expr/expr.py b/py-polars/polars/expr/expr.py index a86da90bc996..a05f85c9134b 100644 --- a/py-polars/polars/expr/expr.py +++ b/py-polars/polars/expr/expr.py @@ -4709,7 +4709,7 @@ def ne(self, other: Any) -> Self: def ne_missing(self, other: Any) -> Self: """ - Method equivalent of equality operator ``expr != other`` where `None` == None`. + Method equivalent of inequality operator ``expr != other`` where ``None == None``. This differs from default ``ne`` where null values are propagated. diff --git a/py-polars/polars/series/series.py b/py-polars/polars/series/series.py index 3a01f34b02e3..f9e8a72e6d36 100644 --- a/py-polars/polars/series/series.py +++ b/py-polars/polars/series/series.py @@ -669,7 +669,7 @@ def ne_missing(self, other: Any) -> Self: def ne_missing(self, other: Any) -> Self | Expr: """ - Method equivalent of equality operator ``series != other`` where `None` == None`. + Method equivalent of inequality operator ``series != other`` where ``None == None``. This differs from the standard ``ne`` where null values are propagated. diff --git a/py-polars/polars/testing/__init__.py b/py-polars/polars/testing/__init__.py index 13cf07939044..2461de6ba6ff 100644 --- a/py-polars/polars/testing/__init__.py +++ b/py-polars/polars/testing/__init__.py @@ -1,6 +1,5 @@ from polars.testing.asserts import ( assert_frame_equal, - assert_frame_equal_local_categoricals, assert_frame_not_equal, assert_series_equal, assert_series_not_equal, @@ -11,5 +10,4 @@ "assert_series_not_equal", "assert_frame_equal", "assert_frame_not_equal", - "assert_frame_equal_local_categoricals", ] diff --git a/py-polars/polars/testing/asserts.py b/py-polars/polars/testing/asserts.py index 7fdd92edb89a..643b7eb231fe 100644 --- a/py-polars/polars/testing/asserts.py +++ b/py-polars/polars/testing/asserts.py @@ -1,6 +1,5 @@ from __future__ import annotations -import textwrap from typing import Any, NoReturn from polars import functions as F @@ -9,7 +8,6 @@ FLOAT_DTYPES, UNSIGNED_INTEGER_DTYPES, Categorical, - DataTypeClass, List, Struct, UInt64, @@ -20,7 +18,6 @@ from polars.exceptions import ComputeError, InvalidAssert from polars.lazyframe import LazyFrame from polars.series import Series -from polars.utils.deprecation import deprecate_function def assert_frame_equal( @@ -81,7 +78,9 @@ def assert_frame_equal( elif isinstance(left, DataFrame) and isinstance(right, DataFrame): objs = "DataFrames" else: - raise_assert_detail("Inputs", "unexpected input types", type(left), type(right)) + _raise_assertion_error( + "Inputs", "unexpected input types", type(left), type(right) + ) if left_not_right := [c for c in left.columns if c not in right.columns]: raise AssertionError( @@ -102,13 +101,13 @@ def assert_frame_equal( if check_dtype: # check this _before_ we collect left_schema, right_schema = left.schema, right.schema if left_schema != right_schema: - raise_assert_detail( + _raise_assertion_error( objs, "lazy schemas are not equal", left_schema, right_schema ) left, right = left.collect(), right.collect() # type: ignore[union-attr] if
left.shape[0] != right.shape[0]: # type: ignore[union-attr] - raise_assert_detail(objs, "length mismatch", left.shape, right.shape) # type: ignore[union-attr] + _raise_assertion_error(objs, "length mismatch", left.shape, right.shape) # type: ignore[union-attr] if not check_row_order: try: @@ -133,7 +132,7 @@ def assert_frame_equal( categorical_as_str=categorical_as_str, ) except AssertionError as exc: - msg = f"values for column {c!r} are different." + msg = f"values for column {c!r} are different" raise AssertionError(msg) from exc @@ -254,17 +253,16 @@ def assert_series_equal( >>> assert_series_equal(s1, s2) # doctest: +SKIP """ - if not ( - isinstance(left, Series) # type: ignore[redundant-expr] - and isinstance(right, Series) - ): - raise_assert_detail("Inputs", "unexpected input types", type(left), type(right)) + if not (isinstance(left, Series) and isinstance(right, Series)): # type: ignore[redundant-expr] + _raise_assertion_error( + "Inputs", "unexpected input types", type(left), type(right) + ) if len(left) != len(right): - raise_assert_detail("Series", "length mismatch", len(left), len(right)) + _raise_assertion_error("Series", "length mismatch", len(left), len(right)) if check_names and left.name != right.name: - raise_assert_detail("Series", "name mismatch", left.name, right.name) + _raise_assertion_error("Series", "name mismatch", left.name, right.name) _assert_series_inner( left, @@ -355,12 +353,7 @@ def _assert_series_inner( ) -> None: """Compare Series dtype + values.""" if check_dtype and left.dtype != right.dtype: - raise_assert_detail("Series", "dtype mismatch", left.dtype, right.dtype) - - if left.null_count() != right.null_count(): - raise_assert_detail( - "Series", "null_count is not equal", left.null_count(), right.null_count() - ) + _raise_assertion_error("Series", "dtype mismatch", left.dtype, right.dtype) if categorical_as_str and left.dtype == Categorical: left = left.cast(Utf8) @@ -370,14 +363,14 @@ def _assert_series_inner( unequal = left.ne_missing(right) # handle NaN values (which compare unequal to themselves) - comparing_float_dtypes = left.dtype in FLOAT_DTYPES and right.dtype in FLOAT_DTYPES + comparing_floats = left.dtype in FLOAT_DTYPES and right.dtype in FLOAT_DTYPES if unequal.any() and nans_compare_equal: # when both dtypes are scalar floats - if comparing_float_dtypes: + if comparing_floats: unequal = unequal & ~( (left.is_nan() & right.is_nan()).fill_null(F.lit(False)) ) - if comparing_float_dtypes and not nans_compare_equal: + if comparing_floats and not nans_compare_equal: unequal = unequal | left.is_nan() | right.is_nan() # check nested dtypes in separate function @@ -406,50 +399,74 @@ def _assert_series_inner( # assert exact, or with tolerance if unequal.any(): if check_exact: - raise_assert_detail( + _raise_assertion_error( "Series", "exact value mismatch", - left=list(left), - right=list(right), + left=left.to_list(), + right=right.to_list(), ) else: - # apply check with tolerance (to the known-unequal matches). 
- left, right = left.filter(unequal), right.filter(unequal) + equal, nan_info = _check_series_equal_inexact( + left, + right, + unequal, + atol=atol, + rtol=rtol, + nans_compare_equal=nans_compare_equal, + comparing_floats=comparing_floats, + ) - if all(tp in UNSIGNED_INTEGER_DTYPES for tp in (left.dtype, right.dtype)): - # avoid potential "subtract-with-overflow" panic on uint math - s_diff = Series( - "diff", [abs(v1 - v2) for v1, v2 in zip(left, right)], dtype=UInt64 - ) - else: - s_diff = (left - right).abs() - - mismatch, nan_info = False, "" - if ((s_diff > (atol + rtol * right.abs())).sum() != 0) or ( - left.is_null() != right.is_null() - ).any(): - mismatch = True - elif comparing_float_dtypes: - # note: take special care with NaN values. - # if NaNs don't compare as equal, any NaN in the left Series is - # sufficient for a mismatch because the if condition above already - # compares the null values. - if not nans_compare_equal and left.is_nan().any(): - nan_info = " (nans_compare_equal=False)" - mismatch = True - elif (left.is_nan() != right.is_nan()).any(): - nan_info = f" (nans_compare_equal={nans_compare_equal})" - mismatch = True - - if mismatch: - raise_assert_detail( + if not equal: + _raise_assertion_error( "Series", f"value mismatch{nan_info}", - left=list(left), - right=list(right), + left=left.to_list(), + right=right.to_list(), ) +def _check_series_equal_inexact( + left: Series, + right: Series, + unequal: Series, + atol: float, + rtol: float, + *, + nans_compare_equal: bool, + comparing_floats: bool, +) -> tuple[bool, str]: + # apply check with tolerance (to the known-unequal matches). + left, right = left.filter(unequal), right.filter(unequal) + + if all(tp in UNSIGNED_INTEGER_DTYPES for tp in (left.dtype, right.dtype)): + # avoid potential "subtract-with-overflow" panic on uint math + s_diff = Series( + "diff", [abs(v1 - v2) for v1, v2 in zip(left, right)], dtype=UInt64 + ) + else: + s_diff = (left - right).abs() + + equal, nan_info = True, "" + if ((s_diff > (atol + rtol * right.abs())).sum() != 0) or ( + left.is_null() != right.is_null() + ).any(): + equal = False + + elif comparing_floats: + # note: take special care with NaN values. + # if NaNs don't compare as equal, any NaN in the left Series is + # sufficient for a mismatch because the if condition above already + # compares the null values. 
+ if not nans_compare_equal and left.is_nan().any(): + equal = False + nan_info = " (nans_compare_equal=False)" + elif (left.is_nan() != right.is_nan()).any(): + equal = False + nan_info = f" (nans_compare_equal={nans_compare_equal})" + + return equal, nan_info + + def _assert_series_nested( left: Series, right: Series, @@ -472,16 +489,16 @@ def _assert_series_nested( if nans_compare_equal: continue else: - raise_assert_detail( + _raise_assertion_error( "Series", f"Nested value mismatch (nans_compare_equal={nans_compare_equal})", s1, s2, ) elif (s1 is None and s2 is not None) or (s2 is None and s1 is not None): - raise_assert_detail("Series", "nested value mismatch", s1, s2) + _raise_assertion_error("Series", "nested value mismatch", s1, s2) elif len(s1) != len(s2): - raise_assert_detail( + _raise_assertion_error( "Series", "nested list length mismatch", len(s1), len(s2) ) @@ -501,14 +518,14 @@ def _assert_series_nested( elif left.dtype == Struct == right.dtype: ls, rs = left.struct.unnest(), right.struct.unnest() if len(ls.columns) != len(rs.columns): - raise_assert_detail( + _raise_assertion_error( "Series", "nested struct fields mismatch", len(ls.columns), len(rs.columns), ) elif len(ls) != len(rs): - raise_assert_detail( + _raise_assertion_error( "Series", "nested struct length mismatch", len(ls), len(rs) ) for s1, s2 in zip(ls, rs): @@ -529,53 +546,13 @@ def _assert_series_nested( return False -def raise_assert_detail( +def _raise_assertion_error( obj: str, detail: str, left: Any, right: Any, - exc: AssertionError | None = None, ) -> NoReturn: """Raise a detailed assertion error.""" __tracebackhide__ = True - - error_msg = textwrap.dedent( - f"""\ - {obj} are different ({detail}) - [left]: {left} - [right]: {right}\ - """ - ) - - raise AssertionError(error_msg) from exc - - -def is_categorical_dtype(data_type: Any) -> bool: - """Check if the input is a polars Categorical dtype.""" - return ( - type(data_type) is DataTypeClass - and issubclass(data_type, Categorical) - or isinstance(data_type, Categorical) - ) - - -@deprecate_function( - "Use `assert_frame_equal` instead and pass `categorical_as_str=True`.", - version="0.18.13", -) -def assert_frame_equal_local_categoricals(df_a: DataFrame, df_b: DataFrame) -> None: - """Assert frame equal for frames containing categoricals.""" - for (a_name, a_value), (b_name, b_value) in zip( - df_a.schema.items(), df_b.schema.items() - ): - if a_name != b_name: - print(f"{a_name} != {b_name}") - raise AssertionError - if a_value != b_value: - print(f"{a_value} != {b_value}") - raise AssertionError - - cat_to_str = F.col(Categorical).cast(str) - assert_frame_equal(df_a.with_columns(cat_to_str), df_b.with_columns(cat_to_str)) - cat_to_phys = F.col(Categorical).to_physical() - assert_frame_equal(df_a.with_columns(cat_to_phys), df_b.with_columns(cat_to_phys)) + msg = f"{obj} are different ({detail})\n[left]: {left}\n[right]: {right}" + raise AssertionError(msg) diff --git a/py-polars/polars/testing/parametric/primitives.py b/py-polars/polars/testing/parametric/primitives.py index 5f3c34730993..27194b56c5e5 100644 --- a/py-polars/polars/testing/parametric/primitives.py +++ b/py-polars/polars/testing/parametric/primitives.py @@ -26,7 +26,6 @@ ) from polars.series import Series from polars.string_cache import StringCache -from polars.testing.asserts import is_categorical_dtype from polars.testing.parametric.strategies import ( _flexhash, all_strategies, @@ -431,7 +430,7 @@ def draw_series(draw: DrawFn) -> Series: dtype=series_dtype, values=series_values, ) 
- if is_categorical_dtype(dtype): + if dtype == Categorical: s = s.cast(Categorical) if series_size and (chunked or (chunked is None and draw(booleans()))): split_at = series_size // 2 diff --git a/py-polars/pyproject.toml b/py-polars/pyproject.toml index 21ce228ce519..24b6ac8e6e46 100644 --- a/py-polars/pyproject.toml +++ b/py-polars/pyproject.toml @@ -190,6 +190,7 @@ strict = true [tool.pytest.ini_options] addopts = [ + "--tb=short", "--strict-config", "--strict-markers", "--import-mode=importlib", diff --git a/py-polars/tests/unit/test_testing.py b/py-polars/tests/unit/testing/test_testing.py similarity index 98% rename from py-polars/tests/unit/test_testing.py rename to py-polars/tests/unit/testing/test_testing.py index c44ea15961db..c000abcc7ef6 100644 --- a/py-polars/tests/unit/test_testing.py +++ b/py-polars/tests/unit/testing/test_testing.py @@ -10,7 +10,6 @@ from polars.exceptions import InvalidAssert from polars.testing import ( assert_frame_equal, - assert_frame_equal_local_categoricals, assert_frame_not_equal, assert_series_equal, assert_series_not_equal, @@ -102,7 +101,7 @@ def test_compare_series_nulls() -> None: srs2 = pl.Series([1, None, None]) assert_series_not_equal(srs1, srs2) - with pytest.raises(AssertionError, match="null_count is not equal"): + with pytest.raises(AssertionError, match="value mismatch"): assert_series_equal(srs1, srs2) @@ -326,7 +325,7 @@ def test_assert_frame_equal_ignore_row_order() -> None: df1 = pl.DataFrame({"a": [1, 2], "b": [4, 3]}) df2 = pl.DataFrame({"a": [2, 1], "b": [3, 4]}) df3 = pl.DataFrame({"b": [3, 4], "a": [2, 1]}) - with pytest.raises(AssertionError, match="values for column 'a' are different."): + with pytest.raises(AssertionError, match="values for column 'a' are different"): assert_frame_equal(df1, df2) assert_frame_equal(df1, df2, check_row_order=False) @@ -1035,8 +1034,13 @@ def test_assert_series_equal_categorical_vs_str() -> None: assert_series_equal(s1, s2, categorical_as_str=True) -def test_assert_frame_equal_local_categoricals_deprecated() -> None: - df = pl.Series(["a", "b", "a"], dtype=pl.Categorical).to_frame() - - with pytest.deprecated_call(): - assert_frame_equal_local_categoricals(df, df) +def test_assert_series_equal_full_series() -> None: + s1 = pl.Series([1, 2, 3]) + s2 = pl.Series([1, 2, 4]) + msg = ( + r"Series are different \(value mismatch\)\n" + r"\[left\]: \[1, 2, 3\]\n" + r"\[right\]: \[1, 2, 4\]" + ) + with pytest.raises(AssertionError, match=msg): + assert_series_equal(s1, s2) From 4fb3f070ab0c4af5f606a26f8b9ae879d4cd8801 Mon Sep 17 00:00:00 2001 From: Jerry Wu Date: Tue, 17 Oct 2023 13:04:05 +0800 Subject: [PATCH 021/103] docs(python): Minor tweak in code example in section Coming from Pandas (#11764) --- docs/user-guide/migration/pandas.md | 25 +++++++++++-------------- 1 file changed, 11 insertions(+), 14 deletions(-) diff --git a/docs/user-guide/migration/pandas.md b/docs/user-guide/migration/pandas.md index d6674ac43f06..a9a039a7b1d4 100644 --- a/docs/user-guide/migration/pandas.md +++ b/docs/user-guide/migration/pandas.md @@ -147,19 +147,20 @@ called `hundredXValue` where the `value` column is multiplied by 100. In `Pandas` this would be: ```python -df["tenXValue"] = df["value"] * 10 -df["hundredXValue"] = df["value"] * 100 +df.assign( + tenXValue=lambda df_: df_.value * 10, + hundredXValue=lambda df_: df_.value * 100 +) ``` These column assignments are executed sequentially. 
-In `Polars` we add columns to `df` using the `.with_columns` method and name them with -the `.alias` method: +In `Polars` we add columns to `df` using the `.with_columns` method: ```python df.with_columns( - (pl.col("value") * 10).alias("tenXValue"), - (pl.col("value") * 100).alias("hundredXValue"), + tenXValue=pl.col("value") * 10, + hundredXValue=pl.col("value") * 100, ) ``` @@ -174,7 +175,7 @@ the values in column `a` based on a condition. When the value in column `c` is e In `Pandas` this would be: ```python -df.loc[df["c"] == 2, "a"] = df.loc[df["c"] == 2, "b"] +df.assign(a=lambda df_: df_.a.where(df_.c != 2, df_.b)) ``` while in `Polars` this would be: @@ -187,21 +188,17 @@ df.with_columns( ) ``` -The `Polars` way is pure in that the original `DataFrame` is not modified. The `mask` is -also not computed twice as in `Pandas` (you could prevent this in `Pandas`, but that -would require setting a temporary variable). - -Additionally `Polars` can compute every branch of an `if -> then -> otherwise` in +`Polars` can compute every branch of an `if -> then -> otherwise` in parallel. This is valuable, when the branches get more expensive to compute. #### Filtering We want to filter the dataframe `df` with housing data based on some criteria. -In `Pandas` you filter the dataframe by passing Boolean expressions to the `loc` method: +In `Pandas` you filter the dataframe by passing Boolean expressions to the `query` method: ```python -df.loc[(df['sqft_living'] > 2500) & (df['price'] < 300000)] +df.query('m2_living > 2500 and price < 300000') ``` while in `Polars` you call the `filter` method: diff --git a/crates/polars-ffi/src/lib.rs b/crates/polars-ffi/src/lib.rs index 59b2b0b8d9e9..1be2c01c477b 100644 --- a/crates/polars-ffi/src/lib.rs +++ b/crates/polars-ffi/src/lib.rs @@ -97,11 +97,7 @@ pub unsafe fn import_series(e: SeriesExport) -> PolarsResult<Series> { }) .collect::<PolarsResult<Vec<_>>>()?; - Ok(Series::from_chunks_and_dtype_unchecked( - &field.name, - chunks, - &(&field.data_type).into(), - )) + Series::try_from((field.name.as_str(), chunks)) } /// # Safety From 798372435b8eb340434c7bc33183b244fed5ab0d Mon Sep 17 00:00:00 2001 From: Weijie Guo Date: Tue, 17 Oct 2023 14:46:31 +0800 Subject: [PATCH 023/103] feat: Expressify pct_change and move to ops (#11786) --- crates/polars-core/Cargo.toml | 1 - crates/polars-core/src/series/ops/mod.rs | 2 - .../polars-core/src/series/ops/pct_change.rs | 48 -------------- crates/polars-ops/Cargo.toml | 1 + crates/polars-ops/src/series/ops/mod.rs | 4 ++ .../polars-ops/src/series/ops/pct_change.rs | 62 +++++++++++++++++++ crates/polars-plan/Cargo.toml | 2 +- .../src/dsl/function_expr/dispatch.rs | 5 ++ .../polars-plan/src/dsl/function_expr/mod.rs | 6 ++ .../src/dsl/function_expr/schema.rs | 5 ++ crates/polars-plan/src/dsl/mod.rs | 12 +--- crates/polars/Cargo.toml | 2 +- py-polars/polars/expr/expr.py | 3 +- py-polars/polars/series/series.py | 2 +- py-polars/src/expr/general.rs | 4 +- py-polars/tests/unit/series/test_series.py | 1 + 16 files changed, 93 insertions(+), 67 deletions(-) delete mode 100644 crates/polars-core/src/series/ops/pct_change.rs create mode 100644 crates/polars-ops/src/series/ops/pct_change.rs diff --git a/crates/polars-core/Cargo.toml b/crates/polars-core/Cargo.toml index
8a14cce9c024..7c28a57e21a4 100644 --- a/crates/polars-core/Cargo.toml +++ b/crates/polars-core/Cargo.toml @@ -92,7 +92,6 @@ group_by_list = [] # rolling window functions rolling_window = [] diff = [] -pct_change = ["diff"] moment = [] diagonal_concat = [] horizontal_concat = [] diff --git a/crates/polars-core/src/series/ops/mod.rs b/crates/polars-core/src/series/ops/mod.rs index 8730c012b163..1cf51c1743dd 100644 --- a/crates/polars-core/src/series/ops/mod.rs +++ b/crates/polars-core/src/series/ops/mod.rs @@ -7,8 +7,6 @@ mod extend; #[cfg(feature = "moment")] pub mod moment; mod null; -#[cfg(feature = "pct_change")] -pub mod pct_change; #[cfg(feature = "round_series")] mod round; mod to_list; diff --git a/crates/polars-core/src/series/ops/pct_change.rs b/crates/polars-core/src/series/ops/pct_change.rs deleted file mode 100644 index 4135a0427bae..000000000000 --- a/crates/polars-core/src/series/ops/pct_change.rs +++ /dev/null @@ -1,48 +0,0 @@ -use crate::prelude::*; -use crate::series::ops::NullBehavior; - -impl Series { - pub fn pct_change(&self, n: i64) -> PolarsResult<Series> { - match self.dtype() { - DataType::Float64 | DataType::Float32 => {}, - _ => return self.cast(&DataType::Float64)?.pct_change(n), - } - let nn = self.fill_null(FillNullStrategy::Forward(None))?; - nn.diff(n, NullBehavior::Ignore)?.divide(&nn.shift(n)) - } -} - -#[cfg(test)] -mod test { - use super::*; - - #[test] - fn test_nulls() -> PolarsResult<()> { - let s = Series::new("", &[Some(1), None, Some(2), None, Some(3)]); - assert_eq!( - s.pct_change(1)?, - Series::new("", &[None, Some(0.0f64), Some(1.0), Some(0.), Some(0.5)]) - ); - Ok(()) - } - - #[test] - fn test_same() -> PolarsResult<()> { - let s = Series::new("", &[Some(1), Some(1), Some(1)]); - assert_eq!( - s.pct_change(1)?, - Series::new("", &[None, Some(0.0f64), Some(0.0)]) - ); - Ok(()) - } - - #[test] - fn test_two_periods() -> PolarsResult<()> { - let s = Series::new("", &[Some(1), Some(2), Some(4), Some(8), Some(16)]); - assert_eq!( - s.pct_change(2)?, - Series::new("", &[None, None, Some(3.0f64), Some(3.0), Some(3.0)]) - ); - Ok(()) - } -} diff --git a/crates/polars-ops/Cargo.toml b/crates/polars-ops/Cargo.toml index 821ad9290c82..3e1f83fa8978 100644 --- a/crates/polars-ops/Cargo.toml +++ b/crates/polars-ops/Cargo.toml @@ -83,6 +83,7 @@ interpolate = [] list_to_struct = ["polars-core/dtype-struct"] list_count = [] diff = ["polars-core/diff"] +pct_change = ["polars-core/diff"] strings = ["polars-core/strings"] string_justify = ["polars-core/strings"] string_from_radix = ["polars-core/strings"] diff --git a/crates/polars-ops/src/series/ops/mod.rs b/crates/polars-ops/src/series/ops/mod.rs index 55fc68dc48eb..c524a880d967 100644 --- a/crates/polars-ops/src/series/ops/mod.rs +++ b/crates/polars-ops/src/series/ops/mod.rs @@ -24,6 +24,8 @@ mod is_last_distinct; mod is_unique; #[cfg(feature = "log")] mod log; +#[cfg(feature = "pct_change")] +mod pct_change; #[cfg(feature = "rank")] mod rank; #[cfg(feature = "rle")] mod rle; @@ -62,6 +64,8 @@ pub use is_last_distinct::*; pub use is_unique::*; #[cfg(feature = "log")] pub use log::*; +#[cfg(feature = "pct_change")] +pub use pct_change::*; use polars_core::prelude::*; #[cfg(feature = "rank")] pub use rank::*; diff --git a/crates/polars-ops/src/series/ops/pct_change.rs b/crates/polars-ops/src/series/ops/pct_change.rs new file mode 100644 index 000000000000..bd03bdd67e05 --- /dev/null +++ b/crates/polars-ops/src/series/ops/pct_change.rs @@ -0,0 +1,62 @@ +use polars_core::prelude::*; +use polars_core::series::ops::NullBehavior;
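+ +/// Computes the percent change between the current element and the element `n` +/// rows earlier, as `diff(n) / shift(n)`. Nulls are forward-filled before the +/// change is computed; `n` must contain exactly one value.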
+pub fn pct_change(s: &Series, n: &Series) -> PolarsResult<Series> { + polars_ensure!( + n.len() == 1, + ComputeError: "n must be a single value." + ); + + match s.dtype() { + DataType::Float64 | DataType::Float32 => {}, + _ => return pct_change(&s.cast(&DataType::Float64)?, n), + } + + let fill_null_s = s.fill_null(FillNullStrategy::Forward(None))?; + + let n_s = n.cast(&DataType::Int64)?; + if let Some(n) = n_s.i64()?.get(0) { + fill_null_s + .diff(n, NullBehavior::Ignore)? + .divide(&fill_null_s.shift(n)) + } else { + Ok(Series::full_null(s.name(), s.len(), s.dtype())) + } +} + +#[cfg(test)] +mod test { + use polars_core::prelude::{PolarsResult, Series}; + + use super::pct_change; + + #[test] + fn test_nulls() -> PolarsResult<()> { + let s = Series::new("", &[Some(1), None, Some(2), None, Some(3)]); + assert_eq!( + pct_change(&s, &Series::new("i", &[1]))?, + Series::new("", &[None, Some(0.0f64), Some(1.0), Some(0.), Some(0.5)]) + ); + Ok(()) + } + + #[test] + fn test_same() -> PolarsResult<()> { + let s = Series::new("", &[Some(1), Some(1), Some(1)]); + assert_eq!( + pct_change(&s, &Series::new("i", &[1]))?, + Series::new("", &[None, Some(0.0f64), Some(0.0)]) + ); + Ok(()) + } + + #[test] + fn test_two_periods() -> PolarsResult<()> { + let s = Series::new("", &[Some(1), Some(2), Some(4), Some(8), Some(16)]); + assert_eq!( + pct_change(&s, &Series::new("i", &[2]))?, + Series::new("", &[None, None, Some(3.0f64), Some(3.0), Some(3.0)]) + ); + Ok(()) + } +} diff --git a/crates/polars-plan/Cargo.toml b/crates/polars-plan/Cargo.toml index a68252bf0daf..8ea14f5fc7b4 100644 --- a/crates/polars-plan/Cargo.toml +++ b/crates/polars-plan/Cargo.toml @@ -108,7 +108,7 @@ rolling_window = [ ] rank = ["polars-ops/rank"] diff = ["polars-core/diff", "polars-ops/diff"] -pct_change = ["polars-core/pct_change"] +pct_change = ["polars-ops/pct_change"] moment = ["polars-core/moment", "polars-ops/moment"] abs = ["polars-core/abs"] random = ["polars-core/random"] diff --git a/crates/polars-plan/src/dsl/function_expr/dispatch.rs b/crates/polars-plan/src/dsl/function_expr/dispatch.rs index bbe11390fb01..8cd638a89b8e 100644 --- a/crates/polars-plan/src/dsl/function_expr/dispatch.rs +++ b/crates/polars-plan/src/dsl/function_expr/dispatch.rs @@ -18,6 +18,11 @@ pub(super) fn diff(s: &Series, n: i64, null_behavior: NullBehavior) -> PolarsResult<Series> { s.diff(n, null_behavior) } +#[cfg(feature = "pct_change")] +pub(super) fn pct_change(s: &[Series]) -> PolarsResult<Series> { + polars_ops::prelude::pct_change(&s[0], &s[1]) +} + #[cfg(feature = "interpolate")] pub(super) fn interpolate(s: &Series, method: InterpolationMethod) -> PolarsResult<Series> { Ok(polars_ops::prelude::interpolate(s, method)) diff --git a/crates/polars-plan/src/dsl/function_expr/mod.rs b/crates/polars-plan/src/dsl/function_expr/mod.rs index cf6d269e4668..9cad3792f60b 100644 --- a/crates/polars-plan/src/dsl/function_expr/mod.rs +++ b/crates/polars-plan/src/dsl/function_expr/mod.rs @@ -190,6 +190,8 @@ pub enum FunctionExpr { ShrinkType, #[cfg(feature = "diff")] Diff(i64, NullBehavior), + #[cfg(feature = "pct_change")] + PctChange, #[cfg(feature = "interpolate")] Interpolate(InterpolationMethod), #[cfg(feature = "log")] @@ -415,6 +417,8 @@ impl Display for FunctionExpr { ShrinkType => "shrink_dtype", #[cfg(feature = "diff")] Diff(_, _) => "diff", + #[cfg(feature = "pct_change")] + PctChange => "pct_change", #[cfg(feature = "interpolate")] Interpolate(_) => "interpolate", #[cfg(feature = "log")] @@ -714,6 +718,8 @@ impl From<FunctionExpr> for SpecialEq<Arc<dyn SeriesUdf>> { ShrinkType => map_owned!(shrink_type::shrink), #[cfg(feature =
"diff")] Diff(n, null_behavior) => map!(dispatch::diff, n, null_behavior), + #[cfg(feature = "pct_change")] + PctChange => map_as_slice!(dispatch::pct_change), #[cfg(feature = "interpolate")] Interpolate(method) => { map!(dispatch::interpolate, method) diff --git a/crates/polars-plan/src/dsl/function_expr/schema.rs b/crates/polars-plan/src/dsl/function_expr/schema.rs index a787b50d8987..8ae84c97a907 100644 --- a/crates/polars-plan/src/dsl/function_expr/schema.rs +++ b/crates/polars-plan/src/dsl/function_expr/schema.rs @@ -147,6 +147,11 @@ impl FunctionExpr { DataType::UInt8 => DataType::Int16, dt => dt.clone(), }), + #[cfg(feature = "pct_change")] + PctChange => mapper.map_dtype(|dt| match dt { + DataType::Float64 | DataType::Float32 => dt.clone(), + _ => DataType::Float64, + }), #[cfg(feature = "interpolate")] Interpolate(method) => match method { InterpolationMethod::Linear => mapper.map_numeric_to_float_dtype(), diff --git a/crates/polars-plan/src/dsl/mod.rs b/crates/polars-plan/src/dsl/mod.rs index e9e689513c74..87f265bae580 100644 --- a/crates/polars-plan/src/dsl/mod.rs +++ b/crates/polars-plan/src/dsl/mod.rs @@ -1570,16 +1570,8 @@ impl Expr { #[cfg(feature = "pct_change")] /// Computes percentage change between values. - pub fn pct_change(self, n: i64) -> Expr { - use DataType::*; - self.apply( - move |s| s.pct_change(n).map(Some), - GetOutput::map_dtype(|dt| match dt { - Float64 | Float32 => dt.clone(), - _ => Float64, - }), - ) - .with_fmt("pct_change") + pub fn pct_change(self, n: Expr) -> Expr { + self.apply_many_private(FunctionExpr::PctChange, &[n], false, false) } #[cfg(feature = "moment")] diff --git a/crates/polars/Cargo.toml b/crates/polars/Cargo.toml index b8c78f4ef0ca..876193e089fb 100644 --- a/crates/polars/Cargo.toml +++ b/crates/polars/Cargo.toml @@ -144,7 +144,7 @@ rolling_window = ["polars-core/rolling_window", "polars-lazy?/rolling_window", " interpolate = ["polars-ops/interpolate", "polars-lazy?/interpolate"] rank = ["polars-lazy?/rank", "polars-ops/rank"] diff = ["polars-core/diff", "polars-lazy?/diff", "polars-ops/diff"] -pct_change = ["polars-core/pct_change", "polars-lazy?/pct_change"] +pct_change = ["polars-ops/pct_change", "polars-lazy?/pct_change"] moment = ["polars-core/moment", "polars-lazy?/moment", "polars-ops/moment"] range = ["polars-lazy?/range"] true_div = ["polars-lazy?/true_div"] diff --git a/py-polars/polars/expr/expr.py b/py-polars/polars/expr/expr.py index a05f85c9134b..6d365764afef 100644 --- a/py-polars/polars/expr/expr.py +++ b/py-polars/polars/expr/expr.py @@ -7464,7 +7464,7 @@ def diff(self, n: int = 1, null_behavior: NullBehavior = "ignore") -> Self: """ return self._from_pyexpr(self._pyexpr.diff(n, null_behavior)) - def pct_change(self, n: int = 1) -> Self: + def pct_change(self, n: int | IntoExprColumn = 1) -> Self: """ Computes percentage change between values. @@ -7500,6 +7500,7 @@ def pct_change(self, n: int = 1) -> Self: └──────┴────────────┘ """ + n = parse_as_expression(n) return self._from_pyexpr(self._pyexpr.pct_change(n)) def skew(self, *, bias: bool = True) -> Self: diff --git a/py-polars/polars/series/series.py b/py-polars/polars/series/series.py index f9e8a72e6d36..e23036041fe7 100644 --- a/py-polars/polars/series/series.py +++ b/py-polars/polars/series/series.py @@ -6033,7 +6033,7 @@ def diff(self, n: int = 1, null_behavior: NullBehavior = "ignore") -> Series: """ - def pct_change(self, n: int = 1) -> Series: + def pct_change(self, n: int | IntoExprColumn = 1) -> Series: """ Computes percentage change between values. 
diff --git a/py-polars/src/expr/general.rs b/py-polars/src/expr/general.rs index a14424027441..46137400969c 100644 --- a/py-polars/src/expr/general.rs +++ b/py-polars/src/expr/general.rs @@ -741,8 +741,8 @@ impl PyExpr { } #[cfg(feature = "pct_change")] - fn pct_change(&self, n: i64) -> Self { - self.inner.clone().pct_change(n).into() + fn pct_change(&self, n: Self) -> Self { + self.inner.clone().pct_change(n.inner).into() } fn skew(&self, bias: bool) -> Self { diff --git a/py-polars/tests/unit/series/test_series.py b/py-polars/tests/unit/series/test_series.py index 52aba0dde302..559783025df9 100644 --- a/py-polars/tests/unit/series/test_series.py +++ b/py-polars/tests/unit/series/test_series.py @@ -1267,6 +1267,7 @@ def test_pct_change() -> None: s = pl.Series("a", [1, 2, 4, 8, 16, 32, 64]) expected = pl.Series("a", [None, None, float("inf"), 3.0, 3.0, 3.0, 3.0]) assert_series_equal(s.pct_change(2), expected) + assert_series_equal(s.pct_change(pl.Series([2])), expected) # negative assert pl.Series(range(5)).pct_change(-1).to_list() == [ -1.0, From 4476fbdb6957c02e968f718e050766f6ca393e8b Mon Sep 17 00:00:00 2001 From: Armin Berres <20811121+aberres@users.noreply.github.com> Date: Tue, 17 Oct 2023 09:13:31 +0200 Subject: [PATCH 024/103] Fix typo in docs (#11776) Co-authored-by: Armin Berres --- docs/user-guide/io/cloud-storage.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/user-guide/io/cloud-storage.md b/docs/user-guide/io/cloud-storage.md index a10226a99e65..ba686a5a0f11 100644 --- a/docs/user-guide/io/cloud-storage.md +++ b/docs/user-guide/io/cloud-storage.md @@ -32,7 +32,7 @@ Polars can scan a Parquet file in lazy mode from cloud storage. We may need to p This query creates a `LazyFrame` without downloading the file. In the `LazyFrame` we have access to file metadata such as the schema. Polars uses the `object_store.rs` library internally to manage the interface with the cloud storage providers and so no extra dependencies are required in Python to scan a cloud Parquet file. -If we create a lazy query with [predicate and projection pushdowns](../lazy/optimizations.md), the query optimiszr will apply them before the file is downloaded. This can significantly reduce the amount of data that needs to be downloaded. The query evaluation is triggered by calling `collect`. +If we create a lazy query with [predicate and projection pushdowns](../lazy/optimizations.md), the query optimizer will apply them before the file is downloaded. This can significantly reduce the amount of data that needs to be downloaded. The query evaluation is triggered by calling `collect`. 
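A minimal sketch of that pattern (the bucket, file name, and column names are hypothetical, and credentials are assumed to be configured for the storage backend):

```python
import polars as pl

lf = pl.scan_parquet("s3://my-bucket/data.parquet")  # nothing is downloaded yet
result = (
    lf.filter(pl.col("year") == 2023)  # predicate pushdown
    .select(["year", "value"])  # projection pushdown
    .collect()  # download and evaluation happen here
)
```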
{{code_block('user-guide/io/cloud-storage','scan_parquet_query',[])}} From 8463def11dff7ca01de62c562691e616e1edb983 Mon Sep 17 00:00:00 2001 From: Ritchie Vink Date: Tue, 17 Oct 2023 11:32:26 +0400 Subject: [PATCH 025/103] fix: fix key in object-store cache (#11790) --- .../polars-io/src/cloud/object_store_setup.rs | 8 ++-- crates/polars-ops/src/series/ops/cum_agg.rs | 38 ------- .../polars-ops/src/series/ops/pct_change.rs | 37 ------ 3 files changed, 4 insertions(+), 79 deletions(-) diff --git a/crates/polars-io/src/cloud/object_store_setup.rs b/crates/polars-io/src/cloud/object_store_setup.rs index 91119dbbf248..be7c4a28d822 100644 --- a/crates/polars-io/src/cloud/object_store_setup.rs +++ b/crates/polars-io/src/cloud/object_store_setup.rs @@ -4,7 +4,7 @@ use tokio::sync::RwLock; use super::*; -type CacheKey = (CloudType, Option<CloudOptions>); +type CacheKey = (String, Option<CloudOptions>); /// A very simple cache that only stores a single object-store. /// This greatly reduces the query times as multiple object stores (when reading many small files) @@ -34,9 +34,8 @@ fn err_missing_configuration(feature: &str, scheme: &str) -> BuildResult { pub async fn build_object_store(url: &str, options: Option<&CloudOptions>) -> BuildResult { let cloud_location = CloudLocation::new(url)?; - let cloud_type = CloudType::from_str(url)?; let options = options.cloned(); - let key = (cloud_type, options); + let key = (url.to_string(), options); { let cache = OBJECT_STORE_CACHE.read().await; @@ -47,7 +46,8 @@ pub async fn build_object_store(url: &str, options: Option<&CloudOptions>) -> BuildResult { } } - let store = match key.0 { + let cloud_type = CloudType::from_str(url)?; + let store = match cloud_type { CloudType::File => { let local = LocalFileSystem::new(); Ok::<_, PolarsError>(Arc::new(local) as Arc<dyn ObjectStore>) diff --git a/crates/polars-ops/src/series/ops/cum_agg.rs b/crates/polars-ops/src/series/ops/cum_agg.rs index 12faeb193ac5..70fb1a0bfcf6 100644 --- a/crates/polars-ops/src/series/ops/cum_agg.rs +++ b/crates/polars-ops/src/series/ops/cum_agg.rs @@ -228,41 +228,3 @@ pub fn cumcount(s: &Series, reverse: bool) -> PolarsResult<Series> { Ok(ca.into_series()) } } - -#[cfg(test)] -mod test { - use polars_core::prelude::*; - - #[test] - #[cfg(feature = "dtype-u8")] - fn test_cummax() { - let ca = UInt8Chunked::new("foo", &[None, Some(1), Some(3), None, Some(1)]); - let out = ca.cummax(true); - assert_eq!(Vec::from(&out), &[None, Some(3), Some(3), None, Some(1)]); - let out = ca.cummax(false); - assert_eq!(Vec::from(&out), &[None, Some(1), Some(3), None, Some(3)]); - } - - #[test] - #[cfg(feature = "dtype-u8")] - fn test_cummin() { - let ca = UInt8Chunked::new("foo", &[None, Some(1), Some(3), None, Some(2)]); - let out = ca.cummin(true); - assert_eq!(Vec::from(&out), &[None, Some(1), Some(2), None, Some(2)]); - let out = ca.cummin(false); - assert_eq!(Vec::from(&out), &[None, Some(1), Some(1), None, Some(1)]); - } - - #[test] - fn test_cumsum() { - let ca = Int32Chunked::new("foo", &[None, Some(1), Some(3), None, Some(1)]); - let out = ca.cumsum(true); - assert_eq!(Vec::from(&out), &[None, Some(5), Some(4), None, Some(1)]); - let out = ca.cumsum(false); - assert_eq!(Vec::from(&out), &[None, Some(1), Some(4), None, Some(5)]); - - // just check if the trait bounds allow for floats - let ca = Float32Chunked::new("foo", &[None, Some(1.0), Some(3.0), None, Some(1.0)]); - let _out = ca.cumsum(false); - } -} diff --git a/crates/polars-ops/src/series/ops/pct_change.rs b/crates/polars-ops/src/series/ops/pct_change.rs index
bd03bdd67e05..24df891b382e 100644 --- a/crates/polars-ops/src/series/ops/pct_change.rs +++ b/crates/polars-ops/src/series/ops/pct_change.rs @@ -23,40 +23,3 @@ pub fn pct_change(s: &Series, n: &Series) -> PolarsResult<Series> { Ok(Series::full_null(s.name(), s.len(), s.dtype())) } } - -#[cfg(test)] -mod test { - use polars_core::prelude::{PolarsResult, Series}; - - use super::pct_change; - - #[test] - fn test_nulls() -> PolarsResult<()> { - let s = Series::new("", &[Some(1), None, Some(2), None, Some(3)]); - assert_eq!( - pct_change(&s, &Series::new("i", &[1]))?, - Series::new("", &[None, Some(0.0f64), Some(1.0), Some(0.), Some(0.5)]) - ); - Ok(()) - } - - #[test] - fn test_same() -> PolarsResult<()> { - let s = Series::new("", &[Some(1), Some(1), Some(1)]); - assert_eq!( - pct_change(&s, &Series::new("i", &[1]))?, - Series::new("", &[None, Some(0.0f64), Some(0.0)]) - ); - Ok(()) - } - - #[test] - fn test_two_periods() -> PolarsResult<()> { - let s = Series::new("", &[Some(1), Some(2), Some(4), Some(8), Some(16)]); - assert_eq!( - pct_change(&s, &Series::new("i", &[2]))?, - Series::new("", &[None, None, Some(3.0f64), Some(3.0), Some(3.0)]) - ); - Ok(()) - } -} From 89ffd882f9b58b8b3d5c23d3ececbb7fd53fb219 Mon Sep 17 00:00:00 2001 From: Ritchie Vink Date: Tue, 17 Oct 2023 12:05:47 +0400 Subject: [PATCH 026/103] python polars 0.19.9 (#11791) --- py-polars/Cargo.lock | 2 +- py-polars/Cargo.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/py-polars/Cargo.lock b/py-polars/Cargo.lock index de57f81c1845..3a3827adae7e 100644 --- a/py-polars/Cargo.lock +++ b/py-polars/Cargo.lock @@ -1926,7 +1926,7 @@ dependencies = [ [[package]] name = "py-polars" -version = "0.19.8" +version = "0.19.9" dependencies = [ "ahash", "built", diff --git a/py-polars/Cargo.toml b/py-polars/Cargo.toml index b3d8a6222c0e..a30ae13483ea 100644 --- a/py-polars/Cargo.toml +++ b/py-polars/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "py-polars" -version = "0.19.8" +version = "0.19.9" edition = "2021" [lib] From a507d67d7dfb7cfe0458218b33b74099198c7926 Mon Sep 17 00:00:00 2001 From: Stijn de Gooijer Date: Tue, 17 Oct 2023 12:18:08 +0200 Subject: [PATCH 027/103] refactor(python): Remove unused `_to_rust_syntax` util (#11795) --- py-polars/polars/testing/_private.py | 38 ---------------------------- 1 file changed, 38 deletions(-) delete mode 100644 py-polars/polars/testing/_private.py diff --git a/py-polars/polars/testing/_private.py b/py-polars/polars/testing/_private.py deleted file mode 100644 index f7a47c49c1fd..000000000000 --- a/py-polars/polars/testing/_private.py +++ /dev/null @@ -1,38 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING - -from polars.datatypes import Utf8 - -if TYPE_CHECKING: - from polars import DataFrame, Series - - -def _to_rust_syntax(df: DataFrame) -> str: - """Utility to generate the syntax that creates a polars 'DataFrame' in Rust.""" - syntax = "df![\n" - - def format_s(s: Series) -> str: - if s.null_count() == 0: - out = str(s.to_list()).replace("'", '"') - if s.dtype != Utf8: - out = out.lower() - return out - else: - tmp = "[" - for val in s: - if val is None: - tmp += "None, " - else: - if isinstance(val, str): - tmp += f'Some("{val}"), ' - else: - val = str(val).lower() - tmp += f"Some({val}), " - tmp = tmp[:-2] + "]" - return tmp - - for s in df: - syntax += f' "{s.name}" => {format_s(s)},\n' - syntax += "]" - return syntax From 2c6c9bd3f8cacae7e2aabcd2dc422f3e92e776f0 Mon Sep 17 00:00:00 2001 From: Stijn de Gooijer Date: Tue, 17 Oct 2023 12:36:01 +0200 Subject: [PATCH
028/103] refactor(python): Minor updates to assertion utils and docstrings (#11798) --- py-polars/polars/testing/asserts.py | 188 ++++++++++++------- py-polars/tests/unit/testing/test_testing.py | 12 ++ 2 files changed, 135 insertions(+), 65 deletions(-) diff --git a/py-polars/polars/testing/asserts.py b/py-polars/polars/testing/asserts.py index 643b7eb231fe..7c44dd3a19f5 100644 --- a/py-polars/polars/testing/asserts.py +++ b/py-polars/polars/testing/asserts.py @@ -34,43 +34,63 @@ def assert_frame_equal( categorical_as_str: bool = False, ) -> None: """ - Raise detailed AssertionError if `left` does NOT equal `right`. + Assert that the left and right frame are equal. + + Raises a detailed ``AssertionError`` if the frames differ. + This function is intended for use in unit tests. Parameters ---------- left - the DataFrame to compare. + The first DataFrame or LazyFrame to compare. right - the DataFrame to compare with. + The second DataFrame or LazyFrame to compare. check_row_order - if False, frames will compare equal if the required rows are present, - irrespective of the order in which they appear; as this requires - sorting, you cannot set on frames that contain unsortable columns. + Require row order to match. + + .. note:: + Setting this to ``False`` requires sorting the data, which will fail on + frames that contain unsortable columns. check_column_order - if False, frames will compare equal if the required columns are present, - irrespective of the order in which they appear. + Require column order to match. check_dtype - if True, data types need to match exactly. + Require data types to match. check_exact - if False, test if values are within tolerance of each other - (see `rtol` & `atol`). + Require data values to match exactly. If set to ``False``, values are considered + equal when within tolerance of each other (see ``rtol`` and ``atol``). rtol - relative tolerance for inexact checking. Fraction of values in `right`. + Relative tolerance for inexact checking. Fraction of values in ``right``. atol - absolute tolerance for inexact checking. + Absolute tolerance for inexact checking. nans_compare_equal - if your assert/test requires float NaN != NaN, set this to False. + Consider NaN values to be equal. categorical_as_str Cast categorical columns to string before comparing. Enabling this helps - compare DataFrames that do not share the same string cache. + compare columns that do not share the same string cache. + + See Also + -------- + assert_series_equal + assert_frame_not_equal Examples -------- >>> from polars.testing import assert_frame_equal >>> df1 = pl.DataFrame({"a": [1, 2, 3]}) - >>> df2 = pl.DataFrame({"a": [2, 3, 4]}) + >>> df2 = pl.DataFrame({"a": [1, 5, 3]}) >>> assert_frame_equal(df1, df2) # doctest: +SKIP - AssertionError: Values for column 'a' are different. + Traceback (most recent call last): + ... + AssertionError: Series are different (value mismatch) + [left]: [1, 2, 3] + [right]: [1, 5, 3] + + The above exception was the direct cause of the following exception: + + Traceback (most recent call last): + ... 
+ AssertionError: values for column 'a' are different + """ collect_input_frames = isinstance(left, LazyFrame) and isinstance(right, LazyFrame) if collect_input_frames: @@ -79,23 +99,23 @@ def assert_frame_equal( objs = "DataFrames" else: _raise_assertion_error( - "Inputs", "unexpected input types", type(left), type(right) + "Inputs", + "unexpected input types", + type(left).__name__, + type(right).__name__, ) if left_not_right := [c for c in left.columns if c not in right.columns]: - raise AssertionError( - f"columns {left_not_right!r} in left frame, but not in right" - ) + msg = f"columns {left_not_right!r} in left frame, but not in right" + raise AssertionError(msg) if right_not_left := [c for c in right.columns if c not in left.columns]: - raise AssertionError( - f"columns {right_not_left!r} in right frame, but not in left" - ) + msg = f"columns {right_not_left!r} in right frame, but not in left" + raise AssertionError(msg) if check_column_order and left.columns != right.columns: - raise AssertionError( - f"columns are not in the same order:\n{left.columns!r}\n{right.columns!r}" - ) + msg = f"columns are not in the same order:\n{left.columns!r}\n{right.columns!r}" + raise AssertionError(msg) if collect_input_frames: if check_dtype: # check this _before_ we collect @@ -114,9 +134,8 @@ def assert_frame_equal( left = left.sort(by=left.columns) right = right.sort(by=left.columns) except ComputeError as exc: - raise InvalidAssert( - "cannot set `check_row_order=False` on frame with unsortable columns" - ) from exc + msg = "cannot set `check_row_order=False` on frame with unsortable columns" + raise InvalidAssert(msg) from exc # note: does not assume a particular column order for c in left.columns: @@ -150,42 +169,53 @@ def assert_frame_not_equal( categorical_as_str: bool = False, ) -> None: """ - Raise AssertionError if `left` DOES equal `right`. + Assert that the left and right frame are **not** equal. + + This function is intended for use in unit tests. Parameters ---------- left - the DataFrame to compare. + The first DataFrame or LazyFrame to compare. right - the DataFrame to compare with. + The second DataFrame or LazyFrame to compare. check_row_order - if False, frames will compare equal if the required rows are present, - irrespective of the order in which they appear; as this requires - sorting, you cannot set on frames that contain unsortable columns. + Require row order to match. + + .. note:: + Setting this to ``False`` requires sorting the data, which will fail on + frames that contain unsortable columns. check_column_order - if False, frames will compare equal if the required columns are present, - irrespective of the order in which they appear. + Require column order to match. check_dtype - if True, data types need to match exactly. + Require data types to match. check_exact - if False, test if values are within tolerance of each other - (see `rtol` & `atol`). + Require data values to match exactly. If set to ``False``, values are considered + equal when within tolerance of each other (see ``rtol`` and ``atol``). rtol - relative tolerance for inexact checking. Fraction of values in `right`. + Relative tolerance for inexact checking. Fraction of values in ``right``. atol - absolute tolerance for inexact checking. + Absolute tolerance for inexact checking. nans_compare_equal - if your assert/test requires float NaN != NaN, set this to False. + Consider NaN values to be equal. categorical_as_str Cast categorical columns to string before comparing. 
Enabling this helps - compare DataFrames that do not share the same string cache. + compare columns that do not share the same string cache. + + See Also + -------- + assert_frame_equal + assert_series_not_equal Examples -------- >>> from polars.testing import assert_frame_not_equal >>> df1 = pl.DataFrame({"a": [1, 2, 3]}) - >>> df2 = pl.DataFrame({"a": [2, 3, 4]}) - >>> assert_frame_not_equal(df1, df2) + >>> df2 = pl.DataFrame({"a": [1, 2, 3]}) + >>> assert_frame_not_equal(df1, df2) # doctest: +SKIP + Traceback (most recent call last): + ... + AssertionError: frames are equal """ try: @@ -204,7 +234,8 @@ def assert_frame_not_equal( except AssertionError: return else: - raise AssertionError("expected the input frames to be unequal") + msg = "frames are equal" + raise AssertionError(msg) def assert_series_equal( @@ -220,42 +251,58 @@ def assert_series_equal( categorical_as_str: bool = False, ) -> None: """ - Raise detailed AssertionError if `left` does NOT equal `right`. + Assert that the left and right Series are equal. + + Raises a detailed ``AssertionError`` if the Series differ. + This function is intended for use in unit tests. Parameters ---------- left - the series to compare. + The first Series to compare. right - the series to compare with. + The second Series to compare. check_dtype - if True, data types need to match exactly. + Require data types to match. check_names - if True, names need to match. + Require names to match. check_exact - if False, test if values are within tolerance of each other - (see `rtol` & `atol`). + Require data values to match exactly. If set to ``False``, values are considered + equal when within tolerance of each other (see ``rtol`` and ``atol``). rtol - relative tolerance for inexact checking. Fraction of values in `right`. + Relative tolerance for inexact checking. Fraction of values in ``right``. atol - absolute tolerance for inexact checking. + Absolute tolerance for inexact checking. nans_compare_equal - if your assert/test requires float NaN != NaN, set this to False. + Consider NaN values to be equal. categorical_as_str Cast categorical columns to string before comparing. Enabling this helps - compare DataFrames that do not share the same string cache. + compare columns that do not share the same string cache. + + See Also + -------- + assert_frame_equal + assert_series_not_equal Examples -------- >>> from polars.testing import assert_series_equal >>> s1 = pl.Series([1, 2, 3]) - >>> s2 = pl.Series([2, 3, 4]) + >>> s2 = pl.Series([1, 5, 3]) >>> assert_series_equal(s1, s2) # doctest: +SKIP + Traceback (most recent call last): + ... + AssertionError: Series are different (value mismatch) + [left]: [1, 2, 3] + [right]: [1, 5, 3] """ if not (isinstance(left, Series) and isinstance(right, Series)): # type: ignore[redundant-expr] _raise_assertion_error( - "Inputs", "unexpected input types", type(left), type(right) + "Inputs", + "unexpected input types", + type(left).__name__, + type(right).__name__, ) if len(left) != len(right): @@ -289,7 +336,9 @@ def assert_series_not_equal( categorical_as_str: bool = False, ) -> None: """ - Raise AssertionError if `left` DOES equal `right`. + Assert that the left and right Series are **not** equal. + + This function is intended for use in unit tests. Parameters ---------- @@ -314,12 +363,20 @@ def assert_series_not_equal( Cast categorical columns to string before comparing. Enabling this helps compare DataFrames that do not share the same string cache. 
+ See Also + -------- + assert_series_equal + assert_frame_not_equal + Examples -------- >>> from polars.testing import assert_series_not_equal >>> s1 = pl.Series([1, 2, 3]) - >>> s2 = pl.Series([2, 3, 4]) - >>> assert_series_not_equal(s1, s2) + >>> s2 = pl.Series([1, 2, 3]) + >>> assert_series_not_equal(s1, s2) # doctest: +SKIP + Traceback (most recent call last): + ... + AssertionError: Series are equal """ try: @@ -337,7 +394,8 @@ def assert_series_not_equal( except AssertionError: return else: - raise AssertionError("expected the input Series to be unequal") + msg = "Series are equal" + raise AssertionError(msg) def _assert_series_inner( diff --git a/py-polars/tests/unit/testing/test_testing.py b/py-polars/tests/unit/testing/test_testing.py index c000abcc7ef6..74f352dd0bf0 100644 --- a/py-polars/tests/unit/testing/test_testing.py +++ b/py-polars/tests/unit/testing/test_testing.py @@ -1044,3 +1044,15 @@ def test_assert_series_equal_full_series() -> None: ) with pytest.raises(AssertionError, match=msg): assert_series_equal(s1, s2) + + +def test_assert_frame_not_equal() -> None: + df = pl.DataFrame({"a": [1, 2]}) + with pytest.raises(AssertionError, match="frames are equal"): + assert_frame_not_equal(df, df) + + +def test_assert_series_not_equal() -> None: + s = pl.Series("a", [1, 2]) + with pytest.raises(AssertionError, match="Series are equal"): + assert_series_not_equal(s, s) From 32e7a249ac021b73e555d7c2489e1c9e37fe1250 Mon Sep 17 00:00:00 2001 From: Stijn de Gooijer Date: Tue, 17 Oct 2023 13:48:32 +0200 Subject: [PATCH 029/103] chore(python): Bump lint dependencies (#11802) --- .github/workflows/lint-global.yml | 2 +- py-polars/polars/series/utils.py | 2 +- py-polars/polars/utils/_async.py | 4 ++-- py-polars/polars/utils/_construction.py | 2 +- py-polars/requirements-lint.txt | 6 +++--- py-polars/tests/test_udfs.py | 2 +- .../tests/unit/operations/test_aggregations.py | 1 - py-polars/tests/unit/series/test_series.py | 4 ++-- py-polars/tests/unit/sql/test_sql.py | 16 +++++++--------- py-polars/tests/unit/test_exprs.py | 4 ++-- 10 files changed, 20 insertions(+), 23 deletions(-) diff --git a/.github/workflows/lint-global.yml b/.github/workflows/lint-global.yml index 2ebcc0dca3b0..07344c893be9 100644 --- a/.github/workflows/lint-global.yml +++ b/.github/workflows/lint-global.yml @@ -15,4 +15,4 @@ jobs: - name: Lint Markdown and TOML uses: dprint/check@v2.2 - name: Spell Check with Typos - uses: crate-ci/typos@v1.16.8 + uses: crate-ci/typos@v1.16.20 diff --git a/py-polars/polars/series/utils.py b/py-polars/polars/series/utils.py index 00ab66ba5e5e..9e5c95477b74 100644 --- a/py-polars/polars/series/utils.py +++ b/py-polars/polars/series/utils.py @@ -89,7 +89,7 @@ def _undecorated(function: Callable[P, T]) -> Callable[P, T]: def call_expr(func: SeriesMethod) -> SeriesMethod: """Dispatch Series method to an expression implementation.""" - @wraps(func) # type: ignore[arg-type] + @wraps(func) def wrapper(self: Any, *args: P.args, **kwargs: P.kwargs) -> Series: s = wrap_s(self._s) expr = F.col(s.name) diff --git a/py-polars/polars/utils/_async.py b/py-polars/polars/utils/_async.py index 42ddfe85c313..3294bca9428f 100644 --- a/py-polars/polars/utils/_async.py +++ b/py-polars/polars/utils/_async.py @@ -24,8 +24,8 @@ def __init__(self) -> None: "polars.collect_all_async(gevent=True)" ) - from gevent.event import AsyncResult # type: ignore[import] - from gevent.hub import get_hub # type: ignore[import] + from gevent.event import AsyncResult # type: ignore[import-untyped] + from gevent.hub import 
get_hub # type: ignore[import-untyped] self._value: None | Exception | PyDataFrame | list[PyDataFrame] = None self._result = AsyncResult() diff --git a/py-polars/polars/utils/_construction.py b/py-polars/polars/utils/_construction.py index b98bfe41a764..9e38b3e4d38f 100644 --- a/py-polars/polars/utils/_construction.py +++ b/py-polars/polars/utils/_construction.py @@ -860,7 +860,7 @@ def dict_to_pydf( lambda t: pl.Series(t[0], t[1]) if isinstance(t[1], np.ndarray) else t[1], - [(k, v) for k, v in data.items()], + list(data.items()), ), ) ) diff --git a/py-polars/requirements-lint.txt b/py-polars/requirements-lint.txt index 0748b30ce8e8..20bdfbddf117 100644 --- a/py-polars/requirements-lint.txt +++ b/py-polars/requirements-lint.txt @@ -1,5 +1,5 @@ black==23.9.1 blackdoc==0.3.8 -mypy==1.5.1 -ruff==0.0.287 -typos==1.16.8 +mypy==1.6.0 +ruff==0.1.0 +typos==1.16.20 diff --git a/py-polars/tests/test_udfs.py b/py-polars/tests/test_udfs.py index 956edaecd1f6..c3f9c32360fa 100644 --- a/py-polars/tests/test_udfs.py +++ b/py-polars/tests/test_udfs.py @@ -159,7 +159,7 @@ ) def test_bytecode_parser_expression(col: str, func: str, expected: str) -> None: try: - import udfs # type: ignore[import] + import udfs # type: ignore[import-not-found] except ModuleNotFoundError as exc: assert "No module named 'udfs'" in str(exc) # noqa: PT017 # Skip test if udfs can't be imported because it's not in the path. diff --git a/py-polars/tests/unit/operations/test_aggregations.py b/py-polars/tests/unit/operations/test_aggregations.py index 5dff526eabc7..0cd22d787120 100644 --- a/py-polars/tests/unit/operations/test_aggregations.py +++ b/py-polars/tests/unit/operations/test_aggregations.py @@ -132,7 +132,6 @@ def test_quantile_vs_numpy(tp: type, n: int) -> None: np_result = np.quantile(a, q) except IndexError: np_result = None - pass if np_result: # nan check if np_result != np_result: diff --git a/py-polars/tests/unit/series/test_series.py b/py-polars/tests/unit/series/test_series.py index 559783025df9..3fab208b1131 100644 --- a/py-polars/tests/unit/series/test_series.py +++ b/py-polars/tests/unit/series/test_series.py @@ -1438,10 +1438,10 @@ def test_bitwise() -> None: # ensure mistaken use of logical 'and'/'or' raises an exception with pytest.raises(TypeError, match="ambiguous"): - a and b + a and b # type: ignore[redundant-expr] with pytest.raises(TypeError, match="ambiguous"): - a or b + a or b # type: ignore[redundant-expr] def test_to_numpy(monkeypatch: Any) -> None: diff --git a/py-polars/tests/unit/sql/test_sql.py b/py-polars/tests/unit/sql/test_sql.py index 73faef160377..ca281c380ab0 100644 --- a/py-polars/tests/unit/sql/test_sql.py +++ b/py-polars/tests/unit/sql/test_sql.py @@ -942,12 +942,14 @@ def test_sql_trim(foods_ipc_path: Path) -> None: "BY NAME", [(1, "zz"), (2, "yy"), (3, "xx")], ), - ( - # note: waiting for "https://github.com/sqlparser-rs/sqlparser-rs/pull/997" + pytest.param( ["c1", "c2"], ["c2", "c1"], "DISTINCT BY NAME", - None, # [(1, "zz"), (2, "yy"), (3, "xx")], + [(1, "zz"), (2, "yy"), (3, "xx")], + # TODO: Remove xfail marker when supported added in sqlparser-rs + # https://github.com/sqlparser-rs/sqlparser-rs/pull/997 + marks=pytest.mark.xfail, ), ], ) @@ -955,7 +957,7 @@ def test_sql_union( cols1: list[str], cols2: list[str], union_subtype: str, - expected: dict[str, list[int] | list[str]] | None, + expected: list[tuple[int, str]], ) -> None: with pl.SQLContext( frame1=pl.DataFrame({"c1": [1, 2], "c2": ["zz", "yy"]}), @@ -967,11 +969,7 @@ def test_sql_union( UNION {union_subtype} SELECT {', 
'.join(cols2)} FROM frame2 """ - if expected is not None: - assert sorted(ctx.execute(query).rows()) == expected - else: - with pytest.raises(pl.ComputeError, match="sql parser error"): - ctx.execute(query) + assert sorted(ctx.execute(query).rows()) == expected def test_sql_nullif_coalesce(foods_ipc_path: Path) -> None: diff --git a/py-polars/tests/unit/test_exprs.py b/py-polars/tests/unit/test_exprs.py index 8518755eb4f1..a75fa521663a 100644 --- a/py-polars/tests/unit/test_exprs.py +++ b/py-polars/tests/unit/test_exprs.py @@ -428,10 +428,10 @@ def test_logical_boolean() -> None: # note, cannot use expressions in logical # boolean context (eg: and/or/not operators) with pytest.raises(TypeError, match="ambiguous"): - pl.col("colx") and pl.col("coly") + pl.col("colx") and pl.col("coly") # type: ignore[redundant-expr] with pytest.raises(TypeError, match="ambiguous"): - pl.col("colx") or pl.col("coly") + pl.col("colx") or pl.col("coly") # type: ignore[redundant-expr] df = pl.DataFrame({"a": [1, 2, 3, 4, 5], "b": [1, 2, 3, 4, 5]}) From f63014e5b48486ad2b04937d202a3b7c08ed2c17 Mon Sep 17 00:00:00 2001 From: Orson Peters Date: Tue, 17 Oct 2023 15:01:07 +0200 Subject: [PATCH 030/103] fix: patch broken aHash AES intrinsics on ARM (#11801) --- Cargo.toml | 1 + py-polars/Cargo.lock | 3 +-- py-polars/Cargo.toml | 3 +++ 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index d99bf9e0bf25..06145d401f60 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -106,5 +106,6 @@ features = [ ] [patch.crates-io] +ahash = { git = "https://github.com/orlp/aHash", branch = "fix-arm-intrinsics" } # packed_simd_2 = { git = "https://github.com/rust-lang/packed_simd", rev = "e57c7ba11386147e6d2cbad7c88f376aab4bdc86" } # simd-json = { git = "https://github.com/ritchie46/simd-json", branch = "alignment" } diff --git a/py-polars/Cargo.lock b/py-polars/Cargo.lock index 3a3827adae7e..b492b67de12e 100644 --- a/py-polars/Cargo.lock +++ b/py-polars/Cargo.lock @@ -26,8 +26,7 @@ checksum = "aae1277d39aeec15cb388266ecc24b11c80469deae6067e17a1a7aa9e5c1f234" [[package]] name = "ahash" version = "0.8.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2c99f64d1e06488f620f932677e24bc6e2897582980441ae90a671415bd7ec2f" +source = "git+https://github.com/orlp/aHash?branch=fix-arm-intrinsics#80685f88d3c120ef39fb3fde1c7786b044af5e8b" dependencies = [ "cfg-if", "getrandom", diff --git a/py-polars/Cargo.toml b/py-polars/Cargo.toml index a30ae13483ea..445646088c1d 100644 --- a/py-polars/Cargo.toml +++ b/py-polars/Cargo.toml @@ -235,6 +235,9 @@ lto = "thin" codegen-units = 1 lto = "fat" +[patch.crates-io] +ahash = { git = "https://github.com/orlp/aHash", branch = "fix-arm-intrinsics" } + # This is ignored here; would be set in .cargo/config.toml. 
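# (A minimal sketch of how that .cargo/config.toml entry could look — illustrative
# only, not part of this patch:
#
#   [build]
#   rustflags = ["-C", "target-cpu=native"]
# )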
# Should not be used when packaging # target-cpu = "native" From d85c452e50ec0c6b0c929a3392cb51bea575df04 Mon Sep 17 00:00:00 2001 From: Stijn de Gooijer Date: Tue, 17 Oct 2023 15:01:34 +0200 Subject: [PATCH 031/103] depr(python): Deprecate non-keyword args for `ewm` methods (#11804) --- py-polars/polars/expr/expr.py | 4 +++- py-polars/polars/series/series.py | 15 +++++++++++++++ py-polars/tests/unit/series/test_series.py | 4 +++- 3 files changed, 21 insertions(+), 2 deletions(-) diff --git a/py-polars/polars/expr/expr.py b/py-polars/polars/expr/expr.py index 6d365764afef..608cb0db4840 100644 --- a/py-polars/polars/expr/expr.py +++ b/py-polars/polars/expr/expr.py @@ -8300,6 +8300,7 @@ def sample( self._pyexpr.sample_n(n, with_replacement, shuffle, seed) ) + @deprecate_nonkeyword_arguments(version="0.19.10") def ewm_mean( self, com: float | None = None, @@ -8367,7 +8368,6 @@ def ewm_mean( :math:`1-\alpha` and :math:`1` if ``adjust=True``, and :math:`1-\alpha` and :math:`\alpha` if ``adjust=False``. - Examples -------- >>> df = pl.DataFrame({"a": [1, 2, 3]}) @@ -8389,6 +8389,7 @@ def ewm_mean( self._pyexpr.ewm_mean(alpha, adjust, min_periods, ignore_nulls) ) + @deprecate_nonkeyword_arguments(version="0.19.10") def ewm_std( self, com: float | None = None, @@ -8481,6 +8482,7 @@ def ewm_std( self._pyexpr.ewm_std(alpha, adjust, bias, min_periods, ignore_nulls) ) + @deprecate_nonkeyword_arguments(version="0.19.10") def ewm_var( self, com: float | None = None, diff --git a/py-polars/polars/series/series.py b/py-polars/polars/series/series.py index e23036041fe7..14572ca9480f 100644 --- a/py-polars/polars/series/series.py +++ b/py-polars/polars/series/series.py @@ -6413,6 +6413,7 @@ def shuffle(self, seed: int | None = None) -> Series: """ + @deprecate_nonkeyword_arguments(version="0.19.10") def ewm_mean( self, com: float | None = None, @@ -6480,8 +6481,21 @@ def ewm_mean( :math:`1-\alpha` and :math:`1` if ``adjust=True``, and :math:`1-\alpha` and :math:`\alpha` if ``adjust=False``. 
+ Examples + -------- + >>> s = pl.Series([1, 2, 3]) + >>> s.ewm_mean(com=1) + shape: (3,) + Series: '' [f64] + [ + 1.0 + 1.666667 + 2.428571 + ] + """ + @deprecate_nonkeyword_arguments(version="0.19.10") def ewm_std( self, com: float | None = None, @@ -6567,6 +6581,7 @@ def ewm_std( """ + @deprecate_nonkeyword_arguments(version="0.19.10") def ewm_var( self, com: float | None = None, diff --git a/py-polars/tests/unit/series/test_series.py b/py-polars/tests/unit/series/test_series.py index 3fab208b1131..54ba6288e2fa 100644 --- a/py-polars/tests/unit/series/test_series.py +++ b/py-polars/tests/unit/series/test_series.py @@ -2136,7 +2136,9 @@ def test_ewm_mean() -> None: def test_ewm_mean_leading_nulls() -> None: for min_periods in [1, 2, 3]: assert ( - pl.Series([1, 2, 3, 4]).ewm_mean(3, min_periods=min_periods).null_count() + pl.Series([1, 2, 3, 4]) + .ewm_mean(com=3, min_periods=min_periods) + .null_count() == min_periods - 1 ) assert pl.Series([None, 1.0, 1.0, 1.0]).ewm_mean( From 8d29d3cebec713363db4ad5d782c74047e24314d Mon Sep 17 00:00:00 2001 From: nameexhaustion Date: Wed, 18 Oct 2023 00:58:58 +1100 Subject: [PATCH 032/103] fix(rust,python): ensure projections containing only hive columns are projected (#11803) --- .../parquet/read/deserialize/nested_utils.rs | 15 +++--- crates/polars-io/src/parquet/read_impl.rs | 49 +++++++++++++------ py-polars/tests/unit/io/test_hive.py | 12 +++++ 3 files changed, 53 insertions(+), 23 deletions(-) diff --git a/crates/polars-arrow/src/io/parquet/read/deserialize/nested_utils.rs b/crates/polars-arrow/src/io/parquet/read/deserialize/nested_utils.rs index 9466b93cb7dc..ad53b72a7e6f 100644 --- a/crates/polars-arrow/src/io/parquet/read/deserialize/nested_utils.rs +++ b/crates/polars-arrow/src/io/parquet/read/deserialize/nested_utils.rs @@ -499,15 +499,13 @@ where D: NestedDecoder<'a>, { // front[a1, a2, a3, ...]back - if items.len() > 1 { - return MaybeNext::Some(Ok(items.pop_front().unwrap())); + if *remaining == 0 && items.is_empty() { + return MaybeNext::None; } - if *remaining == 0 { - return match items.pop_front() { - Some(decoded) => MaybeNext::Some(Ok(decoded)), - None => MaybeNext::None, - }; + if !items.is_empty() && items.front().unwrap().0.len() > chunk_size.unwrap_or(usize::MAX) { + return MaybeNext::Some(Ok(items.pop_front().unwrap())); } + match iter.next() { Err(e) => MaybeNext::Some(Err(e.into())), Ok(None) => { @@ -541,7 +539,8 @@ where Err(e) => return MaybeNext::Some(Err(e)), }; - if (items.len() == 1) + // if possible, return the value immediately. + if !items.is_empty() && items.front().unwrap().0.len() > chunk_size.unwrap_or(usize::MAX) { MaybeNext::Some(Ok(items.pop_front().unwrap())) diff --git a/crates/polars-io/src/parquet/read_impl.rs b/crates/polars-io/src/parquet/read_impl.rs index 5827f44a57e1..51ec4669f25c 100644 --- a/crates/polars-io/src/parquet/read_impl.rs +++ b/crates/polars-io/src/parquet/read_impl.rs @@ -96,10 +96,15 @@ pub(super) fn array_iter_to_series( } /// Materializes hive partitions. -fn materialize_hive_partitions(df: &mut DataFrame, hive_partition_columns: Option<&[Series]>) { +/// We have a special num_rows arg, as df can be empty when a projection contains +/// only hive partition columns. +/// Safety: num_rows equals the height of the df when the df height is non-zero. 
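// (Illustrative sketch of the failure mode the `num_rows` parameter guards
// against; `hive_cols` is a hypothetical stand-in, not part of the patch.
// A projection of only hive partition columns yields a zero-width frame, so
// the row count cannot be recovered from the frame itself:
//
//     let mut df = DataFrame::new_no_checks(vec![]); // zero-width projection
//     assert_eq!(df.height(), 0);                    // true row count is lost
//     materialize_hive_partitions(&mut df, hive_cols, num_rows);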
+fn materialize_hive_partitions( + df: &mut DataFrame, + hive_partition_columns: Option<&[Series]>, + num_rows: usize, +) { if let Some(hive_columns) = hive_partition_columns { - let num_rows = df.height(); - for s in hive_columns { unsafe { df.with_column_unchecked(s.new_from_index(0, num_rows)) }; } @@ -191,6 +196,7 @@ fn rg_to_dfs_optionally_par_over_columns( assert!(std::env::var("POLARS_PANIC_IF_PARQUET_PARSED").is_err()) } + let projection_height = (*remaining_rows).min(md.num_rows()); let chunk_size = md.num_rows(); let columns = if let ParallelStrategy::Columns = parallel { POOL.install(|| { @@ -200,7 +206,7 @@ fn rg_to_dfs_optionally_par_over_columns( column_idx_to_series( *column_i, md, - *remaining_rows, + projection_height, schema, store, chunk_size, @@ -212,20 +218,26 @@ fn rg_to_dfs_optionally_par_over_columns( projection .iter() .map(|column_i| { - column_idx_to_series(*column_i, md, *remaining_rows, schema, store, chunk_size) + column_idx_to_series( + *column_i, + md, + projection_height, + schema, + store, + chunk_size, + ) }) .collect::>>()? }; - *remaining_rows = - remaining_rows.saturating_sub(file_metadata.row_groups[rg_idx].num_rows()); + *remaining_rows -= projection_height; let mut df = DataFrame::new_no_checks(columns); if let Some(rc) = &row_count { df.with_row_count_mut(&rc.name, Some(*previous_row_count + rc.offset)); } - materialize_hive_partitions(&mut df, hive_partition_columns); + materialize_hive_partitions(&mut df, hive_partition_columns, projection_height); apply_predicate(&mut df, predicate, true)?; *previous_row_count += current_row_count; @@ -265,17 +277,17 @@ fn rg_to_dfs_par_over_rg( let row_count_start = *previous_row_count; let num_rows = rg_md.num_rows(); *previous_row_count += num_rows as IdxSize; - let local_limit = *remaining_rows; - *remaining_rows = remaining_rows.saturating_sub(num_rows); + let projection_height = (*remaining_rows).min(num_rows); + *remaining_rows -= projection_height; - (rg_idx, rg_md, local_limit, row_count_start) + (rg_idx, rg_md, projection_height, row_count_start) }) .collect::>(); let dfs = row_groups .into_par_iter() - .map(|(rg_idx, md, local_limit, row_count_start)| { - if local_limit == 0 + .map(|(rg_idx, md, projection_height, row_count_start)| { + if projection_height == 0 || use_statistics && !read_this_row_group(predicate, &file_metadata.row_groups[rg_idx], schema)? 
{ @@ -291,7 +303,14 @@ fn rg_to_dfs_par_over_rg( let columns = projection .iter() .map(|column_i| { - column_idx_to_series(*column_i, md, local_limit, schema, store, chunk_size) + column_idx_to_series( + *column_i, + md, + projection_height, + schema, + store, + chunk_size, + ) }) .collect::>>()?; @@ -300,8 +319,8 @@ fn rg_to_dfs_par_over_rg( if let Some(rc) = &row_count { df.with_row_count_mut(&rc.name, Some(row_count_start as IdxSize + rc.offset)); } - materialize_hive_partitions(&mut df, hive_partition_columns); + materialize_hive_partitions(&mut df, hive_partition_columns, projection_height); apply_predicate(&mut df, predicate, false)?; Ok(Some(df)) diff --git a/py-polars/tests/unit/io/test_hive.py b/py-polars/tests/unit/io/test_hive.py index 83fce2da1194..4e0bea5316f6 100644 --- a/py-polars/tests/unit/io/test_hive.py +++ b/py-polars/tests/unit/io/test_hive.py @@ -98,3 +98,15 @@ def test_hive_partitioned_projection_pushdown( columns = ["sugars_g", "category"] for streaming in [True, False]: assert q.select(columns).collect(streaming=streaming).columns == columns + + # test that hive partition columns are projected with the correct height when + # the projection contains only hive partition columns (11796) + for parallel in ("row_groups", "columns"): + q = pl.scan_parquet( + root / "**/*.parquet", hive_partitioning=True, parallel=parallel # type: ignore[arg-type] + ) + + expect = q.collect().select("category") + actual = q.select("category").collect() + + assert expect.frame_equal(actual) From 00082c58ecad5c8c03e2c52cff8c4fe17fc9f8d0 Mon Sep 17 00:00:00 2001 From: Brayan Jules Date: Tue, 17 Oct 2023 18:37:40 -0300 Subject: [PATCH 033/103] fix: removing additional asserts from unit test, also improved pattern matching on timestamp casting --- crates/polars-arrow/src/compute/cast/mod.rs | 40 ++++----------------- py-polars/tests/unit/test_queries.py | 10 ------ 2 files changed, 6 insertions(+), 44 deletions(-) diff --git a/crates/polars-arrow/src/compute/cast/mod.rs b/crates/polars-arrow/src/compute/cast/mod.rs index 4324526381ca..d33525285c7d 100644 --- a/crates/polars-arrow/src/compute/cast/mod.rs +++ b/crates/polars-arrow/src/compute/cast/mod.rs @@ -580,23 +580,9 @@ pub fn cast( LargeUtf8 => Ok(Box::new(utf8_to_large_utf8( array.as_any().downcast_ref().unwrap(), ))), - Timestamp(TimeUnit::Nanosecond, None) => { - utf8_to_naive_timestamp_dyn::(array, TimeUnit::Nanosecond) - }, - Timestamp(TimeUnit::Millisecond, None) => { - utf8_to_naive_timestamp_dyn::(array, TimeUnit::Millisecond) - }, - Timestamp(TimeUnit::Microsecond, None) => { - utf8_to_naive_timestamp_dyn::(array, TimeUnit::Microsecond) - }, - Timestamp(TimeUnit::Nanosecond, Some(tz)) => { - utf8_to_timestamp_dyn::(array, tz.clone(), TimeUnit::Nanosecond) - }, - Timestamp(TimeUnit::Millisecond, Some(tz)) => { - utf8_to_timestamp_dyn::(array, tz.clone(), TimeUnit::Millisecond) - }, - Timestamp(TimeUnit::Microsecond, Some(tz)) => { - utf8_to_timestamp_dyn::(array, tz.clone(), TimeUnit::Microsecond) + Timestamp(tu, None) => utf8_to_naive_timestamp_dyn::(array, tu.to_owned()), + Timestamp(tu, Some(tz)) => { + utf8_to_timestamp_dyn::(array, tz.clone(), tu.to_owned()) }, _ => polars_bail!(InvalidOperation: "casting from {from_type:?} to {to_type:?} not supported", @@ -621,23 +607,9 @@ pub fn cast( to_type.clone(), ) .boxed()), - Timestamp(TimeUnit::Nanosecond, None) => { - utf8_to_naive_timestamp_dyn::(array, TimeUnit::Nanosecond) - }, - Timestamp(TimeUnit::Millisecond, None) => { - utf8_to_naive_timestamp_dyn::(array, 
TimeUnit::Millisecond) - }, - Timestamp(TimeUnit::Microsecond, None) => { - utf8_to_naive_timestamp_dyn::(array, TimeUnit::Microsecond) - }, - Timestamp(TimeUnit::Nanosecond, Some(tz)) => { - utf8_to_timestamp_dyn::(array, tz.clone(), TimeUnit::Nanosecond) - }, - Timestamp(TimeUnit::Millisecond, Some(tz)) => { - utf8_to_timestamp_dyn::(array, tz.clone(), TimeUnit::Millisecond) - }, - Timestamp(TimeUnit::Microsecond, Some(tz)) => { - utf8_to_timestamp_dyn::(array, tz.clone(), TimeUnit::Microsecond) + Timestamp(tu, None) => utf8_to_naive_timestamp_dyn::(array, tu.to_owned()), + Timestamp(tu, Some(tz)) => { + utf8_to_timestamp_dyn::(array, tz.clone(), tu.to_owned()) }, _ => polars_bail!(InvalidOperation: "casting from {from_type:?} to {to_type:?} not supported", diff --git a/py-polars/tests/unit/test_queries.py b/py-polars/tests/unit/test_queries.py index 48e116c39656..14835176cc80 100644 --- a/py-polars/tests/unit/test_queries.py +++ b/py-polars/tests/unit/test_queries.py @@ -380,8 +380,6 @@ def test_utf8_date() -> None: ) expected = pl.DataFrame({"x1-date": [date(2021, 1, 1)]}) out = df.select(pl.col("x1-date")) - assert out.shape == (1, 1) - assert out.dtypes == [pl.Date] assert_frame_equal(expected, out) @@ -419,8 +417,6 @@ def test_utf8_datetime() -> None: out = df.select( pl.col("x1-datetime-ns"), pl.col("x1-datetime-ms"), pl.col("x1-datetime-us") ) - assert out.shape == (2, 3) - assert out.dtypes == [pl.Datetime, pl.Datetime, pl.Datetime] assert_frame_equal(expected, out) @@ -477,10 +473,4 @@ def test_utf8_datetime_timezone() -> None: pl.col("x1-datetime-ns"), pl.col("x1-datetime-ms"), pl.col("x1-datetime-us") ) - assert out.shape == (2, 3) - assert out.dtypes == [ - pl.Datetime("ns", "America/Caracas"), - pl.Datetime("ms", "America/Santiago"), - pl.Datetime("us", "UTC"), - ] assert_frame_equal(expected, out) From 45009eb022fd46915fddd9fdddee425d1a9fec58 Mon Sep 17 00:00:00 2001 From: Weijie Guo Date: Wed, 18 Oct 2023 13:14:07 +0800 Subject: [PATCH 034/103] refactor(rust): Make some functions in dsl::mod non-anonymous (#11799) --- crates/polars-ops/src/series/ops/rank.rs | 8 ++- .../src/dsl/function_expr/dispatch.rs | 29 ++++++++++ .../polars-plan/src/dsl/function_expr/mod.rs | 33 +++++++++++ .../src/dsl/function_expr/schema.rs | 13 +++++ crates/polars-plan/src/dsl/mod.rs | 57 +++++-------------- 5 files changed, 95 insertions(+), 45 deletions(-) diff --git a/crates/polars-ops/src/series/ops/rank.rs b/crates/polars-ops/src/series/ops/rank.rs index 41f9b4ca8eb9..568ea02345ee 100644 --- a/crates/polars-ops/src/series/ops/rank.rs +++ b/crates/polars-ops/src/series/ops/rank.rs @@ -6,10 +6,13 @@ use rand::prelude::SliceRandom; use rand::prelude::*; #[cfg(feature = "random")] use rand::{rngs::SmallRng, SeedableRng}; +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; use crate::prelude::SeriesSealed; -#[derive(Copy, Clone)] +#[derive(Copy, Clone, Debug, PartialEq)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub enum RankMethod { Average, Min, @@ -21,7 +24,8 @@ pub enum RankMethod { } // We might want to add a `nulls_last` or `null_behavior` field. 
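// (Editorial note, inferred from this patch rather than stated in it: the
// derives added below appear to be required because `RankOptions` is now
// embedded in the serializable `FunctionExpr::Rank` variant introduced in
// dsl/function_expr/mod.rs later in this diff.)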
-#[derive(Copy, Clone)]
+#[derive(Copy, Clone, Debug, PartialEq)]
+#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
 pub struct RankOptions {
     pub method: RankMethod,
     pub descending: bool,
diff --git a/crates/polars-plan/src/dsl/function_expr/dispatch.rs b/crates/polars-plan/src/dsl/function_expr/dispatch.rs
index 8cd638a89b8e..4b3c5f14c1e3 100644
--- a/crates/polars-plan/src/dsl/function_expr/dispatch.rs
+++ b/crates/polars-plan/src/dsl/function_expr/dispatch.rs
@@ -76,3 +76,32 @@ pub(super) fn max_horizontal(s: &mut [Series]) -> PolarsResult<Option<Series>> {
 pub(super) fn min_horizontal(s: &mut [Series]) -> PolarsResult<Option<Series>> {
     polars_ops::prelude::min_horizontal(s)
 }
+
+pub(super) fn drop_nulls(s: &Series) -> PolarsResult<Series> {
+    Ok(s.drop_nulls())
+}
+
+#[cfg(feature = "mode")]
+pub(super) fn mode(s: &Series) -> PolarsResult<Series> {
+    mode::mode(s)
+}
+
+#[cfg(feature = "moment")]
+pub(super) fn skew(s: &Series, bias: bool) -> PolarsResult<Series> {
+    s.skew(bias).map(|opt_v| Series::new(s.name(), &[opt_v]))
+}
+
+#[cfg(feature = "moment")]
+pub(super) fn kurtosis(s: &Series, fisher: bool, bias: bool) -> PolarsResult<Series> {
+    s.kurtosis(fisher, bias)
+        .map(|opt_v| Series::new(s.name(), &[opt_v]))
+}
+
+pub(super) fn arg_unique(s: &Series) -> PolarsResult<Series> {
+    s.arg_unique().map(|ok| ok.into_series())
+}
+
+#[cfg(feature = "rank")]
+pub(super) fn rank(s: &Series, options: RankOptions, seed: Option<u64>) -> PolarsResult<Series> {
+    Ok(s.rank(options, seed))
+}
diff --git a/crates/polars-plan/src/dsl/function_expr/mod.rs b/crates/polars-plan/src/dsl/function_expr/mod.rs
index 9cad3792f60b..40a5eef2f326 100644
--- a/crates/polars-plan/src/dsl/function_expr/mod.rs
+++ b/crates/polars-plan/src/dsl/function_expr/mod.rs
@@ -138,6 +138,19 @@ pub enum FunctionExpr {
         periods: i64,
     },
     DropNans,
+    DropNulls,
+    #[cfg(feature = "mode")]
+    Mode,
+    #[cfg(feature = "moment")]
+    Skew(bool),
+    #[cfg(feature = "moment")]
+    Kurtosis(bool, bool),
+    ArgUnique,
+    #[cfg(feature = "rank")]
+    Rank {
+        options: RankOptions,
+        seed: Option<u64>,
+    },
     #[cfg(feature = "round_series")]
     Clip {
         has_min: bool,
@@ -372,6 +385,16 @@ impl Display for FunctionExpr {
            RollingSkew { .. } => "rolling_skew",
            ShiftAndFill { .. } => "shift_and_fill",
            DropNans => "drop_nans",
+           DropNulls => "drop_nulls",
+           #[cfg(feature = "mode")]
+           Mode => "mode",
+           #[cfg(feature = "moment")]
+           Skew(_) => "skew",
+           #[cfg(feature = "moment")]
+           Kurtosis(..) => "kurtosis",
+           ArgUnique => "arg_unique",
+           #[cfg(feature = "rank")]
+           Rank { ..
} => "rank", #[cfg(feature = "round_series")] Clip { has_min, has_max } => match (has_min, has_max) { (true, true) => "clip", @@ -626,10 +649,20 @@ impl From for SpecialEq> { map_as_slice!(shift_and_fill::shift_and_fill, periods) }, DropNans => map_owned!(nan::drop_nans), + DropNulls => map!(dispatch::drop_nulls), #[cfg(feature = "round_series")] Clip { has_min, has_max } => { map_as_slice!(clip::clip, has_min, has_max) }, + #[cfg(feature = "mode")] + Mode => map!(dispatch::mode), + #[cfg(feature = "moment")] + Skew(bias) => map!(dispatch::skew, bias), + #[cfg(feature = "moment")] + Kurtosis(fisher, bias) => map!(dispatch::kurtosis, fisher, bias), + ArgUnique => map!(dispatch::arg_unique), + #[cfg(feature = "rank")] + Rank { options, seed } => map!(dispatch::rank, options, seed), ListExpr(lf) => { use ListFunction::*; match lf { diff --git a/crates/polars-plan/src/dsl/function_expr/schema.rs b/crates/polars-plan/src/dsl/function_expr/schema.rs index 8ae84c97a907..cf2178fcd42c 100644 --- a/crates/polars-plan/src/dsl/function_expr/schema.rs +++ b/crates/polars-plan/src/dsl/function_expr/schema.rs @@ -47,8 +47,21 @@ impl FunctionExpr { RollingSkew { .. } => mapper.map_to_float_dtype(), ShiftAndFill { .. } => mapper.with_same_dtype(), DropNans => mapper.with_same_dtype(), + DropNulls => mapper.with_same_dtype(), #[cfg(feature = "round_series")] Clip { .. } => mapper.with_same_dtype(), + #[cfg(feature = "mode")] + Mode => mapper.with_same_dtype(), + #[cfg(feature = "moment")] + Skew(_) => mapper.with_dtype(DataType::Float64), + #[cfg(feature = "moment")] + Kurtosis(..) => mapper.with_dtype(DataType::Float64), + ArgUnique => mapper.with_dtype(IDX_DTYPE), + #[cfg(feature = "rank")] + Rank { options, .. } => mapper.with_dtype(match options.method { + RankMethod::Average => DataType::Float64, + _ => IDX_DTYPE, + }), ListExpr(l) => { use ListFunction::*; match l { diff --git a/crates/polars-plan/src/dsl/mod.rs b/crates/polars-plan/src/dsl/mod.rs index 87f265bae580..0eebb7511ec7 100644 --- a/crates/polars-plan/src/dsl/mod.rs +++ b/crates/polars-plan/src/dsl/mod.rs @@ -2,8 +2,6 @@ //! Domain specific language for the Lazy API. #[cfg(feature = "rolling_window")] use polars_core::utils::ensure_sorted_arg; -#[cfg(feature = "mode")] -use polars_ops::chunked_array::mode::mode; #[cfg(feature = "dtype-categorical")] pub mod cat; #[cfg(feature = "dtype-categorical")] @@ -183,7 +181,7 @@ impl Expr { /// Drop null values. pub fn drop_nulls(self) -> Self { - self.apply(|s| Ok(Some(s.drop_nulls())), GetOutput::same_type()) + self.apply_private(FunctionExpr::DropNulls) } /// Drop NaN values. @@ -345,11 +343,7 @@ impl Expr { /// Get the first index of unique values of this expression. pub fn arg_unique(self) -> Self { - self.apply( - |s: Series| s.arg_unique().map(|ca| Some(ca.into_series())), - GetOutput::from_type(IDX_DTYPE), - ) - .with_fmt("arg_unique") + self.apply_private(FunctionExpr::ArgUnique) } /// Get the index value that has the minimum value. @@ -1142,8 +1136,7 @@ impl Expr { #[cfg(feature = "mode")] /// Compute the mode(s) of this column. This is the most occurring value. pub fn mode(self) -> Expr { - self.apply(|s| mode(&s).map(Some), GetOutput::same_type()) - .with_fmt("mode") + self.apply_private(FunctionExpr::Mode) } /// Keep the original root name @@ -1484,14 +1477,7 @@ impl Expr { #[cfg(feature = "rank")] /// Assign ranks to data, dealing with ties appropriately. 
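    ///
    /// A minimal usage sketch (illustrative; the option values are arbitrary,
    /// and `seed` only matters for `RankMethod::Random`):
    ///
    ///     col("a").rank(
    ///         RankOptions { method: RankMethod::Min, descending: false },
    ///         None,
    ///     )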
pub fn rank(self, options: RankOptions, seed: Option) -> Expr { - self.apply( - move |s| Ok(Some(s.rank(options, seed))), - GetOutput::map_field(move |fld| match options.method { - RankMethod::Average => Field::new(fld.name(), DataType::Float64), - _ => Field::new(fld.name(), IDX_DTYPE), - }), - ) - .with_fmt("rank") + self.apply_private(FunctionExpr::Rank { options, seed }) } #[cfg(feature = "cutqcut")] @@ -1585,19 +1571,11 @@ impl Expr { /// /// see: [scipy](https://github.com/scipy/scipy/blob/47bb6febaa10658c72962b9615d5d5aa2513fa3a/scipy/stats/stats.py#L1024) pub fn skew(self, bias: bool) -> Expr { - self.apply( - move |s| { - s.skew(bias) - .map(|opt_v| Series::new(s.name(), &[opt_v])) - .map(Some) - }, - GetOutput::from_type(DataType::Float64), - ) - .with_function_options(|mut options| { - options.fmt_str = "skew"; - options.auto_explode = true; - options - }) + self.apply_private(FunctionExpr::Skew(bias)) + .with_function_options(|mut options| { + options.auto_explode = true; + options + }) } #[cfg(feature = "moment")] @@ -1609,18 +1587,11 @@ impl Expr { /// If bias is False then the kurtosis is calculated using k statistics to /// eliminate bias coming from biased moment estimators. pub fn kurtosis(self, fisher: bool, bias: bool) -> Expr { - self.apply( - move |s| { - s.kurtosis(fisher, bias) - .map(|opt_v| Some(Series::new(s.name(), &[opt_v]))) - }, - GetOutput::from_type(DataType::Float64), - ) - .with_function_options(|mut options| { - options.fmt_str = "kurtosis"; - options.auto_explode = true; - options - }) + self.apply_private(FunctionExpr::Kurtosis(fisher, bias)) + .with_function_options(|mut options| { + options.auto_explode = true; + options + }) } /// Get maximal value that could be hold by this dtype. From d24c50801a6aa26e34e0ef4c3a440a603b51c433 Mon Sep 17 00:00:00 2001 From: Weijie Guo Date: Wed, 18 Oct 2023 13:15:13 +0800 Subject: [PATCH 035/103] chore(rust): Move ewma to polars-ops (#11794) --- crates/polars-core/Cargo.toml | 1 - crates/polars-core/src/series/ops/ewm.rs | 104 ------------------ crates/polars-core/src/series/ops/mod.rs | 2 - crates/polars-ops/Cargo.toml | 1 + crates/polars-ops/src/series/ops/ewm.rs | 103 +++++++++++++++++ crates/polars-ops/src/series/ops/mod.rs | 4 + crates/polars-plan/Cargo.toml | 2 +- .../polars-plan/src/dsl/function_expr/ewm.rs | 6 +- crates/polars/Cargo.toml | 2 +- 9 files changed, 113 insertions(+), 112 deletions(-) delete mode 100644 crates/polars-core/src/series/ops/ewm.rs create mode 100644 crates/polars-ops/src/series/ops/ewm.rs diff --git a/crates/polars-core/Cargo.toml b/crates/polars-core/Cargo.toml index 7c28a57e21a4..afecf3c5a88e 100644 --- a/crates/polars-core/Cargo.toml +++ b/crates/polars-core/Cargo.toml @@ -96,7 +96,6 @@ moment = [] diagonal_concat = [] horizontal_concat = [] abs = [] -ewma = [] dataframe_arithmetic = [] product = [] unique_counts = [] diff --git a/crates/polars-core/src/series/ops/ewm.rs b/crates/polars-core/src/series/ops/ewm.rs deleted file mode 100644 index 388a44eb4a2d..000000000000 --- a/crates/polars-core/src/series/ops/ewm.rs +++ /dev/null @@ -1,104 +0,0 @@ -use std::convert::TryFrom; - -pub use arrow::legacy::kernels::ewm::EWMOptions; -use arrow::legacy::kernels::ewm::{ewm_mean, ewm_std, ewm_var}; - -use crate::prelude::*; - -fn check_alpha(alpha: f64) -> PolarsResult<()> { - polars_ensure!((0.0..=1.0).contains(&alpha), ComputeError: "alpha must be in [0; 1]"); - Ok(()) -} - -impl Series { - pub fn ewm_mean(&self, options: EWMOptions) -> PolarsResult { - check_alpha(options.alpha)?; - 
match self.dtype() { - DataType::Float32 => { - let xs = self.f32().unwrap(); - let result = ewm_mean( - xs, - options.alpha as f32, - options.adjust, - options.min_periods, - options.ignore_nulls, - ); - Series::try_from((self.name(), Box::new(result) as ArrayRef)) - }, - DataType::Float64 => { - let xs = self.f64().unwrap(); - let result = ewm_mean( - xs, - options.alpha, - options.adjust, - options.min_periods, - options.ignore_nulls, - ); - Series::try_from((self.name(), Box::new(result) as ArrayRef)) - }, - _ => self.cast(&DataType::Float64)?.ewm_mean(options), - } - } - - pub fn ewm_std(&self, options: EWMOptions) -> PolarsResult { - check_alpha(options.alpha)?; - match self.dtype() { - DataType::Float32 => { - let xs = self.f32().unwrap(); - let result = ewm_std( - xs, - options.alpha as f32, - options.adjust, - options.bias, - options.min_periods, - options.ignore_nulls, - ); - Series::try_from((self.name(), Box::new(result) as ArrayRef)) - }, - DataType::Float64 => { - let xs = self.f64().unwrap(); - let result = ewm_std( - xs, - options.alpha, - options.adjust, - options.bias, - options.min_periods, - options.ignore_nulls, - ); - Series::try_from((self.name(), Box::new(result) as ArrayRef)) - }, - _ => self.cast(&DataType::Float64)?.ewm_std(options), - } - } - - pub fn ewm_var(&self, options: EWMOptions) -> PolarsResult { - check_alpha(options.alpha)?; - match self.dtype() { - DataType::Float32 => { - let xs = self.f32().unwrap(); - let result = ewm_var( - xs, - options.alpha as f32, - options.adjust, - options.bias, - options.min_periods, - options.ignore_nulls, - ); - Series::try_from((self.name(), Box::new(result) as ArrayRef)) - }, - DataType::Float64 => { - let xs = self.f64().unwrap(); - let result = ewm_var( - xs, - options.alpha, - options.adjust, - options.bias, - options.min_periods, - options.ignore_nulls, - ); - Series::try_from((self.name(), Box::new(result) as ArrayRef)) - }, - _ => self.cast(&DataType::Float64)?.ewm_var(options), - } - } -} diff --git a/crates/polars-core/src/series/ops/mod.rs b/crates/polars-core/src/series/ops/mod.rs index 1cf51c1743dd..bc57ad3ee480 100644 --- a/crates/polars-core/src/series/ops/mod.rs +++ b/crates/polars-core/src/series/ops/mod.rs @@ -1,8 +1,6 @@ #[cfg(feature = "diff")] pub mod diff; mod downcast; -#[cfg(feature = "ewma")] -mod ewm; mod extend; #[cfg(feature = "moment")] pub mod moment; diff --git a/crates/polars-ops/Cargo.toml b/crates/polars-ops/Cargo.toml index 3e1f83fa8978..8e5c2aaa25c3 100644 --- a/crates/polars-ops/Cargo.toml +++ b/crates/polars-ops/Cargo.toml @@ -113,3 +113,4 @@ convert_index = [] repeat_by = [] peaks = [] cum_agg = [] +ewma = [] diff --git a/crates/polars-ops/src/series/ops/ewm.rs b/crates/polars-ops/src/series/ops/ewm.rs new file mode 100644 index 000000000000..6f4458777306 --- /dev/null +++ b/crates/polars-ops/src/series/ops/ewm.rs @@ -0,0 +1,103 @@ +use std::convert::TryFrom; + +pub use arrow::legacy::kernels::ewm::EWMOptions; +use arrow::legacy::kernels::ewm::{ + ewm_mean as kernel_ewm_mean, ewm_std as kernel_ewm_std, ewm_var as kernel_ewm_var, +}; +use polars_core::prelude::*; + +fn check_alpha(alpha: f64) -> PolarsResult<()> { + polars_ensure!((0.0..=1.0).contains(&alpha), ComputeError: "alpha must be in [0; 1]"); + Ok(()) +} + +pub fn ewm_mean(s: &Series, options: EWMOptions) -> PolarsResult { + check_alpha(options.alpha)?; + match s.dtype() { + DataType::Float32 => { + let xs = s.f32().unwrap(); + let result = kernel_ewm_mean( + xs, + options.alpha as f32, + options.adjust, + 
options.min_periods, + options.ignore_nulls, + ); + Series::try_from((s.name(), Box::new(result) as ArrayRef)) + }, + DataType::Float64 => { + let xs = s.f64().unwrap(); + let result = kernel_ewm_mean( + xs, + options.alpha, + options.adjust, + options.min_periods, + options.ignore_nulls, + ); + Series::try_from((s.name(), Box::new(result) as ArrayRef)) + }, + _ => ewm_mean(&s.cast(&DataType::Float64)?, options), + } +} + +pub fn ewm_std(s: &Series, options: EWMOptions) -> PolarsResult { + check_alpha(options.alpha)?; + match s.dtype() { + DataType::Float32 => { + let xs = s.f32().unwrap(); + let result = kernel_ewm_std( + xs, + options.alpha as f32, + options.adjust, + options.bias, + options.min_periods, + options.ignore_nulls, + ); + Series::try_from((s.name(), Box::new(result) as ArrayRef)) + }, + DataType::Float64 => { + let xs = s.f64().unwrap(); + let result = kernel_ewm_std( + xs, + options.alpha, + options.adjust, + options.bias, + options.min_periods, + options.ignore_nulls, + ); + Series::try_from((s.name(), Box::new(result) as ArrayRef)) + }, + _ => ewm_std(&s.cast(&DataType::Float64)?, options), + } +} + +pub fn ewm_var(s: &Series, options: EWMOptions) -> PolarsResult { + check_alpha(options.alpha)?; + match s.dtype() { + DataType::Float32 => { + let xs = s.f32().unwrap(); + let result = kernel_ewm_var( + xs, + options.alpha as f32, + options.adjust, + options.bias, + options.min_periods, + options.ignore_nulls, + ); + Series::try_from((s.name(), Box::new(result) as ArrayRef)) + }, + DataType::Float64 => { + let xs = s.f64().unwrap(); + let result = kernel_ewm_var( + xs, + options.alpha, + options.adjust, + options.bias, + options.min_periods, + options.ignore_nulls, + ); + Series::try_from((s.name(), Box::new(result) as ArrayRef)) + }, + _ => ewm_var(&s.cast(&DataType::Float64)?, options), + } +} diff --git a/crates/polars-ops/src/series/ops/mod.rs b/crates/polars-ops/src/series/ops/mod.rs index c524a880d967..6437c7a0ffac 100644 --- a/crates/polars-ops/src/series/ops/mod.rs +++ b/crates/polars-ops/src/series/ops/mod.rs @@ -7,6 +7,8 @@ mod clip; mod cum_agg; #[cfg(feature = "cutqcut")] mod cut; +#[cfg(feature = "ewma")] +mod ewm; #[cfg(feature = "round_series")] mod floor_divide; #[cfg(feature = "fused")] @@ -47,6 +49,8 @@ pub use clip::*; pub use cum_agg::*; #[cfg(feature = "cutqcut")] pub use cut::*; +#[cfg(feature = "ewma")] +pub use ewm::*; #[cfg(feature = "round_series")] pub use floor_divide::*; #[cfg(feature = "fused")] diff --git a/crates/polars-plan/Cargo.toml b/crates/polars-plan/Cargo.toml index 8ea14f5fc7b4..ee412f48dbec 100644 --- a/crates/polars-plan/Cargo.toml +++ b/crates/polars-plan/Cargo.toml @@ -113,7 +113,7 @@ moment = ["polars-core/moment", "polars-ops/moment"] abs = ["polars-core/abs"] random = ["polars-core/random"] dynamic_group_by = ["polars-core/dynamic_group_by"] -ewma = ["polars-core/ewma"] +ewma = ["polars-ops/ewma"] dot_diagram = [] unique_counts = ["polars-core/unique_counts"] log = ["polars-ops/log"] diff --git a/crates/polars-plan/src/dsl/function_expr/ewm.rs b/crates/polars-plan/src/dsl/function_expr/ewm.rs index a26285eef33a..b824ca3013e9 100644 --- a/crates/polars-plan/src/dsl/function_expr/ewm.rs +++ b/crates/polars-plan/src/dsl/function_expr/ewm.rs @@ -1,13 +1,13 @@ use super::*; pub(super) fn ewm_mean(s: &Series, options: EWMOptions) -> PolarsResult { - s.ewm_mean(options) + polars_ops::prelude::ewm_mean(s, options) } pub(super) fn ewm_std(s: &Series, options: EWMOptions) -> PolarsResult { - s.ewm_std(options) + 
polars_ops::prelude::ewm_std(s, options) } pub(super) fn ewm_var(s: &Series, options: EWMOptions) -> PolarsResult { - s.ewm_var(options) + polars_ops::prelude::ewm_var(s, options) } diff --git a/crates/polars/Cargo.toml b/crates/polars/Cargo.toml index 876193e089fb..58f6a8334bcd 100644 --- a/crates/polars/Cargo.toml +++ b/crates/polars/Cargo.toml @@ -152,7 +152,7 @@ diagonal_concat = ["polars-core/diagonal_concat", "polars-lazy?/diagonal_concat" horizontal_concat = ["polars-core/horizontal_concat"] abs = ["polars-core/abs", "polars-lazy?/abs"] dynamic_group_by = ["polars-core/dynamic_group_by", "polars-lazy?/dynamic_group_by"] -ewma = ["polars-core/ewma", "polars-lazy?/ewma"] +ewma = ["polars-ops/ewma", "polars-lazy?/ewma"] dot_diagram = ["polars-lazy?/dot_diagram"] dataframe_arithmetic = ["polars-core/dataframe_arithmetic"] product = ["polars-core/product"] From f5f3fa9f66292d2040c3af5a7d92047fe70d0cbe Mon Sep 17 00:00:00 2001 From: Ritchie Vink Date: Wed, 18 Oct 2023 09:55:51 +0400 Subject: [PATCH 036/103] fix(rust): remove flag inconsistency 'map_many' (#11817) --- crates/polars-plan/src/dsl/mod.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/crates/polars-plan/src/dsl/mod.rs b/crates/polars-plan/src/dsl/mod.rs index 0eebb7511ec7..40dbb939de6f 100644 --- a/crates/polars-plan/src/dsl/mod.rs +++ b/crates/polars-plan/src/dsl/mod.rs @@ -647,7 +647,6 @@ impl Expr { options: FunctionOptions { collect_groups: ApplyOptions::ApplyGroups, fmt_str: "", - auto_explode: true, ..Default::default() }, } From d6ef2e4c0bb4fa9f80236ee81710d11427e61a00 Mon Sep 17 00:00:00 2001 From: nameexhaustion Date: Wed, 18 Oct 2023 17:12:11 +1100 Subject: [PATCH 037/103] refactor(rust): remove redundant if branch in nested parquet (#11814) --- .../src/io/parquet/read/deserialize/nested_utils.rs | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/crates/polars-arrow/src/io/parquet/read/deserialize/nested_utils.rs b/crates/polars-arrow/src/io/parquet/read/deserialize/nested_utils.rs index ad53b72a7e6f..482d5117a7da 100644 --- a/crates/polars-arrow/src/io/parquet/read/deserialize/nested_utils.rs +++ b/crates/polars-arrow/src/io/parquet/read/deserialize/nested_utils.rs @@ -502,9 +502,6 @@ where if *remaining == 0 && items.is_empty() { return MaybeNext::None; } - if !items.is_empty() && items.front().unwrap().0.len() > chunk_size.unwrap_or(usize::MAX) { - return MaybeNext::Some(Ok(items.pop_front().unwrap())); - } match iter.next() { Err(e) => MaybeNext::Some(Err(e.into())), @@ -539,7 +536,8 @@ where Err(e) => return MaybeNext::Some(Err(e)), }; - // if possible, return the value immediately. + // this comparison is strictly greater to ensure the contents of the + // row are fully read. 
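    // (Worked instance, for illustration: with chunk_size = Some(100) and a
    // front item holding exactly 100 values, `100 > 100` is false, so the item
    // keeps accumulating; it is only popped once it strictly exceeds the chunk
    // size, so a row straddling a page boundary is never emitted half-read.)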
if !items.is_empty() && items.front().unwrap().0.len() > chunk_size.unwrap_or(usize::MAX) { From c3e1c1ee136f7340b7268b74319a1322d9cfc5af Mon Sep 17 00:00:00 2001 From: Ritchie Vink Date: Wed, 18 Oct 2023 10:42:20 +0400 Subject: [PATCH 038/103] feat: don't require empty config for cloud scan_parquet (#11819) --- .../polars-io/src/cloud/object_store_setup.rs | 35 ++++++------------- 1 file changed, 11 insertions(+), 24 deletions(-) diff --git a/crates/polars-io/src/cloud/object_store_setup.rs b/crates/polars-io/src/cloud/object_store_setup.rs index be7c4a28d822..5826c48f8bd1 100644 --- a/crates/polars-io/src/cloud/object_store_setup.rs +++ b/crates/polars-io/src/cloud/object_store_setup.rs @@ -22,13 +22,6 @@ fn err_missing_feature(feature: &str, scheme: &str) -> BuildResult { "feature '{}' must be enabled in order to use '{}' cloud urls", feature, scheme, ); } -#[cfg(any(feature = "azure", feature = "aws", feature = "gcp"))] -fn err_missing_configuration(feature: &str, scheme: &str) -> BuildResult { - polars_bail!( - ComputeError: - "configuration '{}' must be provided in order to use '{}' cloud urls", feature, scheme, - ); -} /// Build an [`ObjectStore`] based on the URL and passed in url. Return the cloud location and an implementation of the object store. pub async fn build_object_store(url: &str, options: Option<&CloudOptions>) -> BuildResult { @@ -47,6 +40,11 @@ pub async fn build_object_store(url: &str, options: Option<&CloudOptions>) -> Bu } let cloud_type = CloudType::from_str(url)?; + let options = key + .1 + .as_ref() + .map(Cow::Borrowed) + .unwrap_or_else(|| Cow::Owned(Default::default())); let store = match cloud_type { CloudType::File => { let local = LocalFileSystem::new(); @@ -55,11 +53,6 @@ pub async fn build_object_store(url: &str, options: Option<&CloudOptions>) -> Bu CloudType::Aws => { #[cfg(feature = "aws")] { - let options = key - .1 - .as_ref() - .map(Cow::Borrowed) - .unwrap_or_else(|| Cow::Owned(Default::default())); let store = options.build_aws(url).await?; Ok::<_, PolarsError>(Arc::new(store) as Arc) } @@ -68,12 +61,9 @@ pub async fn build_object_store(url: &str, options: Option<&CloudOptions>) -> Bu }, CloudType::Gcp => { #[cfg(feature = "gcp")] - match key.1.as_ref() { - Some(options) => { - let store = options.build_gcp(url)?; - Ok::<_, PolarsError>(Arc::new(store) as Arc) - }, - _ => return err_missing_configuration("gcp", &cloud_location.scheme), + { + let store = options.build_gcp(url)?; + Ok::<_, PolarsError>(Arc::new(store) as Arc) } #[cfg(not(feature = "gcp"))] return err_missing_feature("gcp", &cloud_location.scheme); @@ -81,12 +71,9 @@ pub async fn build_object_store(url: &str, options: Option<&CloudOptions>) -> Bu CloudType::Azure => { { #[cfg(feature = "azure")] - match key.1.as_ref() { - Some(options) => { - let store = options.build_azure(url)?; - Ok::<_, PolarsError>(Arc::new(store) as Arc) - }, - _ => return err_missing_configuration("azure", &cloud_location.scheme), + { + let store = options.build_azure(url)?; + Ok::<_, PolarsError>(Arc::new(store) as Arc) } } #[cfg(not(feature = "azure"))] From 9ea46ef910f341112b9a8de0db04c60d87dc4ac5 Mon Sep 17 00:00:00 2001 From: Weijie Guo Date: Wed, 18 Oct 2023 14:43:26 +0800 Subject: [PATCH 039/103] chore(rust): Move diff to polars-ops (#11818) --- crates/polars-core/Cargo.toml | 2 -- crates/polars-core/src/series/ops/diff.rs | 24 ------------------- crates/polars-core/src/series/ops/mod.rs | 2 -- crates/polars-ops/Cargo.toml | 4 ++-- .../src/chunked_array/list/namespace.rs | 4 +++- 
crates/polars-ops/src/series/ops/diff.rs | 22 +++++++++++++++++ crates/polars-ops/src/series/ops/mod.rs | 4 ++++ .../polars-ops/src/series/ops/pct_change.rs | 6 ++--- crates/polars-plan/Cargo.toml | 2 +- .../src/dsl/function_expr/dispatch.rs | 2 +- crates/polars/Cargo.toml | 2 +- crates/polars/src/lib.rs | 2 +- 12 files changed, 38 insertions(+), 38 deletions(-) delete mode 100644 crates/polars-core/src/series/ops/diff.rs create mode 100644 crates/polars-ops/src/series/ops/diff.rs diff --git a/crates/polars-core/Cargo.toml b/crates/polars-core/Cargo.toml index afecf3c5a88e..12a8a536ebdd 100644 --- a/crates/polars-core/Cargo.toml +++ b/crates/polars-core/Cargo.toml @@ -91,7 +91,6 @@ take_opt_iter = [] group_by_list = [] # rolling window functions rolling_window = [] -diff = [] moment = [] diagonal_concat = [] horizontal_concat = [] @@ -147,7 +146,6 @@ docs-selection = [ "dot_product", "row_hash", "rolling_window", - "diff", "moment", "dtype-categorical", "dtype-decimal", diff --git a/crates/polars-core/src/series/ops/diff.rs b/crates/polars-core/src/series/ops/diff.rs deleted file mode 100644 index b3452ccade81..000000000000 --- a/crates/polars-core/src/series/ops/diff.rs +++ /dev/null @@ -1,24 +0,0 @@ -use crate::prelude::*; -use crate::series::ops::NullBehavior; - -impl Series { - pub fn diff(&self, n: i64, null_behavior: NullBehavior) -> PolarsResult<Series> { - use DataType::*; - let s = match self.dtype() { - UInt8 => self.cast(&Int16).unwrap(), - UInt16 => self.cast(&Int32).unwrap(), - UInt32 | UInt64 => self.cast(&Int64).unwrap(), - _ => self.clone(), - }; - - match null_behavior { - NullBehavior::Ignore => Ok(&s - &s.shift(n)), - NullBehavior::Drop => { - polars_ensure!(n > 0, InvalidOperation: "only positive integer allowed if nulls are dropped in 'diff' operation"); - let n = n as usize; - let len = s.len() - n; - Ok(&self.slice(n as i64, len) - &s.slice(0, len)) - }, - } - } -} diff --git a/crates/polars-core/src/series/ops/mod.rs b/crates/polars-core/src/series/ops/mod.rs index bc57ad3ee480..48766748f58b 100644 --- a/crates/polars-core/src/series/ops/mod.rs +++ b/crates/polars-core/src/series/ops/mod.rs @@ -1,5 +1,3 @@ -#[cfg(feature = "diff")] -pub mod diff; mod downcast; mod extend; #[cfg(feature = "moment")] diff --git a/crates/polars-ops/Cargo.toml b/crates/polars-ops/Cargo.toml index 8e5c2aaa25c3..1ea76b197cd5 100644 --- a/crates/polars-ops/Cargo.toml +++ b/crates/polars-ops/Cargo.toml @@ -82,8 +82,8 @@ to_dummies = [] interpolate = [] list_to_struct = ["polars-core/dtype-struct"] list_count = [] -diff = ["polars-core/diff"] -pct_change = ["polars-core/diff"] +diff = [] +pct_change = ["diff"] strings = ["polars-core/strings"] string_justify = ["polars-core/strings"] string_from_radix = ["polars-core/strings"] diff --git a/crates/polars-ops/src/chunked_array/list/namespace.rs b/crates/polars-ops/src/chunked_array/list/namespace.rs index f9f92a75ad85..37165179d187 100644 --- a/crates/polars-ops/src/chunked_array/list/namespace.rs +++ b/crates/polars-ops/src/chunked_array/list/namespace.rs @@ -17,6 +17,8 @@ use super::*; use crate::chunked_array::list::any_all::*; use crate::chunked_array::list::min_max::{list_max_function, list_min_function}; use crate::chunked_array::list::sum_mean::sum_with_nulls; +#[cfg(feature = "diff")] +use crate::prelude::diff; use crate::prelude::list::sum_mean::{mean_list_numerical, sum_list_numerical}; use crate::series::ArgAgg; @@ -256,7 +258,7 @@ pub trait ListNameSpaceImpl: AsList { #[cfg(feature = "diff")] fn lst_diff(&self, n: i64, null_behavior:
NullBehavior) -> PolarsResult<ListChunked> { let ca = self.as_list(); - ca.try_apply_amortized(|s| s.as_ref().diff(n, null_behavior)) + ca.try_apply_amortized(|s| diff(s.as_ref(), n, null_behavior)) } fn lst_shift(&self, periods: &Series) -> PolarsResult<ListChunked> { diff --git a/crates/polars-ops/src/series/ops/diff.rs b/crates/polars-ops/src/series/ops/diff.rs new file mode 100644 index 000000000000..8fa28768609e --- /dev/null +++ b/crates/polars-ops/src/series/ops/diff.rs @@ -0,0 +1,22 @@ +use polars_core::prelude::*; +use polars_core::series::ops::NullBehavior; + +pub fn diff(s: &Series, n: i64, null_behavior: NullBehavior) -> PolarsResult<Series> { + use DataType::*; + let s = match s.dtype() { + UInt8 => s.cast(&Int16)?, + UInt16 => s.cast(&Int32)?, + UInt32 | UInt64 => s.cast(&Int64)?, + _ => s.clone(), + }; + + match null_behavior { + NullBehavior::Ignore => Ok(&s - &s.shift(n)), + NullBehavior::Drop => { + polars_ensure!(n > 0, InvalidOperation: "only positive integer allowed if nulls are dropped in 'diff' operation"); + let n = n as usize; + let len = s.len() - n; + Ok(&s.slice(n as i64, len) - &s.slice(0, len)) + }, + } +} diff --git a/crates/polars-ops/src/series/ops/mod.rs b/crates/polars-ops/src/series/ops/mod.rs index 6437c7a0ffac..b771520bd192 100644 --- a/crates/polars-ops/src/series/ops/mod.rs +++ b/crates/polars-ops/src/series/ops/mod.rs @@ -7,6 +7,8 @@ mod clip; mod cum_agg; #[cfg(feature = "cutqcut")] mod cut; +#[cfg(feature = "diff")] +mod diff; #[cfg(feature = "ewma")] mod ewm; #[cfg(feature = "round_series")] @@ -49,6 +51,8 @@ pub use clip::*; pub use cum_agg::*; #[cfg(feature = "cutqcut")] pub use cut::*; +#[cfg(feature = "diff")] +pub use diff::*; #[cfg(feature = "ewma")] pub use ewm::*; #[cfg(feature = "round_series")] diff --git a/crates/polars-ops/src/series/ops/pct_change.rs b/crates/polars-ops/src/series/ops/pct_change.rs index 24df891b382e..56c7af142e9b 100644 --- a/crates/polars-ops/src/series/ops/pct_change.rs +++ b/crates/polars-ops/src/series/ops/pct_change.rs @@ -1,6 +1,8 @@ use polars_core::prelude::*; use polars_core::series::ops::NullBehavior; +use crate::prelude::diff; + pub fn pct_change(s: &Series, n: &Series) -> PolarsResult<Series> { polars_ensure!( n.len() == 1, @@ -16,9 +18,7 @@ pub fn pct_change(s: &Series, n: &Series) -> PolarsResult<Series> { let n_s = n.cast(&DataType::Int64)?; if let Some(n) = n_s.i64()?.get(0) { - fill_null_s - .diff(n, NullBehavior::Ignore)?
- .divide(&fill_null_s.shift(n)) + diff(&fill_null_s, n, NullBehavior::Ignore)?.divide(&fill_null_s.shift(n)) } else { Ok(Series::full_null(s.name(), s.len(), s.dtype())) } diff --git a/crates/polars-plan/Cargo.toml b/crates/polars-plan/Cargo.toml index ee412f48dbec..99fbd668c782 100644 --- a/crates/polars-plan/Cargo.toml +++ b/crates/polars-plan/Cargo.toml @@ -107,7 +107,7 @@ rolling_window = [ "polars-time/rolling_window", ] rank = ["polars-ops/rank"] -diff = ["polars-core/diff", "polars-ops/diff"] +diff = ["polars-ops/diff"] pct_change = ["polars-ops/pct_change"] moment = ["polars-core/moment", "polars-ops/moment"] abs = ["polars-core/abs"] diff --git a/crates/polars-plan/src/dsl/function_expr/dispatch.rs b/crates/polars-plan/src/dsl/function_expr/dispatch.rs index 4b3c5f14c1e3..eb00120e970f 100644 --- a/crates/polars-plan/src/dsl/function_expr/dispatch.rs +++ b/crates/polars-plan/src/dsl/function_expr/dispatch.rs @@ -15,7 +15,7 @@ pub(super) fn approx_n_unique(s: &Series) -> PolarsResult<Series> { #[cfg(feature = "diff")] pub(super) fn diff(s: &Series, n: i64, null_behavior: NullBehavior) -> PolarsResult<Series> { - s.diff(n, null_behavior) + polars_ops::prelude::diff(s, n, null_behavior) } #[cfg(feature = "pct_change")] diff --git a/crates/polars/Cargo.toml b/crates/polars/Cargo.toml index 58f6a8334bcd..3ca2884540ca 100644 --- a/crates/polars/Cargo.toml +++ b/crates/polars/Cargo.toml @@ -143,7 +143,7 @@ cum_agg = ["polars-ops/cum_agg", "polars-lazy?/cum_agg"] rolling_window = ["polars-core/rolling_window", "polars-lazy?/rolling_window", "polars-time/rolling_window"] interpolate = ["polars-ops/interpolate", "polars-lazy?/interpolate"] rank = ["polars-lazy?/rank", "polars-ops/rank"] -diff = ["polars-core/diff", "polars-lazy?/diff", "polars-ops/diff"] +diff = ["polars-ops/diff", "polars-lazy?/diff"] pct_change = ["polars-ops/pct_change", "polars-lazy?/pct_change"] moment = ["polars-core/moment", "polars-lazy?/moment", "polars-ops/moment"] range = ["polars-lazy?/range"] diff --git a/crates/polars/src/lib.rs b/crates/polars/src/lib.rs index c04fb989e47d..df7da1257ff1 100644 --- a/crates/polars/src/lib.rs +++ b/crates/polars/src/lib.rs @@ -282,7 +282,7 @@ //! [`cummin`]: polars_ops::prelude::cummin //! [`cummax`]: polars_ops::prelude::cummax //! [`rolling_mean`]: crate::series::Series#method.rolling_mean -//! [`diff`]: crate::series::Series::diff +//! [`diff`]: polars_ops::prelude::diff //! [`List`]: crate::datatypes::DataType::List //! [`Struct`]: crate::datatypes::DataType::Struct //!
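PATCH 039 moves the `diff` kernel out of polars-core, where it was a method on `Series`, into polars-ops as a free function re-exported through `polars_ops::prelude` (the dispatch.rs hunk above shows the call-site change). A minimal sketch of how a downstream call site migrates, assuming the `diff` feature of polars-ops is enabled; the `example` function name is hypothetical:

```rust
use polars_core::prelude::*;
use polars_core::series::ops::NullBehavior;
// After this patch `diff` is a free function in polars-ops,
// no longer a method on `Series`.
use polars_ops::prelude::diff;

fn example(s: &Series) -> PolarsResult<Series> {
    // Before this patch: s.diff(1, NullBehavior::Ignore)
    // After: the Series is passed as the first argument instead.
    diff(s, 1, NullBehavior::Ignore)
}
```

The signature is otherwise unchanged, so callers only need to swap the method call for the function call.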
From 46e70094af11f92d7a6f64069118712171e8bd55 Mon Sep 17 00:00:00 2001 From: Walnut <39544927+Walnut356@users.noreply.github.com> Date: Wed, 18 Oct 2023 01:47:29 -0500 Subject: [PATCH 040/103] fix(rust, python): Edge cases for list count formatting (#11780) --- crates/polars-core/src/fmt.rs | 168 +++++++++++++++++++++++++++++----- py-polars/polars/config.py | 4 +- 2 files changed, 146 insertions(+), 26 deletions(-) diff --git a/crates/polars-core/src/fmt.rs b/crates/polars-core/src/fmt.rs index 24758ca320aa..6dec1221081e 100644 --- a/crates/polars-core/src/fmt.rs +++ b/crates/polars-core/src/fmt.rs @@ -964,36 +964,46 @@ fn fmt_struct(f: &mut Formatter<'_>, vals: &[AnyValue]) -> fmt::Result { impl Series { pub fn fmt_list(&self) -> String { - match self.len() { - 0 => "[]".to_string(), - 1 => format!("[{}]", self.get(0).unwrap()), - 2 => format!("[{}, {}]", self.get(0).unwrap(), self.get(1).unwrap()), - 3 => format!( - "[{}, {}, {}]", - self.get(0).unwrap(), - self.get(1).unwrap(), - self.get(2).unwrap() - ), - _ => { - let max_items = std::env::var(FMT_TABLE_CELL_LIST_LEN) - .as_deref() - .unwrap_or("") - .parse() - .map_or(3, |n: i64| if n < 0 { self.len() } else { n as usize }); + if self.is_empty() { + return "[]".to_owned(); + } + + let max_items = std::env::var(FMT_TABLE_CELL_LIST_LEN) + .as_deref() + .unwrap_or("") + .parse() + .map_or(3, |n: i64| if n < 0 { self.len() } else { n as usize }); match max_items { 0 => "[…]".to_owned(), + _ if max_items >= self.len() => { let mut result = "[".to_owned(); - for (i, item) in self.iter().enumerate() { + for i in 0..self.len() { + let item = self.get(i).unwrap(); write!(result, "{item}").unwrap(); + // this will always leave a trailing ", " after the last item + // but for long lists, this is faster than checking against the length each time + result.push_str(", "); + } + // remove trailing ", " and replace with closing brace + result.pop(); + result.pop(); + result.push(']'); - if i != self.len() - 1 { - result.push_str(", "); - } + result + }, + _ => { + let mut result = "[".to_owned(); - if i == max_items - 2 { + for (i, item) in self.iter().enumerate() { + if i == max_items.saturating_sub(1) { result.push_str("… "); write!(result, "{}", self.get(self.len() - 1).unwrap()).unwrap(); break; + } else { + write!(result, "{item}").unwrap(); + result.push_str(", "); } } result.push(']'); @@ -1136,7 +1146,7 @@ mod test { ListPrimitiveChunkedBuilder::<Int32Type>::new("a", 10, 10, DataType::Int32); builder.append_opt_slice(Some(&[1, 2, 3, 4, 5, 6])); builder.append_opt_slice(None); - let list = builder.finish().into_series(); + let list_long = builder.finish().into_series(); assert_eq!( r#"shape: (2,) Series: 'a' [list[i32]] [ [1, 2, … 6] null ]"#, - format!("{:?}", list) + format!("{:?}", list_long) ); std::env::set_var("POLARS_FMT_TABLE_CELL_LIST_LEN", "10"); @@ -1157,8 +1167,116 @@ Series: 'a' [list[i32]] [1, 2, 3, 4, 5, 6] null ]"#, - format!("{:?}", list) + format!("{:?}", list_long) ); + + std::env::set_var("POLARS_FMT_TABLE_CELL_LIST_LEN", "-1"); + + assert_eq!( + r#"shape: (2,) +Series: 'a' [list[i32]] +[ + [1, 2, 3, 4, 5, 6] + null +]"#, + format!("{:?}", list_long) + ); + + std::env::set_var("POLARS_FMT_TABLE_CELL_LIST_LEN", "0"); + + assert_eq!( + r#"shape: (2,) +Series: 'a' [list[i32]] +[ + […] + null +]"#, + format!("{:?}", list_long) + ); + + std::env::set_var("POLARS_FMT_TABLE_CELL_LIST_LEN", "1"); + + assert_eq!( + r#"shape: (2,) +Series: 'a' [list[i32]] +[ + [… 6] + null +]"#, + format!("{:?}", list_long) + ); + 
+ std::env::set_var("POLARS_FMT_TABLE_CELL_LIST_LEN", "4"); + + assert_eq!( + r#"shape: (2,) +Series: 'a' [list[i32]] +[ + [1, 2, 3, … 6] + null +]"#, + format!("{:?}", list_long) + ); + + let mut builder = + ListPrimitiveChunkedBuilder::<Int32Type>::new("a", 10, 10, DataType::Int32); + builder.append_opt_slice(Some(&[1])); + builder.append_opt_slice(None); + let list_short = builder.finish().into_series(); + + std::env::set_var("POLARS_FMT_TABLE_CELL_LIST_LEN", ""); + + assert_eq!( + r#"shape: (2,) +Series: 'a' [list[i32]] +[ + [1] + null +]"#, + format!("{:?}", list_short) + ); + + std::env::set_var("POLARS_FMT_TABLE_CELL_LIST_LEN", "0"); + + assert_eq!( + r#"shape: (2,) +Series: 'a' [list[i32]] +[ + […] + null +]"#, + format!("{:?}", list_short) + ); + + std::env::set_var("POLARS_FMT_TABLE_CELL_LIST_LEN", "-1"); + + assert_eq!( + r#"shape: (2,) +Series: 'a' [list[i32]] +[ + [1] + null +]"#, + format!("{:?}", list_short) + ); + + let mut builder = + ListPrimitiveChunkedBuilder::<Int32Type>::new("a", 10, 10, DataType::Int32); + builder.append_opt_slice(Some(&[])); + builder.append_opt_slice(None); + let list_empty = builder.finish().into_series(); + + std::env::set_var("POLARS_FMT_TABLE_CELL_LIST_LEN", ""); + + assert_eq!( + r#"shape: (2,) +Series: 'a' [list[i32]] +[ + [] + null +]"#, + format!("{:?}", list_empty) + ); } #[test] diff --git a/py-polars/polars/config.py b/py-polars/polars/config.py index 2d557d914354..3a5a1f1b07cb 100644 --- a/py-polars/polars/config.py +++ b/py-polars/polars/config.py @@ -588,7 +588,9 @@ def set_fmt_table_cell_list_len(cls, n: int | None) -> type[Config]: """ Set the number of elements to display for List values. - Values less than 0 will result in all values being printed. + Empty lists will always print "[]". Negative values will result in all values + being printed. A value of 0 will always print "[…]" for lists with contents. A value + of 1 will print only the final item in the list.
Parameters ---------- From 34d42c6dec6bb41889101bc758b0a4fe6419d4d0 Mon Sep 17 00:00:00 2001 From: Ritchie Vink Date: Wed, 18 Oct 2023 10:56:06 +0400 Subject: [PATCH 041/103] chore(rust): arrow: remove unused arithmetic code and remove doctests (#11820) --- crates/polars-arrow/src/array/ord.rs | 180 -------- .../src/compute/arithmetics/basic/add.rs | 314 +------------ .../src/compute/arithmetics/basic/div.rs | 63 --- .../src/compute/arithmetics/basic/mod.rs | 2 - .../src/compute/arithmetics/basic/mul.rs | 315 +------------ .../src/compute/arithmetics/basic/pow.rs | 49 -- .../src/compute/arithmetics/basic/rem.rs | 111 +---- .../src/compute/arithmetics/basic/sub.rs | 314 +------------ .../src/compute/arithmetics/decimal/add.rs | 220 --------- .../src/compute/arithmetics/decimal/div.rs | 301 ------------ .../src/compute/arithmetics/decimal/mod.rs | 120 ----- .../src/compute/arithmetics/decimal/mul.rs | 313 ------------- .../src/compute/arithmetics/decimal/sub.rs | 237 ---------- .../src/compute/arithmetics/mod.rs | 131 ------ .../src/compute/arithmetics/time.rs | 432 ------------------ .../polars-arrow/src/compute/if_then_else.rs | 18 - crates/polars-arrow/src/io/ipc/mod.rs | 48 -- .../src/io/ipc/write/file_async.rs | 40 -- .../src/io/ipc/write/stream_async.rs | 31 -- .../polars-arrow/src/io/parquet/write/sink.rs | 41 -- 20 files changed, 5 insertions(+), 3275 deletions(-) delete mode 100644 crates/polars-arrow/src/array/ord.rs delete mode 100644 crates/polars-arrow/src/compute/arithmetics/basic/pow.rs delete mode 100644 crates/polars-arrow/src/compute/arithmetics/decimal/add.rs delete mode 100644 crates/polars-arrow/src/compute/arithmetics/decimal/div.rs delete mode 100644 crates/polars-arrow/src/compute/arithmetics/decimal/mod.rs delete mode 100644 crates/polars-arrow/src/compute/arithmetics/decimal/mul.rs delete mode 100644 crates/polars-arrow/src/compute/arithmetics/decimal/sub.rs delete mode 100644 crates/polars-arrow/src/compute/arithmetics/time.rs diff --git a/crates/polars-arrow/src/array/ord.rs b/crates/polars-arrow/src/array/ord.rs deleted file mode 100644 index b585d67600e4..000000000000 --- a/crates/polars-arrow/src/array/ord.rs +++ /dev/null @@ -1,180 +0,0 @@ -//! Contains functions and function factories to order values within arrays. -use std::cmp::Ordering; -use polars_error::polars_bail; - -use crate::array::*; -use crate::datatypes::*; -use crate::offset::Offset; -use crate::types::NativeType; -use crate::util::total_ord::TotalOrd; - -/// Compare the values at two arbitrary indices in two arrays. 
-pub type DynComparator = Box<dyn Fn(usize, usize) -> Ordering + Send + Sync>; - -fn compare_primitives<T: NativeType + TotalOrd>( - left: &dyn Array, - right: &dyn Array, -) -> DynComparator { - let left = left - .as_any() - .downcast_ref::<PrimitiveArray<T>>() - .unwrap() - .clone(); - let right = right - .as_any() - .downcast_ref::<PrimitiveArray<T>>() - .unwrap() - .clone(); - Box::new(move |i, j| left.value(i).tot_cmp(&right.value(j))) -} - -fn compare_boolean(left: &dyn Array, right: &dyn Array) -> DynComparator { - let left = left - .as_any() - .downcast_ref::<BooleanArray>() - .unwrap() - .clone(); - let right = right - .as_any() - .downcast_ref::<BooleanArray>() - .unwrap() - .clone(); - Box::new(move |i, j| left.value(i).cmp(&right.value(j))) -} - -fn compare_string<O: Offset>(left: &dyn Array, right: &dyn Array) -> DynComparator { - let left = left - .as_any() - .downcast_ref::<Utf8Array<O>>() - .unwrap() - .clone(); - let right = right - .as_any() - .downcast_ref::<Utf8Array<O>>() - .unwrap() - .clone(); - Box::new(move |i, j| left.value(i).cmp(right.value(j))) -} - -fn compare_binary<O: Offset>(left: &dyn Array, right: &dyn Array) -> DynComparator { - let left = left - .as_any() - .downcast_ref::<BinaryArray<O>>() - .unwrap() - .clone(); - let right = right - .as_any() - .downcast_ref::<BinaryArray<O>>() - .unwrap() - .clone(); - Box::new(move |i, j| left.value(i).cmp(right.value(j))) -} - -fn compare_dict<K>(left: &DictionaryArray<K>, right: &DictionaryArray<K>) -> Result<DynComparator> -where - K: DictionaryKey, -{ - let left_keys = left.keys().values().clone(); - let right_keys = right.keys().values().clone(); - - let comparator = build_compare(left.values().as_ref(), right.values().as_ref())?; - - Ok(Box::new(move |i: usize, j: usize| { - // safety: all dictionaries keys are guaranteed to be castable to usize - let key_left = unsafe { left_keys[i].as_usize() }; - let key_right = unsafe { right_keys[j].as_usize() }; - (comparator)(key_left, key_right) - })) -} - -macro_rules! dyn_dict { - ($key:ty, $lhs:expr, $rhs:expr) => {{ - let lhs = $lhs.as_any().downcast_ref().unwrap(); - let rhs = $rhs.as_any().downcast_ref().unwrap(); - compare_dict::<$key>(lhs, rhs)? - }}; -} - -/// returns a comparison function that compares values at two different slots -/// between two [`Array`]. -/// # Example -/// ``` -/// use polars_arrow::array::{ord::build_compare, PrimitiveArray}; -/// -/// # fn main() -> polars_arrow::error::Result<()> { -/// let array1 = PrimitiveArray::from_slice([1, 2]); -/// let array2 = PrimitiveArray::from_slice([3, 4]); -/// -/// let cmp = build_compare(&array1, &array2)?; -/// -/// // 1 (index 0 of array1) is smaller than 4 (index 1 of array2) -/// assert_eq!(std::cmp::Ordering::Less, (cmp)(0, 1)); -/// # Ok(()) -/// # } -/// ``` -/// # Error -/// The arrays' [`DataType`] must be equal and the types must have a natural order. -// This is a factory of comparisons.
-pub fn build_compare(left: &dyn Array, right: &dyn Array) -> Result<DynComparator> { - use DataType::*; - use IntervalUnit::*; - use TimeUnit::*; - Ok(match (left.data_type(), right.data_type()) { - (a, b) if a != b => { - polars_bail!(ComputeError: - "Can't compare arrays of different types".to_string(), - ); - }, - (Boolean, Boolean) => compare_boolean(left, right), - (UInt8, UInt8) => compare_primitives::<u8>(left, right), - (UInt16, UInt16) => compare_primitives::<u16>(left, right), - (UInt32, UInt32) => compare_primitives::<u32>(left, right), - (UInt64, UInt64) => compare_primitives::<u64>(left, right), - (Int8, Int8) => compare_primitives::<i8>(left, right), - (Int16, Int16) => compare_primitives::<i16>(left, right), - (Int32, Int32) - | (Date32, Date32) - | (Time32(Second), Time32(Second)) - | (Time32(Millisecond), Time32(Millisecond)) - | (Interval(YearMonth), Interval(YearMonth)) => compare_primitives::<i32>(left, right), - (Int64, Int64) - | (Date64, Date64) - | (Time64(Microsecond), Time64(Microsecond)) - | (Time64(Nanosecond), Time64(Nanosecond)) - | (Timestamp(Second, None), Timestamp(Second, None)) - | (Timestamp(Millisecond, None), Timestamp(Millisecond, None)) - | (Timestamp(Microsecond, None), Timestamp(Microsecond, None)) - | (Timestamp(Nanosecond, None), Timestamp(Nanosecond, None)) - | (Duration(Second), Duration(Second)) - | (Duration(Millisecond), Duration(Millisecond)) - | (Duration(Microsecond), Duration(Microsecond)) - | (Duration(Nanosecond), Duration(Nanosecond)) => compare_primitives::<i64>(left, right), - (Float32, Float32) => compare_primitives::<f32>(left, right), - (Float64, Float64) => compare_primitives::<f64>(left, right), - (Decimal(_, _), Decimal(_, _)) => compare_primitives::<i128>(left, right), - (Utf8, Utf8) => compare_string::<i32>(left, right), - (LargeUtf8, LargeUtf8) => compare_string::<i64>(left, right), - (Binary, Binary) => compare_binary::<i32>(left, right), - (LargeBinary, LargeBinary) => compare_binary::<i64>(left, right), - (Dictionary(key_type_lhs, ..), Dictionary(key_type_rhs, ..)) => { - match (key_type_lhs, key_type_rhs) { - (IntegerType::UInt8, IntegerType::UInt8) => dyn_dict!(u8, left, right), - (IntegerType::UInt16, IntegerType::UInt16) => dyn_dict!(u16, left, right), - (IntegerType::UInt32, IntegerType::UInt32) => dyn_dict!(u32, left, right), - (IntegerType::UInt64, IntegerType::UInt64) => dyn_dict!(u64, left, right), - (IntegerType::Int8, IntegerType::Int8) => dyn_dict!(i8, left, right), - (IntegerType::Int16, IntegerType::Int16) => dyn_dict!(i16, left, right), - (IntegerType::Int32, IntegerType::Int32) => dyn_dict!(i32, left, right), - (IntegerType::Int64, IntegerType::Int64) => dyn_dict!(i64, left, right), - (lhs, _) => { - return Err(Error::InvalidArgumentError(format!( - "Dictionaries do not support keys of type {lhs:?}" - ))) - }, - } - }, - _ => { - unimplemented!() - }, - }) -} diff --git a/crates/polars-arrow/src/compute/arithmetics/basic/add.rs b/crates/polars-arrow/src/compute/arithmetics/basic/add.rs index 4ac4fb8bd02f..ec941edc2381 100644 --- a/crates/polars-arrow/src/compute/arithmetics/basic/add.rs +++ b/crates/polars-arrow/src/compute/arithmetics/basic/add.rs @@ -1,33 +1,12 @@ -//! 
Definition of basic add operations with primitive arrays use std::ops::Add; -use num_traits::ops::overflowing::OverflowingAdd; -use num_traits::{CheckedAdd, SaturatingAdd, WrappingAdd}; - use super::NativeArithmetics; use crate::array::PrimitiveArray; -use crate::bitmap::Bitmap; -use crate::compute::arithmetics::{ - ArrayAdd, ArrayCheckedAdd, ArrayOverflowingAdd, ArraySaturatingAdd, ArrayWrappingAdd, -}; -use crate::compute::arity::{ - binary, binary_checked, binary_with_bitmap, unary, unary_checked, unary_with_bitmap, -}; +use crate::compute::arity::{binary, unary}; /// Adds two primitive arrays with the same type. /// Panics if the sum of one pair of values overflows. -/// -/// # Examples -/// ``` -/// use polars_arrow::compute::arithmetics::basic::add; -/// use polars_arrow::array::PrimitiveArray; -/// -/// let a = PrimitiveArray::from([None, Some(6), None, Some(6)]); -/// let b = PrimitiveArray::from([Some(5), None, None, Some(6)]); -/// let result = add(&a, &b); -/// let expected = PrimitiveArray::from([None, None, None, Some(12)]); -/// assert_eq!(result, expected) -/// ``` pub fn add<T>(lhs: &PrimitiveArray<T>, rhs: &PrimitiveArray<T>) -> PrimitiveArray<T> where T: NativeArithmetics + Add<Output = T>, { binary(lhs, rhs, lhs.data_type().clone(), |a, b| a + b) } -/// Wrapping addition of two [`PrimitiveArray`]s. -/// It wraps around at the boundary of the type if the result overflows. -/// -/// # Examples -/// ``` -/// use polars_arrow::compute::arithmetics::basic::wrapping_add; -/// use polars_arrow::array::PrimitiveArray; -/// -/// let a = PrimitiveArray::from([Some(-100i8), Some(100i8), Some(100i8)]); -/// let b = PrimitiveArray::from([Some(0i8), Some(100i8), Some(0i8)]); -/// let result = wrapping_add(&a, &b); -/// let expected = PrimitiveArray::from([Some(-100i8), Some(-56i8), Some(100i8)]); -/// assert_eq!(result, expected); -/// ``` -pub fn wrapping_add<T>(lhs: &PrimitiveArray<T>, rhs: &PrimitiveArray<T>) -> PrimitiveArray<T> -where - T: NativeArithmetics + WrappingAdd, -{ - let op = move |a: T, b: T| a.wrapping_add(&b); - - binary(lhs, rhs, lhs.data_type().clone(), op) -} - -/// Checked addition of two primitive arrays. If the result from the sum -/// overflows, the validity for that index is changed to None -/// -/// # Examples -/// ``` -/// use polars_arrow::compute::arithmetics::basic::checked_add; -/// use polars_arrow::array::PrimitiveArray; -/// -/// let a = PrimitiveArray::from([Some(100i8), Some(100i8), Some(100i8)]); -/// let b = PrimitiveArray::from([Some(0i8), Some(100i8), Some(0i8)]); -/// let result = checked_add(&a, &b); -/// let expected = PrimitiveArray::from([Some(100i8), None, Some(100i8)]); -/// assert_eq!(result, expected); -/// ``` -pub fn checked_add<T>(lhs: &PrimitiveArray<T>, rhs: &PrimitiveArray<T>) -> PrimitiveArray<T> -where - T: NativeArithmetics + CheckedAdd, -{ - let op = move |a: T, b: T| a.checked_add(&b); - - binary_checked(lhs, rhs, lhs.data_type().clone(), op) -} - -/// Saturating addition of two primitive arrays. If the result from the sum is -/// larger than the possible number for this type, the result for the operation -/// will be the saturated value.
-/// -/// # Examples -/// ``` -/// use polars_arrow::compute::arithmetics::basic::saturating_add; -/// use polars_arrow::array::PrimitiveArray; -/// -/// let a = PrimitiveArray::from([Some(100i8)]); -/// let b = PrimitiveArray::from([Some(100i8)]); -/// let result = saturating_add(&a, &b); -/// let expected = PrimitiveArray::from([Some(127)]); -/// assert_eq!(result, expected); -/// ``` -pub fn saturating_add<T>(lhs: &PrimitiveArray<T>, rhs: &PrimitiveArray<T>) -> PrimitiveArray<T> -where - T: NativeArithmetics + SaturatingAdd, -{ - let op = move |a: T, b: T| a.saturating_add(&b); - - binary(lhs, rhs, lhs.data_type().clone(), op) -} - -/// Overflowing addition of two primitive arrays. If the result from the sum is -/// larger than the possible number for this type, the result for the operation -/// will be an array with overflowed values and a validity array indicating -/// the overflowing elements from the array. -/// -/// # Examples -/// ``` -/// use polars_arrow::compute::arithmetics::basic::overflowing_add; -/// use polars_arrow::array::PrimitiveArray; -/// -/// let a = PrimitiveArray::from([Some(1i8), Some(100i8)]); -/// let b = PrimitiveArray::from([Some(1i8), Some(100i8)]); -/// let (result, overflow) = overflowing_add(&a, &b); -/// let expected = PrimitiveArray::from([Some(2i8), Some(-56i8)]); -/// assert_eq!(result, expected); -/// ``` -pub fn overflowing_add<T>( - lhs: &PrimitiveArray<T>, - rhs: &PrimitiveArray<T>, -) -> (PrimitiveArray<T>, Bitmap) -where - T: NativeArithmetics + OverflowingAdd, -{ - let op = move |a: T, b: T| a.overflowing_add(&b); - - binary_with_bitmap(lhs, rhs, lhs.data_type().clone(), op) -} - -// Implementation of ArrayAdd trait for PrimitiveArrays -impl<T> ArrayAdd<PrimitiveArray<T>> for PrimitiveArray<T> -where - T: NativeArithmetics + Add<Output = T>, -{ - fn add(&self, rhs: &PrimitiveArray<T>) -> Self { - add(self, rhs) - } -} - -impl<T> ArrayWrappingAdd<PrimitiveArray<T>> for PrimitiveArray<T> -where - T: NativeArithmetics + WrappingAdd, -{ - fn wrapping_add(&self, rhs: &PrimitiveArray<T>) -> Self { - wrapping_add(self, rhs) - } -} - -// Implementation of ArrayCheckedAdd trait for PrimitiveArrays -impl<T> ArrayCheckedAdd<PrimitiveArray<T>> for PrimitiveArray<T> -where - T: NativeArithmetics + CheckedAdd, -{ - fn checked_add(&self, rhs: &PrimitiveArray<T>) -> Self { - checked_add(self, rhs) - } -} - -// Implementation of ArraySaturatingAdd trait for PrimitiveArrays -impl<T> ArraySaturatingAdd<PrimitiveArray<T>> for PrimitiveArray<T> -where - T: NativeArithmetics + SaturatingAdd, -{ - fn saturating_add(&self, rhs: &PrimitiveArray<T>) -> Self { - saturating_add(self, rhs) - } -} - -// Implementation of ArraySaturatingAdd trait for PrimitiveArrays -impl<T> ArrayOverflowingAdd<PrimitiveArray<T>> for PrimitiveArray<T> -where - T: NativeArithmetics + OverflowingAdd, -{ - fn overflowing_add(&self, rhs: &PrimitiveArray<T>) -> (Self, Bitmap) { - overflowing_add(self, rhs) - } -} /// Adds a scalar T to a primitive array of type T. /// Panics if the sum of the values overflows. -/// -/// # Examples -/// ``` -/// use polars_arrow::compute::arithmetics::basic::add_scalar; -/// use polars_arrow::array::PrimitiveArray; -/// -/// let a = PrimitiveArray::from([None, Some(6), None, Some(6)]); -/// let result = add_scalar(&a, &1i32); -/// let expected = PrimitiveArray::from([None, Some(7), None, Some(7)]); -/// assert_eq!(result, expected) -/// ``` pub fn add_scalar<T>(lhs: &PrimitiveArray<T>, rhs: &T) -> PrimitiveArray<T> where T: NativeArithmetics + Add<Output = T>, { let rhs = *rhs; unary(lhs, |a| a + rhs, lhs.data_type().clone()) } -/// Wrapping addition of a scalar T to a [`PrimitiveArray`] of type T.
-/// It do nothing if the result overflows. -/// -/// # Examples -/// ``` -/// use polars_arrow::compute::arithmetics::basic::wrapping_add_scalar; -/// use polars_arrow::array::Int8Array; -/// -/// let a = Int8Array::from(&[None, Some(100)]); -/// let result = wrapping_add_scalar(&a, &100i8); -/// let expected = Int8Array::from(&[None, Some(-56)]); -/// assert_eq!(result, expected); -/// ``` -pub fn wrapping_add_scalar<T>(lhs: &PrimitiveArray<T>, rhs: &T) -> PrimitiveArray<T> -where - T: NativeArithmetics + WrappingAdd, -{ - unary(lhs, |a| a.wrapping_add(rhs), lhs.data_type().clone()) -} - -/// Checked addition of a scalar T to a primitive array of type T. If the -/// result from the sum overflows then the validity index for that value is -/// changed to None -/// -/// # Examples -/// ``` -/// use polars_arrow::compute::arithmetics::basic::checked_add_scalar; -/// use polars_arrow::array::Int8Array; -/// -/// let a = Int8Array::from(&[None, Some(100), None, Some(100)]); -/// let result = checked_add_scalar(&a, &100i8); -/// let expected = Int8Array::from(&[None, None, None, None]); -/// assert_eq!(result, expected); -/// ``` -pub fn checked_add_scalar<T>(lhs: &PrimitiveArray<T>, rhs: &T) -> PrimitiveArray<T> -where - T: NativeArithmetics + CheckedAdd, -{ - let rhs = *rhs; - let op = move |a: T| a.checked_add(&rhs); - - unary_checked(lhs, op, lhs.data_type().clone()) -} - -/// Saturated addition of a scalar T to a primitive array of type T. If the -/// result from the sum is larger than the possible number for this type, then -/// the result will be saturated -/// -/// # Examples -/// ``` -/// use polars_arrow::compute::arithmetics::basic::saturating_add_scalar; -/// use polars_arrow::array::PrimitiveArray; -/// -/// let a = PrimitiveArray::from([Some(100i8)]); -/// let result = saturating_add_scalar(&a, &100i8); -/// let expected = PrimitiveArray::from([Some(127)]); -/// assert_eq!(result, expected); -/// ``` -pub fn saturating_add_scalar<T>(lhs: &PrimitiveArray<T>, rhs: &T) -> PrimitiveArray<T> -where - T: NativeArithmetics + SaturatingAdd, -{ - let rhs = *rhs; - let op = move |a: T| a.saturating_add(&rhs); - - unary(lhs, op, lhs.data_type().clone()) -} - -/// Overflowing addition of a scalar T to a primitive array of type T.
If the -/// result from the sum is larger than the possible number for this type, then -/// the result will be an array with overflowed values and a validity -/// array indicating the overflowing elements from the array -/// -/// # Examples -/// ``` -/// use polars_arrow::compute::arithmetics::basic::overflowing_add_scalar; -/// use polars_arrow::array::PrimitiveArray; -/// -/// let a = PrimitiveArray::from([Some(1i8), Some(100i8)]); -/// let (result, overflow) = overflowing_add_scalar(&a, &100i8); -/// let expected = PrimitiveArray::from([Some(101i8), Some(-56i8)]); -/// assert_eq!(result, expected); -/// ``` -pub fn overflowing_add_scalar<T>(lhs: &PrimitiveArray<T>, rhs: &T) -> (PrimitiveArray<T>, Bitmap) -where - T: NativeArithmetics + OverflowingAdd, -{ - let rhs = *rhs; - let op = move |a: T| a.overflowing_add(&rhs); - - unary_with_bitmap(lhs, op, lhs.data_type().clone()) -} - -// Implementation of ArrayAdd trait for PrimitiveArrays with a scalar -impl<T> ArrayAdd<T> for PrimitiveArray<T> -where - T: NativeArithmetics + Add<Output = T>, -{ - fn add(&self, rhs: &T) -> Self { - add_scalar(self, rhs) - } -} - -// Implementation of ArrayCheckedAdd trait for PrimitiveArrays with a scalar -impl<T> ArrayCheckedAdd<T> for PrimitiveArray<T> -where - T: NativeArithmetics + CheckedAdd, -{ - fn checked_add(&self, rhs: &T) -> Self { - checked_add_scalar(self, rhs) - } -} - -// Implementation of ArraySaturatingAdd trait for PrimitiveArrays with a scalar -impl<T> ArraySaturatingAdd<T> for PrimitiveArray<T> -where - T: NativeArithmetics + SaturatingAdd, -{ - fn saturating_add(&self, rhs: &T) -> Self { - saturating_add_scalar(self, rhs) - } -} - -// Implementation of ArraySaturatingAdd trait for PrimitiveArrays with a scalar -impl<T> ArrayOverflowingAdd<T> for PrimitiveArray<T> -where - T: NativeArithmetics + OverflowingAdd, -{ - fn overflowing_add(&self, rhs: &T) -> (Self, Bitmap) { - overflowing_add_scalar(self, rhs) - } -} diff --git a/crates/polars-arrow/src/compute/arithmetics/basic/div.rs b/crates/polars-arrow/src/compute/arithmetics/basic/div.rs index 4b27001543e0..9b5220b1b1ef 100644 --- a/crates/polars-arrow/src/compute/arithmetics/basic/div.rs +++ b/crates/polars-arrow/src/compute/arithmetics/basic/div.rs @@ -8,7 +8,6 @@ use strength_reduce::{ use super::NativeArithmetics; use crate::array::{Array, PrimitiveArray}; -use crate::compute::arithmetics::{ArrayCheckedDiv, ArrayDiv}; use crate::compute::arity::{binary, binary_checked, unary, unary_checked}; use crate::compute::utils::check_same_len; use crate::datatypes::PrimitiveType; @@ -67,39 +66,8 @@ where binary_checked(lhs, rhs, lhs.data_type().clone(), op) } -// Implementation of ArrayDiv trait for PrimitiveArrays -impl<T> ArrayDiv<PrimitiveArray<T>> for PrimitiveArray<T> -where - T: NativeArithmetics + Div<Output = T>, -{ - fn div(&self, rhs: &PrimitiveArray<T>) -> Self { - div(self, rhs) - } -} - -// Implementation of ArrayCheckedDiv trait for PrimitiveArrays -impl<T> ArrayCheckedDiv<PrimitiveArray<T>> for PrimitiveArray<T> -where - T: NativeArithmetics + CheckedDiv, -{ - fn checked_div(&self, rhs: &PrimitiveArray<T>) -> Self { - checked_div(self, rhs) - } -} /// Divide a primitive array of type T by a scalar T. /// Panics if the divisor is zero.
-/// -/// # Examples -/// ``` -/// use polars_arrow::compute::arithmetics::basic::div_scalar; -/// use polars_arrow::array::Int32Array; -/// -/// let a = Int32Array::from(&[None, Some(6), None, Some(6)]); -/// let result = div_scalar(&a, &2i32); -/// let expected = Int32Array::from(&[None, Some(3), None, Some(3)]); -/// assert_eq!(result, expected) -/// ``` pub fn div_scalar<T>(lhs: &PrimitiveArray<T>, rhs: &T) -> PrimitiveArray<T> where T: NativeArithmetics + Div<Output = T> + NumCast, @@ -161,17 +129,6 @@ where /// Checked division of a primitive array of type T by a scalar T. If the /// divisor is zero then the validity array is changed to None. -/// -/// # Examples -/// ``` -/// use polars_arrow::compute::arithmetics::basic::checked_div_scalar; -/// use polars_arrow::array::Int8Array; -/// -/// let a = Int8Array::from(&[Some(-100i8)]); -/// let result = checked_div_scalar(&a, &100i8); -/// let expected = Int8Array::from(&[Some(-1i8)]); -/// assert_eq!(result, expected); -/// ``` pub fn checked_div_scalar<T>(lhs: &PrimitiveArray<T>, rhs: &T) -> PrimitiveArray<T> where T: NativeArithmetics + CheckedDiv, { let rhs = *rhs; let op = move |a: T| a.checked_div(&rhs); unary_checked(lhs, op, lhs.data_type().clone()) } - -// Implementation of ArrayDiv trait for PrimitiveArrays with a scalar -impl<T> ArrayDiv<T> for PrimitiveArray<T> -where - T: NativeArithmetics + Div<Output = T> + NumCast, -{ - fn div(&self, rhs: &T) -> Self { - div_scalar(self, rhs) - } -} - -// Implementation of ArrayCheckedDiv trait for PrimitiveArrays with a scalar -impl<T> ArrayCheckedDiv<T> for PrimitiveArray<T> -where - T: NativeArithmetics + CheckedDiv, -{ - fn checked_div(&self, rhs: &T) -> Self { - checked_div_scalar(self, rhs) - } -} diff --git a/crates/polars-arrow/src/compute/arithmetics/basic/mod.rs b/crates/polars-arrow/src/compute/arithmetics/basic/mod.rs index b01e31c5a214..faa55af6bbd9 100644 --- a/crates/polars-arrow/src/compute/arithmetics/basic/mod.rs +++ b/crates/polars-arrow/src/compute/arithmetics/basic/mod.rs @@ -11,8 +11,6 @@ mod div; pub use div::*; mod mul; pub use mul::*; -mod pow; -pub use pow::*; mod rem; pub use rem::*; mod sub; diff --git a/crates/polars-arrow/src/compute/arithmetics/basic/mul.rs b/crates/polars-arrow/src/compute/arithmetics/basic/mul.rs index becdce1eba4a..a1ed463f0195 100644 --- a/crates/polars-arrow/src/compute/arithmetics/basic/mul.rs +++ b/crates/polars-arrow/src/compute/arithmetics/basic/mul.rs @@ -1,33 +1,12 @@ -//! Definition of basic mul operations with primitive arrays use std::ops::Mul; -use num_traits::ops::overflowing::OverflowingMul; -use num_traits::{CheckedMul, SaturatingMul, WrappingMul}; - use super::NativeArithmetics; use crate::array::PrimitiveArray; -use crate::bitmap::Bitmap; -use crate::compute::arithmetics::{ - ArrayCheckedMul, ArrayMul, ArrayOverflowingMul, ArraySaturatingMul, ArrayWrappingMul, -}; -use crate::compute::arity::{ - binary, binary_checked, binary_with_bitmap, unary, unary_checked, unary_with_bitmap, -}; +use crate::compute::arity::{binary, unary}; /// Multiplies two primitive arrays with the same type. /// Panics if the multiplication of one pair of values overflows.
-/// -/// # Examples -/// ``` -/// use polars_arrow::compute::arithmetics::basic::mul; -/// use polars_arrow::array::Int32Array; -/// -/// let a = Int32Array::from(&[None, Some(6), None, Some(6)]); -/// let b = Int32Array::from(&[Some(5), None, None, Some(6)]); -/// let result = mul(&a, &b); -/// let expected = Int32Array::from(&[None, None, None, Some(36)]); -/// assert_eq!(result, expected) -/// ``` pub fn mul<T>(lhs: &PrimitiveArray<T>, rhs: &PrimitiveArray<T>) -> PrimitiveArray<T> where T: NativeArithmetics + Mul<Output = T>, { binary(lhs, rhs, lhs.data_type().clone(), |a, b| a * b) } -/// Wrapping multiplication of two [`PrimitiveArray`]s. -/// It wraps around at the boundary of the type if the result overflows. -/// -/// # Examples -/// ``` -/// use polars_arrow::compute::arithmetics::basic::wrapping_mul; -/// use polars_arrow::array::PrimitiveArray; -/// -/// let a = PrimitiveArray::from([Some(100i8), Some(0x10i8), Some(100i8)]); -/// let b = PrimitiveArray::from([Some(0i8), Some(0x10i8), Some(0i8)]); -/// let result = wrapping_mul(&a, &b); -/// let expected = PrimitiveArray::from([Some(0), Some(0), Some(0)]); -/// assert_eq!(result, expected); -/// ``` -pub fn wrapping_mul<T>(lhs: &PrimitiveArray<T>, rhs: &PrimitiveArray<T>) -> PrimitiveArray<T> -where - T: NativeArithmetics + WrappingMul, -{ - let op = move |a: T, b: T| a.wrapping_mul(&b); - - binary(lhs, rhs, lhs.data_type().clone(), op) -} - -/// Checked multiplication of two primitive arrays. If the result from the -/// multiplications overflows, the validity for that index is changed -/// returned. -/// -/// # Examples -/// ``` -/// use polars_arrow::compute::arithmetics::basic::checked_mul; -/// use polars_arrow::array::Int8Array; -/// -/// let a = Int8Array::from(&[Some(100i8), Some(100i8), Some(100i8)]); -/// let b = Int8Array::from(&[Some(1i8), Some(100i8), Some(1i8)]); -/// let result = checked_mul(&a, &b); -/// let expected = Int8Array::from(&[Some(100i8), None, Some(100i8)]); -/// assert_eq!(result, expected); -/// ``` -pub fn checked_mul<T>(lhs: &PrimitiveArray<T>, rhs: &PrimitiveArray<T>) -> PrimitiveArray<T> -where - T: NativeArithmetics + CheckedMul, -{ - let op = move |a: T, b: T| a.checked_mul(&b); - - binary_checked(lhs, rhs, lhs.data_type().clone(), op) -} - -/// Saturating multiplication of two primitive arrays. If the result from the -/// multiplication overflows, the result for the -/// operation will be the saturated value. -/// -/// # Examples -/// ``` -/// use polars_arrow::compute::arithmetics::basic::saturating_mul; -/// use polars_arrow::array::Int8Array; -/// -/// let a = Int8Array::from(&[Some(-100i8)]); -/// let b = Int8Array::from(&[Some(100i8)]); -/// let result = saturating_mul(&a, &b); -/// let expected = Int8Array::from(&[Some(-128)]); -/// assert_eq!(result, expected); -/// ``` -pub fn saturating_mul<T>(lhs: &PrimitiveArray<T>, rhs: &PrimitiveArray<T>) -> PrimitiveArray<T> -where - T: NativeArithmetics + SaturatingMul, -{ - let op = move |a: T, b: T| a.saturating_mul(&b); - - binary(lhs, rhs, lhs.data_type().clone(), op) -} - -/// Overflowing multiplication of two primitive arrays. If the result from the -/// mul overflows, the result for the operation will be an array with -/// overflowed values and a validity array indicating the overflowing elements -/// from the array.
-/// -/// # Examples -/// ``` -/// use polars_arrow::compute::arithmetics::basic::overflowing_mul; -/// use polars_arrow::array::Int8Array; -/// -/// let a = Int8Array::from(&[Some(1i8), Some(-100i8)]); -/// let b = Int8Array::from(&[Some(1i8), Some(100i8)]); -/// let (result, overflow) = overflowing_mul(&a, &b); -/// let expected = Int8Array::from(&[Some(1i8), Some(-16i8)]); -/// assert_eq!(result, expected); -/// ``` -pub fn overflowing_mul<T>( - lhs: &PrimitiveArray<T>, - rhs: &PrimitiveArray<T>, -) -> (PrimitiveArray<T>, Bitmap) -where - T: NativeArithmetics + OverflowingMul, -{ - let op = move |a: T, b: T| a.overflowing_mul(&b); - - binary_with_bitmap(lhs, rhs, lhs.data_type().clone(), op) -} - -// Implementation of ArrayMul trait for PrimitiveArrays -impl<T> ArrayMul<PrimitiveArray<T>> for PrimitiveArray<T> -where - T: NativeArithmetics + Mul<Output = T>, -{ - fn mul(&self, rhs: &PrimitiveArray<T>) -> Self { - mul(self, rhs) - } -} - -impl<T> ArrayWrappingMul<PrimitiveArray<T>> for PrimitiveArray<T> -where - T: NativeArithmetics + WrappingMul, -{ - fn wrapping_mul(&self, rhs: &PrimitiveArray<T>) -> Self { - wrapping_mul(self, rhs) - } -} - -// Implementation of ArrayCheckedMul trait for PrimitiveArrays -impl<T> ArrayCheckedMul<PrimitiveArray<T>> for PrimitiveArray<T> -where - T: NativeArithmetics + CheckedMul, -{ - fn checked_mul(&self, rhs: &PrimitiveArray<T>) -> Self { - checked_mul(self, rhs) - } -} - -// Implementation of ArraySaturatingMul trait for PrimitiveArrays -impl<T> ArraySaturatingMul<PrimitiveArray<T>> for PrimitiveArray<T> -where - T: NativeArithmetics + SaturatingMul, -{ - fn saturating_mul(&self, rhs: &PrimitiveArray<T>) -> Self { - saturating_mul(self, rhs) - } -} - -// Implementation of ArraySaturatingMul trait for PrimitiveArrays -impl<T> ArrayOverflowingMul<PrimitiveArray<T>> for PrimitiveArray<T> -where - T: NativeArithmetics + OverflowingMul, -{ - fn overflowing_mul(&self, rhs: &PrimitiveArray<T>) -> (Self, Bitmap) { - overflowing_mul(self, rhs) - } -} /// Multiply a scalar T to a primitive array of type T. /// Panics if the multiplication of the values overflows. -/// -/// # Examples -/// ``` -/// use polars_arrow::compute::arithmetics::basic::mul_scalar; -/// use polars_arrow::array::Int32Array; -/// -/// let a = Int32Array::from(&[None, Some(6), None, Some(6)]); -/// let result = mul_scalar(&a, &2i32); -/// let expected = Int32Array::from(&[None, Some(12), None, Some(12)]); -/// assert_eq!(result, expected) -/// ``` pub fn mul_scalar<T>(lhs: &PrimitiveArray<T>, rhs: &T) -> PrimitiveArray<T> where T: NativeArithmetics + Mul<Output = T>, { let rhs = *rhs; unary(lhs, |a| a * rhs, lhs.data_type().clone()) } -/// Wrapping multiplication of a scalar T to a [`PrimitiveArray`] of type T. -/// It do nothing if the result overflows. -/// -/// # Examples -/// ``` -/// use polars_arrow::compute::arithmetics::basic::wrapping_mul_scalar; -/// use polars_arrow::array::Int8Array; -/// -/// let a = Int8Array::from(&[None, Some(0x10)]); -/// let result = wrapping_mul_scalar(&a, &0x10); -/// let expected = Int8Array::from(&[None, Some(0)]); -/// assert_eq!(result, expected); -/// ``` -pub fn wrapping_mul_scalar<T>(lhs: &PrimitiveArray<T>, rhs: &T) -> PrimitiveArray<T> -where - T: NativeArithmetics + WrappingMul, -{ - unary(lhs, |a| a.wrapping_mul(rhs), lhs.data_type().clone()) -} - -/// Checked multiplication of a scalar T to a primitive array of type T.
If the -/// result from the multiplication overflows, then the validity for that index is -/// changed to None -/// -/// # Examples -/// ``` -/// use polars_arrow::compute::arithmetics::basic::checked_mul_scalar; -/// use polars_arrow::array::Int8Array; -/// -/// let a = Int8Array::from(&[None, Some(100), None, Some(100)]); -/// let result = checked_mul_scalar(&a, &100i8); -/// let expected = Int8Array::from(&[None, None, None, None]); -/// assert_eq!(result, expected); -/// ``` -pub fn checked_mul_scalar<T>(lhs: &PrimitiveArray<T>, rhs: &T) -> PrimitiveArray<T> -where - T: NativeArithmetics + CheckedMul, -{ - let rhs = *rhs; - let op = move |a: T| a.checked_mul(&rhs); - - unary_checked(lhs, op, lhs.data_type().clone()) -} - -/// Saturated multiplication of a scalar T to a primitive array of type T. If the -/// result from the mul overflows for this type, then -/// the result will be saturated -/// -/// # Examples -/// ``` -/// use polars_arrow::compute::arithmetics::basic::saturating_mul_scalar; -/// use polars_arrow::array::Int8Array; -/// -/// let a = Int8Array::from(&[Some(-100i8)]); -/// let result = saturating_mul_scalar(&a, &100i8); -/// let expected = Int8Array::from(&[Some(-128i8)]); -/// assert_eq!(result, expected); -/// ``` -pub fn saturating_mul_scalar<T>(lhs: &PrimitiveArray<T>, rhs: &T) -> PrimitiveArray<T> -where - T: NativeArithmetics + SaturatingMul, -{ - let rhs = *rhs; - let op = move |a: T| a.saturating_mul(&rhs); - - unary(lhs, op, lhs.data_type().clone()) -} - -/// Overflowing multiplication of a scalar T to a primitive array of type T. If -/// the result from the mul overflows for this type, -/// then the result will be an array with overflowed values and a validity -/// array indicating the overflowing elements from the array -/// -/// # Examples -/// ``` -/// use polars_arrow::compute::arithmetics::basic::overflowing_mul_scalar; -/// use polars_arrow::array::Int8Array; -/// -/// let a = Int8Array::from(&[Some(1i8), Some(100i8)]); -/// let (result, overflow) = overflowing_mul_scalar(&a, &100i8); -/// let expected = Int8Array::from(&[Some(100i8), Some(16i8)]); -/// assert_eq!(result, expected); -/// ``` -pub fn overflowing_mul_scalar<T>(lhs: &PrimitiveArray<T>, rhs: &T) -> (PrimitiveArray<T>, Bitmap) -where - T: NativeArithmetics + OverflowingMul, -{ - let rhs = *rhs; - let op = move |a: T| a.overflowing_mul(&rhs); - - unary_with_bitmap(lhs, op, lhs.data_type().clone()) -} - -// Implementation of ArrayMul trait for PrimitiveArrays with a scalar -impl<T> ArrayMul<T> for PrimitiveArray<T> -where - T: NativeArithmetics + Mul<Output = T>, -{ - fn mul(&self, rhs: &T) -> Self { - mul_scalar(self, rhs) - } -} - -// Implementation of ArrayCheckedMul trait for PrimitiveArrays with a scalar -impl<T> ArrayCheckedMul<T> for PrimitiveArray<T> -where - T: NativeArithmetics + CheckedMul, -{ - fn checked_mul(&self, rhs: &T) -> Self { - checked_mul_scalar(self, rhs) - } -} - -// Implementation of ArraySaturatingMul trait for PrimitiveArrays with a scalar -impl<T> ArraySaturatingMul<T> for PrimitiveArray<T> -where - T: NativeArithmetics + SaturatingMul, -{ - fn saturating_mul(&self, rhs: &T) -> Self { - saturating_mul_scalar(self, rhs) - } -} - -// Implementation of ArraySaturatingMul trait for PrimitiveArrays with a scalar -impl<T> ArrayOverflowingMul<T> for PrimitiveArray<T> -where - T: NativeArithmetics + OverflowingMul, -{ - fn overflowing_mul(&self, rhs: &T) -> (Self, Bitmap) { - overflowing_mul_scalar(self, rhs) - } -} diff --git a/crates/polars-arrow/src/compute/arithmetics/basic/pow.rs b/crates/polars-arrow/src/compute/arithmetics/basic/pow.rs
deleted file mode 100644 index 173c4a351aa5..000000000000 --- a/crates/polars-arrow/src/compute/arithmetics/basic/pow.rs +++ /dev/null @@ -1,49 +0,0 @@ -//! Definition of basic pow operations with primitive arrays -use num_traits::{checked_pow, CheckedMul, One, Pow}; - -use super::NativeArithmetics; -use crate::array::PrimitiveArray; -use crate::compute::arity::{unary, unary_checked}; - -/// Raises an array of primitives to the power of exponent. Panics if one of -/// the values values overflows. -/// -/// # Examples -/// ``` -/// use polars_arrow::compute::arithmetics::basic::powf_scalar; -/// use polars_arrow::array::Float32Array; -/// -/// let a = Float32Array::from(&[Some(2f32), None]); -/// let actual = powf_scalar(&a, 2.0); -/// let expected = Float32Array::from(&[Some(4f32), None]); -/// assert_eq!(expected, actual); -/// ``` -pub fn powf_scalar<T>(array: &PrimitiveArray<T>, exponent: T) -> PrimitiveArray<T> -where - T: NativeArithmetics + Pow<T, Output = T>, -{ - unary(array, |x| x.pow(exponent), array.data_type().clone()) -} - -/// Checked operation of raising an array of primitives to the power of -/// exponent. If the result from the multiplications overflows, the validity -/// for that index is changed returned. -/// -/// # Examples -/// ``` -/// use polars_arrow::compute::arithmetics::basic::checked_powf_scalar; -/// use polars_arrow::array::Int8Array; -/// -/// let a = Int8Array::from(&[Some(1i8), None, Some(7i8)]); -/// let actual = checked_powf_scalar(&a, 8usize); -/// let expected = Int8Array::from(&[Some(1i8), None, None]); -/// assert_eq!(expected, actual); -/// ``` -pub fn checked_powf_scalar<T>(array: &PrimitiveArray<T>, exponent: usize) -> PrimitiveArray<T> -where - T: NativeArithmetics + CheckedMul + One, -{ - let op = move |a: T| checked_pow(a, exponent); - - unary_checked(array, op, array.data_type().clone()) -} diff --git a/crates/polars-arrow/src/compute/arithmetics/basic/rem.rs b/crates/polars-arrow/src/compute/arithmetics/basic/rem.rs index d0ac512b5604..46eeb16cb8c6 100644 --- a/crates/polars-arrow/src/compute/arithmetics/basic/rem.rs +++ b/crates/polars-arrow/src/compute/arithmetics/basic/rem.rs @@ -1,30 +1,17 @@ use std::ops::Rem; -use num_traits::{CheckedRem, NumCast}; +use num_traits::NumCast; use strength_reduce::{ StrengthReducedU16, StrengthReducedU32, StrengthReducedU64, StrengthReducedU8, }; use super::NativeArithmetics; use crate::array::{Array, PrimitiveArray}; -use crate::compute::arithmetics::{ArrayCheckedRem, ArrayRem}; -use crate::compute::arity::{binary, binary_checked, unary, unary_checked}; +use crate::compute::arity::{binary, unary}; use crate::datatypes::PrimitiveType; /// Remainder of two primitive arrays with the same type. /// Panics if the divisor is zero or one pair of values overflows. -/// -/// # Examples -/// ``` -/// use polars_arrow::compute::arithmetics::basic::rem; -/// use polars_arrow::array::Int32Array; -/// -/// let a = Int32Array::from(&[Some(10), Some(7)]); -/// let b = Int32Array::from(&[Some(5), Some(6)]); -/// let result = rem(&a, &b); -/// let expected = Int32Array::from(&[Some(0), Some(1)]); -/// assert_eq!(result, expected) -/// ``` pub fn rem<T>(lhs: &PrimitiveArray<T>, rhs: &PrimitiveArray<T>) -> PrimitiveArray<T> where T: NativeArithmetics + Rem<Output = T>, { binary(lhs, rhs, lhs.data_type().clone(), |a, b| a % b) } -/// Checked remainder of two primitive arrays.
If the result from the remainder -/// overflows, the result for the operation will change the validity array -/// making this operation None -/// -/// # Examples -/// ``` -/// use polars_arrow::compute::arithmetics::basic::checked_rem; -/// use polars_arrow::array::Int8Array; -/// -/// let a = Int8Array::from(&[Some(-100i8), Some(10i8)]); -/// let b = Int8Array::from(&[Some(100i8), Some(0i8)]); -/// let result = checked_rem(&a, &b); -/// let expected = Int8Array::from(&[Some(-0i8), None]); -/// assert_eq!(result, expected); -/// ``` -pub fn checked_rem<T>(lhs: &PrimitiveArray<T>, rhs: &PrimitiveArray<T>) -> PrimitiveArray<T> -where - T: NativeArithmetics + CheckedRem, -{ - let op = move |a: T, b: T| a.checked_rem(&b); - - binary_checked(lhs, rhs, lhs.data_type().clone(), op) -} - -impl<T> ArrayRem<PrimitiveArray<T>> for PrimitiveArray<T> -where - T: NativeArithmetics + Rem<Output = T>, -{ - fn rem(&self, rhs: &PrimitiveArray<T>) -> Self { - rem(self, rhs) - } -} - -impl<T> ArrayCheckedRem<PrimitiveArray<T>> for PrimitiveArray<T> -where - T: NativeArithmetics + CheckedRem, -{ - fn checked_rem(&self, rhs: &PrimitiveArray<T>) -> Self { - checked_rem(self, rhs) - } -} /// Remainder a primitive array of type T by a scalar T. /// Panics if the divisor is zero. -/// -/// # Examples -/// ``` -/// use polars_arrow::compute::arithmetics::basic::rem_scalar; -/// use polars_arrow::array::Int32Array; -/// -/// let a = Int32Array::from(&[None, Some(6), None, Some(7)]); -/// let result = rem_scalar(&a, &2i32); -/// let expected = Int32Array::from(&[None, Some(0), None, Some(1)]); -/// assert_eq!(result, expected) -/// ``` pub fn rem_scalar<T>(lhs: &PrimitiveArray<T>, rhs: &T) -> PrimitiveArray<T> where T: NativeArithmetics + Rem<Output = T> + NumCast, @@ -153,44 +87,3 @@ where _ => unary(lhs, |a| a % rhs, lhs.data_type().clone()), } } - -/// Checked remainder of a primitive array of type T by a scalar T. If the -/// divisor is zero then the validity array is changed to None. -/// -/// # Examples -/// ``` -/// use polars_arrow::compute::arithmetics::basic::checked_rem_scalar; -/// use polars_arrow::array::Int8Array; -/// -/// let a = Int8Array::from(&[Some(-100i8)]); -/// let result = checked_rem_scalar(&a, &100i8); -/// let expected = Int8Array::from(&[Some(0i8)]); -/// assert_eq!(result, expected); -/// ``` -pub fn checked_rem_scalar<T>(lhs: &PrimitiveArray<T>, rhs: &T) -> PrimitiveArray<T> -where - T: NativeArithmetics + CheckedRem, -{ - let rhs = *rhs; - let op = move |a: T| a.checked_rem(&rhs); - - unary_checked(lhs, op, lhs.data_type().clone()) -} - -impl<T> ArrayRem<T> for PrimitiveArray<T> -where - T: NativeArithmetics + Rem<Output = T> + NumCast, -{ - fn rem(&self, rhs: &T) -> Self { - rem_scalar(self, rhs) - } -} - -impl<T> ArrayCheckedRem<T> for PrimitiveArray<T> -where - T: NativeArithmetics + CheckedRem, -{ - fn checked_rem(&self, rhs: &T) -> Self { - checked_rem_scalar(self, rhs) - } -} diff --git a/crates/polars-arrow/src/compute/arithmetics/basic/sub.rs b/crates/polars-arrow/src/compute/arithmetics/basic/sub.rs index 43f267c6bf13..33acb99b3ef6 100644 --- a/crates/polars-arrow/src/compute/arithmetics/basic/sub.rs +++ b/crates/polars-arrow/src/compute/arithmetics/basic/sub.rs @@ -1,33 +1,12 @@ -//! 
Definition of basic sub operations with primitive arrays use std::ops::Sub; -use num_traits::ops::overflowing::OverflowingSub; -use num_traits::{CheckedSub, SaturatingSub, WrappingSub}; - use super::NativeArithmetics; use crate::array::PrimitiveArray; -use crate::bitmap::Bitmap; -use crate::compute::arithmetics::{ - ArrayCheckedSub, ArrayOverflowingSub, ArraySaturatingSub, ArraySub, ArrayWrappingSub, -}; -use crate::compute::arity::{ - binary, binary_checked, binary_with_bitmap, unary, unary_checked, unary_with_bitmap, -}; +use crate::compute::arity::{binary, unary}; /// Subtracts two primitive arrays with the same type. /// Panics if the subtraction of one pair of values overflows. -/// -/// # Examples -/// ``` -/// use polars_arrow::compute::arithmetics::basic::sub; -/// use polars_arrow::array::Int32Array; -/// -/// let a = Int32Array::from(&[None, Some(6), None, Some(6)]); -/// let b = Int32Array::from(&[Some(5), None, None, Some(6)]); -/// let result = sub(&a, &b); -/// let expected = Int32Array::from(&[None, None, None, Some(0)]); -/// assert_eq!(result, expected) -/// ``` pub fn sub<T>(lhs: &PrimitiveArray<T>, rhs: &PrimitiveArray<T>) -> PrimitiveArray<T> where T: NativeArithmetics + Sub<Output = T>, { binary(lhs, rhs, lhs.data_type().clone(), |a, b| a - b) } -/// Wrapping subtraction of two [`PrimitiveArray`]s. -/// It wraps around at the boundary of the type if the result overflows. -/// -/// # Examples -/// ``` -/// use polars_arrow::compute::arithmetics::basic::wrapping_sub; -/// use polars_arrow::array::PrimitiveArray; -/// -/// let a = PrimitiveArray::from([Some(-100i8), Some(-100i8), Some(100i8)]); -/// let b = PrimitiveArray::from([Some(0i8), Some(100i8), Some(0i8)]); -/// let result = wrapping_sub(&a, &b); -/// let expected = PrimitiveArray::from([Some(-100i8), Some(56i8), Some(100i8)]); -/// assert_eq!(result, expected); -/// ``` -pub fn wrapping_sub<T>(lhs: &PrimitiveArray<T>, rhs: &PrimitiveArray<T>) -> PrimitiveArray<T> -where - T: NativeArithmetics + WrappingSub, -{ - let op = move |a: T, b: T| a.wrapping_sub(&b); - - binary(lhs, rhs, lhs.data_type().clone(), op) -} - -/// Checked subtraction of two primitive arrays. If the result from the -/// subtraction overflow, the validity for that index is changed -/// -/// # Examples -/// ``` -/// use polars_arrow::compute::arithmetics::basic::checked_sub; -/// use polars_arrow::array::Int8Array; -/// -/// let a = Int8Array::from(&[Some(100i8), Some(-100i8), Some(100i8)]); -/// let b = Int8Array::from(&[Some(1i8), Some(100i8), Some(0i8)]); -/// let result = checked_sub(&a, &b); -/// let expected = Int8Array::from(&[Some(99i8), None, Some(100i8)]); -/// assert_eq!(result, expected); -/// ``` -pub fn checked_sub<T>(lhs: &PrimitiveArray<T>, rhs: &PrimitiveArray<T>) -> PrimitiveArray<T> -where - T: NativeArithmetics + CheckedSub, -{ - let op = move |a: T, b: T| a.checked_sub(&b); - - binary_checked(lhs, rhs, lhs.data_type().clone(), op) -} - -/// Saturating subtraction of two primitive arrays. If the result from the sub -/// is smaller than the possible number for this type, the result for the -/// operation will be the saturated value.
-/// -/// # Examples -/// ``` -/// use polars_arrow::compute::arithmetics::basic::saturating_sub; -/// use polars_arrow::array::Int8Array; -/// -/// let a = Int8Array::from(&[Some(-100i8)]); -/// let b = Int8Array::from(&[Some(100i8)]); -/// let result = saturating_sub(&a, &b); -/// let expected = Int8Array::from(&[Some(-128)]); -/// assert_eq!(result, expected); -/// ``` -pub fn saturating_sub(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray -where - T: NativeArithmetics + SaturatingSub, -{ - let op = move |a: T, b: T| a.saturating_sub(&b); - - binary(lhs, rhs, lhs.data_type().clone(), op) -} - -/// Overflowing subtraction of two primitive arrays. If the result from the sub -/// is smaller than the possible number for this type, the result for the -/// operation will be an array with overflowed values and a validity array -/// indicating the overflowing elements from the array. -/// -/// # Examples -/// ``` -/// use polars_arrow::compute::arithmetics::basic::overflowing_sub; -/// use polars_arrow::array::Int8Array; -/// -/// let a = Int8Array::from(&[Some(1i8), Some(-100i8)]); -/// let b = Int8Array::from(&[Some(1i8), Some(100i8)]); -/// let (result, overflow) = overflowing_sub(&a, &b); -/// let expected = Int8Array::from(&[Some(0i8), Some(56i8)]); -/// assert_eq!(result, expected); -/// ``` -pub fn overflowing_sub( - lhs: &PrimitiveArray, - rhs: &PrimitiveArray, -) -> (PrimitiveArray, Bitmap) -where - T: NativeArithmetics + OverflowingSub, -{ - let op = move |a: T, b: T| a.overflowing_sub(&b); - - binary_with_bitmap(lhs, rhs, lhs.data_type().clone(), op) -} - -// Implementation of ArraySub trait for PrimitiveArrays -impl ArraySub> for PrimitiveArray -where - T: NativeArithmetics + Sub, -{ - fn sub(&self, rhs: &PrimitiveArray) -> Self { - sub(self, rhs) - } -} - -impl ArrayWrappingSub> for PrimitiveArray -where - T: NativeArithmetics + WrappingSub, -{ - fn wrapping_sub(&self, rhs: &PrimitiveArray) -> Self { - wrapping_sub(self, rhs) - } -} - -// Implementation of ArrayCheckedSub trait for PrimitiveArrays -impl ArrayCheckedSub> for PrimitiveArray -where - T: NativeArithmetics + CheckedSub, -{ - fn checked_sub(&self, rhs: &PrimitiveArray) -> Self { - checked_sub(self, rhs) - } -} - -// Implementation of ArraySaturatingSub trait for PrimitiveArrays -impl ArraySaturatingSub> for PrimitiveArray -where - T: NativeArithmetics + SaturatingSub, -{ - fn saturating_sub(&self, rhs: &PrimitiveArray) -> Self { - saturating_sub(self, rhs) - } -} - -// Implementation of ArraySaturatingSub trait for PrimitiveArrays -impl ArrayOverflowingSub> for PrimitiveArray -where - T: NativeArithmetics + OverflowingSub, -{ - fn overflowing_sub(&self, rhs: &PrimitiveArray) -> (Self, Bitmap) { - overflowing_sub(self, rhs) - } -} - /// Subtract a scalar T to a primitive array of type T. /// Panics if the subtraction of the values overflows. -/// -/// # Examples -/// ``` -/// use polars_arrow::compute::arithmetics::basic::sub_scalar; -/// use polars_arrow::array::Int32Array; -/// -/// let a = Int32Array::from(&[None, Some(6), None, Some(6)]); -/// let result = sub_scalar(&a, &1i32); -/// let expected = Int32Array::from(&[None, Some(5), None, Some(5)]); -/// assert_eq!(result, expected) -/// ``` pub fn sub_scalar(lhs: &PrimitiveArray, rhs: &T) -> PrimitiveArray where T: NativeArithmetics + Sub, @@ -202,136 +23,3 @@ where let rhs = *rhs; unary(lhs, |a| a - rhs, lhs.data_type().clone()) } - -/// Wrapping subtraction of a scalar T to a [`PrimitiveArray`] of type T. -/// It do nothing if the result overflows. 
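All of the `*_scalar` variants removed in this file follow one broadcast pattern: map a single rhs value over the array while null slots stay null. A rough sketch of that pattern, with `Vec<Option<i32>>` standing in for a `PrimitiveArray` plus its validity bitmap and a hypothetical `sub_scalar_sketch` name:

```
fn sub_scalar_sketch(values: &[Option<i32>], rhs: i32) -> Vec<Option<i32>> {
    // nulls propagate untouched; valid slots get the scalar subtracted
    values.iter().copied().map(|v| v.map(|a| a - rhs)).collect()
}

fn main() {
    assert_eq!(sub_scalar_sketch(&[None, Some(6)], 1), vec![None, Some(5)]);
}
```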
-/// -/// # Examples -/// ``` -/// use polars_arrow::compute::arithmetics::basic::wrapping_sub_scalar; -/// use polars_arrow::array::Int8Array; -/// -/// let a = Int8Array::from(&[None, Some(-100)]); -/// let result = wrapping_sub_scalar(&a, &100i8); -/// let expected = Int8Array::from(&[None, Some(56)]); -/// assert_eq!(result, expected); -/// ``` -pub fn wrapping_sub_scalar(lhs: &PrimitiveArray, rhs: &T) -> PrimitiveArray -where - T: NativeArithmetics + WrappingSub, -{ - unary(lhs, |a| a.wrapping_sub(rhs), lhs.data_type().clone()) -} - -/// Checked subtraction of a scalar T to a primitive array of type T. If the -/// result from the subtraction overflows, then the validity for that index -/// is changed to None -/// -/// # Examples -/// ``` -/// use polars_arrow::compute::arithmetics::basic::checked_sub_scalar; -/// use polars_arrow::array::Int8Array; -/// -/// let a = Int8Array::from(&[None, Some(-100), None, Some(-100)]); -/// let result = checked_sub_scalar(&a, &100i8); -/// let expected = Int8Array::from(&[None, None, None, None]); -/// assert_eq!(result, expected); -/// ``` -pub fn checked_sub_scalar(lhs: &PrimitiveArray, rhs: &T) -> PrimitiveArray -where - T: NativeArithmetics + CheckedSub, -{ - let rhs = *rhs; - let op = move |a: T| a.checked_sub(&rhs); - - unary_checked(lhs, op, lhs.data_type().clone()) -} - -/// Saturated subtraction of a scalar T to a primitive array of type T. If the -/// result from the sub is smaller than the possible number for this type, then -/// the result will be saturated -/// -/// # Examples -/// ``` -/// use polars_arrow::compute::arithmetics::basic::saturating_sub_scalar; -/// use polars_arrow::array::Int8Array; -/// -/// let a = Int8Array::from(&[Some(-100i8)]); -/// let result = saturating_sub_scalar(&a, &100i8); -/// let expected = Int8Array::from(&[Some(-128i8)]); -/// assert_eq!(result, expected); -/// ``` -pub fn saturating_sub_scalar(lhs: &PrimitiveArray, rhs: &T) -> PrimitiveArray -where - T: NativeArithmetics + SaturatingSub, -{ - let rhs = *rhs; - let op = move |a: T| a.saturating_sub(&rhs); - - unary(lhs, op, lhs.data_type().clone()) -} - -/// Overflowing subtraction of a scalar T to a primitive array of type T. 
If -/// the result from the sub is smaller than the possible number for this type, -/// then the result will be an array with overflowed values and a validity -/// array indicating the overflowing elements from the array -/// -/// # Examples -/// ``` -/// use polars_arrow::compute::arithmetics::basic::overflowing_sub_scalar; -/// use polars_arrow::array::Int8Array; -/// -/// let a = Int8Array::from(&[Some(1i8), Some(-100i8)]); -/// let (result, overflow) = overflowing_sub_scalar(&a, &100i8); -/// let expected = Int8Array::from(&[Some(-99i8), Some(56i8)]); -/// assert_eq!(result, expected); -/// ``` -pub fn overflowing_sub_scalar(lhs: &PrimitiveArray, rhs: &T) -> (PrimitiveArray, Bitmap) -where - T: NativeArithmetics + OverflowingSub, -{ - let rhs = *rhs; - let op = move |a: T| a.overflowing_sub(&rhs); - - unary_with_bitmap(lhs, op, lhs.data_type().clone()) -} - -// Implementation of ArraySub trait for PrimitiveArrays with a scalar -impl ArraySub for PrimitiveArray -where - T: NativeArithmetics + Sub, -{ - fn sub(&self, rhs: &T) -> Self { - sub_scalar(self, rhs) - } -} - -// Implementation of ArrayCheckedSub trait for PrimitiveArrays with a scalar -impl ArrayCheckedSub for PrimitiveArray -where - T: NativeArithmetics + CheckedSub, -{ - fn checked_sub(&self, rhs: &T) -> Self { - checked_sub_scalar(self, rhs) - } -} - -// Implementation of ArraySaturatingSub trait for PrimitiveArrays with a scalar -impl ArraySaturatingSub for PrimitiveArray -where - T: NativeArithmetics + SaturatingSub, -{ - fn saturating_sub(&self, rhs: &T) -> Self { - saturating_sub_scalar(self, rhs) - } -} - -// Implementation of ArraySaturatingSub trait for PrimitiveArrays with a scalar -impl ArrayOverflowingSub for PrimitiveArray -where - T: NativeArithmetics + OverflowingSub, -{ - fn overflowing_sub(&self, rhs: &T) -> (Self, Bitmap) { - overflowing_sub_scalar(self, rhs) - } -} diff --git a/crates/polars-arrow/src/compute/arithmetics/decimal/add.rs b/crates/polars-arrow/src/compute/arithmetics/decimal/add.rs deleted file mode 100644 index 63f912e59e60..000000000000 --- a/crates/polars-arrow/src/compute/arithmetics/decimal/add.rs +++ /dev/null @@ -1,220 +0,0 @@ -//! Defines the addition arithmetic kernels for [`PrimitiveArray`] representing decimals. -use polars_error::{polars_bail, PolarsResult}; - -use super::{adjusted_precision_scale, get_parameters, max_value, number_digits}; -use crate::array::PrimitiveArray; -use crate::compute::arithmetics::{ArrayAdd, ArrayCheckedAdd, ArraySaturatingAdd}; -use crate::compute::arity::{binary, binary_checked}; -use crate::compute::utils::{check_same_len, combine_validities}; -use crate::datatypes::DataType; - -/// Adds two decimal [`PrimitiveArray`] with the same precision and scale. -/// # Error -/// Errors if the precision and scale are different. -/// # Panic -/// This function panics iff the added numbers result in a number larger than -/// the possible number for the precision. -pub fn add(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray { - let (precision, _) = get_parameters(lhs.data_type(), rhs.data_type()).unwrap(); - - let max = max_value(precision); - let op = move |a, b| { - let res: i128 = a + b; - - assert!( - res.abs() <= max, - "Overflow in addition presented for precision {precision}" - ); - - res - }; - - binary(lhs, rhs, lhs.data_type().clone(), op) -} - -/// Saturated addition of two decimal primitive arrays with the same precision -/// and scale. If the precision and scale is different, then an -/// InvalidArgumentError is returned. 
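Every deleted decimal kernel enforces the same precision bound: a `Decimal(precision, _)` payload stored as `i128` must stay within ±(10^precision − 1). A small standalone sketch of that bound, assuming the same `i128` representation the kernels use:

```
// the bound every decimal kernel checks against
fn max_value(precision: u32) -> i128 {
    10i128.pow(precision) - 1
}

fn main() {
    let max = max_value(5); // Decimal(5, 2) holds at most 999.99
    assert_eq!(max, 99_999);
    let res = 99_000i128 + 1_000i128; // 990.00 + 10.00 at scale 2
    // the plain kernel asserts res.abs() <= max and panics here;
    // the checked/saturating variants return None or clamp instead
    assert!(res.abs() > max);
}
```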
If the result from the sum is larger than -/// the possible number with the selected precision then the resulted number in -/// the arrow array is the maximum number for the selected precision. -/// -/// # Examples -/// ``` -/// use polars_arrow::compute::arithmetics::decimal::saturating_add; -/// use polars_arrow::array::PrimitiveArray; -/// use polars_arrow::datatypes::DataType; -/// -/// let a = PrimitiveArray::from([Some(99000i128), Some(11100i128), None, Some(22200i128)]).to(DataType::Decimal(5, 2)); -/// let b = PrimitiveArray::from([Some(01000i128), Some(22200i128), None, Some(11100i128)]).to(DataType::Decimal(5, 2)); -/// -/// let result = saturating_add(&a, &b); -/// let expected = PrimitiveArray::from([Some(99999i128), Some(33300i128), None, Some(33300i128)]).to(DataType::Decimal(5, 2)); -/// -/// assert_eq!(result, expected); -/// ``` -pub fn saturating_add( - lhs: &PrimitiveArray, - rhs: &PrimitiveArray, -) -> PrimitiveArray { - let (precision, _) = get_parameters(lhs.data_type(), rhs.data_type()).unwrap(); - - let max = max_value(precision); - let op = move |a, b| { - let res: i128 = a + b; - - if res.abs() > max { - if res > 0 { - max - } else { - -max - } - } else { - res - } - }; - - binary(lhs, rhs, lhs.data_type().clone(), op) -} - -/// Checked addition of two decimal primitive arrays with the same precision -/// and scale. If the precision and scale is different, then an -/// InvalidArgumentError is returned. If the result from the sum is larger than -/// the possible number with the selected precision (overflowing), then the -/// validity for that index is changed to None -/// -/// # Examples -/// ``` -/// use polars_arrow::compute::arithmetics::decimal::checked_add; -/// use polars_arrow::array::PrimitiveArray; -/// use polars_arrow::datatypes::DataType; -/// -/// let a = PrimitiveArray::from([Some(99000i128), Some(11100i128), None, Some(22200i128)]).to(DataType::Decimal(5, 2)); -/// let b = PrimitiveArray::from([Some(01000i128), Some(22200i128), None, Some(11100i128)]).to(DataType::Decimal(5, 2)); -/// -/// let result = checked_add(&a, &b); -/// let expected = PrimitiveArray::from([None, Some(33300i128), None, Some(33300i128)]).to(DataType::Decimal(5, 2)); -/// -/// assert_eq!(result, expected); -/// ``` -pub fn checked_add(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray { - let (precision, _) = get_parameters(lhs.data_type(), rhs.data_type()).unwrap(); - - let max = max_value(precision); - let op = move |a, b| { - let result: i128 = a + b; - - if result.abs() > max { - None - } else { - Some(result) - } - }; - - binary_checked(lhs, rhs, lhs.data_type().clone(), op) -} - -// Implementation of ArrayAdd trait for PrimitiveArrays -impl ArrayAdd> for PrimitiveArray { - fn add(&self, rhs: &PrimitiveArray) -> Self { - add(self, rhs) - } -} - -// Implementation of ArrayCheckedAdd trait for PrimitiveArrays -impl ArrayCheckedAdd> for PrimitiveArray { - fn checked_add(&self, rhs: &PrimitiveArray) -> Self { - checked_add(self, rhs) - } -} - -// Implementation of ArraySaturatingAdd trait for PrimitiveArrays -impl ArraySaturatingAdd> for PrimitiveArray { - fn saturating_add(&self, rhs: &PrimitiveArray) -> Self { - saturating_add(self, rhs) - } -} - -/// Adaptive addition of two decimal primitive arrays with different precision -/// and scale. If the precision and scale is different, then the smallest scale -/// and precision is adjusted to the largest precision and scale. 
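The scale alignment used by the adaptive kernels is simple enough to show standalone: when scales differ, the operand with the smaller scale is shifted left by 10^diff so both sides share the larger scale before combining (mirroring the `11111.11 + 11111.111` example below):

```
fn main() {
    // 11111.11 is Decimal(7, 2); 11111.111 is Decimal(8, 3); diff = 1
    let lhs = 11_111_11i128;  // scale 2
    let rhs = 11_111_111i128; // scale 3
    let shift = 10i128.pow(1);
    let sum = lhs * shift + rhs; // both operands now at scale 3
    assert_eq!(sum, 22_222_221); // 22222.221 as Decimal(8, 3)
}
```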
If during the -/// addition one of the results is larger than the max possible value, the -/// result precision is changed to the precision of the max value -/// -/// ```nocode -/// 11111.11 -> 7, 2 -/// 11111.111 -> 8, 3 -/// ------------------ -/// 22222.221 -> 8, 3 -/// ``` -/// # Examples -/// ``` -/// use polars_arrow::compute::arithmetics::decimal::adaptive_add; -/// use polars_arrow::array::PrimitiveArray; -/// use polars_arrow::datatypes::DataType; -/// -/// let a = PrimitiveArray::from([Some(11111_11i128)]).to(DataType::Decimal(7, 2)); -/// let b = PrimitiveArray::from([Some(11111_111i128)]).to(DataType::Decimal(8, 3)); -/// let result = adaptive_add(&a, &b).unwrap(); -/// let expected = PrimitiveArray::from([Some(22222_221i128)]).to(DataType::Decimal(8, 3)); -/// -/// assert_eq!(result, expected); -/// ``` -pub fn adaptive_add( - lhs: &PrimitiveArray, - rhs: &PrimitiveArray, -) -> PolarsResult> { - check_same_len(lhs, rhs)?; - - let (lhs_p, lhs_s, rhs_p, rhs_s) = - if let (DataType::Decimal(lhs_p, lhs_s), DataType::Decimal(rhs_p, rhs_s)) = - (lhs.data_type(), rhs.data_type()) - { - (*lhs_p, *lhs_s, *rhs_p, *rhs_s) - } else { - polars_bail!(ComputeError: "Incorrect data type for the array") - }; - - // The resulting precision is mutable because it could change while - // looping through the iterator - let (mut res_p, res_s, diff) = adjusted_precision_scale(lhs_p, lhs_s, rhs_p, rhs_s); - - let shift = 10i128.pow(diff as u32); - let mut max = max_value(res_p); - - let values = lhs - .values() - .iter() - .zip(rhs.values().iter()) - .map(|(l, r)| { - // Based on the array's scales one of the arguments in the sum has to be shifted - // to the left to match the final scale - let res = if lhs_s > rhs_s { - l + r * shift - } else { - l * shift + r - }; - - // The precision of the resulting array will change if one of the - // sums during the iteration produces a value bigger than the - // possible value for the initial precision - - // 99.9999 -> 6, 4 - // 00.0001 -> 6, 4 - // ----------------- - // 100.0000 -> 7, 4 - if res.abs() > max { - res_p = number_digits(res); - max = max_value(res_p); - } - res - }) - .collect::>(); - - let validity = combine_validities(lhs.validity(), rhs.validity()); - - Ok(PrimitiveArray::::new( - DataType::Decimal(res_p, res_s), - values.into(), - validity, - )) -} diff --git a/crates/polars-arrow/src/compute/arithmetics/decimal/div.rs b/crates/polars-arrow/src/compute/arithmetics/decimal/div.rs deleted file mode 100644 index 6516717e3239..000000000000 --- a/crates/polars-arrow/src/compute/arithmetics/decimal/div.rs +++ /dev/null @@ -1,301 +0,0 @@ -//! Defines the division arithmetic kernels for Decimal -//! `PrimitiveArrays`. - -use polars_error::{polars_bail, PolarsResult}; - -use super::{adjusted_precision_scale, get_parameters, max_value, number_digits}; -use crate::array::PrimitiveArray; -use crate::compute::arithmetics::{ArrayCheckedDiv, ArrayDiv}; -use crate::compute::arity::{binary, binary_checked, unary}; -use crate::compute::utils::{check_same_len, combine_validities}; -use crate::datatypes::DataType; -use crate::scalar::{PrimitiveScalar, Scalar}; - -/// Divide two decimal primitive arrays with the same precision and scale. If -/// the precision and scale is different, then an InvalidArgumentError is -/// returned. This function panics if the dividend is divided by 0 or None. -/// This function also panics if the division produces a number larger -/// than the possible number for the array precision. 
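The division rescaling described above is worth a worked example: the dividend is scaled up by 10^scale before the integer division so the quotient keeps the operands' scale. A std-only sketch of the arithmetic:

```
fn main() {
    let scale = 10i128.pow(3); // both operands are Decimal(_, 3)
    let a = 222_222i128;       // 222.222
    let b = 123_456i128;       // 123.456
    let numeral = a * scale;   // 222222000, dividend scaled up
    let quotient = numeral.checked_div(b).expect("division by zero");
    assert_eq!(quotient, 1_800); // 1.800 at scale 3
}
```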
-/// -/// # Examples -/// ``` -/// use polars_arrow::compute::arithmetics::decimal::div; -/// use polars_arrow::array::PrimitiveArray; -/// use polars_arrow::datatypes::DataType; -/// -/// let a = PrimitiveArray::from([Some(1_00i128), Some(4_00i128), Some(6_00i128)]).to(DataType::Decimal(5, 2)); -/// let b = PrimitiveArray::from([Some(1_00i128), Some(2_00i128), Some(2_00i128)]).to(DataType::Decimal(5, 2)); -/// -/// let result = div(&a, &b); -/// let expected = PrimitiveArray::from([Some(1_00i128), Some(2_00i128), Some(3_00i128)]).to(DataType::Decimal(5, 2)); -/// -/// assert_eq!(result, expected); -/// ``` -pub fn div(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray { - let (precision, scale) = get_parameters(lhs.data_type(), rhs.data_type()).unwrap(); - - let scale = 10i128.pow(scale as u32); - let max = max_value(precision); - let op = move |a: i128, b: i128| { - // The division is done using the numbers without scale. - // The dividend is scaled up to maintain precision after the - // division - - // 222.222 --> 222222000 - // 123.456 --> 123456 - // -------- --------- - // 1.800 <-- 1800 - let numeral: i128 = a * scale; - - // The division can overflow if the dividend is divided - // by zero. - let res: i128 = numeral.checked_div(b).expect("Found division by zero"); - - assert!( - res.abs() <= max, - "Overflow in multiplication presented for precision {precision}" - ); - - res - }; - - binary(lhs, rhs, lhs.data_type().clone(), op) -} - -/// Multiply a decimal [`PrimitiveArray`] with a [`PrimitiveScalar`] with the same precision and scale. If -/// the precision and scale is different, then an InvalidArgumentError is -/// returned. This function panics if the multiplied numbers result in a number -/// larger than the possible number for the selected precision. -pub fn div_scalar(lhs: &PrimitiveArray, rhs: &PrimitiveScalar) -> PrimitiveArray { - let (precision, scale) = get_parameters(lhs.data_type(), rhs.data_type()).unwrap(); - - let rhs = if let Some(rhs) = *rhs.value() { - rhs - } else { - return PrimitiveArray::::new_null(lhs.data_type().clone(), lhs.len()); - }; - - let scale = 10i128.pow(scale as u32); - let max = max_value(precision); - - let op = move |a: i128| { - // The division is done using the numbers without scale. - // The dividend is scaled up to maintain precision after the - // division - - // 222.222 --> 222222000 - // 123.456 --> 123456 - // -------- --------- - // 1.800 <-- 1800 - let numeral: i128 = a * scale; - - // The division can overflow if the dividend is divided - // by zero. - let res: i128 = numeral.checked_div(rhs).expect("Found division by zero"); - - assert!( - res.abs() <= max, - "Overflow in multiplication presented for precision {precision}" - ); - - res - }; - - unary(lhs, op, lhs.data_type().clone()) -} - -/// Saturated division of two decimal primitive arrays with the same -/// precision and scale. If the precision and scale is different, then an -/// InvalidArgumentError is returned. If the result from the division is -/// larger than the possible number with the selected precision then the -/// resulted number in the arrow array is the maximum number for the selected -/// precision. The function panics if divided by zero. 
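The checked variants below replace panics with nulls; `i128`'s own `checked_div` shows the core behavior the kernel maps over each slot (std-only illustration):

```
fn main() {
    let a = 1_00i128; // 1.00 at scale 2
    assert_eq!(a.checked_div(0), None);            // zero divisor -> null slot
    assert_eq!(6_00i128.checked_div(2_00), Some(3)); // 6.00 / 2.00 before rescaling
}
```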
-/// -/// # Examples -/// ``` -/// use polars_arrow::compute::arithmetics::decimal::saturating_div; -/// use polars_arrow::array::PrimitiveArray; -/// use polars_arrow::datatypes::DataType; -/// -/// let a = PrimitiveArray::from([Some(999_99i128), Some(4_00i128), Some(6_00i128)]).to(DataType::Decimal(5, 2)); -/// let b = PrimitiveArray::from([Some(000_01i128), Some(2_00i128), Some(2_00i128)]).to(DataType::Decimal(5, 2)); -/// -/// let result = saturating_div(&a, &b); -/// let expected = PrimitiveArray::from([Some(999_99i128), Some(2_00i128), Some(3_00i128)]).to(DataType::Decimal(5, 2)); -/// -/// assert_eq!(result, expected); -/// ``` -pub fn saturating_div( - lhs: &PrimitiveArray, - rhs: &PrimitiveArray, -) -> PrimitiveArray { - let (precision, scale) = get_parameters(lhs.data_type(), rhs.data_type()).unwrap(); - - let scale = 10i128.pow(scale as u32); - let max = max_value(precision); - - let op = move |a: i128, b: i128| { - let numeral: i128 = a * scale; - - match numeral.checked_div(b) { - Some(res) => match res { - res if res.abs() > max => { - if res > 0 { - max - } else { - -max - } - }, - _ => res, - }, - None => 0, - } - }; - - binary(lhs, rhs, lhs.data_type().clone(), op) -} - -/// Checked division of two decimal primitive arrays with the same precision -/// and scale. If the precision and scale is different, then an -/// InvalidArgumentError is returned. If the divisor is zero, then the -/// validity for that index is changed to None -/// -/// # Examples -/// ``` -/// use polars_arrow::compute::arithmetics::decimal::checked_div; -/// use polars_arrow::array::PrimitiveArray; -/// use polars_arrow::datatypes::DataType; -/// -/// let a = PrimitiveArray::from([Some(1_00i128), Some(4_00i128), Some(6_00i128)]).to(DataType::Decimal(5, 2)); -/// let b = PrimitiveArray::from([Some(000_00i128), None, Some(2_00i128)]).to(DataType::Decimal(5, 2)); -/// -/// let result = checked_div(&a, &b); -/// let expected = PrimitiveArray::from([None, None, Some(3_00i128)]).to(DataType::Decimal(5, 2)); -/// -/// assert_eq!(result, expected); -/// ``` -pub fn checked_div(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray { - let (precision, scale) = get_parameters(lhs.data_type(), rhs.data_type()).unwrap(); - - let scale = 10i128.pow(scale as u32); - let max = max_value(precision); - - let op = move |a: i128, b: i128| { - let numeral: i128 = a * scale; - - match numeral.checked_div(b) { - Some(res) => match res { - res if res.abs() > max => None, - _ => Some(res), - }, - None => None, - } - }; - - binary_checked(lhs, rhs, lhs.data_type().clone(), op) -} - -// Implementation of ArrayDiv trait for PrimitiveArrays -impl ArrayDiv> for PrimitiveArray { - fn div(&self, rhs: &PrimitiveArray) -> Self { - div(self, rhs) - } -} - -// Implementation of ArrayCheckedDiv trait for PrimitiveArrays -impl ArrayCheckedDiv> for PrimitiveArray { - fn checked_div(&self, rhs: &PrimitiveArray) -> Self { - checked_div(self, rhs) - } -} - -/// Adaptive division of two decimal primitive arrays with different precision -/// and scale. If the precision and scale is different, then the smallest scale -/// and precision is adjusted to the largest precision and scale. If during the -/// division one of the results is larger than the max possible value, the -/// result precision is changed to the precision of the max value. The function -/// panics when divided by zero. 
-/// -/// ```nocode -/// 1000.00 -> 7, 2 -/// 10.0000 -> 6, 4 -/// ----------------- -/// 100.0000 -> 9, 4 -/// ``` -/// # Examples -/// ``` -/// use polars_arrow::compute::arithmetics::decimal::adaptive_div; -/// use polars_arrow::array::PrimitiveArray; -/// use polars_arrow::datatypes::DataType; -/// -/// let a = PrimitiveArray::from([Some(1000_00i128)]).to(DataType::Decimal(7, 2)); -/// let b = PrimitiveArray::from([Some(10_0000i128)]).to(DataType::Decimal(6, 4)); -/// let result = adaptive_div(&a, &b).unwrap(); -/// let expected = PrimitiveArray::from([Some(100_0000i128)]).to(DataType::Decimal(9, 4)); -/// -/// assert_eq!(result, expected); -/// ``` -pub fn adaptive_div( - lhs: &PrimitiveArray, - rhs: &PrimitiveArray, -) -> PolarsResult> { - check_same_len(lhs, rhs)?; - - let (lhs_p, lhs_s, rhs_p, rhs_s) = - if let (DataType::Decimal(lhs_p, lhs_s), DataType::Decimal(rhs_p, rhs_s)) = - (lhs.data_type(), rhs.data_type()) - { - (*lhs_p, *lhs_s, *rhs_p, *rhs_s) - } else { - polars_bail!(ComputeError: "Incorrect data type for the array") - }; - - // The resulting precision is mutable because it could change while - // looping through the iterator - let (mut res_p, res_s, diff) = adjusted_precision_scale(lhs_p, lhs_s, rhs_p, rhs_s); - - let shift = 10i128.pow(diff as u32); - let shift_1 = 10i128.pow(res_s as u32); - let mut max = max_value(res_p); - - let values = lhs - .values() - .iter() - .zip(rhs.values().iter()) - .map(|(l, r)| { - let numeral: i128 = l * shift_1; - - // Based on the array's scales one of the arguments in the sum has to be shifted - // to the left to match the final scale - let res = if lhs_s > rhs_s { - numeral.checked_div(r * shift) - } else { - (numeral * shift).checked_div(*r) - } - .expect("Found division by zero"); - - // The precision of the resulting array will change if one of the - // multiplications during the iteration produces a value bigger - // than the possible value for the initial precision - - // 10.0000 -> 6, 4 - // 00.1000 -> 6, 4 - // ----------------- - // 100.0000 -> 7, 4 - if res.abs() > max { - res_p = number_digits(res); - max = max_value(res_p); - } - - res - }) - .collect::>(); - - let validity = combine_validities(lhs.validity(), rhs.validity()); - - Ok(PrimitiveArray::::new( - DataType::Decimal(res_p, res_s), - values.into(), - validity, - )) -} diff --git a/crates/polars-arrow/src/compute/arithmetics/decimal/mod.rs b/crates/polars-arrow/src/compute/arithmetics/decimal/mod.rs deleted file mode 100644 index d0cabb7d359a..000000000000 --- a/crates/polars-arrow/src/compute/arithmetics/decimal/mod.rs +++ /dev/null @@ -1,120 +0,0 @@ -//! Defines the arithmetic kernels for Decimal `PrimitiveArrays`. The -//! [`Decimal`](crate::datatypes::DataType::Decimal) type specifies the -//! precision and scale parameters. These affect the arithmetic operations and -//! need to be considered while doing operations with Decimal numbers. 
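The precision/scale rule the adaptive kernels rely on is small enough to restate standalone; this sketch mirrors the `adjusted_precision_scale` helper deleted just below — the result keeps the larger digit count on each side of the decimal point:

```
fn adjusted_precision_scale(
    lhs_p: usize, lhs_s: usize,
    rhs_p: usize, rhs_s: usize,
) -> (usize, usize, usize) {
    // digits before the point: take the larger of the two sides
    let digits_before = (lhs_p - lhs_s).max(rhs_p - rhs_s);
    // digits after the point: take the larger scale; diff is the shift needed
    let (res_s, diff) = if lhs_s > rhs_s {
        (lhs_s, lhs_s - rhs_s)
    } else {
        (rhs_s, rhs_s - lhs_s)
    };
    (digits_before + res_s, res_s, diff)
}

fn main() {
    // 11.1111 is Decimal(5, 4); 11111.01 is Decimal(7, 2) -> Decimal(9, 4)
    assert_eq!(adjusted_precision_scale(5, 4, 7, 2), (9, 4, 2));
}
```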
- -mod add; -pub use add::*; -mod div; -pub use div::*; -mod mul; -pub use mul::*; -use polars_error::{polars_bail, PolarsResult}; - -mod sub; -pub use sub::*; - -use crate::datatypes::DataType; - -/// Maximum value that can exist with a selected precision -#[inline] -fn max_value(precision: usize) -> i128 { - 10i128.pow(precision as u32) - 1 -} - -// Calculates the number of digits in a i128 number -fn number_digits(num: i128) -> usize { - let mut num = num.abs(); - let mut digit: i128 = 0; - let base = 10i128; - - while num != 0 { - num /= base; - digit += 1; - } - - digit as usize -} - -fn get_parameters(lhs: &DataType, rhs: &DataType) -> PolarsResult<(usize, usize)> { - if let (DataType::Decimal(lhs_p, lhs_s), DataType::Decimal(rhs_p, rhs_s)) = - (lhs.to_logical_type(), rhs.to_logical_type()) - { - if lhs_p == rhs_p && lhs_s == rhs_s { - Ok((*lhs_p, *lhs_s)) - } else { - polars_bail!(InvalidOperation: - "Arrays must have the same precision and scale" - ) - } - } else { - unreachable!() - } -} - -/// Returns the adjusted precision and scale for the lhs and rhs precision and -/// scale -fn adjusted_precision_scale( - lhs_p: usize, - lhs_s: usize, - rhs_p: usize, - rhs_s: usize, -) -> (usize, usize, usize) { - // The initial new precision and scale is based on the number of digits - // that lhs and rhs number has before and after the point. The max - // number of digits before and after the point will make the last - // precision and scale of the result - - // Digits before/after point - // before after - // 11.1111 -> 5, 4 -> 2 4 - // 11111.01 -> 7, 2 -> 5 2 - // ----------------- - // 11122.1211 -> 9, 4 -> 5 4 - let lhs_digits_before = lhs_p - lhs_s; - let rhs_digits_before = rhs_p - rhs_s; - - let res_digits_before = std::cmp::max(lhs_digits_before, rhs_digits_before); - - let (res_s, diff) = if lhs_s > rhs_s { - (lhs_s, lhs_s - rhs_s) - } else { - (rhs_s, rhs_s - lhs_s) - }; - - let res_p = res_digits_before + res_s; - - (res_p, res_s, diff) -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_max_value() { - assert_eq!(999, max_value(3)); - assert_eq!(99999, max_value(5)); - assert_eq!(999999, max_value(6)); - } - - #[test] - fn test_number_digits() { - assert_eq!(2, number_digits(12i128)); - assert_eq!(3, number_digits(123i128)); - assert_eq!(4, number_digits(1234i128)); - assert_eq!(6, number_digits(123456i128)); - assert_eq!(7, number_digits(1234567i128)); - assert_eq!(7, number_digits(-1234567i128)); - assert_eq!(3, number_digits(-123i128)); - } - - #[test] - fn test_adjusted_precision_scale() { - // 11.1111 -> 5, 4 -> 2 4 - // 11111.01 -> 7, 2 -> 5 2 - // ----------------- - // 11122.1211 -> 9, 4 -> 5 4 - assert_eq!((9, 4, 2), adjusted_precision_scale(5, 4, 7, 2)) - } -} diff --git a/crates/polars-arrow/src/compute/arithmetics/decimal/mul.rs b/crates/polars-arrow/src/compute/arithmetics/decimal/mul.rs deleted file mode 100644 index 698b47717ffb..000000000000 --- a/crates/polars-arrow/src/compute/arithmetics/decimal/mul.rs +++ /dev/null @@ -1,313 +0,0 @@ -//! Defines the multiplication arithmetic kernels for Decimal -//! `PrimitiveArrays`. 
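The multiplication kernels removed below all share one rescaling step: multiply the unscaled `i128` representations (guarded by `checked_mul`, since an `i128 * i128` can overflow), then divide by 10^scale to restore the scale. A worked sketch of the arithmetic:

```
fn main() {
    let scale = 10i128.pow(3); // both operands are Decimal(_, 3)
    let a = 111_111i128;       // 111.111
    let b = 222_222i128;       // 222.222
    let raw = a.checked_mul(b).expect("i128 overflow");
    assert_eq!(raw, 24_691_308_642);
    assert_eq!(raw / scale, 24_691_308); // 24691.308 at scale 3
}
```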
- -use polars_error::{polars_bail, PolarsResult}; - -use super::{adjusted_precision_scale, get_parameters, max_value, number_digits}; -use crate::array::PrimitiveArray; -use crate::compute::arithmetics::{ArrayCheckedMul, ArrayMul, ArraySaturatingMul}; -use crate::compute::arity::{binary, binary_checked, unary}; -use crate::compute::utils::{check_same_len, combine_validities}; -use crate::datatypes::DataType; -use crate::scalar::{PrimitiveScalar, Scalar}; - -/// Multiply two decimal primitive arrays with the same precision and scale. If -/// the precision and scale is different, then an InvalidArgumentError is -/// returned. This function panics if the multiplied numbers result in a number -/// larger than the possible number for the selected precision. -/// -/// # Examples -/// ``` -/// use polars_arrow::compute::arithmetics::decimal::mul; -/// use polars_arrow::array::PrimitiveArray; -/// use polars_arrow::datatypes::DataType; -/// -/// let a = PrimitiveArray::from([Some(1_00i128), Some(1_00i128), None, Some(2_00i128)]).to(DataType::Decimal(5, 2)); -/// let b = PrimitiveArray::from([Some(1_00i128), Some(2_00i128), None, Some(2_00i128)]).to(DataType::Decimal(5, 2)); -/// -/// let result = mul(&a, &b); -/// let expected = PrimitiveArray::from([Some(1_00i128), Some(2_00i128), None, Some(4_00i128)]).to(DataType::Decimal(5, 2)); -/// -/// assert_eq!(result, expected); -/// ``` -pub fn mul(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray { - let (precision, scale) = get_parameters(lhs.data_type(), rhs.data_type()).unwrap(); - - let scale = 10i128.pow(scale as u32); - let max = max_value(precision); - - let op = move |a: i128, b: i128| { - // The multiplication between i128 can overflow if they are - // very large numbers. For that reason a checked - // multiplication is used. - let res: i128 = a.checked_mul(b).expect("Mayor overflow for multiplication"); - - // The multiplication is done using the numbers without scale. - // The resulting scale of the value has to be corrected by - // dividing by (10^scale) - - // 111.111 --> 111111 - // 222.222 --> 222222 - // -------- ------- - // 24691.308 <-- 24691308642 - let res = res / scale; - - assert!( - res.abs() <= max, - "Overflow in multiplication presented for precision {precision}" - ); - - res - }; - - binary(lhs, rhs, lhs.data_type().clone(), op) -} - -/// Multiply a decimal [`PrimitiveArray`] with a [`PrimitiveScalar`] with the same precision and scale. If -/// the precision and scale is different, then an InvalidArgumentError is -/// returned. This function panics if the multiplied numbers result in a number -/// larger than the possible number for the selected precision. -pub fn mul_scalar(lhs: &PrimitiveArray, rhs: &PrimitiveScalar) -> PrimitiveArray { - let (precision, scale) = get_parameters(lhs.data_type(), rhs.data_type()).unwrap(); - - let rhs = if let Some(rhs) = *rhs.value() { - rhs - } else { - return PrimitiveArray::::new_null(lhs.data_type().clone(), lhs.len()); - }; - - let scale = 10i128.pow(scale as u32); - let max = max_value(precision); - - let op = move |a: i128| { - // The multiplication between i128 can overflow if they are - // very large numbers. For that reason a checked - // multiplication is used. - let res: i128 = a - .checked_mul(rhs) - .expect("Mayor overflow for multiplication"); - - // The multiplication is done using the numbers without scale. 
- // The resulting scale of the value has to be corrected by - // dividing by (10^scale) - - // 111.111 --> 111111 - // 222.222 --> 222222 - // -------- ------- - // 24691.308 <-- 24691308642 - let res = res / scale; - - assert!( - res.abs() <= max, - "Overflow in multiplication presented for precision {precision}" - ); - - res - }; - - unary(lhs, op, lhs.data_type().clone()) -} - -/// Saturated multiplication of two decimal primitive arrays with the same -/// precision and scale. If the precision and scale is different, then an -/// InvalidArgumentError is returned. If the result from the multiplication is -/// larger than the possible number with the selected precision then the -/// resulted number in the arrow array is the maximum number for the selected -/// precision. -/// -/// # Examples -/// ``` -/// use polars_arrow::compute::arithmetics::decimal::saturating_mul; -/// use polars_arrow::array::PrimitiveArray; -/// use polars_arrow::datatypes::DataType; -/// -/// let a = PrimitiveArray::from([Some(999_99i128), Some(1_00i128), None, Some(2_00i128)]).to(DataType::Decimal(5, 2)); -/// let b = PrimitiveArray::from([Some(10_00i128), Some(2_00i128), None, Some(2_00i128)]).to(DataType::Decimal(5, 2)); -/// -/// let result = saturating_mul(&a, &b); -/// let expected = PrimitiveArray::from([Some(999_99i128), Some(2_00i128), None, Some(4_00i128)]).to(DataType::Decimal(5, 2)); -/// -/// assert_eq!(result, expected); -/// ``` -pub fn saturating_mul( - lhs: &PrimitiveArray, - rhs: &PrimitiveArray, -) -> PrimitiveArray { - let (precision, scale) = get_parameters(lhs.data_type(), rhs.data_type()).unwrap(); - - let scale = 10i128.pow(scale as u32); - let max = max_value(precision); - - let op = move |a: i128, b: i128| match a.checked_mul(b) { - Some(res) => { - let res = res / scale; - - match res { - res if res.abs() > max => { - if res > 0 { - max - } else { - -max - } - }, - _ => res, - } - }, - None => max, - }; - - binary(lhs, rhs, lhs.data_type().clone(), op) -} - -/// Checked multiplication of two decimal primitive arrays with the same -/// precision and scale. If the precision and scale is different, then an -/// InvalidArgumentError is returned. 
If the result from the mul is larger than -/// the possible number with the selected precision (overflowing), then the -/// validity for that index is changed to None -/// -/// # Examples -/// ``` -/// use polars_arrow::compute::arithmetics::decimal::checked_mul; -/// use polars_arrow::array::PrimitiveArray; -/// use polars_arrow::datatypes::DataType; -/// -/// let a = PrimitiveArray::from([Some(999_99i128), Some(1_00i128), None, Some(2_00i128)]).to(DataType::Decimal(5, 2)); -/// let b = PrimitiveArray::from([Some(10_00i128), Some(2_00i128), None, Some(2_00i128)]).to(DataType::Decimal(5, 2)); -/// -/// let result = checked_mul(&a, &b); -/// let expected = PrimitiveArray::from([None, Some(2_00i128), None, Some(4_00i128)]).to(DataType::Decimal(5, 2)); -/// -/// assert_eq!(result, expected); -/// ``` -pub fn checked_mul(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray { - let (precision, scale) = get_parameters(lhs.data_type(), rhs.data_type()).unwrap(); - - let scale = 10i128.pow(scale as u32); - let max = max_value(precision); - - let op = move |a: i128, b: i128| match a.checked_mul(b) { - Some(res) => { - let res = res / scale; - - match res { - res if res.abs() > max => None, - _ => Some(res), - } - }, - None => None, - }; - - binary_checked(lhs, rhs, lhs.data_type().clone(), op) -} - -// Implementation of ArrayMul trait for PrimitiveArrays -impl ArrayMul> for PrimitiveArray { - fn mul(&self, rhs: &PrimitiveArray) -> Self { - mul(self, rhs) - } -} - -// Implementation of ArrayCheckedMul trait for PrimitiveArrays -impl ArrayCheckedMul> for PrimitiveArray { - fn checked_mul(&self, rhs: &PrimitiveArray) -> Self { - checked_mul(self, rhs) - } -} - -// Implementation of ArraySaturatingMul trait for PrimitiveArrays -impl ArraySaturatingMul> for PrimitiveArray { - fn saturating_mul(&self, rhs: &PrimitiveArray) -> Self { - saturating_mul(self, rhs) - } -} - -/// Adaptive multiplication of two decimal primitive arrays with different -/// precision and scale. If the precision and scale is different, then the -/// smallest scale and precision is adjusted to the largest precision and -/// scale. 
If during the multiplication one of the results is larger than the -/// max possible value, the result precision is changed to the precision of the -/// max value -/// -/// ```nocode -/// 11111.0 -> 6, 1 -/// 10.002 -> 5, 3 -/// ----------------- -/// 111132.222 -> 9, 3 -/// ``` -/// # Examples -/// ``` -/// use polars_arrow::compute::arithmetics::decimal::adaptive_mul; -/// use polars_arrow::array::PrimitiveArray; -/// use polars_arrow::datatypes::DataType; -/// -/// let a = PrimitiveArray::from([Some(11111_0i128), Some(1_0i128)]).to(DataType::Decimal(6, 1)); -/// let b = PrimitiveArray::from([Some(10_002i128), Some(2_000i128)]).to(DataType::Decimal(5, 3)); -/// let result = adaptive_mul(&a, &b).unwrap(); -/// let expected = PrimitiveArray::from([Some(111132_222i128), Some(2_000i128)]).to(DataType::Decimal(9, 3)); -/// -/// assert_eq!(result, expected); -/// ``` -pub fn adaptive_mul( - lhs: &PrimitiveArray, - rhs: &PrimitiveArray, -) -> PolarsResult> { - check_same_len(lhs, rhs)?; - - let (lhs_p, lhs_s, rhs_p, rhs_s) = - if let (DataType::Decimal(lhs_p, lhs_s), DataType::Decimal(rhs_p, rhs_s)) = - (lhs.data_type(), rhs.data_type()) - { - (*lhs_p, *lhs_s, *rhs_p, *rhs_s) - } else { - polars_bail!(ComputeError: "Incorrect data type for the array") - }; - - // The resulting precision is mutable because it could change while - // looping through the iterator - let (mut res_p, res_s, diff) = adjusted_precision_scale(lhs_p, lhs_s, rhs_p, rhs_s); - - let shift = 10i128.pow(diff as u32); - let shift_1 = 10i128.pow(res_s as u32); - let mut max = max_value(res_p); - - let values = lhs - .values() - .iter() - .zip(rhs.values().iter()) - .map(|(l, r)| { - // Based on the array's scales one of the arguments in the sum has to be shifted - // to the left to match the final scale - let res = if lhs_s > rhs_s { - l.checked_mul(r * shift) - } else { - (l * shift).checked_mul(*r) - } - .expect("Mayor overflow for multiplication"); - - let res = res / shift_1; - - // The precision of the resulting array will change if one of the - // multiplications during the iteration produces a value bigger - // than the possible value for the initial precision - - // 10.0000 -> 6, 4 - // 10.0000 -> 6, 4 - // ----------------- - // 100.0000 -> 7, 4 - if res.abs() > max { - res_p = number_digits(res); - max = max_value(res_p); - } - - res - }) - .collect::>(); - - let validity = combine_validities(lhs.validity(), rhs.validity()); - - Ok(PrimitiveArray::::new( - DataType::Decimal(res_p, res_s), - values.into(), - validity, - )) -} diff --git a/crates/polars-arrow/src/compute/arithmetics/decimal/sub.rs b/crates/polars-arrow/src/compute/arithmetics/decimal/sub.rs deleted file mode 100644 index 73840acc34b4..000000000000 --- a/crates/polars-arrow/src/compute/arithmetics/decimal/sub.rs +++ /dev/null @@ -1,237 +0,0 @@ -//! Defines the subtract arithmetic kernels for Decimal `PrimitiveArrays`. - -use polars_error::{polars_bail, PolarsResult}; - -use super::{adjusted_precision_scale, get_parameters, max_value, number_digits}; -use crate::array::PrimitiveArray; -use crate::compute::arithmetics::{ArrayCheckedSub, ArraySaturatingSub, ArraySub}; -use crate::compute::arity::{binary, binary_checked}; -use crate::compute::utils::{check_same_len, combine_validities}; -use crate::datatypes::DataType; - -/// Subtract two decimal primitive arrays with the same precision and scale. If -/// the precision and scale is different, then an InvalidArgumentError is -/// returned. 
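The saturating policy shared by these decimal kernels also reduces to a few lines: a result outside the precision bound snaps to the signed maximum instead of panicking. A sketch under the same `i128` representation (the first assertion mirrors the `saturating_sub` example below):

```
fn clamp(res: i128, max: i128) -> i128 {
    // snap out-of-range results to the signed precision bound
    if res.abs() > max {
        if res > 0 { max } else { -max }
    } else {
        res
    }
}

fn main() {
    let max = 99_999i128; // bound for Decimal(5, _)
    assert_eq!(clamp(-100_000, max), -99_999); // -990.00 - 10.00 saturates
    assert_eq!(clamp(12_345, max), 12_345);
}
```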
This function panics if the subtracted numbers result in a number -/// smaller than the possible number for the selected precision. -/// -/// # Examples -/// ``` -/// use polars_arrow::compute::arithmetics::decimal::sub; -/// use polars_arrow::array::PrimitiveArray; -/// use polars_arrow::datatypes::DataType; -/// -/// let a = PrimitiveArray::from([Some(1i128), Some(1i128), None, Some(2i128)]).to(DataType::Decimal(5, 2)); -/// let b = PrimitiveArray::from([Some(1i128), Some(2i128), None, Some(2i128)]).to(DataType::Decimal(5, 2)); -/// -/// let result = sub(&a, &b); -/// let expected = PrimitiveArray::from([Some(0i128), Some(-1i128), None, Some(0i128)]).to(DataType::Decimal(5, 2)); -/// -/// assert_eq!(result, expected); -/// ``` -pub fn sub(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray { - let (precision, _) = get_parameters(lhs.data_type(), rhs.data_type()).unwrap(); - - let max = max_value(precision); - - let op = move |a, b| { - let res: i128 = a - b; - - assert!( - res.abs() <= max, - "Overflow in subtract presented for precision {precision}" - ); - - res - }; - - binary(lhs, rhs, lhs.data_type().clone(), op) -} - -/// Saturated subtraction of two decimal primitive arrays with the same -/// precision and scale. If the precision and scale is different, then an -/// InvalidArgumentError is returned. If the result from the sum is smaller -/// than the possible number with the selected precision then the resulted -/// number in the arrow array is the minimum number for the selected precision. -/// -/// # Examples -/// ``` -/// use polars_arrow::compute::arithmetics::decimal::saturating_sub; -/// use polars_arrow::array::PrimitiveArray; -/// use polars_arrow::datatypes::DataType; -/// -/// let a = PrimitiveArray::from([Some(-99000i128), Some(11100i128), None, Some(22200i128)]).to(DataType::Decimal(5, 2)); -/// let b = PrimitiveArray::from([Some(01000i128), Some(22200i128), None, Some(11100i128)]).to(DataType::Decimal(5, 2)); -/// -/// let result = saturating_sub(&a, &b); -/// let expected = PrimitiveArray::from([Some(-99999i128), Some(-11100i128), None, Some(11100i128)]).to(DataType::Decimal(5, 2)); -/// -/// assert_eq!(result, expected); -/// ``` -pub fn saturating_sub( - lhs: &PrimitiveArray, - rhs: &PrimitiveArray, -) -> PrimitiveArray { - let (precision, _) = get_parameters(lhs.data_type(), rhs.data_type()).unwrap(); - - let max = max_value(precision); - - let op = move |a, b| { - let res: i128 = a - b; - - match res { - res if res.abs() > max => { - if res > 0 { - max - } else { - -max - } - }, - _ => res, - } - }; - - binary(lhs, rhs, lhs.data_type().clone(), op) -} - -// Implementation of ArraySub trait for PrimitiveArrays -impl ArraySub> for PrimitiveArray { - fn sub(&self, rhs: &PrimitiveArray) -> Self { - sub(self, rhs) - } -} - -// Implementation of ArrayCheckedSub trait for PrimitiveArrays -impl ArrayCheckedSub> for PrimitiveArray { - fn checked_sub(&self, rhs: &PrimitiveArray) -> Self { - checked_sub(self, rhs) - } -} - -// Implementation of ArraySaturatingSub trait for PrimitiveArrays -impl ArraySaturatingSub> for PrimitiveArray { - fn saturating_sub(&self, rhs: &PrimitiveArray) -> Self { - saturating_sub(self, rhs) - } -} -/// Checked subtract of two decimal primitive arrays with the same precision -/// and scale. If the precision and scale is different, then an -/// InvalidArgumentError is returned. 
If the result from the sub is larger than -/// the possible number with the selected precision (overflowing), then the -/// validity for that index is changed to None -/// -/// # Examples -/// ``` -/// use polars_arrow::compute::arithmetics::decimal::checked_sub; -/// use polars_arrow::array::PrimitiveArray; -/// use polars_arrow::datatypes::DataType; -/// -/// let a = PrimitiveArray::from([Some(-99000i128), Some(11100i128), None, Some(22200i128)]).to(DataType::Decimal(5, 2)); -/// let b = PrimitiveArray::from([Some(01000i128), Some(22200i128), None, Some(11100i128)]).to(DataType::Decimal(5, 2)); -/// -/// let result = checked_sub(&a, &b); -/// let expected = PrimitiveArray::from([None, Some(-11100i128), None, Some(11100i128)]).to(DataType::Decimal(5, 2)); -/// -/// assert_eq!(result, expected); -/// ``` -pub fn checked_sub(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray { - let (precision, _) = get_parameters(lhs.data_type(), rhs.data_type()).unwrap(); - - let max = max_value(precision); - - let op = move |a, b| { - let res: i128 = a - b; - - match res { - res if res.abs() > max => None, - _ => Some(res), - } - }; - - binary_checked(lhs, rhs, lhs.data_type().clone(), op) -} - -/// Adaptive subtract of two decimal primitive arrays with different precision -/// and scale. If the precision and scale is different, then the smallest scale -/// and precision is adjusted to the largest precision and scale. If during the -/// addition one of the results is smaller than the min possible value, the -/// result precision is changed to the precision of the min value -/// -/// ```nocode -/// 99.9999 -> 6, 4 -/// -00.0001 -> 6, 4 -/// ----------------- -/// 100.0000 -> 7, 4 -/// ``` -/// # Examples -/// ``` -/// use polars_arrow::compute::arithmetics::decimal::adaptive_sub; -/// use polars_arrow::array::PrimitiveArray; -/// use polars_arrow::datatypes::DataType; -/// -/// let a = PrimitiveArray::from([Some(99_9999i128)]).to(DataType::Decimal(6, 4)); -/// let b = PrimitiveArray::from([Some(-00_0001i128)]).to(DataType::Decimal(6, 4)); -/// let result = adaptive_sub(&a, &b).unwrap(); -/// let expected = PrimitiveArray::from([Some(100_0000i128)]).to(DataType::Decimal(7, 4)); -/// -/// assert_eq!(result, expected); -/// ``` -pub fn adaptive_sub( - lhs: &PrimitiveArray, - rhs: &PrimitiveArray, -) -> PolarsResult> { - check_same_len(lhs, rhs)?; - - let (lhs_p, lhs_s, rhs_p, rhs_s) = - if let (DataType::Decimal(lhs_p, lhs_s), DataType::Decimal(rhs_p, rhs_s)) = - (lhs.data_type(), rhs.data_type()) - { - (*lhs_p, *lhs_s, *rhs_p, *rhs_s) - } else { - polars_bail!(ComputeError: "Incorrect data type for the array") - }; - - // The resulting precision is mutable because it could change while - // looping through the iterator - let (mut res_p, res_s, diff) = adjusted_precision_scale(lhs_p, lhs_s, rhs_p, rhs_s); - - let shift = 10i128.pow(diff as u32); - let mut max = max_value(res_p); - - let values = lhs - .values() - .iter() - .zip(rhs.values().iter()) - .map(|(l, r)| { - // Based on the array's scales one of the arguments in the sum has to be shifted - // to the left to match the final scale - let res: i128 = if lhs_s > rhs_s { - l - r * shift - } else { - l * shift - r - }; - - // The precision of the resulting array will change if one of the - // subtraction during the iteration produces a value bigger than the - // possible value for the initial precision - - // -99.9999 -> 6, 4 - // 00.0001 -> 6, 4 - // ----------------- - // -100.0000 -> 7, 4 - if res.abs() > max { - res_p = number_digits(res); 
- max = max_value(res_p); - } - - res - }) - .collect::>(); - - let validity = combine_validities(lhs.validity(), rhs.validity()); - - Ok(PrimitiveArray::::new( - DataType::Decimal(res_p, res_s), - values.into(), - validity, - )) -} diff --git a/crates/polars-arrow/src/compute/arithmetics/mod.rs b/crates/polars-arrow/src/compute/arithmetics/mod.rs index c25b265b1a35..38883ee044cf 100644 --- a/crates/polars-arrow/src/compute/arithmetics/mod.rs +++ b/crates/polars-arrow/src/compute/arithmetics/mod.rs @@ -1,132 +1 @@ -//! Defines basic arithmetic kernels for [`PrimitiveArray`](crate::array::PrimitiveArray)s. -//! -//! The Arithmetics module is composed by basic arithmetics operations that can -//! be performed on [`PrimitiveArray`](crate::array::PrimitiveArray). -//! -//! Whenever possible, each operation declares variations -//! of the basic operation that offers different guarantees: -//! * plain: panics on overflowing and underflowing. -//! * checked: turns an overflowing to a null. -//! * saturating: turns the overflowing to the MAX or MIN value respectively. -//! * overflowing: returns an extra [`Bitmap`] denoting whether the operation overflowed. -//! * adaptive: for [`Decimal`](crate::datatypes::DataType::Decimal) only, -//! adjusts the precision and scale to make the resulting value fit. -#[forbid(unsafe_code)] pub mod basic; -#[cfg(feature = "compute_arithmetics_decimal")] -pub mod decimal; - -use crate::bitmap::Bitmap; - -pub trait ArrayAdd: Sized { - /// Adds itself to `rhs` - fn add(&self, rhs: &Rhs) -> Self; -} - -/// Defines wrapping addition operation for primitive arrays -pub trait ArrayWrappingAdd: Sized { - /// Adds itself to `rhs` using wrapping addition - fn wrapping_add(&self, rhs: &Rhs) -> Self; -} - -/// Defines checked addition operation for primitive arrays -pub trait ArrayCheckedAdd: Sized { - /// Checked add - fn checked_add(&self, rhs: &Rhs) -> Self; -} - -/// Defines saturating addition operation for primitive arrays -pub trait ArraySaturatingAdd: Sized { - /// Saturating add - fn saturating_add(&self, rhs: &Rhs) -> Self; -} - -/// Defines Overflowing addition operation for primitive arrays -pub trait ArrayOverflowingAdd: Sized { - /// Overflowing add - fn overflowing_add(&self, rhs: &Rhs) -> (Self, Bitmap); -} - -/// Defines basic subtraction operation for primitive arrays -pub trait ArraySub: Sized { - /// subtraction - fn sub(&self, rhs: &Rhs) -> Self; -} - -/// Defines wrapping subtraction operation for primitive arrays -pub trait ArrayWrappingSub: Sized { - /// wrapping subtraction - fn wrapping_sub(&self, rhs: &Rhs) -> Self; -} - -/// Defines checked subtraction operation for primitive arrays -pub trait ArrayCheckedSub: Sized { - /// checked subtraction - fn checked_sub(&self, rhs: &Rhs) -> Self; -} - -/// Defines saturating subtraction operation for primitive arrays -pub trait ArraySaturatingSub: Sized { - /// saturarting subtraction - fn saturating_sub(&self, rhs: &Rhs) -> Self; -} - -/// Defines Overflowing subtraction operation for primitive arrays -pub trait ArrayOverflowingSub: Sized { - /// overflowing subtraction - fn overflowing_sub(&self, rhs: &Rhs) -> (Self, Bitmap); -} - -/// Defines basic multiplication operation for primitive arrays -pub trait ArrayMul: Sized { - /// multiplication - fn mul(&self, rhs: &Rhs) -> Self; -} - -/// Defines wrapping multiplication operation for primitive arrays -pub trait ArrayWrappingMul: Sized { - /// wrapping multiplication - fn wrapping_mul(&self, rhs: &Rhs) -> Self; -} - -/// Defines checked multiplication 
operation for primitive arrays
-pub trait ArrayCheckedMul<Rhs>: Sized {
-    /// checked multiplication
-    fn checked_mul(&self, rhs: &Rhs) -> Self;
-}
-
-/// Defines saturating multiplication operation for primitive arrays
-pub trait ArraySaturatingMul<Rhs>: Sized {
-    /// saturating multiplication
-    fn saturating_mul(&self, rhs: &Rhs) -> Self;
-}
-
-/// Defines Overflowing multiplication operation for primitive arrays
-pub trait ArrayOverflowingMul<Rhs>: Sized {
-    /// overflowing multiplication
-    fn overflowing_mul(&self, rhs: &Rhs) -> (Self, Bitmap);
-}
-
-/// Defines basic division operation for primitive arrays
-pub trait ArrayDiv<Rhs>: Sized {
-    /// division
-    fn div(&self, rhs: &Rhs) -> Self;
-}
-
-/// Defines checked division operation for primitive arrays
-pub trait ArrayCheckedDiv<Rhs>: Sized {
-    /// checked division
-    fn checked_div(&self, rhs: &Rhs) -> Self;
-}
-
-/// Defines basic remainder operation for primitive arrays
-pub trait ArrayRem<Rhs>: Sized {
-    /// remainder
-    fn rem(&self, rhs: &Rhs) -> Self;
-}
-
-/// Defines checked remainder operation for primitive arrays
-pub trait ArrayCheckedRem<Rhs>: Sized {
-    /// checked remainder
-    fn checked_rem(&self, rhs: &Rhs) -> Self;
-}
diff --git a/crates/polars-arrow/src/compute/arithmetics/time.rs b/crates/polars-arrow/src/compute/arithmetics/time.rs
deleted file mode 100644
index 0e3003e638b3..000000000000
--- a/crates/polars-arrow/src/compute/arithmetics/time.rs
+++ /dev/null
@@ -1,432 +0,0 @@
-//! Defines the arithmetic kernels for adding a Duration to a Timestamp,
-//! Time32, Time64, Date32 and Date64.
-//!
-//! For the purposes of Arrow Implementations, adding this value to a Timestamp
-//! ("t1") naively (i.e. simply summing the two numbers) is acceptable even
-//! though in some cases the resulting Timestamp (t2) would not account for
-//! leap-seconds during the elapsed time between "t1" and "t2". Similarly,
-//! representing the difference between two Unix timestamps is acceptable, but
-//! would yield a value that is possibly a few seconds off from the true
-//! elapsed time.
-
-use std::ops::{Add, Sub};
-
-use num_traits::AsPrimitive;
-use polars_error::{polars_bail, PolarsResult};
-
-use crate::array::PrimitiveArray;
-use crate::compute::arity::{binary, unary};
-use crate::datatypes::{DataType, TimeUnit};
-use crate::scalar::{PrimitiveScalar, Scalar};
-use crate::temporal_conversions;
-use crate::types::{months_days_ns, NativeType};
-
-/// Creates the scale required to add or subtract a Duration to a time array
-/// (Timestamp, Time, or Date). The resulting scale always multiplies the rhs
-/// number (Duration) so it can be added to the lhs number (time array).
-fn create_scale(lhs: &DataType, rhs: &DataType) -> PolarsResult<f64> {
-    // Matching on both data types from both numbers to calculate the correct
-    // scale for the operation. The timestamp, Time and duration have a
-    // Timeunit enum in their data type. This enum is used to describe the
-    // addition of the duration. The Date32 and Date64 have different rules for
-    // the scaling.
-    let scale = match (lhs, rhs) {
-        (DataType::Timestamp(timeunit_a, _), DataType::Duration(timeunit_b))
-        | (DataType::Time32(timeunit_a), DataType::Duration(timeunit_b))
-        | (DataType::Time64(timeunit_a), DataType::Duration(timeunit_b)) => {
-            // The scale is based on the TimeUnit that each of the numbers have.
- temporal_conversions::timeunit_scale(*timeunit_a, *timeunit_b) - }, - (DataType::Date32, DataType::Duration(timeunit)) => { - // Date32 represents the time elapsed time since UNIX epoch - // (1970-01-01) in days (32 bits). The duration value has to be - // scaled to days to be able to add the value to the Date. - temporal_conversions::timeunit_scale(TimeUnit::Second, *timeunit) - / temporal_conversions::SECONDS_IN_DAY as f64 - }, - (DataType::Date64, DataType::Duration(timeunit)) => { - // Date64 represents the time elapsed time since UNIX epoch - // (1970-01-01) in milliseconds (64 bits). The duration value has - // to be scaled to milliseconds to be able to add the value to the - // Date. - temporal_conversions::timeunit_scale(TimeUnit::Millisecond, *timeunit) - }, - _ => { - polars_bail!(ComputeError: - "Incorrect data type for the arguments" - ) - }, - }; - - Ok(scale) -} - -/// Adds a duration to a time array (Timestamp, Time and Date). The timeunit -/// enum is used to scale correctly both arrays; adding seconds with seconds, -/// or milliseconds with milliseconds. -/// -/// # Examples -/// ``` -/// use polars_arrow::compute::arithmetics::time::add_duration; -/// use polars_arrow::array::PrimitiveArray; -/// use polars_arrow::datatypes::{DataType, TimeUnit}; -/// -/// let timestamp = PrimitiveArray::from([ -/// Some(100000i64), -/// Some(200000i64), -/// None, -/// Some(300000i64), -/// ]) -/// .to(DataType::Timestamp( -/// TimeUnit::Second, -/// Some("America/New_York".to_string()), -/// )); -/// -/// let duration = PrimitiveArray::from([Some(10i64), Some(20i64), None, Some(30i64)]) -/// .to(DataType::Duration(TimeUnit::Second)); -/// -/// let result = add_duration(×tamp, &duration); -/// let expected = PrimitiveArray::from([ -/// Some(100010i64), -/// Some(200020i64), -/// None, -/// Some(300030i64), -/// ]) -/// .to(DataType::Timestamp( -/// TimeUnit::Second, -/// Some("America/New_York".to_string()), -/// )); -/// -/// assert_eq!(result, expected); -/// ``` -pub fn add_duration( - time: &PrimitiveArray, - duration: &PrimitiveArray, -) -> PrimitiveArray -where - f64: AsPrimitive, - T: NativeType + Add, -{ - let scale = create_scale(time.data_type(), duration.data_type()).unwrap(); - - // Closure for the binary operation. The closure contains the scale - // required to add a duration to the timestamp array. - let op = move |a: T, b: i64| a + (b as f64 * scale).as_(); - - binary(time, duration, time.data_type().clone(), op) -} - -/// Adds a duration to a time array (Timestamp, Time and Date). The timeunit -/// enum is used to scale correctly both arrays; adding seconds with seconds, -/// or milliseconds with milliseconds. -pub fn add_duration_scalar( - time: &PrimitiveArray, - duration: &PrimitiveScalar, -) -> PrimitiveArray -where - f64: AsPrimitive, - T: NativeType + Add, -{ - let scale = create_scale(time.data_type(), duration.data_type()).unwrap(); - let duration = if let Some(duration) = *duration.value() { - duration - } else { - return PrimitiveArray::::new_null(time.data_type().clone(), time.len()); - }; - - // Closure for the binary operation. The closure contains the scale - // required to add a duration to the timestamp array. - let op = move |a: T| a + (duration as f64 * scale).as_(); - - unary(time, op, time.data_type().clone()) -} - -/// Subtract a duration to a time array (Timestamp, Time and Date). The timeunit -/// enum is used to scale correctly both arrays; adding seconds with seconds, -/// or milliseconds with milliseconds. 
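The scaling rule in `create_scale` reduces to unit conversion: the duration is multiplied into the time array's unit before the add/subtract, and `Date32` additionally shrinks seconds to days. A sketch with the unit sizes hardcoded for illustration (the kernel derives them from the `TimeUnit` enums):

```
fn main() {
    // Timestamp(Millisecond) + Duration(Second): scale = 1000 ms per second
    let ts_ms = 100_000_000i64;
    let dur_s = 10i64;
    let scale = 1000.0_f64;
    assert_eq!(ts_ms + (dur_s as f64 * scale) as i64, 100_010_000);

    // Date32 stores days, so a Duration(Second) shrinks by 86_400 s/day;
    // the kernel folds this into a fractional scale (1.0 / 86_400.0)
    let date_days = 19_000i32;
    let dur_s = 172_800i64; // exactly two days
    assert_eq!(date_days + (dur_s as f64 / 86_400.0) as i32, 19_002);
}
```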
-///
-/// # Examples
-/// ```
-/// use polars_arrow::compute::arithmetics::time::subtract_duration;
-/// use polars_arrow::array::PrimitiveArray;
-/// use polars_arrow::datatypes::{DataType, TimeUnit};
-///
-/// let timestamp = PrimitiveArray::from([
-///     Some(100000i64),
-///     Some(200000i64),
-///     None,
-///     Some(300000i64),
-/// ])
-/// .to(DataType::Timestamp(
-///     TimeUnit::Second,
-///     Some("America/New_York".to_string()),
-/// ));
-///
-/// let duration = PrimitiveArray::from([Some(10i64), Some(20i64), None, Some(30i64)])
-///     .to(DataType::Duration(TimeUnit::Second));
-///
-/// let result = subtract_duration(&timestamp, &duration);
-/// let expected = PrimitiveArray::from([
-///     Some(99990i64),
-///     Some(199980i64),
-///     None,
-///     Some(299970i64),
-/// ])
-/// .to(DataType::Timestamp(
-///     TimeUnit::Second,
-///     Some("America/New_York".to_string()),
-/// ));
-///
-/// assert_eq!(result, expected);
-///
-/// ```
-pub fn subtract_duration<T>(
-    time: &PrimitiveArray<T>,
-    duration: &PrimitiveArray<i64>,
-) -> PrimitiveArray<T>
-where
-    f64: AsPrimitive<T>,
-    T: NativeType + Sub<T, Output = T>,
-{
-    let scale = create_scale(time.data_type(), duration.data_type()).unwrap();
-
-    // Closure for the binary operation. The closure contains the scale
-    // required to subtract a duration from the timestamp array.
-    let op = move |a: T, b: i64| a - (b as f64 * scale).as_();
-
-    binary(time, duration, time.data_type().clone(), op)
-}
-
-/// Subtracts a duration from a time array (Timestamp, Time and Date). The
-/// timeunit enum is used to correctly scale both arrays; subtracting seconds
-/// with seconds, or milliseconds with milliseconds.
-pub fn sub_duration_scalar<T>(
-    time: &PrimitiveArray<T>,
-    duration: &PrimitiveScalar<i64>,
-) -> PrimitiveArray<T>
-where
-    f64: AsPrimitive<T>,
-    T: NativeType + Sub<T, Output = T>,
-{
-    let scale = create_scale(time.data_type(), duration.data_type()).unwrap();
-    let duration = if let Some(duration) = *duration.value() {
-        duration
-    } else {
-        return PrimitiveArray::<T>::new_null(time.data_type().clone(), time.len());
-    };
-
-    let op = move |a: T| a - (duration as f64 * scale).as_();
-
-    unary(time, op, time.data_type().clone())
-}
-
-/// Calculates the difference between two timestamps, returning an array of type
-/// Duration. The timeunit enum is used to correctly scale both arrays;
-/// subtracting seconds with seconds, or milliseconds with milliseconds.
-///
-/// # Examples
-/// ```
-/// use polars_arrow::compute::arithmetics::time::subtract_timestamps;
-/// use polars_arrow::array::PrimitiveArray;
-/// use polars_arrow::datatypes::{DataType, TimeUnit};
-/// let timestamp_a = PrimitiveArray::from([
-///     Some(100_010i64),
-///     Some(200_020i64),
-///     None,
-///     Some(300_030i64),
-/// ])
-/// .to(DataType::Timestamp(TimeUnit::Second, None));
-///
-/// let timestamp_b = PrimitiveArray::from([
-///     Some(100_000i64),
-///     Some(200_000i64),
-///     None,
-///     Some(300_000i64),
-/// ])
-/// .to(DataType::Timestamp(TimeUnit::Second, None));
-///
-/// let expected = PrimitiveArray::from([Some(10i64), Some(20i64), None, Some(30i64)])
-///     .to(DataType::Duration(TimeUnit::Second));
-///
-/// let result = subtract_timestamps(&timestamp_a, &timestamp_b).unwrap();
-/// assert_eq!(result, expected);
-/// ```
-pub fn subtract_timestamps(
-    lhs: &PrimitiveArray<i64>,
-    rhs: &PrimitiveArray<i64>,
-) -> PolarsResult<PrimitiveArray<i64>> {
-    // Matching on both data types from both arrays.
-    // Both timestamps have a TimeUnit enum in their data types.
-    // This enum is used to adjust the scale between the timestamps.
-    match (lhs.data_type(), rhs.data_type()) {
-        // Naive timestamp comparison. It doesn't take into account timezones
-        // from the Timestamp timeunit.
-        (DataType::Timestamp(timeunit_a, None), DataType::Timestamp(timeunit_b, None)) => {
-            // Closure for the binary operation. The closure contains the scale
-            // required to calculate the difference between the timestamps.
-            let scale = temporal_conversions::timeunit_scale(*timeunit_a, *timeunit_b);
-            let op = move |a, b| a - (b as f64 * scale) as i64;
-
-            Ok(binary(lhs, rhs, DataType::Duration(*timeunit_a), op))
-        },
-        _ => polars_bail!(ComputeError:
-            "Incorrect data type for the arguments"
-        )
-    }
-}
-
-/// Calculates the difference between two timestamps as [`DataType::Duration`] with the same time scale.
-pub fn sub_timestamps_scalar(
-    lhs: &PrimitiveArray<i64>,
-    rhs: &PrimitiveScalar<i64>,
-) -> PolarsResult<PrimitiveArray<i64>> {
-    let (scale, timeunit_a) =
-        if let (DataType::Timestamp(timeunit_a, None), DataType::Timestamp(timeunit_b, None)) =
-            (lhs.data_type(), rhs.data_type())
-        {
-            (
-                temporal_conversions::timeunit_scale(*timeunit_a, *timeunit_b),
-                timeunit_a,
-            )
-        } else {
-            return Err(Error::InvalidArgumentError(
-                "sub_timestamps_scalar requires both arguments to be timestamps without timezone"
-                    .to_string(),
-            ));
-        };
-
-    let rhs = if let Some(value) = *rhs.value() {
-        value
-    } else {
-        return Ok(PrimitiveArray::<i64>::new_null(
-            lhs.data_type().clone(),
-            lhs.len(),
-        ));
-    };
-
-    let op = move |a| a - (rhs as f64 * scale) as i64;
-
-    Ok(unary(lhs, op, DataType::Duration(*timeunit_a)))
-}
-
-/// Adds an interval to a [`DataType::Timestamp`].
-pub fn add_interval(
-    timestamp: &PrimitiveArray<i64>,
-    interval: &PrimitiveArray<months_days_ns>,
-) -> PolarsResult<PrimitiveArray<i64>> {
-    match timestamp.data_type().to_logical_type() {
-        DataType::Timestamp(time_unit, Some(timezone_str)) => {
-            let time_unit = *time_unit;
-            let timezone = temporal_conversions::parse_offset(timezone_str);
-            match timezone {
-                Ok(timezone) => Ok(binary(
-                    timestamp,
-                    interval,
-                    timestamp.data_type().clone(),
-                    |timestamp, interval| {
-                        temporal_conversions::add_interval(
-                            timestamp, time_unit, interval, &timezone,
-                        )
-                    },
-                )),
-                #[cfg(feature = "chrono-tz")]
-                Err(_) => {
-                    let timezone = temporal_conversions::parse_offset_tz(timezone_str)?;
-                    Ok(binary(
-                        timestamp,
-                        interval,
-                        timestamp.data_type().clone(),
-                        |timestamp, interval| {
-                            temporal_conversions::add_interval(
-                                timestamp, time_unit, interval, &timezone,
-                            )
-                        },
-                    ))
-                },
-                #[cfg(not(feature = "chrono-tz"))]
-                _ => Err(Error::InvalidArgumentError(format!(
-                    "timezone \"{}\" cannot be parsed (feature chrono-tz is not active)",
-                    timezone_str
-                ))),
-            }
-        },
-        DataType::Timestamp(time_unit, None) => {
-            let time_unit = *time_unit;
-            Ok(binary(
-                timestamp,
-                interval,
-                timestamp.data_type().clone(),
-                |timestamp, interval| {
-                    temporal_conversions::add_naive_interval(timestamp, time_unit, interval)
-                },
-            ))
-        },
-        _ => Err(Error::InvalidArgumentError(
-            "Adding an interval is only supported for `DataType::Timestamp`".to_string(),
-        )),
-    }
-}
-
-/// Adds an interval to a [`DataType::Timestamp`].
-pub fn add_interval_scalar(
-    timestamp: &PrimitiveArray<i64>,
-    interval: &PrimitiveScalar<months_days_ns>,
-) -> PolarsResult<PrimitiveArray<i64>> {
-    let interval = if let Some(interval) = *interval.value() {
-        interval
-    } else {
-        return Ok(PrimitiveArray::<i64>::new_null(
-            timestamp.data_type().clone(),
-            timestamp.len(),
-        ));
-    };
-
-    match timestamp.data_type().to_logical_type() {
-        DataType::Timestamp(time_unit, Some(timezone_str)) => {
-            let time_unit = *time_unit;
-            let timezone = temporal_conversions::parse_offset(timezone_str);
-            match timezone {
-                Ok(timezone) => Ok(unary(
-                    timestamp,
-                    |timestamp| {
-                        temporal_conversions::add_interval(
-                            timestamp, time_unit, interval, &timezone,
-                        )
-                    },
-                    timestamp.data_type().clone(),
-                )),
-                #[cfg(feature = "chrono-tz")]
-                Err(_) => {
-                    let timezone = temporal_conversions::parse_offset_tz(timezone_str)?;
-                    Ok(unary(
-                        timestamp,
-                        |timestamp| {
-                            temporal_conversions::add_interval(
-                                timestamp, time_unit, interval, &timezone,
-                            )
-                        },
-                        timestamp.data_type().clone(),
-                    ))
-                },
-                #[cfg(not(feature = "chrono-tz"))]
-                _ => Err(Error::InvalidArgumentError(format!(
-                    "timezone \"{}\" cannot be parsed (feature chrono-tz is not active)",
-                    timezone_str
-                ))),
-            }
-        },
-        DataType::Timestamp(time_unit, None) => {
-            let time_unit = *time_unit;
-            Ok(unary(
-                timestamp,
-                |timestamp| {
-                    temporal_conversions::add_naive_interval(timestamp, time_unit, interval)
-                },
-                timestamp.data_type().clone(),
-            ))
-        },
-        _ => Err(Error::InvalidArgumentError(
-            "Adding an interval is only supported for `DataType::Timestamp`".to_string(),
-        )),
-    }
-}
diff --git a/crates/polars-arrow/src/compute/if_then_else.rs b/crates/polars-arrow/src/compute/if_then_else.rs
index 292f4e484f81..9433f431fb19 100644
--- a/crates/polars-arrow/src/compute/if_then_else.rs
+++ b/crates/polars-arrow/src/compute/if_then_else.rs
@@ -6,24 +6,6 @@ use crate::bitmap::utils::SlicesIterator;
 
 /// Returns the values from `lhs` if the predicate is `true` or from the `rhs` if the predicate is false
 /// Returns `None` if the predicate is `None`.
-/// # Example
-/// ```rust
-/// # use polars_arrow::error::Result;
-/// use polars_arrow::compute::if_then_else::if_then_else;
-/// use polars_arrow::array::{Int32Array, BooleanArray};
-///
-/// # fn main() -> Result<()> {
-/// let lhs = Int32Array::from_slice(&[1, 2, 3]);
-/// let rhs = Int32Array::from_slice(&[4, 5, 6]);
-/// let predicate = BooleanArray::from(&[Some(true), None, Some(false)]);
-/// let result = if_then_else(&predicate, &lhs, &rhs)?;
-///
-/// let expected = Int32Array::from(&[Some(1), None, Some(6)]);
-///
-/// assert_eq!(expected, result.as_ref());
-/// # Ok(())
-/// # }
-/// ```
 pub fn if_then_else(
     predicate: &BooleanArray,
     lhs: &dyn Array,
diff --git a/crates/polars-arrow/src/io/ipc/mod.rs b/crates/polars-arrow/src/io/ipc/mod.rs
index 6ac9c3011b79..39ad7753359f 100644
--- a/crates/polars-arrow/src/io/ipc/mod.rs
+++ b/crates/polars-arrow/src/io/ipc/mod.rs
@@ -25,54 +25,6 @@
 //! the case of the `File` variant it also implements [`Seek`](std::io::Seek). In
 //! practice it means that `File`s can be arbitrarily accessed while `Stream`s are only
 //! read in certain order - the one they were written in (first in, first out).
-//!
-//! # Examples
-//! Read and write to a file:
-//! ```
-//! use polars_arrow::io::ipc::{{read::{FileReader, read_file_metadata}}, {write::{FileWriter, WriteOptions}}};
-//! # use std::fs::File;
-//! # use polars_arrow::datatypes::{Field, Schema, DataType};
-//! # use polars_arrow::array::{Int32Array, Array};
-//! # use polars_arrow::chunk::Chunk;
-//! # use polars_arrow::error::Error;
-//! // Setup the writer
-//! let path = "example.arrow".to_string();
-//! let mut file = File::create(&path)?;
-//! let x_coord = Field::new("x", DataType::Int32, false);
-//! let y_coord = Field::new("y", DataType::Int32, false);
-//! let schema = Schema::from(vec![x_coord, y_coord]);
-//! let options = WriteOptions {compression: None};
-//! let mut writer = FileWriter::try_new(file, schema, None, options)?;
-//!
-//! // Setup the data
-//! let x_data = Int32Array::from_slice([-1i32, 1]);
-//! let y_data = Int32Array::from_slice([1i32, -1]);
-//! let chunk = Chunk::try_new(vec![x_data.boxed(), y_data.boxed()])?;
-//!
-//! // Write the messages and finalize the stream
-//! for _ in 0..5 {
-//!     writer.write(&chunk, None);
-//! }
-//! writer.finish();
-//!
-//! // Fetch some of the data and get the reader back
-//! let mut reader = File::open(&path)?;
-//! let metadata = read_file_metadata(&mut reader)?;
-//! let mut reader = FileReader::new(reader, metadata, None, None);
-//! let row1 = reader.next().unwrap(); // [[-1, 1], [1, -1]]
-//! let row2 = reader.next().unwrap(); // [[-1, 1], [1, -1]]
-//! let mut reader = reader.into_inner();
-//! // Do more stuff with the reader, like seeking ahead.
-//! # Ok::<(), Error>(())
-//! ```
-//!
-//! For further information and examples please consult the
-//! [user guide](https://jorgecarleitao.github.io/polars_arrow/io/index.html).
-//! For even more examples check the `examples` folder in the main repository
-//! ([1](https://github.com/jorgecarleitao/polars_arrow/blob/main/examples/ipc_file_read.rs),
-//! [2](https://github.com/jorgecarleitao/polars_arrow/blob/main/examples/ipc_file_write.rs),
-//! [3](https://github.com/jorgecarleitao/polars_arrow/tree/main/examples/ipc_pyarrow)).
-
 mod compression;
 mod endianness;
diff --git a/crates/polars-arrow/src/io/ipc/write/file_async.rs b/crates/polars-arrow/src/io/ipc/write/file_async.rs
index cad67f35fdea..5ed1350a65ff 100644
--- a/crates/polars-arrow/src/io/ipc/write/file_async.rs
+++ b/crates/polars-arrow/src/io/ipc/write/file_async.rs
@@ -21,46 +21,6 @@ type WriteOutput<W> = (usize, Option<Block>, Vec<Block>, Option<W>);
 ///
 /// The file header is automatically written before writing the first chunk, and the file footer is
 /// automatically written when the sink is closed.
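The header/chunks/footer layout that these sinks produce can also be exercised end-to-end from Python. A minimal round-trip sketch, assuming a local polars install and a writable working directory (the file name is illustrative and not part of this patch):

```python
import polars as pl
from polars.testing import assert_frame_equal

df = pl.DataFrame({"x": [-1, 1], "y": [1, -1]})

# write_ipc emits the header before the first chunk and the footer on close
df.write_ipc("example.arrow")
assert_frame_equal(pl.read_ipc("example.arrow"), df)
```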
-///
-/// # Examples
-///
-/// ```
-/// use futures::{SinkExt, TryStreamExt, io::Cursor};
-/// use polars_arrow::array::{Array, Int32Array};
-/// use polars_arrow::datatypes::{DataType, Field, Schema};
-/// use polars_arrow::chunk::Chunk;
-/// use polars_arrow::io::ipc::write::file_async::FileSink;
-/// use polars_arrow::io::ipc::read::file_async::{read_file_metadata_async, FileStream};
-/// # futures::executor::block_on(async move {
-/// let schema = Schema::from(vec![
-///     Field::new("values", DataType::Int32, true),
-/// ]);
-///
-/// let mut buffer = Cursor::new(vec![]);
-/// let mut sink = FileSink::new(
-///     &mut buffer,
-///     schema,
-///     None,
-///     Default::default(),
-/// );
-///
-/// // Write chunks to file
-/// for i in 0..3 {
-///     let values = Int32Array::from(&[Some(i), None]);
-///     let chunk = Chunk::new(vec![values.boxed()]);
-///     sink.feed(chunk.into()).await?;
-/// }
-/// sink.close().await?;
-/// drop(sink);
-///
-/// // Read chunks from file
-/// buffer.set_position(0);
-/// let metadata = read_file_metadata_async(&mut buffer).await?;
-/// let mut stream = FileStream::new(buffer, metadata, None, None);
-/// let chunks = stream.try_collect::<Vec<_>>().await?;
-/// # polars_arrow::error::Result::Ok(())
-/// # }).unwrap();
-/// ```
 pub struct FileSink<'a, W: AsyncWrite + Unpin + Send + 'a> {
     writer: Option<W>,
     task: Option<BoxFuture<'a, Result<WriteOutput<W>>>>,
diff --git a/crates/polars-arrow/src/io/ipc/write/stream_async.rs b/crates/polars-arrow/src/io/ipc/write/stream_async.rs
index 7e8d056ce52b..49305c2ab383 100644
--- a/crates/polars-arrow/src/io/ipc/write/stream_async.rs
+++ b/crates/polars-arrow/src/io/ipc/write/stream_async.rs
@@ -17,37 +17,6 @@ use crate::datatypes::*;
 /// A sink that writes array [`chunks`](crate::chunk::Chunk) as an IPC stream.
 ///
 /// The stream header is automatically written before writing the first chunk.
-///
-/// # Examples
-///
-/// ```
-/// use futures::SinkExt;
-/// use polars_arrow::array::{Array, Int32Array};
-/// use polars_arrow::datatypes::{DataType, Field, Schema};
-/// use polars_arrow::chunk::Chunk;
-/// # use polars_arrow::io::ipc::write::stream_async::StreamSink;
-/// # futures::executor::block_on(async move {
-/// let schema = Schema::from(vec![
-///     Field::new("values", DataType::Int32, true),
-/// ]);
-///
-/// let mut buffer = vec![];
-/// let mut sink = StreamSink::new(
-///     &mut buffer,
-///     &schema,
-///     None,
-///     Default::default(),
-/// );
-///
-/// for i in 0..3 {
-///     let values = Int32Array::from(&[Some(i), None]);
-///     let chunk = Chunk::new(vec![values.boxed()]);
-///     sink.feed(chunk.into()).await?;
-/// }
-/// sink.close().await?;
-/// # polars_arrow::error::Result::Ok(())
-/// # }).unwrap();
-/// ```
 pub struct StreamSink<'a, W: AsyncWrite + Unpin + Send + 'a> {
     writer: Option<W>,
     task: Option<BoxFuture<'a, Result<Option<W>>>>,
diff --git a/crates/polars-arrow/src/io/parquet/write/sink.rs b/crates/polars-arrow/src/io/parquet/write/sink.rs
index 284e2a6f7639..d8d2734ce461 100644
--- a/crates/polars-arrow/src/io/parquet/write/sink.rs
+++ b/crates/polars-arrow/src/io/parquet/write/sink.rs
@@ -18,47 +18,6 @@ use crate::datatypes::Schema;
 ///
 /// Any values in the sink's `metadata` field will be written to the file's footer
 /// when the sink is closed.
-///
-/// # Examples
-///
-/// ```
-/// use futures::SinkExt;
-/// use polars_arrow::array::{Array, Int32Array};
-/// use polars_arrow::datatypes::{DataType, Field, Schema};
-/// use polars_arrow::chunk::Chunk;
-/// use polars_arrow::io::parquet::write::{Encoding, WriteOptions, CompressionOptions, Version};
-/// # use polars_arrow::io::parquet::write::FileSink;
-/// # futures::executor::block_on(async move {
-///
-/// let schema = Schema::from(vec![
-///     Field::new("values", DataType::Int32, true),
-/// ]);
-/// let encoding = vec![vec![Encoding::Plain]];
-/// let options = WriteOptions {
-///     write_statistics: true,
-///     compression: CompressionOptions::Uncompressed,
-///     version: Version::V2,
-///     data_pagesize_limit: None,
-/// };
-///
-/// let mut buffer = vec![];
-/// let mut sink = FileSink::try_new(
-///     &mut buffer,
-///     schema,
-///     encoding,
-///     options,
-/// )?;
-///
-/// for i in 0..3 {
-///     let values = Int32Array::from(&[Some(i), None]);
-///     let chunk = Chunk::new(vec![values.boxed()]);
-///     sink.feed(chunk).await?;
-/// }
-/// sink.metadata.insert(String::from("key"), Some(String::from("value")));
-/// sink.close().await?;
-/// # polars_arrow::error::Result::Ok(())
-/// # }).unwrap();
-/// ```
 pub struct FileSink<'a, W: AsyncWrite + Send + Unpin> {
     writer: Option<FileStreamer<W>>,
     task: Option<BoxFuture<'a, Result<Option<FileStreamer<W>>>>>,

From 89cc1e2d67448183aa622e1b5f7130793b29b081 Mon Sep 17 00:00:00 2001
From: Stijn de Gooijer
Date: Wed, 18 Oct 2023 10:32:03 +0200
Subject: [PATCH 042/103] refactor(python): Assert utils refactor (#11813)

---
 py-polars/polars/testing/__init__.py          |   4 +-
 py-polars/polars/testing/asserts/__init__.py  |   9 +
 py-polars/polars/testing/asserts/frame.py     | 274 +++++++++++
 .../testing/{asserts.py => asserts/series.py} | 460 +++++------------
 py-polars/polars/testing/asserts/utils.py     |  15 +
 py-polars/tests/unit/testing/__init__.py      |   0
 .../unit/testing/test_assert_frame_equal.py   | 420 ++++++++++++++++
 ...testing.py => test_assert_series_equal.py} | 409 +---------------
 8 files changed, 835 insertions(+), 756 deletions(-)
 create mode 100644 py-polars/polars/testing/asserts/__init__.py
 create mode 100644 py-polars/polars/testing/asserts/frame.py
 rename py-polars/polars/testing/{asserts.py => asserts/series.py} (54%)
 create mode 100644 py-polars/polars/testing/asserts/utils.py
 create mode 100644 py-polars/tests/unit/testing/__init__.py
 create mode 100644 py-polars/tests/unit/testing/test_assert_frame_equal.py
 rename py-polars/tests/unit/testing/{test_testing.py => test_assert_series_equal.py} (60%)

diff --git a/py-polars/polars/testing/__init__.py b/py-polars/polars/testing/__init__.py
index 2461de6ba6ff..b5962f7fba2c 100644
--- a/py-polars/polars/testing/__init__.py
+++ b/py-polars/polars/testing/__init__.py
@@ -6,8 +6,8 @@
 )
 
 __all__ = [
-    "assert_series_equal",
-    "assert_series_not_equal",
     "assert_frame_equal",
     "assert_frame_not_equal",
+    "assert_series_equal",
+    "assert_series_not_equal",
 ]
diff --git a/py-polars/polars/testing/asserts/__init__.py b/py-polars/polars/testing/asserts/__init__.py
new file mode 100644
index 000000000000..4e00da7cc1fa
--- /dev/null
+++ b/py-polars/polars/testing/asserts/__init__.py
@@ -0,0 +1,9 @@
+from polars.testing.asserts.frame import assert_frame_equal, assert_frame_not_equal
+from polars.testing.asserts.series import assert_series_equal, assert_series_not_equal
+
+__all__ = [
+    "assert_frame_equal",
+    "assert_frame_not_equal",
+    "assert_series_equal",
+    "assert_series_not_equal",
+]
diff --git a/py-polars/polars/testing/asserts/frame.py
b/py-polars/polars/testing/asserts/frame.py new file mode 100644 index 000000000000..4faa4d810050 --- /dev/null +++ b/py-polars/polars/testing/asserts/frame.py @@ -0,0 +1,274 @@ +from __future__ import annotations + +from typing import cast + +from polars.dataframe import DataFrame +from polars.exceptions import ComputeError, InvalidAssert +from polars.lazyframe import LazyFrame +from polars.testing.asserts.series import _assert_series_values_equal +from polars.testing.asserts.utils import raise_assertion_error + + +def assert_frame_equal( + left: DataFrame | LazyFrame, + right: DataFrame | LazyFrame, + *, + check_row_order: bool = True, + check_column_order: bool = True, + check_dtype: bool = True, + check_exact: bool = False, + rtol: float = 1e-5, + atol: float = 1e-8, + nans_compare_equal: bool = True, + categorical_as_str: bool = False, +) -> None: + """ + Assert that the left and right frame are equal. + + Raises a detailed ``AssertionError`` if the frames differ. + This function is intended for use in unit tests. + + Parameters + ---------- + left + The first DataFrame or LazyFrame to compare. + right + The second DataFrame or LazyFrame to compare. + check_row_order + Require row order to match. + + .. note:: + Setting this to ``False`` requires sorting the data, which will fail on + frames that contain unsortable columns. + check_column_order + Require column order to match. + check_dtype + Require data types to match. + check_exact + Require data values to match exactly. If set to ``False``, values are considered + equal when within tolerance of each other (see ``rtol`` and ``atol``). + rtol + Relative tolerance for inexact checking. Fraction of values in ``right``. + atol + Absolute tolerance for inexact checking. + nans_compare_equal + Consider NaN values to be equal. + categorical_as_str + Cast categorical columns to string before comparing. Enabling this helps + compare columns that do not share the same string cache. + + See Also + -------- + assert_series_equal + assert_frame_not_equal + + Notes + ----- + When using pytest, it may be worthwhile to shorten Python traceback printing + by passing ``--tb=short``. The default mode tends to be unhelpfully verbose. + More information in the + `pytest docs `_. + + Examples + -------- + >>> from polars.testing import assert_frame_equal + >>> df1 = pl.DataFrame({"a": [1, 2, 3]}) + >>> df2 = pl.DataFrame({"a": [1, 5, 3]}) + >>> assert_frame_equal(df1, df2) # doctest: +SKIP + Traceback (most recent call last): + ... + AssertionError: Series are different (value mismatch) + [left]: [1, 2, 3] + [right]: [1, 5, 3] + + The above exception was the direct cause of the following exception: + + Traceback (most recent call last): + ... 
+ AssertionError: values for column 'a' are different + + """ + lazy = _assert_correct_input_type(left, right) + objects = "LazyFrames" if lazy else "DataFrames" + + _assert_frame_schema_equal( + left, + right, + check_column_order=check_column_order, + check_dtype=check_dtype, + objects=objects, + ) + + if lazy: + left, right = left.collect(), right.collect() # type: ignore[union-attr] + left, right = cast(DataFrame, left), cast(DataFrame, right) + + if left.height != right.height: + raise_assertion_error( + objects, "number of rows does not match", left.height, right.height + ) + + if not check_row_order: + left, right = _sort_dataframes(left, right) + + for c in left.columns: + try: + _assert_series_values_equal( + left.get_column(c), + right.get_column(c), + check_exact=check_exact, + rtol=rtol, + atol=atol, + nans_compare_equal=nans_compare_equal, + categorical_as_str=categorical_as_str, + ) + except AssertionError as exc: + msg = f"values for column {c!r} are different" + raise AssertionError(msg) from exc + + +def _assert_correct_input_type( + left: DataFrame | LazyFrame, right: DataFrame | LazyFrame +) -> bool: + if isinstance(left, DataFrame) and isinstance(right, DataFrame): + return False + elif isinstance(left, LazyFrame) and isinstance(right, LazyFrame): + return True + else: + raise_assertion_error( + "inputs", + "unexpected input types", + type(left).__name__, + type(right).__name__, + ) + + +def _assert_frame_schema_equal( + left: DataFrame | LazyFrame, + right: DataFrame | LazyFrame, + *, + check_dtype: bool, + check_column_order: bool, + objects: str, +) -> None: + left_schema, right_schema = left.schema, right.schema + + # Fast path for equal frames + if left_schema == right_schema: + return + + # Special error message for when column names do not match + if left_schema.keys() != right_schema.keys(): + if left_not_right := [c for c in left_schema if c not in right_schema]: + msg = f"columns {left_not_right!r} in left {objects[:-1]}, but not in right" + raise AssertionError(msg) + else: + right_not_left = [c for c in right_schema if c not in left_schema] + msg = f"columns {right_not_left!r} in right {objects[:-1]}, but not in left" + raise AssertionError(msg) + + if check_column_order: + left_columns, right_columns = list(left_schema), list(right_schema) + if left_columns != right_columns: + detail = "columns are not in the same order" + raise_assertion_error(objects, detail, left_columns, right_columns) + + if check_dtype: + left_schema_dict, right_schema_dict = dict(left_schema), dict(right_schema) + if check_column_order or left_schema_dict != right_schema_dict: + detail = "dtypes do not match" + raise_assertion_error(objects, detail, left_schema_dict, right_schema_dict) + + +def _sort_dataframes(left: DataFrame, right: DataFrame) -> tuple[DataFrame, DataFrame]: + by = left.columns + try: + left = left.sort(by) + right = right.sort(by) + except ComputeError as exc: + msg = "cannot set `check_row_order=False` on frame with unsortable columns" + raise InvalidAssert(msg) from exc + return left, right + + +def assert_frame_not_equal( + left: DataFrame | LazyFrame, + right: DataFrame | LazyFrame, + *, + check_row_order: bool = True, + check_column_order: bool = True, + check_dtype: bool = True, + check_exact: bool = False, + rtol: float = 1e-5, + atol: float = 1e-8, + nans_compare_equal: bool = True, + categorical_as_str: bool = False, +) -> None: + """ + Assert that the left and right frame are **not** equal. + + This function is intended for use in unit tests. 
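As a quick illustration of the sorting fallback implemented by `_sort_dataframes` above, this is how `check_row_order=False` behaves from the caller's side (a sketch mirroring the unit tests further down, not part of the patch itself):

```python
import polars as pl
from polars.testing import assert_frame_equal

df1 = pl.DataFrame({"a": [1, 2], "b": [4, 3]})
df2 = pl.DataFrame({"a": [2, 1], "b": [3, 4]})

# Both frames are sorted by all columns before comparison, so the
# differing row order is accepted; unsortable columns raise InvalidAssert.
assert_frame_equal(df1, df2, check_row_order=False)
```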
+ + Parameters + ---------- + left + The first DataFrame or LazyFrame to compare. + right + The second DataFrame or LazyFrame to compare. + check_row_order + Require row order to match. + + .. note:: + Setting this to ``False`` requires sorting the data, which will fail on + frames that contain unsortable columns. + check_column_order + Require column order to match. + check_dtype + Require data types to match. + check_exact + Require data values to match exactly. If set to ``False``, values are considered + equal when within tolerance of each other (see ``rtol`` and ``atol``). + rtol + Relative tolerance for inexact checking. Fraction of values in ``right``. + atol + Absolute tolerance for inexact checking. + nans_compare_equal + Consider NaN values to be equal. + categorical_as_str + Cast categorical columns to string before comparing. Enabling this helps + compare columns that do not share the same string cache. + + See Also + -------- + assert_frame_equal + assert_series_not_equal + + Examples + -------- + >>> from polars.testing import assert_frame_not_equal + >>> df1 = pl.DataFrame({"a": [1, 2, 3]}) + >>> df2 = pl.DataFrame({"a": [1, 2, 3]}) + >>> assert_frame_not_equal(df1, df2) # doctest: +SKIP + Traceback (most recent call last): + ... + AssertionError: frames are equal + + """ + try: + assert_frame_equal( + left=left, + right=right, + check_column_order=check_column_order, + check_row_order=check_row_order, + check_dtype=check_dtype, + check_exact=check_exact, + rtol=rtol, + atol=atol, + nans_compare_equal=nans_compare_equal, + categorical_as_str=categorical_as_str, + ) + except AssertionError: + return + else: + msg = "frames are equal" + raise AssertionError(msg) diff --git a/py-polars/polars/testing/asserts.py b/py-polars/polars/testing/asserts/series.py similarity index 54% rename from py-polars/polars/testing/asserts.py rename to py-polars/polars/testing/asserts/series.py index 7c44dd3a19f5..028c1cff6db7 100644 --- a/py-polars/polars/testing/asserts.py +++ b/py-polars/polars/testing/asserts/series.py @@ -1,9 +1,6 @@ from __future__ import annotations -from typing import Any, NoReturn - from polars import functions as F -from polars.dataframe import DataFrame from polars.datatypes import ( FLOAT_DTYPES, UNSIGNED_INTEGER_DTYPES, @@ -15,227 +12,8 @@ dtype_to_py_type, unpack_dtypes, ) -from polars.exceptions import ComputeError, InvalidAssert -from polars.lazyframe import LazyFrame from polars.series import Series - - -def assert_frame_equal( - left: DataFrame | LazyFrame, - right: DataFrame | LazyFrame, - *, - check_row_order: bool = True, - check_column_order: bool = True, - check_dtype: bool = True, - check_exact: bool = False, - rtol: float = 1.0e-5, - atol: float = 1.0e-8, - nans_compare_equal: bool = True, - categorical_as_str: bool = False, -) -> None: - """ - Assert that the left and right frame are equal. - - Raises a detailed ``AssertionError`` if the frames differ. - This function is intended for use in unit tests. - - Parameters - ---------- - left - The first DataFrame or LazyFrame to compare. - right - The second DataFrame or LazyFrame to compare. - check_row_order - Require row order to match. - - .. note:: - Setting this to ``False`` requires sorting the data, which will fail on - frames that contain unsortable columns. - check_column_order - Require column order to match. - check_dtype - Require data types to match. - check_exact - Require data values to match exactly. 
If set to ``False``, values are considered - equal when within tolerance of each other (see ``rtol`` and ``atol``). - rtol - Relative tolerance for inexact checking. Fraction of values in ``right``. - atol - Absolute tolerance for inexact checking. - nans_compare_equal - Consider NaN values to be equal. - categorical_as_str - Cast categorical columns to string before comparing. Enabling this helps - compare columns that do not share the same string cache. - - See Also - -------- - assert_series_equal - assert_frame_not_equal - - Examples - -------- - >>> from polars.testing import assert_frame_equal - >>> df1 = pl.DataFrame({"a": [1, 2, 3]}) - >>> df2 = pl.DataFrame({"a": [1, 5, 3]}) - >>> assert_frame_equal(df1, df2) # doctest: +SKIP - Traceback (most recent call last): - ... - AssertionError: Series are different (value mismatch) - [left]: [1, 2, 3] - [right]: [1, 5, 3] - - The above exception was the direct cause of the following exception: - - Traceback (most recent call last): - ... - AssertionError: values for column 'a' are different - - """ - collect_input_frames = isinstance(left, LazyFrame) and isinstance(right, LazyFrame) - if collect_input_frames: - objs = "LazyFrames" - elif isinstance(left, DataFrame) and isinstance(right, DataFrame): - objs = "DataFrames" - else: - _raise_assertion_error( - "Inputs", - "unexpected input types", - type(left).__name__, - type(right).__name__, - ) - - if left_not_right := [c for c in left.columns if c not in right.columns]: - msg = f"columns {left_not_right!r} in left frame, but not in right" - raise AssertionError(msg) - - if right_not_left := [c for c in right.columns if c not in left.columns]: - msg = f"columns {right_not_left!r} in right frame, but not in left" - raise AssertionError(msg) - - if check_column_order and left.columns != right.columns: - msg = f"columns are not in the same order:\n{left.columns!r}\n{right.columns!r}" - raise AssertionError(msg) - - if collect_input_frames: - if check_dtype: # check this _before_ we collect - left_schema, right_schema = left.schema, right.schema - if left_schema != right_schema: - _raise_assertion_error( - objs, "lazy schemas are not equal", left_schema, right_schema - ) - left, right = left.collect(), right.collect() # type: ignore[union-attr] - - if left.shape[0] != right.shape[0]: # type: ignore[union-attr] - _raise_assertion_error(objs, "length mismatch", left.shape, right.shape) # type: ignore[union-attr] - - if not check_row_order: - try: - left = left.sort(by=left.columns) - right = right.sort(by=left.columns) - except ComputeError as exc: - msg = "cannot set `check_row_order=False` on frame with unsortable columns" - raise InvalidAssert(msg) from exc - - # note: does not assume a particular column order - for c in left.columns: - try: - _assert_series_inner( - left[c], # type: ignore[arg-type, index] - right[c], # type: ignore[arg-type, index] - check_dtype=check_dtype, - check_exact=check_exact, - atol=atol, - rtol=rtol, - nans_compare_equal=nans_compare_equal, - categorical_as_str=categorical_as_str, - ) - except AssertionError as exc: - msg = f"values for column {c!r} are different" - raise AssertionError(msg) from exc - - -def assert_frame_not_equal( - left: DataFrame | LazyFrame, - right: DataFrame | LazyFrame, - *, - check_row_order: bool = True, - check_column_order: bool = True, - check_dtype: bool = True, - check_exact: bool = False, - rtol: float = 1.0e-5, - atol: float = 1.0e-8, - nans_compare_equal: bool = True, - categorical_as_str: bool = False, -) -> None: - """ - Assert 
that the left and right frame are **not** equal. - - This function is intended for use in unit tests. - - Parameters - ---------- - left - The first DataFrame or LazyFrame to compare. - right - The second DataFrame or LazyFrame to compare. - check_row_order - Require row order to match. - - .. note:: - Setting this to ``False`` requires sorting the data, which will fail on - frames that contain unsortable columns. - check_column_order - Require column order to match. - check_dtype - Require data types to match. - check_exact - Require data values to match exactly. If set to ``False``, values are considered - equal when within tolerance of each other (see ``rtol`` and ``atol``). - rtol - Relative tolerance for inexact checking. Fraction of values in ``right``. - atol - Absolute tolerance for inexact checking. - nans_compare_equal - Consider NaN values to be equal. - categorical_as_str - Cast categorical columns to string before comparing. Enabling this helps - compare columns that do not share the same string cache. - - See Also - -------- - assert_frame_equal - assert_series_not_equal - - Examples - -------- - >>> from polars.testing import assert_frame_not_equal - >>> df1 = pl.DataFrame({"a": [1, 2, 3]}) - >>> df2 = pl.DataFrame({"a": [1, 2, 3]}) - >>> assert_frame_not_equal(df1, df2) # doctest: +SKIP - Traceback (most recent call last): - ... - AssertionError: frames are equal - - """ - try: - assert_frame_equal( - left=left, - right=right, - check_column_order=check_column_order, - check_row_order=check_row_order, - check_dtype=check_dtype, - check_exact=check_exact, - rtol=rtol, - atol=atol, - nans_compare_equal=nans_compare_equal, - categorical_as_str=categorical_as_str, - ) - except AssertionError: - return - else: - msg = "frames are equal" - raise AssertionError(msg) +from polars.testing.asserts.utils import raise_assertion_error def assert_series_equal( @@ -245,8 +23,8 @@ def assert_series_equal( check_dtype: bool = True, check_names: bool = True, check_exact: bool = False, - rtol: float = 1.0e-5, - atol: float = 1.0e-8, + rtol: float = 1e-5, + atol: float = 1e-8, nans_compare_equal: bool = True, categorical_as_str: bool = False, ) -> None: @@ -284,6 +62,13 @@ def assert_series_equal( assert_frame_equal assert_series_not_equal + Notes + ----- + When using pytest, it may be worthwhile to shorten Python traceback printing + by passing ``--tb=short``. The default mode tends to be unhelpfully verbose. + More information in the + `pytest docs `_. 
+ Examples -------- >>> from polars.testing import assert_series_equal @@ -298,124 +83,46 @@ def assert_series_equal( """ if not (isinstance(left, Series) and isinstance(right, Series)): # type: ignore[redundant-expr] - _raise_assertion_error( - "Inputs", + raise_assertion_error( + "inputs", "unexpected input types", type(left).__name__, type(right).__name__, ) - if len(left) != len(right): - _raise_assertion_error("Series", "length mismatch", len(left), len(right)) + if left.len() != right.len(): + raise_assertion_error("Series", "length mismatch", left.len(), right.len()) if check_names and left.name != right.name: - _raise_assertion_error("Series", "name mismatch", left.name, right.name) + raise_assertion_error("Series", "name mismatch", left.name, right.name) + + if check_dtype and left.dtype != right.dtype: + raise_assertion_error("Series", "dtype mismatch", left.dtype, right.dtype) - _assert_series_inner( + _assert_series_values_equal( left, right, - check_dtype=check_dtype, check_exact=check_exact, - atol=atol, rtol=rtol, + atol=atol, nans_compare_equal=nans_compare_equal, categorical_as_str=categorical_as_str, ) -def assert_series_not_equal( +def _assert_series_values_equal( left: Series, right: Series, *, - check_dtype: bool = True, - check_names: bool = True, - check_exact: bool = False, - rtol: float = 1.0e-5, - atol: float = 1.0e-8, - nans_compare_equal: bool = True, - categorical_as_str: bool = False, -) -> None: - """ - Assert that the left and right Series are **not** equal. - - This function is intended for use in unit tests. - - Parameters - ---------- - left - the series to compare. - right - the series to compare with. - check_dtype - if True, data types need to match exactly. - check_names - if True, names need to match. - check_exact - if False, test if values are within tolerance of each other - (see `rtol` & `atol`). - rtol - relative tolerance for inexact checking. Fraction of values in `right`. - atol - absolute tolerance for inexact checking. - nans_compare_equal - if your assert/test requires float NaN != NaN, set this to False. - categorical_as_str - Cast categorical columns to string before comparing. Enabling this helps - compare DataFrames that do not share the same string cache. - - See Also - -------- - assert_series_equal - assert_frame_not_equal - - Examples - -------- - >>> from polars.testing import assert_series_not_equal - >>> s1 = pl.Series([1, 2, 3]) - >>> s2 = pl.Series([1, 2, 3]) - >>> assert_series_not_equal(s1, s2) # doctest: +SKIP - Traceback (most recent call last): - ... 
- AssertionError: Series are equal - - """ - try: - assert_series_equal( - left=left, - right=right, - check_dtype=check_dtype, - check_names=check_names, - check_exact=check_exact, - rtol=rtol, - atol=atol, - nans_compare_equal=nans_compare_equal, - categorical_as_str=categorical_as_str, - ) - except AssertionError: - return - else: - msg = "Series are equal" - raise AssertionError(msg) - - -def _assert_series_inner( - left: Series, - right: Series, - *, - check_dtype: bool, check_exact: bool, - atol: float, rtol: float, + atol: float, nans_compare_equal: bool, categorical_as_str: bool, ) -> None: - """Compare Series dtype + values.""" - if check_dtype and left.dtype != right.dtype: - _raise_assertion_error("Series", "dtype mismatch", left.dtype, right.dtype) - + """Assert that the values in both Series are equal.""" if categorical_as_str and left.dtype == Categorical: - left = left.cast(Utf8) - right = right.cast(Utf8) + left, right = left.cast(Utf8), right.cast(Utf8) # create mask of which (if any) values are unequal unequal = left.ne_missing(right) @@ -436,10 +143,9 @@ def _assert_series_inner( if _assert_series_nested( left=left.filter(unequal), right=right.filter(unequal), - check_dtype=check_dtype, check_exact=check_exact, - atol=atol, rtol=rtol, + atol=atol, nans_compare_equal=nans_compare_equal, categorical_as_str=categorical_as_str, ): @@ -457,7 +163,7 @@ def _assert_series_inner( # assert exact, or with tolerance if unequal.any(): if check_exact: - _raise_assertion_error( + raise_assertion_error( "Series", "exact value mismatch", left=left.to_list(), @@ -468,14 +174,14 @@ def _assert_series_inner( left, right, unequal, - atol=atol, rtol=rtol, + atol=atol, nans_compare_equal=nans_compare_equal, comparing_floats=comparing_floats, ) if not equal: - _raise_assertion_error( + raise_assertion_error( "Series", f"value mismatch{nan_info}", left=left.to_list(), @@ -487,9 +193,9 @@ def _check_series_equal_inexact( left: Series, right: Series, unequal: Series, - atol: float, - rtol: float, *, + rtol: float, + atol: float, nans_compare_equal: bool, comparing_floats: bool, ) -> tuple[bool, str]: @@ -529,10 +235,9 @@ def _assert_series_nested( left: Series, right: Series, *, - check_dtype: bool, check_exact: bool, - atol: float, rtol: float, + atol: float, nans_compare_equal: bool, categorical_as_str: bool, ) -> bool: @@ -547,26 +252,25 @@ def _assert_series_nested( if nans_compare_equal: continue else: - _raise_assertion_error( + raise_assertion_error( "Series", f"Nested value mismatch (nans_compare_equal={nans_compare_equal})", s1, s2, ) elif (s1 is None and s2 is not None) or (s2 is None and s1 is not None): - _raise_assertion_error("Series", "nested value mismatch", s1, s2) + raise_assertion_error("Series", "nested value mismatch", s1, s2) elif len(s1) != len(s2): - _raise_assertion_error( + raise_assertion_error( "Series", "nested list length mismatch", len(s1), len(s2) ) - _assert_series_inner( + _assert_series_values_equal( s1, s2, - check_dtype=check_dtype, check_exact=check_exact, - atol=atol, rtol=rtol, + atol=atol, nans_compare_equal=nans_compare_equal, categorical_as_str=categorical_as_str, ) @@ -576,24 +280,23 @@ def _assert_series_nested( elif left.dtype == Struct == right.dtype: ls, rs = left.struct.unnest(), right.struct.unnest() if len(ls.columns) != len(rs.columns): - _raise_assertion_error( + raise_assertion_error( "Series", "nested struct fields mismatch", len(ls.columns), len(rs.columns), ) elif len(ls) != len(rs): - _raise_assertion_error( + raise_assertion_error( 
"Series", "nested struct length mismatch", len(ls), len(rs) ) for s1, s2 in zip(ls, rs): - _assert_series_inner( + _assert_series_values_equal( s1, s2, - check_dtype=check_dtype, check_exact=check_exact, - atol=atol, rtol=rtol, + atol=atol, nans_compare_equal=nans_compare_equal, categorical_as_str=categorical_as_str, ) @@ -604,13 +307,76 @@ def _assert_series_nested( return False -def _raise_assertion_error( - obj: str, - detail: str, - left: Any, - right: Any, -) -> NoReturn: - """Raise a detailed assertion error.""" - __tracebackhide__ = True - msg = f"{obj} are different ({detail})\n[left]: {left}\n[right]: {right}" - raise AssertionError(msg) +def assert_series_not_equal( + left: Series, + right: Series, + *, + check_dtype: bool = True, + check_names: bool = True, + check_exact: bool = False, + rtol: float = 1e-5, + atol: float = 1e-8, + nans_compare_equal: bool = True, + categorical_as_str: bool = False, +) -> None: + """ + Assert that the left and right Series are **not** equal. + + This function is intended for use in unit tests. + + Parameters + ---------- + left + the series to compare. + right + the series to compare with. + check_dtype + if True, data types need to match exactly. + check_names + if True, names need to match. + check_exact + if False, test if values are within tolerance of each other + (see `rtol` & `atol`). + rtol + relative tolerance for inexact checking. Fraction of values in `right`. + atol + absolute tolerance for inexact checking. + nans_compare_equal + if your assert/test requires float NaN != NaN, set this to False. + categorical_as_str + Cast categorical columns to string before comparing. Enabling this helps + compare DataFrames that do not share the same string cache. + + See Also + -------- + assert_series_equal + assert_frame_not_equal + + Examples + -------- + >>> from polars.testing import assert_series_not_equal + >>> s1 = pl.Series([1, 2, 3]) + >>> s2 = pl.Series([1, 2, 3]) + >>> assert_series_not_equal(s1, s2) # doctest: +SKIP + Traceback (most recent call last): + ... 
+ AssertionError: Series are equal + + """ + try: + assert_series_equal( + left=left, + right=right, + check_dtype=check_dtype, + check_names=check_names, + check_exact=check_exact, + rtol=rtol, + atol=atol, + nans_compare_equal=nans_compare_equal, + categorical_as_str=categorical_as_str, + ) + except AssertionError: + return + else: + msg = "Series are equal" + raise AssertionError(msg) diff --git a/py-polars/polars/testing/asserts/utils.py b/py-polars/polars/testing/asserts/utils.py new file mode 100644 index 000000000000..713e57170ac1 --- /dev/null +++ b/py-polars/polars/testing/asserts/utils.py @@ -0,0 +1,15 @@ +from __future__ import annotations + +from typing import Any, NoReturn + + +def raise_assertion_error( + objects: str, + detail: str, + left: Any, + right: Any, +) -> NoReturn: + """Raise a detailed assertion error.""" + __tracebackhide__ = True + msg = f"{objects} are different ({detail})\n[left]: {left}\n[right]: {right}" + raise AssertionError(msg) diff --git a/py-polars/tests/unit/testing/__init__.py b/py-polars/tests/unit/testing/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/py-polars/tests/unit/testing/test_assert_frame_equal.py b/py-polars/tests/unit/testing/test_assert_frame_equal.py new file mode 100644 index 000000000000..c9a3bf19814d --- /dev/null +++ b/py-polars/tests/unit/testing/test_assert_frame_equal.py @@ -0,0 +1,420 @@ +from __future__ import annotations + +import math +from typing import Any + +import pytest + +import polars as pl +from polars.exceptions import InvalidAssert +from polars.testing import assert_frame_equal, assert_frame_not_equal + + +@pytest.mark.parametrize( + ("df1", "df2", "kwargs"), + [ + pytest.param( + pl.DataFrame({"a": [0.2, 0.3]}), + pl.DataFrame({"a": [0.2, 0.3]}), + {"atol": 1e-15}, + id="equal_floats_low_atol", + ), + pytest.param( + pl.DataFrame({"a": [0.2, 0.3]}), + pl.DataFrame({"a": [0.2, 0.3000000000000001]}), + {"atol": 1e-15}, + id="approx_equal_float_low_atol", + ), + pytest.param( + pl.DataFrame({"a": [0.2, 0.3]}), + pl.DataFrame({"a": [0.2, 0.31]}), + {"atol": 0.1}, + id="approx_equal_float_high_atol", + ), + pytest.param( + pl.DataFrame({"a": [0.2, 1.3]}), + pl.DataFrame({"a": [0.2, 0.9]}), + {"atol": 1}, + id="approx_equal_float_integer_atol", + ), + pytest.param( + pl.DataFrame({"a": [0.0, 1.0, 2.0]}, schema={"a": pl.Float64}), + pl.DataFrame({"a": [0, 1, 2]}, schema={"a": pl.Int64}), + {"check_dtype": False}, + id="equal_int_float_integer_no_check_dtype", + ), + pytest.param( + pl.DataFrame({"a": [0, 1, 2]}, schema={"a": pl.Float64}), + pl.DataFrame({"a": [0, 1, 2]}, schema={"a": pl.Float32}), + {"check_dtype": False}, + id="equal_int_float_integer_no_check_dtype", + ), + pytest.param( + pl.DataFrame({"a": [0, 1, 2]}, schema={"a": pl.Int64}), + pl.DataFrame({"a": [0, 1, 2]}, schema={"a": pl.Int64}), + {}, + id="equal_int", + ), + pytest.param( + pl.DataFrame({"a": ["a", "b", "c"]}, schema={"a": pl.Utf8}), + pl.DataFrame({"a": ["a", "b", "c"]}, schema={"a": pl.Utf8}), + {}, + id="equal_str", + ), + pytest.param( + pl.DataFrame({"a": [[0.2, 0.3]]}), + pl.DataFrame({"a": [[0.2, 0.300001]]}), + {"atol": 1e-5}, + id="list_of_float_low_atol", + ), + pytest.param( + pl.DataFrame({"a": [[0.2, 0.3]]}), + pl.DataFrame({"a": [[0.2, 0.31]]}), + {"atol": 0.1}, + id="list_of_float_high_atol", + ), + pytest.param( + pl.DataFrame({"a": [[0.2, 1.3]]}), + pl.DataFrame({"a": [[0.2, 0.9]]}), + {"atol": 1}, + id="list_of_float_integer_atol", + ), + pytest.param( + pl.DataFrame({"a": [[0.2, 0.3]]}), + 
pl.DataFrame({"a": [[0.2, 0.300000001]]}), + {"rtol": 1e-5}, + id="list_of_float_low_rtol", + ), + pytest.param( + pl.DataFrame({"a": [[0.2, 0.3]]}), + pl.DataFrame({"a": [[0.2, 0.301]]}), + {"rtol": 0.1}, + id="list_of_float_high_rtol", + ), + pytest.param( + pl.DataFrame({"a": [[0.2, 1.3]]}), + pl.DataFrame({"a": [[0.2, 0.9]]}), + {"rtol": 1}, + id="list_of_float_integer_rtol", + ), + pytest.param( + pl.DataFrame({"a": [[None, 1.3]]}), + pl.DataFrame({"a": [[None, 0.9]]}), + {"rtol": 1}, + id="list_of_none_and_float_integer_rtol", + ), + pytest.param( + pl.DataFrame({"a": [[None, 1.3]]}), + pl.DataFrame({"a": [[None, 0.9]]}), + {"rtol": 1, "nans_compare_equal": False}, + id="list_of_none_and_float_integer_rtol", + ), + pytest.param( + pl.DataFrame({"a": [[[0.2, 3.0]]]}), + pl.DataFrame({"a": [[[0.2, 3.00000001]]]}), + {"atol": 0.1, "nans_compare_equal": True}, + id="nested_list_of_float_atol_high_nans_compare_equal_true", + ), + pytest.param( + pl.DataFrame({"a": [[[0.2, 3.0]]]}), + pl.DataFrame({"a": [[[0.2, 3.00000001]]]}), + {"atol": 0.1, "nans_compare_equal": False}, + id="nested_list_of_float_atol_high_nans_compare_equal_false", + ), + ], +) +def test_assert_frame_equal_passes_assertion( + df1: pl.DataFrame, + df2: pl.DataFrame, + kwargs: dict[str, Any], +) -> None: + assert_frame_equal(df1, df2, **kwargs) + with pytest.raises(AssertionError): + assert_frame_not_equal(df1, df2, **kwargs) + + +@pytest.mark.parametrize( + ("df1", "df2", "kwargs"), + [ + pytest.param( + pl.DataFrame({"a": [[0.2, 0.3]]}), + pl.DataFrame({"a": [[0.2, 0.3, 0.4]]}), + {}, + id="list_of_float_different_lengths", + ), + pytest.param( + pl.DataFrame({"a": [[0.2, 0.3]]}), + pl.DataFrame({"a": [[0.2, 0.3000000000000001]]}), + {"check_exact": True}, + id="list_of_float_check_exact", + ), + pytest.param( + pl.DataFrame({"a": [[0.2, 0.3]]}), + pl.DataFrame({"a": [[0.2, 0.300001]]}), + {"atol": 1e-15, "rtol": 0}, + id="list_of_float_too_low_atol", + ), + pytest.param( + pl.DataFrame({"a": [[0.2, 0.3]]}), + pl.DataFrame({"a": [[0.2, 0.30000001]]}), + {"atol": -1, "rtol": 0}, + id="list_of_float_negative_atol", + ), + pytest.param( + pl.DataFrame({"a": [[math.nan, 1.3]]}), + pl.DataFrame({"a": [[math.nan, 0.9]]}), + {"rtol": 1, "nans_compare_equal": False}, + id="list_of_nan_and_float_integer_rtol", + ), + pytest.param( + pl.DataFrame({"a": [[2.0, 3.0]]}), + pl.DataFrame({"a": [[2, 3]]}), + {"check_exact": False, "check_dtype": True}, + id="list_of_float_list_of_int_check_dtype_true", + ), + pytest.param( + pl.DataFrame({"a": [[[0.2, math.nan, 3.0]]]}), + pl.DataFrame({"a": [[[0.2, math.nan, 3.11]]]}), + {"atol": 0.1, "rtol": 0, "nans_compare_equal": True}, + id="nested_list_of_float_and_nan_atol_high_nans_compare_equal_true", + ), + pytest.param( + pl.DataFrame({"a": [[[0.2, math.nan, 3.0]]]}), + pl.DataFrame({"a": [[[0.2, math.nan, 3.0]]]}), + {"nans_compare_equal": False}, + id="nested_list_of_float_and_nan_atol_high_nans_compare_equal_false", + ), + pytest.param( + pl.DataFrame({"a": [[[0.2, 3.0]]]}), + pl.DataFrame({"a": [[[0.2, 3.11]]]}), + {"atol": 0.1, "nans_compare_equal": False}, + id="nested_list_of_float_atol_high_nans_compare_equal_false", + ), + pytest.param( + pl.DataFrame({"a": [[[[0.2, 3.0]]]]}), + pl.DataFrame({"a": [[[[0.2, 3.11]]]]}), + {"atol": 0.1, "rtol": 0, "nans_compare_equal": True}, + id="double_nested_list_of_float_atol_high_nans_compare_equal_true", + ), + pytest.param( + pl.DataFrame({"a": [[[[0.2, math.nan, 3.0]]]]}), + pl.DataFrame({"a": [[[[0.2, math.nan, 3.0]]]]}), + {"atol": 0.1, 
"nans_compare_equal": False}, + id="double_nested_list_of_float_and_nan_atol_high_nans_compare_equal_false", + ), + pytest.param( + pl.DataFrame({"a": [[[[0.2, 3.0]]]]}), + pl.DataFrame({"a": [[[[0.2, 3.11]]]]}), + {"atol": 0.1, "rtol": 0, "nans_compare_equal": False}, + id="double_nested_list_of_float_atol_high_nans_compare_equal_false", + ), + pytest.param( + pl.DataFrame({"a": [[[[[0.2, 3.0]]]]]}), + pl.DataFrame({"a": [[[[[0.2, 3.11]]]]]}), + {"atol": 0.1, "rtol": 0, "nans_compare_equal": True}, + id="triple_nested_list_of_float_atol_high_nans_compare_equal_true", + ), + pytest.param( + pl.DataFrame({"a": [[[[[0.2, math.nan, 3.0]]]]]}), + pl.DataFrame({"a": [[[[[0.2, math.nan, 3.0]]]]]}), + {"atol": 0.1, "nans_compare_equal": False}, + id="triple_nested_list_of_float_and_nan_atol_high_nans_compare_equal_true", + ), + ], +) +def test_assert_frame_equal_raises_assertion_error( + df1: pl.DataFrame, + df2: pl.DataFrame, + kwargs: dict[str, Any], +) -> None: + with pytest.raises(AssertionError): + assert_frame_equal(df1, df2, **kwargs) + assert_frame_not_equal(df1, df2, **kwargs) + + +def test_compare_frame_equal_nans() -> None: + nan = float("NaN") + df1 = pl.DataFrame( + data={"x": [1.0, nan], "y": [nan, 2.0]}, + schema=[("x", pl.Float32), ("y", pl.Float64)], + ) + assert_frame_equal(df1, df1, check_exact=True) + + df2 = pl.DataFrame( + data={"x": [1.0, nan], "y": [None, 2.0]}, + schema=[("x", pl.Float32), ("y", pl.Float64)], + ) + assert_frame_not_equal(df1, df2) + with pytest.raises(AssertionError, match="values for column 'y' are different"): + assert_frame_equal(df1, df2, check_exact=True) + + +def test_compare_frame_equal_nested_nans() -> None: + nan = float("NaN") + + # list dtype + df1 = pl.DataFrame( + data={"x": [[1.0, nan]], "y": [[nan, 2.0]]}, + schema=[("x", pl.List(pl.Float32)), ("y", pl.List(pl.Float64))], + ) + assert_frame_equal(df1, df1, check_exact=True) + + df2 = pl.DataFrame( + data={"x": [[1.0, nan]], "y": [[None, 2.0]]}, + schema=[("x", pl.List(pl.Float32)), ("y", pl.List(pl.Float64))], + ) + assert_frame_not_equal(df1, df2) + with pytest.raises(AssertionError, match="values for column 'y' are different"): + assert_frame_equal(df1, df2, check_exact=True) + + # struct dtype + df3 = pl.from_dicts( + [ + { + "id": 1, + "struct": [ + {"x": "text", "y": [0.0, nan]}, + {"x": "text", "y": [0.0, nan]}, + ], + }, + { + "id": 2, + "struct": [ + {"x": "text", "y": [1]}, + {"x": "text", "y": [1]}, + ], + }, + ] + ) + df4 = pl.from_dicts( + [ + { + "id": 1, + "struct": [ + {"x": "text", "y": [0.0, nan], "z": ["$"]}, + {"x": "text", "y": [0.0, nan], "z": ["$"]}, + ], + }, + { + "id": 2, + "struct": [ + {"x": "text", "y": [nan, 1], "z": ["!"]}, + {"x": "text", "y": [nan, 1], "z": ["?"]}, + ], + }, + ] + ) + + assert_frame_equal(df3, df3) + assert_frame_not_equal(df3, df3, nans_compare_equal=False) + + assert_frame_equal(df4, df4) + assert_frame_not_equal(df4, df4, nans_compare_equal=False) + + assert_frame_not_equal(df3, df4) + for check_dtype in (True, False): + with pytest.raises(AssertionError, match="mismatch|different"): + assert_frame_equal(df3, df4, check_dtype=check_dtype) + + +def test_assert_frame_equal_pass() -> None: + df1 = pl.DataFrame({"a": [1, 2]}) + df2 = pl.DataFrame({"a": [1, 2]}) + assert_frame_equal(df1, df2) + + +def test_assert_frame_equal_types() -> None: + df1 = pl.DataFrame({"a": [1, 2]}) + srs1 = pl.Series(values=[1, 2], name="a") + with pytest.raises( + AssertionError, match=r"inputs are different \(unexpected input types\)" + ): + 
assert_frame_equal(df1, srs1) # type: ignore[arg-type] + + +def test_assert_frame_equal_length_mismatch() -> None: + df1 = pl.DataFrame({"a": [1, 2]}) + df2 = pl.DataFrame({"a": [1, 2, 3]}) + with pytest.raises( + AssertionError, + match=r"DataFrames are different \(number of rows does not match\)", + ): + assert_frame_equal(df1, df2) + + +def test_assert_frame_equal_column_mismatch() -> None: + df1 = pl.DataFrame({"a": [1, 2]}) + df2 = pl.DataFrame({"b": [1, 2]}) + with pytest.raises( + AssertionError, match="columns \\['a'\\] in left DataFrame, but not in right" + ): + assert_frame_equal(df1, df2) + + +def test_assert_frame_equal_column_mismatch2() -> None: + df1 = pl.LazyFrame({"a": [1, 2]}) + df2 = pl.LazyFrame({"a": [1, 2], "b": [3, 4], "c": [5, 6]}) + with pytest.raises( + AssertionError, + match="columns \\['b', 'c'\\] in right LazyFrame, but not in left", + ): + assert_frame_equal(df1, df2) + + +def test_assert_frame_equal_column_mismatch_order() -> None: + df1 = pl.DataFrame({"b": [3, 4], "a": [1, 2]}) + df2 = pl.DataFrame({"a": [1, 2], "b": [3, 4]}) + with pytest.raises(AssertionError, match="columns are not in the same order"): + assert_frame_equal(df1, df2) + + assert_frame_equal(df1, df2, check_column_order=False) + + +def test_assert_frame_equal_ignore_row_order() -> None: + df1 = pl.DataFrame({"a": [1, 2], "b": [4, 3]}) + df2 = pl.DataFrame({"a": [2, 1], "b": [3, 4]}) + df3 = pl.DataFrame({"b": [3, 4], "a": [2, 1]}) + with pytest.raises(AssertionError, match="values for column 'a' are different"): + assert_frame_equal(df1, df2) + + assert_frame_equal(df1, df2, check_row_order=False) + # eg: + # ┌─────┬─────┐ ┌─────┬─────┐ + # │ a ┆ b │ │ a ┆ b │ + # │ --- ┆ --- │ │ --- ┆ --- │ + # │ i64 ┆ i64 │ (eq) │ i64 ┆ i64 │ + # ╞═════╪═════╡ == ╞═════╪═════╡ + # │ 1 ┆ 4 │ │ 2 ┆ 3 │ + # │ 2 ┆ 3 │ │ 1 ┆ 4 │ + # └─────┴─────┘ └─────┴─────┘ + + with pytest.raises(AssertionError, match="columns are not in the same order"): + assert_frame_equal(df1, df3, check_row_order=False) + + assert_frame_equal(df1, df3, check_row_order=False, check_column_order=False) + + # note: not all column types support sorting + with pytest.raises( + InvalidAssert, + match="cannot set `check_row_order=False`.*unsortable columns", + ): + assert_frame_equal( + left=pl.DataFrame({"a": [[1, 2], [3, 4]], "b": [3, 4]}), + right=pl.DataFrame({"a": [[3, 4], [1, 2]], "b": [4, 3]}), + check_row_order=False, + ) + + +def test_assert_frame_equal_dtypes_mismatch() -> None: + data = {"a": [1, 2], "b": [3, 4]} + df1 = pl.DataFrame(data, schema={"a": pl.Int8, "b": pl.Int16}) + df2 = pl.DataFrame(data, schema={"b": pl.Int16, "a": pl.Int16}) + + with pytest.raises(AssertionError, match="dtypes do not match"): + assert_frame_equal(df1, df2, check_column_order=False) + + +def test_assert_frame_not_equal() -> None: + df = pl.DataFrame({"a": [1, 2]}) + with pytest.raises(AssertionError, match="frames are equal"): + assert_frame_not_equal(df, df) diff --git a/py-polars/tests/unit/testing/test_testing.py b/py-polars/tests/unit/testing/test_assert_series_equal.py similarity index 60% rename from py-polars/tests/unit/testing/test_testing.py rename to py-polars/tests/unit/testing/test_assert_series_equal.py index 74f352dd0bf0..13c298fc1746 100644 --- a/py-polars/tests/unit/testing/test_testing.py +++ b/py-polars/tests/unit/testing/test_assert_series_equal.py @@ -7,13 +7,7 @@ import pytest import polars as pl -from polars.exceptions import InvalidAssert -from polars.testing import ( - assert_frame_equal, - assert_frame_not_equal, - 
assert_series_equal, - assert_series_not_equal, -) +from polars.testing import assert_series_equal, assert_series_not_equal def test_compare_series_value_mismatch() -> None: @@ -146,7 +140,7 @@ def test_compare_series_type_mismatch() -> None: srs2 = pl.DataFrame({"col1": [2, 3, 4]}) with pytest.raises( - AssertionError, match=r"Inputs are different \(unexpected input types\)" + AssertionError, match=r"inputs are different \(unexpected input types\)" ): assert_series_equal(srs1, srs2) # type: ignore[arg-type] @@ -185,399 +179,6 @@ def test_compare_series_value_exact_mismatch() -> None: assert_series_equal(srs1, srs2, check_exact=True) -def test_compare_frame_equal_nans() -> None: - nan = float("NaN") - df1 = pl.DataFrame( - data={"x": [1.0, nan], "y": [nan, 2.0]}, - schema=[("x", pl.Float32), ("y", pl.Float64)], - ) - assert_frame_equal(df1, df1, check_exact=True) - - df2 = pl.DataFrame( - data={"x": [1.0, nan], "y": [None, 2.0]}, - schema=[("x", pl.Float32), ("y", pl.Float64)], - ) - assert_frame_not_equal(df1, df2) - with pytest.raises(AssertionError, match="values for column 'y' are different"): - assert_frame_equal(df1, df2, check_exact=True) - - -def test_compare_frame_equal_nested_nans() -> None: - nan = float("NaN") - - # list dtype - df1 = pl.DataFrame( - data={"x": [[1.0, nan]], "y": [[nan, 2.0]]}, - schema=[("x", pl.List(pl.Float32)), ("y", pl.List(pl.Float64))], - ) - assert_frame_equal(df1, df1, check_exact=True) - - df2 = pl.DataFrame( - data={"x": [[1.0, nan]], "y": [[None, 2.0]]}, - schema=[("x", pl.List(pl.Float32)), ("y", pl.List(pl.Float64))], - ) - assert_frame_not_equal(df1, df2) - with pytest.raises(AssertionError, match="values for column 'y' are different"): - assert_frame_equal(df1, df2, check_exact=True) - - # struct dtype - df3 = pl.from_dicts( - [ - { - "id": 1, - "struct": [ - {"x": "text", "y": [0.0, nan]}, - {"x": "text", "y": [0.0, nan]}, - ], - }, - { - "id": 2, - "struct": [ - {"x": "text", "y": [1]}, - {"x": "text", "y": [1]}, - ], - }, - ] - ) - df4 = pl.from_dicts( - [ - { - "id": 1, - "struct": [ - {"x": "text", "y": [0.0, nan], "z": ["$"]}, - {"x": "text", "y": [0.0, nan], "z": ["$"]}, - ], - }, - { - "id": 2, - "struct": [ - {"x": "text", "y": [nan, 1], "z": ["!"]}, - {"x": "text", "y": [nan, 1], "z": ["?"]}, - ], - }, - ] - ) - - assert_frame_equal(df3, df3) - assert_frame_not_equal(df3, df3, nans_compare_equal=False) - - assert_frame_equal(df4, df4) - assert_frame_not_equal(df4, df4, nans_compare_equal=False) - - assert_frame_not_equal(df3, df4) - for check_dtype in (True, False): - with pytest.raises(AssertionError, match="mismatch|different"): - assert_frame_equal(df3, df4, check_dtype=check_dtype) - - -def test_assert_frame_equal_pass() -> None: - df1 = pl.DataFrame({"a": [1, 2]}) - df2 = pl.DataFrame({"a": [1, 2]}) - assert_frame_equal(df1, df2) - - -def test_assert_frame_equal_types() -> None: - df1 = pl.DataFrame({"a": [1, 2]}) - srs1 = pl.Series(values=[1, 2], name="a") - with pytest.raises( - AssertionError, match=r"Inputs are different \(unexpected input types\)" - ): - assert_frame_equal(df1, srs1) # type: ignore[arg-type] - - -def test_assert_frame_equal_length_mismatch() -> None: - df1 = pl.DataFrame({"a": [1, 2]}) - df2 = pl.DataFrame({"a": [1, 2, 3]}) - with pytest.raises( - AssertionError, match=r"DataFrames are different \(length mismatch\)" - ): - assert_frame_equal(df1, df2) - - -def test_assert_frame_equal_column_mismatch() -> None: - df1 = pl.DataFrame({"a": [1, 2]}) - df2 = pl.DataFrame({"b": [1, 2]}) - with pytest.raises( - 
AssertionError, match="columns \\['a'\\] in left frame, but not in right" - ): - assert_frame_equal(df1, df2) - - -def test_assert_frame_equal_column_mismatch2() -> None: - df1 = pl.DataFrame({"a": [1, 2]}) - df2 = pl.DataFrame({"a": [1, 2], "b": [3, 4], "c": [5, 6]}) - with pytest.raises( - AssertionError, match="columns \\['b', 'c'\\] in right frame, but not in left" - ): - assert_frame_equal(df1, df2) - - -def test_assert_frame_equal_column_mismatch_order() -> None: - df1 = pl.DataFrame({"b": [3, 4], "a": [1, 2]}) - df2 = pl.DataFrame({"a": [1, 2], "b": [3, 4]}) - with pytest.raises(AssertionError, match="columns are not in the same order"): - assert_frame_equal(df1, df2) - - assert_frame_equal(df1, df2, check_column_order=False) - - -def test_assert_frame_equal_ignore_row_order() -> None: - df1 = pl.DataFrame({"a": [1, 2], "b": [4, 3]}) - df2 = pl.DataFrame({"a": [2, 1], "b": [3, 4]}) - df3 = pl.DataFrame({"b": [3, 4], "a": [2, 1]}) - with pytest.raises(AssertionError, match="values for column 'a' are different"): - assert_frame_equal(df1, df2) - - assert_frame_equal(df1, df2, check_row_order=False) - # eg: - # ┌─────┬─────┐ ┌─────┬─────┐ - # │ a ┆ b │ │ a ┆ b │ - # │ --- ┆ --- │ │ --- ┆ --- │ - # │ i64 ┆ i64 │ (eq) │ i64 ┆ i64 │ - # ╞═════╪═════╡ == ╞═════╪═════╡ - # │ 1 ┆ 4 │ │ 2 ┆ 3 │ - # │ 2 ┆ 3 │ │ 1 ┆ 4 │ - # └─────┴─────┘ └─────┴─────┘ - - with pytest.raises(AssertionError, match="columns are not in the same order"): - assert_frame_equal(df1, df3, check_row_order=False) - - assert_frame_equal(df1, df3, check_row_order=False, check_column_order=False) - - # note: not all column types support sorting - with pytest.raises( - InvalidAssert, - match="cannot set `check_row_order=False`.*unsortable columns", - ): - assert_frame_equal( - left=pl.DataFrame({"a": [[1, 2], [3, 4]], "b": [3, 4]}), - right=pl.DataFrame({"a": [[3, 4], [1, 2]], "b": [4, 3]}), - check_row_order=False, - ) - - -@pytest.mark.parametrize( - ("df1", "df2", "kwargs"), - [ - pytest.param( - pl.DataFrame({"a": [0.2, 0.3]}), - pl.DataFrame({"a": [0.2, 0.3]}), - {"atol": 1e-15}, - id="equal_floats_low_atol", - ), - pytest.param( - pl.DataFrame({"a": [0.2, 0.3]}), - pl.DataFrame({"a": [0.2, 0.3000000000000001]}), - {"atol": 1e-15}, - id="approx_equal_float_low_atol", - ), - pytest.param( - pl.DataFrame({"a": [0.2, 0.3]}), - pl.DataFrame({"a": [0.2, 0.31]}), - {"atol": 0.1}, - id="approx_equal_float_high_atol", - ), - pytest.param( - pl.DataFrame({"a": [0.2, 1.3]}), - pl.DataFrame({"a": [0.2, 0.9]}), - {"atol": 1}, - id="approx_equal_float_integer_atol", - ), - pytest.param( - pl.DataFrame({"a": [0.0, 1.0, 2.0]}, schema={"a": pl.Float64}), - pl.DataFrame({"a": [0, 1, 2]}, schema={"a": pl.Int64}), - {"check_dtype": False}, - id="equal_int_float_integer_no_check_dtype", - ), - pytest.param( - pl.DataFrame({"a": [0, 1, 2]}, schema={"a": pl.Float64}), - pl.DataFrame({"a": [0, 1, 2]}, schema={"a": pl.Float32}), - {"check_dtype": False}, - id="equal_int_float_integer_no_check_dtype", - ), - pytest.param( - pl.DataFrame({"a": [0, 1, 2]}, schema={"a": pl.Int64}), - pl.DataFrame({"a": [0, 1, 2]}, schema={"a": pl.Int64}), - {}, - id="equal_int", - ), - pytest.param( - pl.DataFrame({"a": ["a", "b", "c"]}, schema={"a": pl.Utf8}), - pl.DataFrame({"a": ["a", "b", "c"]}, schema={"a": pl.Utf8}), - {}, - id="equal_str", - ), - pytest.param( - pl.DataFrame({"a": [[0.2, 0.3]]}), - pl.DataFrame({"a": [[0.2, 0.300001]]}), - {"atol": 1e-5}, - id="list_of_float_low_atol", - ), - pytest.param( - pl.DataFrame({"a": [[0.2, 0.3]]}), - 
pl.DataFrame({"a": [[0.2, 0.31]]}), - {"atol": 0.1}, - id="list_of_float_high_atol", - ), - pytest.param( - pl.DataFrame({"a": [[0.2, 1.3]]}), - pl.DataFrame({"a": [[0.2, 0.9]]}), - {"atol": 1}, - id="list_of_float_integer_atol", - ), - pytest.param( - pl.DataFrame({"a": [[0.2, 0.3]]}), - pl.DataFrame({"a": [[0.2, 0.300000001]]}), - {"rtol": 1e-5}, - id="list_of_float_low_rtol", - ), - pytest.param( - pl.DataFrame({"a": [[0.2, 0.3]]}), - pl.DataFrame({"a": [[0.2, 0.301]]}), - {"rtol": 0.1}, - id="list_of_float_high_rtol", - ), - pytest.param( - pl.DataFrame({"a": [[0.2, 1.3]]}), - pl.DataFrame({"a": [[0.2, 0.9]]}), - {"rtol": 1}, - id="list_of_float_integer_rtol", - ), - pytest.param( - pl.DataFrame({"a": [[None, 1.3]]}), - pl.DataFrame({"a": [[None, 0.9]]}), - {"rtol": 1}, - id="list_of_none_and_float_integer_rtol", - ), - pytest.param( - pl.DataFrame({"a": [[None, 1.3]]}), - pl.DataFrame({"a": [[None, 0.9]]}), - {"rtol": 1, "nans_compare_equal": False}, - id="list_of_none_and_float_integer_rtol", - ), - pytest.param( - pl.DataFrame({"a": [[[0.2, 3.0]]]}), - pl.DataFrame({"a": [[[0.2, 3.00000001]]]}), - {"atol": 0.1, "nans_compare_equal": True}, - id="nested_list_of_float_atol_high_nans_compare_equal_true", - ), - pytest.param( - pl.DataFrame({"a": [[[0.2, 3.0]]]}), - pl.DataFrame({"a": [[[0.2, 3.00000001]]]}), - {"atol": 0.1, "nans_compare_equal": False}, - id="nested_list_of_float_atol_high_nans_compare_equal_false", - ), - ], -) -def test_assert_frame_equal_passes_assertion( - df1: pl.DataFrame, - df2: pl.DataFrame, - kwargs: Any, -) -> None: - assert_frame_equal(df1, df2, **kwargs) - with pytest.raises(AssertionError): - assert_frame_not_equal(df1, df2, **kwargs) - - -@pytest.mark.parametrize( - ("df1", "df2", "kwargs"), - [ - pytest.param( - pl.DataFrame({"a": [[0.2, 0.3]]}), - pl.DataFrame({"a": [[0.2, 0.3, 0.4]]}), - {}, - id="list_of_float_different_lengths", - ), - pytest.param( - pl.DataFrame({"a": [[0.2, 0.3]]}), - pl.DataFrame({"a": [[0.2, 0.3000000000000001]]}), - {"check_exact": True}, - id="list_of_float_check_exact", - ), - pytest.param( - pl.DataFrame({"a": [[0.2, 0.3]]}), - pl.DataFrame({"a": [[0.2, 0.300001]]}), - {"atol": 1e-15, "rtol": 0}, - id="list_of_float_too_low_atol", - ), - pytest.param( - pl.DataFrame({"a": [[0.2, 0.3]]}), - pl.DataFrame({"a": [[0.2, 0.30000001]]}), - {"atol": -1, "rtol": 0}, - id="list_of_float_negative_atol", - ), - pytest.param( - pl.DataFrame({"a": [[math.nan, 1.3]]}), - pl.DataFrame({"a": [[math.nan, 0.9]]}), - {"rtol": 1, "nans_compare_equal": False}, - id="list_of_nan_and_float_integer_rtol", - ), - pytest.param( - pl.DataFrame({"a": [[2.0, 3.0]]}), - pl.DataFrame({"a": [[2, 3]]}), - {"check_exact": False, "check_dtype": True}, - id="list_of_float_list_of_int_check_dtype_true", - ), - pytest.param( - pl.DataFrame({"a": [[[0.2, math.nan, 3.0]]]}), - pl.DataFrame({"a": [[[0.2, math.nan, 3.11]]]}), - {"atol": 0.1, "rtol": 0, "nans_compare_equal": True}, - id="nested_list_of_float_and_nan_atol_high_nans_compare_equal_true", - ), - pytest.param( - pl.DataFrame({"a": [[[0.2, math.nan, 3.0]]]}), - pl.DataFrame({"a": [[[0.2, math.nan, 3.0]]]}), - {"nans_compare_equal": False}, - id="nested_list_of_float_and_nan_atol_high_nans_compare_equal_false", - ), - pytest.param( - pl.DataFrame({"a": [[[0.2, 3.0]]]}), - pl.DataFrame({"a": [[[0.2, 3.11]]]}), - {"atol": 0.1, "nans_compare_equal": False}, - id="nested_list_of_float_atol_high_nans_compare_equal_false", - ), - pytest.param( - pl.DataFrame({"a": [[[[0.2, 3.0]]]]}), - pl.DataFrame({"a": [[[[0.2, 
3.11]]]]}), - {"atol": 0.1, "rtol": 0, "nans_compare_equal": True}, - id="double_nested_list_of_float_atol_high_nans_compare_equal_true", - ), - pytest.param( - pl.DataFrame({"a": [[[[0.2, math.nan, 3.0]]]]}), - pl.DataFrame({"a": [[[[0.2, math.nan, 3.0]]]]}), - {"atol": 0.1, "nans_compare_equal": False}, - id="double_nested_list_of_float_and_nan_atol_high_nans_compare_equal_false", - ), - pytest.param( - pl.DataFrame({"a": [[[[0.2, 3.0]]]]}), - pl.DataFrame({"a": [[[[0.2, 3.11]]]]}), - {"atol": 0.1, "rtol": 0, "nans_compare_equal": False}, - id="double_nested_list_of_float_atol_high_nans_compare_equal_false", - ), - pytest.param( - pl.DataFrame({"a": [[[[[0.2, 3.0]]]]]}), - pl.DataFrame({"a": [[[[[0.2, 3.11]]]]]}), - {"atol": 0.1, "rtol": 0, "nans_compare_equal": True}, - id="triple_nested_list_of_float_atol_high_nans_compare_equal_true", - ), - pytest.param( - pl.DataFrame({"a": [[[[[0.2, math.nan, 3.0]]]]]}), - pl.DataFrame({"a": [[[[[0.2, math.nan, 3.0]]]]]}), - {"atol": 0.1, "nans_compare_equal": False}, - id="triple_nested_list_of_float_and_nan_atol_high_nans_compare_equal_true", - ), - ], -) -def test_assert_frame_equal_raises_assertion_error( - df1: pl.DataFrame, - df2: pl.DataFrame, - kwargs: Any, -) -> None: - with pytest.raises(AssertionError): - assert_frame_equal(df1, df2, **kwargs) - assert_frame_not_equal(df1, df2, **kwargs) - - def test_assert_series_equal_int_overflow() -> None: # internally may call 'abs' if not check_exact, which can overflow on signed int s0 = pl.Series([-128], dtype=pl.Int8) @@ -1046,12 +647,6 @@ def test_assert_series_equal_full_series() -> None: assert_series_equal(s1, s2) -def test_assert_frame_not_equal() -> None: - df = pl.DataFrame({"a": [1, 2]}) - with pytest.raises(AssertionError, match="frames are equal"): - assert_frame_not_equal(df, df) - - def test_assert_series_not_equal() -> None: s = pl.Series("a", [1, 2]) with pytest.raises(AssertionError, match="Series are equal"): From 28a99f6336367246d33e7a8b82cd5e609256e68f Mon Sep 17 00:00:00 2001 From: Weijie Guo Date: Thu, 19 Oct 2023 11:36:35 +0800 Subject: [PATCH 043/103] chore(rust): Move round to ops (#11838) --- crates/polars-core/Cargo.toml | 1 - crates/polars-core/src/series/mod.rs | 9 --- crates/polars-core/src/series/ops/mod.rs | 2 - crates/polars-ops/src/series/ops/mod.rs | 4 ++ .../src/series/ops/round.rs | 66 ++++++++++++------- crates/polars-plan/Cargo.toml | 2 +- crates/polars/Cargo.toml | 2 +- 7 files changed, 50 insertions(+), 36 deletions(-) rename crates/{polars-core => polars-ops}/src/series/ops/round.rs (55%) diff --git a/crates/polars-core/Cargo.toml b/crates/polars-core/Cargo.toml index 12a8a536ebdd..8ec0afe8a936 100644 --- a/crates/polars-core/Cargo.toml +++ b/crates/polars-core/Cargo.toml @@ -138,7 +138,6 @@ docs-selection = [ "temporal", "random", "zip_with", - "round_series", "checked_arithmetic", "is_first_distinct", "is_last_distinct", diff --git a/crates/polars-core/src/series/mod.rs b/crates/polars-core/src/series/mod.rs index d8672a93ae56..5b978a29ce3c 100644 --- a/crates/polars-core/src/series/mod.rs +++ b/crates/polars-core/src/series/mod.rs @@ -1043,13 +1043,4 @@ mod test { let _ = series.slice(-6, 2); let _ = series.slice(4, 2); } - - #[test] - #[cfg(feature = "round_series")] - fn test_round_series() { - let series = Series::new("a", &[1.003, 2.23222, 3.4352]); - let out = series.round(2).unwrap(); - let ca = out.f64().unwrap(); - assert_eq!(ca.get(0), Some(1.0)); - } } diff --git a/crates/polars-core/src/series/ops/mod.rs 
b/crates/polars-core/src/series/ops/mod.rs
index 48766748f58b..650e5cbecaa4 100644
--- a/crates/polars-core/src/series/ops/mod.rs
+++ b/crates/polars-core/src/series/ops/mod.rs
@@ -3,8 +3,6 @@ mod extend;
 #[cfg(feature = "moment")]
 pub mod moment;
 mod null;
-#[cfg(feature = "round_series")]
-mod round;
 mod to_list;
 mod unique;
 #[cfg(feature = "serde")]
diff --git a/crates/polars-ops/src/series/ops/mod.rs b/crates/polars-ops/src/series/ops/mod.rs
index b771520bd192..ea7f4a9a455e 100644
--- a/crates/polars-ops/src/series/ops/mod.rs
+++ b/crates/polars-ops/src/series/ops/mod.rs
@@ -36,6 +36,8 @@ mod rank;
 mod rle;
 #[cfg(feature = "rolling_window")]
 mod rolling;
+#[cfg(feature = "round_series")]
+mod round;
 #[cfg(feature = "search_sorted")]
 mod search_sorted;
 #[cfg(feature = "to_dummies")]
@@ -81,6 +83,8 @@ pub use rank::*;
 pub use rle::*;
 #[cfg(feature = "rolling_window")]
 pub use rolling::*;
+#[cfg(feature = "round_series")]
+pub use round::*;
 #[cfg(feature = "search_sorted")]
 pub use search_sorted::*;
 #[cfg(feature = "to_dummies")]
diff --git a/crates/polars-core/src/series/ops/round.rs b/crates/polars-ops/src/series/ops/round.rs
similarity index 55%
rename from crates/polars-core/src/series/ops/round.rs
rename to crates/polars-ops/src/series/ops/round.rs
index 37abe7797941..542107ed5871 100644
--- a/crates/polars-core/src/series/ops/round.rs
+++ b/crates/polars-ops/src/series/ops/round.rs
@@ -1,14 +1,17 @@
 use num_traits::pow::Pow;
+use polars_core::prelude::*;

-use crate::prelude::*;
+use crate::series::ops::SeriesSealed;

-impl Series {
+pub trait RoundSeries: SeriesSealed {
 /// Round underlying floating point array to given decimal.
- pub fn round(&self, decimals: u32) -> PolarsResult<Series> {
- if let Ok(ca) = self.f32() {
- if decimals == 0 {
+ fn round(&self, decimals: u32) -> PolarsResult<Series> {
+ let s = self.as_series();
+
+ if let Ok(ca) = s.f32() {
+ return if decimals == 0 {
 let s = ca.apply_values(|val| val.round()).into_series();
- return Ok(s);
+ Ok(s)
 } else {
 // Note we do the computation on f64 floats to not lose precision
 // when the computation is done, we cast to f32
 let multiplier = 10.0.pow(decimals as f64);
 let s = ca
 .apply_values(|val| ((val as f64 * multiplier).round() / multiplier) as f32)
 .into_series();
- return Ok(s);
- }
+ Ok(s)
+ };
 }
- if let Ok(ca) = self.f64() {
- if decimals == 0 {
+ if let Ok(ca) = s.f64() {
+ return if decimals == 0 {
 let s = ca.apply_values(|val| val.round()).into_series();
- return Ok(s);
+ Ok(s)
 } else {
 let multiplier = 10.0.pow(decimals as f64);
 let s = ca
 .apply_values(|val| (val * multiplier).round() / multiplier)
 .into_series();
- return Ok(s);
- }
+ Ok(s)
+ };
 }
- polars_bail!(opq = round, self.dtype());
+ polars_bail!(opq = round, s.dtype());
 }

 /// Floor underlying floating point array to the lowest integers smaller or equal to the float value.
- pub fn floor(&self) -> PolarsResult<Series> {
- if let Ok(ca) = self.f32() {
+ fn floor(&self) -> PolarsResult<Series> {
+ let s = self.as_series();
+
+ if let Ok(ca) = s.f32() {
 let s = ca.apply_values(|val| val.floor()).into_series();
 return Ok(s);
 }
- if let Ok(ca) = self.f64() {
+ if let Ok(ca) = s.f64() {
 let s = ca.apply_values(|val| val.floor()).into_series();
 return Ok(s);
 }
- polars_bail!(opq = floor, self.dtype());
+ polars_bail!(opq = floor, s.dtype());
 }

 /// Ceil underlying floating point array to the highest integers smaller or equal to the float value.
- pub fn ceil(&self) -> PolarsResult<Series> {
- if let Ok(ca) = self.f32() {
+ fn ceil(&self) -> PolarsResult<Series> {
+ let s = self.as_series();
+
+ if let Ok(ca) = s.f32() {
 let s = ca.apply_values(|val| val.ceil()).into_series();
 return Ok(s);
 }
- if let Ok(ca) = self.f64() {
+ if let Ok(ca) = s.f64() {
 let s = ca.apply_values(|val| val.ceil()).into_series();
 return Ok(s);
 }
- polars_bail!(opq = ceil, self.dtype());
+ polars_bail!(opq = ceil, s.dtype());
+ }
+}
+
+impl RoundSeries for Series {}
+
+#[cfg(test)]
+mod test {
+ use crate::prelude::*;
+
+ #[test]
+ fn test_round_series() {
+ let series = Series::new("a", &[1.003, 2.23222, 3.4352]);
+ let out = series.round(2).unwrap();
+ let ca = out.f64().unwrap();
+ assert_eq!(ca.get(0), Some(1.0));
 }
 }
diff --git a/crates/polars-plan/Cargo.toml b/crates/polars-plan/Cargo.toml
index 99fbd668c782..1a088bc806ef 100644
--- a/crates/polars-plan/Cargo.toml
+++ b/crates/polars-plan/Cargo.toml
@@ -89,7 +89,7 @@ extract_jsonpath = ["polars-ops/extract_jsonpath"]
 approx_unique = ["polars-ops/approx_unique"]
 is_in = ["polars-ops/is_in"]
 repeat_by = ["polars-ops/repeat_by"]
-round_series = ["polars-core/round_series"]
+round_series = ["polars-ops/round_series"]
 is_first_distinct = ["polars-core/is_first_distinct", "polars-ops/is_first_distinct"]
 is_last_distinct = ["polars-core/is_last_distinct", "polars-ops/is_last_distinct"]
 is_unique = ["polars-ops/is_unique"]
diff --git a/crates/polars/Cargo.toml b/crates/polars/Cargo.toml
index 3ca2884540ca..369b1d07f9cf 100644
--- a/crates/polars/Cargo.toml
+++ b/crates/polars/Cargo.toml
@@ -113,7 +113,7 @@ sort_multiple = ["polars-core/sort_multiple"]
 approx_unique = ["polars-lazy?/approx_unique", "polars-ops/approx_unique"]
 is_in = ["polars-lazy?/is_in"]
 zip_with = ["polars-core/zip_with", "polars-ops/zip_with"]
-round_series = ["polars-core/round_series", "polars-lazy?/round_series", "polars-ops/round_series"]
+round_series = ["polars-ops/round_series", "polars-lazy?/round_series"]
 checked_arithmetic = ["polars-core/checked_arithmetic"]
 repeat_by = ["polars-ops/repeat_by", "polars-lazy?/repeat_by"]
 is_first_distinct = ["polars-lazy?/is_first_distinct", "polars-ops/is_first_distinct"]
From a42185f9ecc7c5016dc5a3cecc40ee9f200e88cf Mon Sep 17 00:00:00 2001
From: Eric Woolsey
Date: Thu, 19 Oct 2023 01:08:28 -0700
Subject: [PATCH 044/103] docs(rust): Update doc comments for with_column to
 reflect that columns can be updated (#11840)

---
 crates/polars-lazy/src/frame/mod.rs | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/crates/polars-lazy/src/frame/mod.rs b/crates/polars-lazy/src/frame/mod.rs
index 261fbaaaf517..20292101113e 100644
--- a/crates/polars-lazy/src/frame/mod.rs
+++ b/crates/polars-lazy/src/frame/mod.rs
@@ -1182,7 +1182,7 @@ impl LazyFrame {
 JoinBuilder::new(self)
 }

- /// Add a column, given as an expression, to a DataFrame.
+ /// Add or replace a column, given as an expression, to a DataFrame.
 ///
 /// # Example
 ///
@@ -1214,7 +1214,7 @@ impl LazyFrame {
 Self::from_logical_plan(lp, opt_state)
 }

- /// Add multiple columns, given as expressions, to a DataFrame.
+ /// Add or replace multiple columns, given as expressions, to a DataFrame.
 ///
 /// # Example
 ///
@@ -1239,7 +1239,7 @@ impl LazyFrame {
 )
 }

- /// Add multiple columns to a DataFrame, but evaluate them sequentially.
+ /// Add or replace multiple columns to a DataFrame, but evaluate them sequentially.
 pub fn with_columns_seq<E: AsRef<[Expr]>>(self, exprs: E) -> LazyFrame {
 let exprs = exprs.as_ref().to_vec();
 self.with_columns_impl(
From a05b29862ccced75db8fcc8b503f39dcc612ceed Mon Sep 17 00:00:00 2001
From: Weijie Guo
Date: Thu, 19 Oct 2023 21:02:29 +0800
Subject: [PATCH 045/103] fix: propagate validity when cast primitive to list
 (#11846)

---
 crates/polars-arrow/src/compute/cast/mod.rs | 7 ++++++-
 py-polars/tests/unit/datatypes/test_list.py | 8 ++++----
 py-polars/tests/unit/series/test_series.py | 12 ++++++++++++
 3 files changed, 22 insertions(+), 5 deletions(-)

diff --git a/crates/polars-arrow/src/compute/cast/mod.rs b/crates/polars-arrow/src/compute/cast/mod.rs
index 4ca928b4a98c..bbabbe279439 100644
--- a/crates/polars-arrow/src/compute/cast/mod.rs
+++ b/crates/polars-arrow/src/compute/cast/mod.rs
@@ -520,7 +520,12 @@ pub fn cast(
 // Safety: offsets _are_ monotonically increasing
 let offsets = unsafe { Offsets::new_unchecked(offsets) };

- let list_array = ListArray::<i64>::new(to_type.clone(), offsets.into(), values, None);
+ let list_array = ListArray::<i64>::new(
+ to_type.clone(),
+ offsets.into(),
+ values,
+ array.validity().cloned(),
+ );

 Ok(Box::new(list_array))
 },
diff --git a/py-polars/tests/unit/datatypes/test_list.py b/py-polars/tests/unit/datatypes/test_list.py
index abe5ec123780..9806a643bf0a 100644
--- a/py-polars/tests/unit/datatypes/test_list.py
+++ b/py-polars/tests/unit/datatypes/test_list.py
@@ -468,10 +468,10 @@ def test_list_recursive_categorical_cast() -> None:
 @pytest.mark.parametrize(
 ("data", "expected_data", "dtype"),
 [
- ([1, 2], [[1], [2]], pl.Int64),
- ([1.0, 2.0], [[1.0], [2.0]], pl.Float64),
- (["x", "y"], [["x"], ["y"]], pl.Utf8),
- ([True, False], [[True], [False]], pl.Boolean),
+ ([None, 1, 2], [None, [1], [2]], pl.Int64),
+ ([None, 1.0, 2.0], [None, [1.0], [2.0]], pl.Float64),
+ ([None, "x", "y"], [None, ["x"], ["y"]], pl.Utf8),
+ ([None, True, False], [None, [True], [False]], pl.Boolean),
 ],
 )
 def test_non_nested_cast_to_list(
diff --git a/py-polars/tests/unit/series/test_series.py b/py-polars/tests/unit/series/test_series.py
index 54ba6288e2fa..6d31979f973c 100644
--- a/py-polars/tests/unit/series/test_series.py
+++ b/py-polars/tests/unit/series/test_series.py
@@ -243,6 +243,18 @@ def test_equal() -> None:
 assert s3.dt.convert_time_zone("Asia/Tokyo").series_equal(s4) is True


+@pytest.mark.parametrize(
+ "dtype",
+ [pl.Int64, pl.Float64, pl.Utf8, pl.Boolean],
+)
+def test_eq_missing_list_and_primitive(dtype: PolarsDataType) -> None:
+ s1 = pl.Series([None, None], dtype=dtype)
+ s2 = pl.Series([None, None], dtype=pl.List(dtype))
+
+ expected = pl.Series([True, True])
+ assert_series_equal(s1.eq_missing(s2), expected)
+
+
 def test_to_frame() -> None:
 s1 = pl.Series([1, 2])
 s2 = pl.Series("s", [1, 2])
From d39c360d4b4b5dd97de39d3cc610bc2d267fa7d9 Mon Sep 17 00:00:00 2001
From: nameexhaustion
Date: Fri, 20 Oct 2023 01:12:45 +1100
Subject: [PATCH 046/103] fix(rust): panic on hive scan from cloud (#11847)

---
 crates/polars-io/src/parquet/read.rs | 7 +++++++
 .../src/physical_plan/executors/scan/parquet.rs | 3 ++-
 crates/polars-pipe/src/executors/sources/parquet.rs | 3 ++-
 crates/polars-plan/src/logical_plan/schema.rs | 6 +++++-
 4 files changed, 16 insertions(+), 3 deletions(-)

diff --git a/crates/polars-io/src/parquet/read.rs b/crates/polars-io/src/parquet/read.rs
index 2586ce0b25d8..4df47cf9e658 100644
--- a/crates/polars-io/src/parquet/read.rs
+++ b/crates/polars-io/src/parquet/read.rs
@@ -129,6 +129,13 @@ impl<R: MmapBytesReader> ParquetReader<R> {
 self
 }

+ /// Set the [`Schema`] if already known. This must be exactly the same as
+ /// the schema in the file itself.
+ pub fn with_schema(mut self, schema: Option<SchemaRef>) -> Self {
+ self.schema = schema;
+ self
+ }
+
 /// [`Schema`] of the file.
 pub fn schema(&mut self) -> PolarsResult<SchemaRef> {
 match &self.schema {
diff --git a/crates/polars-lazy/src/physical_plan/executors/scan/parquet.rs b/crates/polars-lazy/src/physical_plan/executors/scan/parquet.rs
index 143f3eeaaa43..3579bd9de004 100644
--- a/crates/polars-lazy/src/physical_plan/executors/scan/parquet.rs
+++ b/crates/polars-lazy/src/physical_plan/executors/scan/parquet.rs
@@ -57,6 +57,7 @@ impl ParquetExec {
 if let Some(file) = file {
 ParquetReader::new(file)
+ .with_schema(Some(self.file_info.reader_schema.clone()))
 .with_n_rows(n_rows)
 .read_parallel(self.options.parallel)
 .with_row_count(mem::take(&mut self.file_options.row_count))
@@ -72,7 +73,7 @@ impl ParquetExec {
 let reader = ParquetAsyncReader::from_uri(
 &self.path.to_string_lossy(),
 self.cloud_options.as_ref(),
- Some(self.file_info.schema.clone()),
+ Some(self.file_info.reader_schema.clone()),
 self.metadata.clone(),
 )
 .await?
diff --git a/crates/polars-pipe/src/executors/sources/parquet.rs b/crates/polars-pipe/src/executors/sources/parquet.rs
index b8afaf1b8208..c487d34e8d89 100644
--- a/crates/polars-pipe/src/executors/sources/parquet.rs
+++ b/crates/polars-pipe/src/executors/sources/parquet.rs
@@ -80,7 +80,7 @@ impl ParquetSource {
 ParquetAsyncReader::from_uri(
 &uri,
 self.cloud_options.as_ref(),
- Some(self.file_info.schema.clone()),
+ Some(self.file_info.reader_schema.clone()),
 self.metadata.clone(),
 )
 .await?
@@ -102,6 +102,7 @@ impl ParquetSource {
 let file = std::fs::File::open(path).unwrap();
 ParquetReader::new(file)
+ .with_schema(Some(self.file_info.reader_schema.clone()))
 .with_n_rows(file_options.n_rows)
 .with_row_count(file_options.row_count)
 .with_projection(projection)
diff --git a/crates/polars-plan/src/logical_plan/schema.rs b/crates/polars-plan/src/logical_plan/schema.rs
index e97377c41e75..223ec2df7fd6 100644
--- a/crates/polars-plan/src/logical_plan/schema.rs
+++ b/crates/polars-plan/src/logical_plan/schema.rs
@@ -45,6 +45,9 @@ impl LogicalPlan {
 #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
 pub struct FileInfo {
 pub schema: SchemaRef,
+ // Stores the schema used for the reader, as the main schema can contain
+ pub reader_schema: SchemaRef, // - known size // - estimated size pub row_estimation: (Option, usize), @@ -54,7 +57,8 @@ pub struct FileInfo { impl FileInfo { pub fn new(schema: SchemaRef, row_estimation: (Option, usize)) -> Self { Self { - schema, + schema: schema.clone(), + reader_schema: schema.clone(), row_estimation, hive_parts: None, } From f41e8f41dbb5354d5da2cb9a8b0231c5bfd7a71d Mon Sep 17 00:00:00 2001 From: Orson Peters Date: Thu, 19 Oct 2023 16:19:06 +0200 Subject: [PATCH 047/103] perf: properly push down slice before left/asof join (#11854) --- .../polars-ops/src/frame/join/asof/groups.rs | 35 +++++++------ .../src/frame/join/hash_join/mod.rs | 49 +++++++++---------- crates/polars-ops/src/frame/join/mod.rs | 14 ++++-- crates/polars-ops/src/series/ops/round.rs | 2 +- 4 files changed, 55 insertions(+), 45 deletions(-) diff --git a/crates/polars-ops/src/frame/join/asof/groups.rs b/crates/polars-ops/src/frame/join/asof/groups.rs index a8e31f69af50..7bd4fe1a2db4 100644 --- a/crates/polars-ops/src/frame/join/asof/groups.rs +++ b/crates/polars-ops/src/frame/join/asof/groups.rs @@ -774,9 +774,20 @@ pub trait AsofJoinBy: IntoDf { suffix: Option<&str>, slice: Option<(i64, usize)>, ) -> PolarsResult { - let self_df = self.to_df(); + let (self_sliced_slot, other_sliced_slot); // Keeps temporaries alive. + let (self_df, other_df); + if let Some((offset, len)) = slice { + self_sliced_slot = self.to_df().slice(offset, len); + other_sliced_slot = other.slice(offset, len); + self_df = &self_sliced_slot; + other_df = &other_sliced_slot; + } else { + self_df = self.to_df(); + other_df = other; + } + let left_asof = self_df.column(left_on)?.to_physical_repr(); - let right_asof = other.column(right_on)?.to_physical_repr(); + let right_asof = other_df.column(right_on)?.to_physical_repr(); let right_asof_name = right_asof.name(); let left_asof_name = left_asof.name(); @@ -787,7 +798,7 @@ pub trait AsofJoinBy: IntoDf { )?; let mut left_by = self_df.select(left_by)?; - let mut right_by = other.select(right_by)?; + let mut right_by = other_df.select(right_by)?; unsafe { for (l, r) in left_by @@ -826,7 +837,7 @@ pub trait AsofJoinBy: IntoDf { drop_these.push(right_asof_name); } - let cols = other + let cols = other_df .get_columns() .iter() .filter_map(|s| { @@ -837,19 +848,15 @@ pub trait AsofJoinBy: IntoDf { } }) .collect(); - let other = DataFrame::new_no_checks(cols); + let proj_other_df = DataFrame::new_no_checks(cols); - let mut left = self_df.clone(); - let mut right_join_tuples = &*right_join_tuples; - - if let Some((offset, len)) = slice { - left = left.slice(offset, len); - right_join_tuples = slice_slice(right_join_tuples, offset, len); - } + let left = self_df.clone(); + let right_join_tuples = &*right_join_tuples; // SAFETY: join tuples are in bounds. 
- let right_df =
- unsafe { other.take_unchecked(&right_join_tuples.iter().copied().collect_ca("")) };
+ let right_df = unsafe {
+ proj_other_df.take_unchecked(&right_join_tuples.iter().copied().collect_ca(""))
+ };
 _finish_join(left, right_df, suffix)
 }
diff --git a/crates/polars-ops/src/frame/join/hash_join/mod.rs b/crates/polars-ops/src/frame/join/hash_join/mod.rs
index 04aaef862338..b6f4df2cba6e 100644
--- a/crates/polars-ops/src/frame/join/hash_join/mod.rs
+++ b/crates/polars-ops/src/frame/join/hash_join/mod.rs
@@ -132,19 +132,11 @@ pub trait JoinDispatch: IntoDf {
 ) -> PolarsResult<DataFrame> {
 let ca_self = self.to_df();
 let (left_idx, right_idx) = ids;
- let materialize_left = || {
- let mut left_idx = &*left_idx;
- if let Some((offset, len)) = args.slice {
- left_idx = slice_slice(left_idx, offset, len);
- }
- unsafe { ca_self._create_left_df_from_slice(left_idx, true, true) }
- };
+ let materialize_left =
+ || unsafe { ca_self._create_left_df_from_slice(&left_idx, true, true) };

 let materialize_right = || {
- let mut right_idx = &*right_idx;
- if let Some((offset, len)) = args.slice {
- right_idx = slice_slice(right_idx, offset, len);
- }
+ let right_idx = &*right_idx;
 unsafe { other.take_unchecked(&right_idx.iter().copied().collect_ca("")) }
 };
 let (df_left, df_right) = POOL.join(materialize_left, materialize_right);
@@ -161,39 +153,38 @@ pub trait JoinDispatch: IntoDf {
 ) -> PolarsResult<DataFrame> {
 let ca_self = self.to_df();
 let suffix = &args.suffix;
- let slice = args.slice;
 let (left_idx, right_idx) = ids;

 let materialize_left = || match left_idx {
- ChunkJoinIds::Left(left_idx) => {
+ ChunkJoinIds::Left(left_idx) => unsafe {
 let mut left_idx = &*left_idx;
- if let Some((offset, len)) = slice {
+ if let Some((offset, len)) = args.slice {
 left_idx = slice_slice(left_idx, offset, len);
 }
- unsafe { ca_self._create_left_df_from_slice(left_idx, true, true) }
+ ca_self._create_left_df_from_slice(left_idx, true, true)
 },
- ChunkJoinIds::Right(left_idx) => {
+ ChunkJoinIds::Right(left_idx) => unsafe {
 let mut left_idx = &*left_idx;
- if let Some((offset, len)) = slice {
+ if let Some((offset, len)) = args.slice {
 left_idx = slice_slice(left_idx, offset, len);
 }
- unsafe { ca_self.create_left_df_chunked(left_idx, true) }
+ ca_self.create_left_df_chunked(left_idx, true)
 },
 };

 let materialize_right = || match right_idx {
- ChunkJoinOptIds::Left(right_idx) => {
+ ChunkJoinOptIds::Left(right_idx) => unsafe {
 let mut right_idx = &*right_idx;
- if let Some((offset, len)) = slice {
+ if let Some((offset, len)) = args.slice {
 right_idx = slice_slice(right_idx, offset, len);
 }
- unsafe { other.take_unchecked(&right_idx.iter().copied().collect_ca("")) }
+ other.take_unchecked(&right_idx.iter().copied().collect_ca(""))
 },
- ChunkJoinOptIds::Right(right_idx) => {
+ ChunkJoinOptIds::Right(right_idx) => unsafe {
 let mut right_idx = &*right_idx;
- if let Some((offset, len)) = slice {
+ if let Some((offset, len)) = args.slice {
 right_idx = slice_slice(right_idx, offset, len);
 }
- unsafe { other._take_opt_chunked_unchecked(right_idx) }
+ other._take_opt_chunked_unchecked(right_idx)
 },
 };
 let (df_left, df_right) = POOL.join(materialize_left, materialize_right);
@@ -213,9 +204,17 @@ pub trait JoinDispatch: IntoDf {
 #[cfg(feature = "dtype-categorical")]
 _check_categorical_src(s_left.dtype(), s_right.dtype())?;

- // ensure that the chunks are aligned otherwise we go OOB
 let mut left = ca_self.clone();
 let mut s_left = s_left.clone();
+ // Eagerly limit left if possible.
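+ // A left join emits at least one output row per left-hand row, so when the
+ // requested slice starts at offset 0, restricting the left side to its first
+ // `len` rows up front is a sound upper bound; the exact output slice is
+ // still applied when the join result is materialized.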
+ if let Some((offset, len)) = args.slice { + if offset == 0 { + left = left.slice(0, len); + s_left = s_left.slice(0, len); + } + } + + // Ensure that the chunks are aligned otherwise we go OOB. let mut right = other.clone(); let mut s_right = s_right.clone(); if left.should_rechunk() { diff --git a/crates/polars-ops/src/frame/join/mod.rs b/crates/polars-ops/src/frame/join/mod.rs index 17d3f59b833f..30aaead46702 100644 --- a/crates/polars-ops/src/frame/join/mod.rs +++ b/crates/polars-ops/src/frame/join/mod.rs @@ -188,7 +188,7 @@ pub trait DataFrameJoinOps: IntoDf { _check_categorical_src(l.dtype(), r.dtype())? } - // Single keys + // Single keys. if selected_left.len() == 1 { let s_left = left_df.column(selected_left[0].name())?; let s_right = other.column(selected_right[0].name())?; @@ -255,12 +255,13 @@ pub trait DataFrameJoinOps: IntoDf { } new.unwrap() } - // make sure that we don't have logical types. - // we don't overwrite the original selected as that might be used to create a column in the new df + + // Make sure that we don't have logical types. + // We don't overwrite the original selected as that might be used to create a column in the new df. let selected_left_physical = _to_physical_and_bit_repr(&selected_left); let selected_right_physical = _to_physical_and_bit_repr(&selected_right); - // multiple keys + // Multiple keys. match args.how { JoinType::Inner => { let left = DataFrame::new_no_checks(selected_left_physical); @@ -290,8 +291,11 @@ pub trait DataFrameJoinOps: IntoDf { JoinType::Left => { let mut left = DataFrame::new_no_checks(selected_left_physical); let mut right = DataFrame::new_no_checks(selected_right_physical); - let ids = _left_join_multiple_keys(&mut left, &mut right, None, None); + if let Some((offset, len)) = args.slice { + left = left.slice(offset, len); + } + let ids = _left_join_multiple_keys(&mut left, &mut right, None, None); left_df._finish_left_join(ids, &remove_selected(other, &selected_right), args) }, JoinType::Outer => { diff --git a/crates/polars-ops/src/series/ops/round.rs b/crates/polars-ops/src/series/ops/round.rs index 542107ed5871..9d090099a42a 100644 --- a/crates/polars-ops/src/series/ops/round.rs +++ b/crates/polars-ops/src/series/ops/round.rs @@ -72,7 +72,7 @@ impl RoundSeries for Series {} #[cfg(test)] mod test { - use crate::prelude::*; + use super::*; #[test] fn test_round_series() { From dfbc5f4922e7b84d8c81c2f307ee3c0157345c44 Mon Sep 17 00:00:00 2001 From: Orson Peters Date: Thu, 19 Oct 2023 19:22:34 +0200 Subject: [PATCH 048/103] refactor: add missing polars-ops tests to CI (#11859) --- .github/workflows/test-rust.yml | 18 +++++++------ crates/Makefile | 24 ++++++++--------- .../src/chunked_array/interpolate.rs | 27 +++++++++++++------ crates/polars-ops/src/frame/join/mod.rs | 6 ++++- .../series/ops/approx_algo/hyperloglogplus.rs | 1 + .../polars-plan/src/dsl/functions/temporal.rs | 2 ++ crates/polars-plan/src/dsl/mod.rs | 16 +---------- 7 files changed, 50 insertions(+), 44 deletions(-) diff --git a/.github/workflows/test-rust.yml b/.github/workflows/test-rust.yml index d4085dcfa82c..381235da5d6b 100644 --- a/.github/workflows/test-rust.yml +++ b/.github/workflows/test-rust.yml @@ -45,27 +45,29 @@ jobs: - name: Compile tests run: > cargo test --all-features --no-run + -p polars-core + -p polars-io -p polars-lazy + -p polars-ops -p polars-plan - -p polars-io - -p polars-core - -p polars-time - -p polars-utils -p polars-row -p polars-sql + -p polars-time + -p polars-utils - name: Run tests if: github.ref_name != 'main' run: > 
cargo test --all-features + -p polars-core + -p polars-io -p polars-lazy + -p polars-ops -p polars-plan - -p polars-io - -p polars-core - -p polars-time - -p polars-utils -p polars-row -p polars-sql + -p polars-time + -p polars-utils integration-test: runs-on: ${{ matrix.os }} diff --git a/crates/Makefile b/crates/Makefile index 718f3dde5580..eaf0f02d5cce 100644 --- a/crates/Makefile +++ b/crates/Makefile @@ -42,30 +42,30 @@ miri: ## Run miri .PHONY: test test: ## Run tests cargo test --all-features \ - -p polars-lazy \ - -p polars-io \ -p polars-core \ - -p polars-time \ - -p polars-utils \ - -p polars-row \ - -p polars-sql \ + -p polars-io \ + -p polars-lazy \ -p polars-ops \ -p polars-plan \ + -p polars-row \ + -p polars-sql \ + -p polars-time \ + -p polars-utils \ -- \ --test-threads=2 .PHONY: nextest nextest: ## Run tests with nextest cargo nextest run --all-features \ - -p polars-lazy \ - -p polars-io \ -p polars-core \ - -p polars-time \ - -p polars-utils \ - -p polars-row \ - -p polars-sql \ + -p polars-io \ + -p polars-lazy \ -p polars-ops \ -p polars-plan \ + -p polars-row \ + -p polars-sql \ + -p polars-time \ + -p polars-utils \ .PHONY: integration-tests integration-tests: ## Run integration tests diff --git a/crates/polars-ops/src/chunked_array/interpolate.rs b/crates/polars-ops/src/chunked_array/interpolate.rs index b06c06574960..6c2fe4949f4c 100644 --- a/crates/polars-ops/src/chunked_array/interpolate.rs +++ b/crates/polars-ops/src/chunked_array/interpolate.rs @@ -242,26 +242,34 @@ mod test { fn test_interpolate() { let ca = UInt32Chunked::new("", &[Some(1), None, None, Some(4), Some(5)]); let out = interpolate(&ca.into_series(), InterpolationMethod::Linear); - let out = out.u32().unwrap(); + let out = out.f64().unwrap(); assert_eq!( Vec::from(out), - &[Some(1), Some(2), Some(3), Some(4), Some(5)] + &[Some(1.0), Some(2.0), Some(3.0), Some(4.0), Some(5.0)] ); let ca = UInt32Chunked::new("", &[None, Some(1), None, None, Some(4), Some(5)]); let out = interpolate(&ca.into_series(), InterpolationMethod::Linear); - let out = out.u32().unwrap(); + let out = out.f64().unwrap(); assert_eq!( Vec::from(out), - &[None, Some(1), Some(2), Some(3), Some(4), Some(5)] + &[None, Some(1.0), Some(2.0), Some(3.0), Some(4.0), Some(5.0)] ); let ca = UInt32Chunked::new("", &[None, Some(1), None, None, Some(4), Some(5), None]); let out = interpolate(&ca.into_series(), InterpolationMethod::Linear); - let out = out.u32().unwrap(); + let out = out.f64().unwrap(); assert_eq!( Vec::from(out), - &[None, Some(1), Some(2), Some(3), Some(4), Some(5), None] + &[ + None, + Some(1.0), + Some(2.0), + Some(3.0), + Some(4.0), + Some(5.0), + None + ] ); let ca = UInt32Chunked::new("", &[None, Some(1), None, None, Some(4), Some(5), None]); let out = interpolate(&ca.into_series(), InterpolationMethod::Nearest); @@ -276,8 +284,11 @@ mod test { fn test_interpolate_decreasing_unsigned() { let ca = UInt32Chunked::new("", &[Some(4), None, None, Some(1)]); let out = interpolate(&ca.into_series(), InterpolationMethod::Linear); - let out = out.u32().unwrap(); - assert_eq!(Vec::from(out), &[Some(4), Some(3), Some(2), Some(1)]) + let out = out.f64().unwrap(); + assert_eq!( + Vec::from(out), + &[Some(4.0), Some(3.0), Some(2.0), Some(1.0)] + ) } #[test] diff --git a/crates/polars-ops/src/frame/join/mod.rs b/crates/polars-ops/src/frame/join/mod.rs index 30aaead46702..0b92ec715960 100644 --- a/crates/polars-ops/src/frame/join/mod.rs +++ b/crates/polars-ops/src/frame/join/mod.rs @@ -52,12 +52,13 @@ pub trait DataFrameJoinOps: 
IntoDf {
 ///
 /// ```no_run
 /// # use polars_core::prelude::*;
+ /// # use polars_ops::prelude::*;
 /// let df1: DataFrame = df!("Fruit" => &["Apple", "Banana", "Pear"],
 /// "Phosphorus (mg/100g)" => &[11, 22, 12])?;
 /// let df2: DataFrame = df!("Name" => &["Apple", "Banana", "Pear"],
 /// "Potassium (mg/100g)" => &[107, 358, 115])?;
 ///
- /// let df3: DataFrame = df1.join(&df2, ["Fruit"], ["Name"], JoinType::Inner, None)?;
+ /// let df3: DataFrame = df1.join(&df2, ["Fruit"], ["Name"], JoinArgs::new(JoinType::Inner))?;
 /// assert_eq!(df3.shape(), (3, 3));
 /// println!("{}", df3);
 /// # Ok::<(), PolarsError>(())
@@ -373,6 +374,7 @@ pub trait DataFrameJoinOps: IntoDf {
 ///
 /// ```
 /// # use polars_core::prelude::*;
+ /// # use polars_ops::prelude::*;
 /// fn join_dfs(left: &DataFrame, right: &DataFrame) -> PolarsResult<DataFrame> {
 /// left.inner_join(right, ["join_column_left"], ["join_column_right"])
 /// }
@@ -395,6 +397,7 @@ pub trait DataFrameJoinOps: IntoDf {
 ///
 /// ```no_run
 /// # use polars_core::prelude::*;
+ /// # use polars_ops::prelude::*;
 /// let df1: DataFrame = df!("Wavelength (nm)" => &[480.0, 650.0, 577.0, 1201.0, 100.0])?;
 /// let df2: DataFrame = df!("Color" => &["Blue", "Yellow", "Red"],
 /// "Wavelength nm" => &[480.0, 577.0, 650.0])?;
@@ -437,6 +440,7 @@ pub trait DataFrameJoinOps: IntoDf {
 ///
 /// ```
 /// # use polars_core::prelude::*;
+ /// # use polars_ops::prelude::*;
 /// fn join_dfs(left: &DataFrame, right: &DataFrame) -> PolarsResult<DataFrame> {
 /// left.outer_join(right, ["join_column_left"], ["join_column_right"])
 /// }
diff --git a/crates/polars-ops/src/series/ops/approx_algo/hyperloglogplus.rs b/crates/polars-ops/src/series/ops/approx_algo/hyperloglogplus.rs
index 7df61317d9bc..d507a1fcf20c 100644
--- a/crates/polars-ops/src/series/ops/approx_algo/hyperloglogplus.rs
+++ b/crates/polars-ops/src/series/ops/approx_algo/hyperloglogplus.rs
@@ -9,6 +9,7 @@
 //! # Examples
 //!
 //! ```
+//! # use polars_ops::prelude::*;
 //! let mut hllp = HyperLogLog::new();
 //! hllp.add(&12345);
 //! hllp.add(&23456);
diff --git a/crates/polars-plan/src/dsl/functions/temporal.rs b/crates/polars-plan/src/dsl/functions/temporal.rs
index 6ae7a8ee0c5c..b2eb3af5a229 100644
--- a/crates/polars-plan/src/dsl/functions/temporal.rs
+++ b/crates/polars-plan/src/dsl/functions/temporal.rs
@@ -156,6 +156,7 @@ pub fn datetime(args: DatetimeArgs) -> Expr {
 /// their default value of `lit(0)`, as demonstrated below.
 ///
 /// ```
+/// # use polars_plan::prelude::*;
 /// let args = DurationArgs {
 ///     days: lit(5),
 ///     hours: col("num_hours"),
 ///     minutes: col("num_minutes"),
 ///     ..Default::default()  // other fields are lit(0)
 /// };
 /// ```
 /// If you prefer builder syntax, `with_*` methods are also available.
 /// ```
+/// # use polars_plan::prelude::*;
 /// let args = DurationArgs::new().with_weeks(lit(42)).with_hours(lit(84));
 /// ```
 #[derive(Debug, Clone)]
diff --git a/crates/polars-plan/src/dsl/mod.rs b/crates/polars-plan/src/dsl/mod.rs
index 40dbb939de6f..05dca3120e10 100644
--- a/crates/polars-plan/src/dsl/mod.rs
+++ b/crates/polars-plan/src/dsl/mod.rs
@@ -1141,6 +1141,7 @@ impl Expr {
 /// Keep the original root name
 ///
 /// ```rust,no_run
+ /// # use polars_core::prelude::*;
 /// # use polars_plan::prelude::*;
 /// fn example(df: LazyFrame) -> LazyFrame {
 ///     df.select([
@@ -1181,21 +1182,6 @@ impl Expr {
 /// Exclude a column from a wildcard/regex selection.
 ///
 /// You may also use regexes in the exclude as long as they start with `^` and end with `$`.
- ///
- /// # Example
- ///
- /// ```rust
- /// use polars_core::prelude::*;
- /// use polars_lazy::prelude::*;
- ///
- /// // Select all columns except foo.
- /// fn example(df: DataFrame) -> LazyFrame {
- ///     df.lazy()
- ///         .select(&[
- ///             col("*").exclude(&["foo"])
- ///         ])
- /// }
- /// ```
 pub fn exclude(self, columns: impl IntoVec<String>) -> Expr {
 let v = columns
 .into_vec()
From cd0288e55f9d4dd68fb7197a859dcd63df1a2255 Mon Sep 17 00:00:00 2001
From: Stijn de Gooijer
Date: Thu, 19 Oct 2023 23:20:32 +0200
Subject: [PATCH 049/103] build: Bump docs dependencies (#11852)

---
 .github/dependabot.yml | 12 ++++++++++++
 docs/requirements.txt | 6 +++---
 docs/src/python/user-guide/sql/intro.py | 16 ++++++++--------
 mkdocs.yml | 4 ++--
 4 files changed, 25 insertions(+), 13 deletions(-)

diff --git a/.github/dependabot.yml b/.github/dependabot.yml
index 31d8d580266f..ee769d681e8a 100644
--- a/.github/dependabot.yml
+++ b/.github/dependabot.yml
@@ -48,3 +48,15 @@ updates:
 prefix: build(python)
 prefix-development: chore(python)
 labels: ['skip changelog']
+
+ # Documentation
+ - package-ecosystem: pip
+ directory: docs
+ schedule:
+ interval: monthly
+ ignore:
+ - dependency-name: '*'
+ update-types: ['version-update:semver-patch']
+ commit-message:
+ prefix: chore(python)
+ labels: ['skip changelog']
diff --git a/docs/requirements.txt b/docs/requirements.txt
index 2c317b06415b..35d0629394c0 100644
--- a/docs/requirements.txt
+++ b/docs/requirements.txt
@@ -3,7 +3,7 @@
 pyarrow
 graphviz
 matplotlib
-mkdocs-material==9.2.5
+mkdocs-material==9.4.6
 mkdocs-macros-plugin==1.0.4
-markdown-exec[ansi]==1.6.0
+markdown-exec[ansi]==1.7.0
-PyGithub==1.59.1
+PyGithub==2.1.1
diff --git a/docs/src/python/user-guide/sql/intro.py b/docs/src/python/user-guide/sql/intro.py
index 3b59ac9e70d1..143ec75c4f76 100644
--- a/docs/src/python/user-guide/sql/intro.py
+++ b/docs/src/python/user-guide/sql/intro.py
@@ -39,7 +39,7 @@
 # --8<-- [end:execute]

 # --8<-- [start:prepare_multiple_sources]
-with open("products_categories.json", "w") as temp_file:
+with open("docs/data/products_categories.json", "w") as temp_file:
 json_data = """{"product_id": 1, "category": "Category 1"}
{"product_id": 2, "category": "Category 1"}
{"product_id": 3, "category": "Category 2"}

 temp_file.write(json_data)

-with open("products_masterdata.csv", "w") as temp_file:
+with open("docs/data/products_masterdata.csv", "w") as temp_file:
 csv_data = """product_id,product_name
1,Product A
2,Product B

@@ -73,19 +73,19 @@
 # sales_data is a Pandas DataFrame with schema {'product_id': Int64, 'sales': Int64}

 ctx = pl.SQLContext(
- products_masterdata=pl.scan_csv("products_masterdata.csv"),
- products_categories=pl.scan_ndjson("products_categories.json"),
+ products_masterdata=pl.scan_csv("docs/data/products_masterdata.csv"),
+ products_categories=pl.scan_ndjson("docs/data/products_categories.json"),
 sales_data=pl.from_pandas(sales_data),
 eager_execution=True,
)

query = """
-SELECT 
+SELECT
 product_id,
 product_name,
 category,
 sales
-FROM 
+FROM
 products_masterdata
 LEFT JOIN products_categories USING (product_id)
 LEFT JOIN sales_data USING (product_id)
@@ -95,6 +95,6 @@
 # --8<-- [end:execute_multiple_sources]

 # --8<-- [start:clean_multiple_sources]
-os.remove("products_categories.json")
-os.remove("products_masterdata.csv")
+os.remove("docs/data/products_categories.json")
+os.remove("docs/data/products_masterdata.csv")
 # --8<-- [end:clean_multiple_sources]
diff --git a/mkdocs.yml b/mkdocs.yml index 501d047b35e5..0c56ab8298e5 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -140,8 +140,8 @@ markdown_extensions: - pymdownx.details - attr_list - pymdownx.emoji: - emoji_index: !!python/name:materialx.emoji.twemoji - emoji_generator: !!python/name:materialx.emoji.to_svg + emoji_index: !!python/name:material.extensions.emoji.twemoji + emoji_generator: !!python/name:material.extensions.emoji.to_svg - pymdownx.superfences - pymdownx.tabbed: alternate_style: true From 21e2c0cf21dff0c10c8f16775c7ffc386a59d2d2 Mon Sep 17 00:00:00 2001 From: Jerry Wu Date: Fri, 20 Oct 2023 05:22:39 +0800 Subject: [PATCH 050/103] docs(python): fix typo in code example in section Expressions - Basic operators (#11848) --- docs/src/python/user-guide/expressions/operators.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/src/python/user-guide/expressions/operators.py b/docs/src/python/user-guide/expressions/operators.py index 6f617487c81e..92bf57952332 100644 --- a/docs/src/python/user-guide/expressions/operators.py +++ b/docs/src/python/user-guide/expressions/operators.py @@ -34,7 +34,7 @@ # --8<-- [start:logical] df_logical = df.select( (pl.col("nrs") > 1).alias("nrs > 1"), - (pl.col("random") <= 0.5).alias("random < .5"), + (pl.col("random") <= 0.5).alias("random <= .5"), (pl.col("nrs") != 1).alias("nrs != 1"), (pl.col("nrs") == 1).alias("nrs == 1"), ((pl.col("random") <= 0.5) & (pl.col("nrs") > 1)).alias("and_expr"), # and From 65659b9541d45b3cbf9b40d11f3433a2bd0027c0 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 19 Oct 2023 23:41:48 +0200 Subject: [PATCH 051/103] chore(python): bump hypothesis from 6.87.1 to 6.88.1 in /py-polars (#11865) Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- py-polars/docs/requirements-docs.txt | 2 +- py-polars/requirements-dev.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/py-polars/docs/requirements-docs.txt b/py-polars/docs/requirements-docs.txt index b755ae5dc02f..592c5fc249b0 100644 --- a/py-polars/docs/requirements-docs.txt +++ b/py-polars/docs/requirements-docs.txt @@ -4,7 +4,7 @@ numpy pandas pyarrow -hypothesis==6.87.1 +hypothesis==6.88.1 sphinx==7.2.4 diff --git a/py-polars/requirements-dev.txt b/py-polars/requirements-dev.txt index 33094478f6f2..f63137249749 100644 --- a/py-polars/requirements-dev.txt +++ b/py-polars/requirements-dev.txt @@ -53,7 +53,7 @@ gevent # TOOLING # ------- -hypothesis==6.87.1 +hypothesis==6.88.1 pytest==7.4.0 pytest-cov==4.1.0 pytest-xdist==3.3.1 From 754067c83118d943fc2fd81757dde3d1e2333222 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 19 Oct 2023 23:42:11 +0200 Subject: [PATCH 052/103] build(rust): update aws-creds requirement from 0.35.0 to 0.36.0 (#11868) Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- examples/read_parquet_cloud/Cargo.toml | 2 +- examples/write_parquet_cloud/Cargo.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/read_parquet_cloud/Cargo.toml b/examples/read_parquet_cloud/Cargo.toml index f6f5b56eb430..bbb43403bd95 100644 --- a/examples/read_parquet_cloud/Cargo.toml +++ b/examples/read_parquet_cloud/Cargo.toml @@ -6,4 +6,4 @@ edition = "2021" [dependencies] polars = { path = "../../crates/polars", features = ["lazy", "aws", "parquet"] } -aws-creds = "0.35.0" +aws-creds = "0.36.0" diff --git 
a/examples/write_parquet_cloud/Cargo.toml b/examples/write_parquet_cloud/Cargo.toml index 7bf6a24e46d3..fe02ad8f8457 100644 --- a/examples/write_parquet_cloud/Cargo.toml +++ b/examples/write_parquet_cloud/Cargo.toml @@ -6,5 +6,5 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -aws-creds = "0.35.0" +aws-creds = "0.36.0" polars = { path = "../../crates/polars", features = ["lazy", "aws", "parquet", "cloud_write"] } From 2fea82016733573861d6a5d6d0cd1c4c1fb505e0 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 19 Oct 2023 23:46:13 +0200 Subject: [PATCH 053/103] chore(python): bump black from 23.9.1 to 23.10.0 in /py-polars (#11866) Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Stijn de Gooijer --- .github/workflows/docs-global.yml | 2 +- py-polars/requirements-lint.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/docs-global.yml b/.github/workflows/docs-global.yml index 6e8f12bcae5e..54e02c389469 100644 --- a/.github/workflows/docs-global.yml +++ b/.github/workflows/docs-global.yml @@ -26,7 +26,7 @@ jobs: - uses: psf/black@stable with: src: docs/src/python - version: "23.9.1" + version: "23.10.0" deploy: runs-on: ubuntu-latest diff --git a/py-polars/requirements-lint.txt b/py-polars/requirements-lint.txt index 20bdfbddf117..40588f518050 100644 --- a/py-polars/requirements-lint.txt +++ b/py-polars/requirements-lint.txt @@ -1,4 +1,4 @@ -black==23.9.1 +black==23.10.0 blackdoc==0.3.8 mypy==1.6.0 ruff==0.1.0 From 570dca7a2e19c704e1730e8412e9d7a19eebbd2c Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 19 Oct 2023 23:55:53 +0200 Subject: [PATCH 054/103] build(rust): update regex-syntax requirement from 0.7 to 0.8 (#11870) Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- crates/polars-arrow/Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/polars-arrow/Cargo.toml b/crates/polars-arrow/Cargo.toml index 39828ff7e59f..777f1364ad09 100644 --- a/crates/polars-arrow/Cargo.toml +++ b/crates/polars-arrow/Cargo.toml @@ -36,7 +36,7 @@ lexical-core = { workspace = true, optional = true } fallible-streaming-iterator = { workspace = true, optional = true } regex = { workspace = true, optional = true } -regex-syntax = { version = "0.7", optional = true } +regex-syntax = { version = "0.8", optional = true } streaming-iterator = { workspace = true } indexmap = { workspace = true, optional = true } From 6cf037b904cfe213b58f9a9bfddbd0c43c8f1129 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Fri, 20 Oct 2023 00:10:32 +0200 Subject: [PATCH 055/103] build(rust): update zstd requirement from 0.12 to 0.13 (#11869) Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Stijn de Gooijer --- crates/polars-arrow/Cargo.toml | 2 +- py-polars/Cargo.lock | 24 +++++++++++++++++++++--- 2 files changed, 22 insertions(+), 4 deletions(-) diff --git a/crates/polars-arrow/Cargo.toml b/crates/polars-arrow/Cargo.toml index 777f1364ad09..d464aad88f6b 100644 --- a/crates/polars-arrow/Cargo.toml +++ b/crates/polars-arrow/Cargo.toml @@ -47,7 +47,7 @@ hex = { workspace = true, optional = true } # for IPC compression lz4 = { version = "1.24", optional = true } -zstd = { 
version = "0.12", optional = true } +zstd = { version = "0.13", optional = true } base64 = { workspace = true, optional = true } diff --git a/py-polars/Cargo.lock b/py-polars/Cargo.lock index b492b67de12e..afbe06e97abe 100644 --- a/py-polars/Cargo.lock +++ b/py-polars/Cargo.lock @@ -1489,7 +1489,7 @@ dependencies = [ "snap", "streaming-decompression", "xxhash-rust", - "zstd", + "zstd 0.12.4", ] [[package]] @@ -1626,7 +1626,7 @@ dependencies = [ "simdutf8", "streaming-iterator", "strength_reduce", - "zstd", + "zstd 0.13.0", ] [[package]] @@ -3114,7 +3114,16 @@ version = "0.12.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1a27595e173641171fc74a1232b7b1c7a7cb6e18222c11e9dfb9888fa424c53c" dependencies = [ - "zstd-safe", + "zstd-safe 6.0.6", +] + +[[package]] +name = "zstd" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bffb3309596d527cfcba7dfc6ed6052f1d39dfbd7c867aa2e865e4a449c10110" +dependencies = [ + "zstd-safe 7.0.0", ] [[package]] @@ -3127,6 +3136,15 @@ dependencies = [ "zstd-sys", ] +[[package]] +name = "zstd-safe" +version = "7.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43747c7422e2924c11144d5229878b98180ef8b06cca4ab5af37afc8a8d8ea3e" +dependencies = [ + "zstd-sys", +] + [[package]] name = "zstd-sys" version = "2.0.8+zstd.1.5.5" From eac03a2211591708d0b1d096f33c9a812e23ca03 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Fri, 20 Oct 2023 00:20:48 +0200 Subject: [PATCH 056/103] build(rust): update pyo3-build-config requirement from 0.19 to 0.20 (#11872) Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- examples/python_rust_compiled_function/Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/python_rust_compiled_function/Cargo.toml b/examples/python_rust_compiled_function/Cargo.toml index 13381f035e59..da8b5f37096a 100644 --- a/examples/python_rust_compiled_function/Cargo.toml +++ b/examples/python_rust_compiled_function/Cargo.toml @@ -14,4 +14,4 @@ polars = { path = "../../crates/polars" } pyo3 = { workspace = true, features = ["extension-module"] } [build-dependencies] -pyo3-build-config = "0.19" +pyo3-build-config = "0.20" From c8cfdeeced44928d92ce1f2b2718893d39461796 Mon Sep 17 00:00:00 2001 From: Romano Vacca Date: Fri, 20 Oct 2023 06:53:31 +0200 Subject: [PATCH 057/103] docs: fix incorrect example of valid time zones (#11873) --- .../transformations/time-series/timezones.md | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/docs/user-guide/transformations/time-series/timezones.md b/docs/user-guide/transformations/time-series/timezones.md index a12b97c68dd9..de5046d4cafd 100644 --- a/docs/user-guide/transformations/time-series/timezones.md +++ b/docs/user-guide/transformations/time-series/timezones.md @@ -12,13 +12,13 @@ hide: The `Datetime` datatype can have a time zone associated with it. Examples of valid time zones are: -- `None`: no time zone, also known as "time zone naive"; -- `UTC`: Coordinated Universal Time; +- `None`: no time zone, also known as "time zone naive". +- `UTC`: Coordinated Universal Time. - `Asia/Kathmandu`: time zone in "area/location" format. See the [list of tz database time zones](https://en.wikipedia.org/wiki/List_of_tz_database_time_zones) - to see what's available; -- `+01:00`: fixed offsets. 
May be useful when parsing, but you almost certainly want the "Area/Location"
-  format above instead as it will deal with irregularities such as DST (Daylight Saving Time) for you.
+  to see what's available.
+
+Caution: fixed offsets such as `+02:00` should not be used for handling time zones. It's advised to use the "Area/Location" format mentioned above, as it can manage time zones more effectively.
 
 Note that, because a `Datetime` can only have a single time zone, it is
 impossible to have a column with multiple time zones. If you are parsing data
@@ -27,8 +27,8 @@ them all to a common time zone (`UTC`), see [parsing dates and times](parsing.md
 
 The main methods for setting and converting between time zones are:
 
-- `dt.convert_time_zone`: convert from one time zone to another;
-- `dt.replace_time_zone`: set/unset/change time zone;
+- `dt.convert_time_zone`: convert from one time zone to another.
+- `dt.replace_time_zone`: set/unset/change time zone.
 
 Let's look at some examples of common operations:
 

From e21c1a70d2f59757b6702568ab34b511ff7fc007 Mon Sep 17 00:00:00 2001
From: Romano Vacca
Date: Fri, 20 Oct 2023 06:55:17 +0200
Subject: [PATCH 058/103] fix(rust): series.to_numpy fails with dtype=Null (#11858)

---
 py-polars/src/series/export.rs       | 5 +++++
 py-polars/tests/unit/test_interop.py | 8 ++++++++
 2 files changed, 13 insertions(+)

diff --git a/py-polars/src/series/export.rs b/py-polars/src/series/export.rs
index fb1f9df18e08..2ec3eec4d343 100644
--- a/py-polars/src/series/export.rs
+++ b/py-polars/src/series/export.rs
@@ -69,6 +69,11 @@ impl PySeries {
                     PyArray1::from_iter(py, ca.into_iter().map(|opt_v| opt_v.to_object(py)));
                 Ok(np_arr.into_py(py))
             },
+            DataType::Null => {
+                let n = s.len();
+                let np_arr = PyArray1::from_iter(py, std::iter::repeat(f32::NAN).take(n));
+                Ok(np_arr.into_py(py))
+            },
             dt => {
                 raise_err!(
                     format!("'to_numpy' not supported for dtype: {dt:?}"),
diff --git a/py-polars/tests/unit/test_interop.py b/py-polars/tests/unit/test_interop.py
index d3981cc8528a..af4efa23edcc 100644
--- a/py-polars/tests/unit/test_interop.py
+++ b/py-polars/tests/unit/test_interop.py
@@ -79,6 +79,14 @@ def test_to_numpy_no_zero_copy(
         series.to_numpy(zero_copy_only=True, use_pyarrow=use_pyarrow)
 
 
+def test_to_numpy_empty_no_pyarrow() -> None:
+    series = pl.Series([], dtype=pl.Null)
+    result = series.to_numpy()
+    assert result.dtype == pl.Float32
+    assert result.shape == (0,)
+    assert result.size == 0
+
+
 def test_from_pandas() -> None:
     df = pd.DataFrame(
         {

From c2562d8b3908ae0117387f482208faf0d30f1f7b Mon Sep 17 00:00:00 2001
From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com>
Date: Fri, 20 Oct 2023 06:57:25 +0200
Subject: [PATCH 059/103] build(rust): update simd-json requirement from 0.11
 to 0.12 (#11871)

---
 Cargo.toml           | 2 +-
 py-polars/Cargo.lock | 4 ++--
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/Cargo.toml b/Cargo.toml
index 06145d401f60..eda1f6d2a3af 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -45,7 +45,7 @@ rayon = "1.8"
 regex = "1.9"
 serde = "1.0.188"
 serde_json = "1"
-simd-json = { version = "0.11", features = ["allow-non-simd", "known-key"] }
+simd-json = { version = "0.12", features = ["known-key"] }
 smartstring = "1"
 sqlparser = "0.38"
 strum_macros = "0.25"
diff --git a/py-polars/Cargo.lock b/py-polars/Cargo.lock
index afbe06e97abe..549ce47c3d89 100644
--- a/py-polars/Cargo.lock
+++ b/py-polars/Cargo.lock
@@ -2394,9 +2394,9 @@ dependencies = [
 
 [[package]]
 name = "simd-json"
-version = "0.11.1"
+version = "0.12.0"
 source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "474b451aaac1828ed12f6454a80fe58b940ae2998d10389d41533940a6f641bf" +checksum = "f0f07a84c7456b901b8dd2c1d44caca8b0fd2c2616206ee5acc9d9da61e8d9ec" dependencies = [ "ahash", "getrandom", From a7fdbee97ad7187fe82615c2eb1c66c7fea1ca09 Mon Sep 17 00:00:00 2001 From: Ritchie Vink Date: Fri, 20 Oct 2023 07:00:19 +0200 Subject: [PATCH 060/103] docs: add section about plugins (#11855) --- CONTRIBUTING.md | 2 +- docs/index.md | 4 - docs/user-guide/expressions/plugins.md | 231 ++++++++++++++++++ .../expressions/user-defined-functions.md | 2 +- mkdocs.yml | 1 + 5 files changed, 234 insertions(+), 6 deletions(-) create mode 100644 docs/user-guide/expressions/plugins.md diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 44321d2f35bb..280e5e03e581 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -157,7 +157,7 @@ The user guide is maintained in the `docs/user-guide` folder. Before creating a The user guide is built using [MkDocs](https://www.mkdocs.org/). You install the dependencies for building the user guide by running `make requirements` in the root of the repo. -Run `mkdocs serve` to build and serve the user guide so you can view it locally and see updates as you make changes. +Run `mkdocs serve` to build and serve the user guide, so you can view it locally and see updates as you make changes. #### Creating a new user guide page diff --git a/docs/index.md b/docs/index.md index 2621ba4ee11d..c5c3dfb1bfbf 100644 --- a/docs/index.md +++ b/docs/index.md @@ -52,10 +52,6 @@ See the results in h2oai's [db-benchmark](https://duckdblabs.github.io/db-benchm {{code_block('home/example','example',['scan_csv','filter','group_by','collect'])}} -## Sponsors - -[](https://www.xomnia.com/)   [](https://www.jetbrains.com) - ## Community `Polars` has a very active community with frequent releases (approximately weekly). Below are some of the top contributors to the project: diff --git a/docs/user-guide/expressions/plugins.md b/docs/user-guide/expressions/plugins.md new file mode 100644 index 000000000000..bc39ecfb7eca --- /dev/null +++ b/docs/user-guide/expressions/plugins.md @@ -0,0 +1,231 @@ +# Expression plugins + +Expression plugins are the preferred way to create user defined functions. They allow you to compile a rust function +and register that as an expression into the polars library. The polars engine will dynamically link your function at runtime +and your expression will run almost as fast as native expressions. Note that this works without any interference of python +and thus no GIL contention. + +They will benefit from the same benefits default expression have: + +- Optimization +- Parallelism +- Rust native performance + +To get started we will see what is needed to create a custom expression. + +## Our first custom expression: Pig Latin + +For our first expression we are going to create a pig latin converter. Pig latin is a silly language where in every word +the first letter is removed, added to the back and finally "ay" is added. So the word "pig" would convert to "igpay". + +We could of course already do that with expressions, e.g. `col(..) + col(..).str.slice(0, 1) + "ay"`, but a specialized +function for this would perform better and allows us to learn about the plugins. 
+
+### Setting up
+
+We start a new library with the following `Cargo.toml` file:
+
+```toml
+[package]
+name = "expression_lib"
+version = "0.1.0"
+edition = "2021"
+
+[lib]
+name = "expression_lib"
+crate-type = ["cdylib"]
+
+[dependencies]
+polars = { version = "*" }
+pyo3 = { version = "0.20.0", features = ["extension-module"] }
+pyo3-polars = { version = "*", features = ["derive"] }
+serde = { version = "1", features = ["derive"] }
+```
+
+### Writing the expression
+
+In this library we create a helper function that converts a `&str` to pig-latin, and we create the function that we will
+expose as an expression. To expose a function we must add the `#[polars_expr(output_type=DataType)]` attribute, and the function
+must always accept `inputs: &[Series]` as its first argument.
+
+```rust
+use polars::prelude::*;
+use pyo3_polars::derive::polars_expr;
+use std::fmt::Write;
+
+fn pig_latin_str(value: &str, output: &mut String) {
+    if let Some(first_char) = value.chars().next() {
+        write!(output, "{}{}ay", &value[1..], first_char).unwrap()
+    }
+}
+
+#[polars_expr(output_type=Utf8)]
+fn pig_latinnify(inputs: &[Series]) -> PolarsResult<Series> {
+    let ca = inputs[0].utf8()?;
+    let out: Utf8Chunked = ca.apply_to_buffer(pig_latin_str);
+    Ok(out.into_series())
+}
+```
+
+This is all that is needed on the Rust side. On the Python side we must set up a folder with the same name as defined in
+the `Cargo.toml`, in this case "expression_lib". We create a folder named `expression_lib` in the same directory as our
+Rust `src` folder, and in it we create an `expression_lib/__init__.py`.
+
+Then we create a new class `Language` that will hold the expressions for our new `expr.language` namespace. There, the
+function name of our expression is registered. Note that it is important that this name is correct, otherwise the main polars
+package cannot resolve the function name. Furthermore, we can set additional keyword arguments that explain to polars how
+this expression behaves. In this case we tell polars that this function is elementwise, which allows polars to run this
+expression in batches. For other operations, think for instance of a sort or a slice, this would not be allowed.
+
+```python
+import polars as pl
+from polars.type_aliases import IntoExpr
+from polars.utils.udfs import _get_shared_lib_location
+
+# Boilerplate needed to inform polars of the location of the binary wheel.
+lib = _get_shared_lib_location(__file__)
+
+@pl.api.register_expr_namespace("language")
+class Language:
+    def __init__(self, expr: pl.Expr):
+        self._expr = expr
+
+    def pig_latinnify(self) -> pl.Expr:
+        return self._expr._register_plugin(
+            lib=lib,
+            symbol="pig_latinnify",
+            is_elementwise=True,
+        )
+```
+
+We can then compile this library in our environment by installing `maturin` and running `maturin develop --release`.
+
+And that's it. Our expression is ready to use!
+
+```python
+import polars as pl
+from expression_lib import Language
+
+df = pl.DataFrame(
+    {
+        "convert": ["pig", "latin", "is", "silly"],
+    }
+)
+
+
+out = df.with_columns(
+    pig_latin=pl.col("convert").language.pig_latinnify(),
+)
+```
+
+## Accepting kwargs
+
+If you want to accept `kwargs` (keyword arguments) in a polars expression, all you have to do is define a Rust `struct`
+and make sure that it derives `serde::Deserialize`.
+
+```rust
+/// Provide your own kwargs struct with the proper schema and accept that type
+/// in your plugin expression.
+#[derive(Deserialize)]
+pub struct MyKwargs {
+    float_arg: f64,
+    integer_arg: i64,
+    string_arg: String,
+    boolean_arg: bool,
+}
+
+/// If you want to accept `kwargs`, you define a `kwargs` argument
+/// in the second position of your plugin. You can provide any custom struct that is deserializable
+/// with the pickle protocol (on the Rust side).
+#[polars_expr(output_type=Utf8)]
+fn append_kwargs(input: &[Series], kwargs: MyKwargs) -> PolarsResult<Series> {
+    let input = &input[0];
+    let input = input.cast(&DataType::Utf8)?;
+    let ca = input.utf8().unwrap();
+
+    Ok(ca
+        .apply_to_buffer(|val, buf| {
+            write!(
+                buf,
+                "{}-{}-{}-{}-{}",
+                val, kwargs.float_arg, kwargs.integer_arg, kwargs.string_arg, kwargs.boolean_arg
+            )
+            .unwrap()
+        })
+        .into_series())
+}
+```
+
+On the Python side, the kwargs can be passed when we register the plugin.
+
+```python
+@pl.api.register_expr_namespace("my_expr")
+class MyCustomExpr:
+    def __init__(self, expr: pl.Expr):
+        self._expr = expr
+
+    def append_args(
+        self,
+        float_arg: float,
+        integer_arg: int,
+        string_arg: str,
+        boolean_arg: bool,
+    ) -> pl.Expr:
+        """
+        This example shows how arguments other than `Series` can be used.
+        """
+        return self._expr._register_plugin(
+            lib=lib,
+            args=[],
+            kwargs={
+                "float_arg": float_arg,
+                "integer_arg": integer_arg,
+                "string_arg": string_arg,
+                "boolean_arg": boolean_arg,
+            },
+            symbol="append_kwargs",
+            is_elementwise=True,
+        )
+```
+
+## Output data types
+
+Output data types of course don't have to be fixed. They often depend on the input types of an expression. To accommodate
+this, you can provide the `#[polars_expr()]` macro with an `output_type_func` argument that points to a function. This
+function can map input fields `&[Field]` to an output `Field` (name and data type).
+
+The snippet below shows an example where we use the utility `FieldsMapper` to help with this mapping.
+
+```rust
+use polars_plan::dsl::FieldsMapper;
+
+fn haversine_output(input_fields: &[Field]) -> PolarsResult<Field> {
+    FieldsMapper::new(input_fields).map_to_float_dtype()
+}
+
+#[polars_expr(output_type_func=haversine_output)]
+fn haversine(inputs: &[Series]) -> PolarsResult<Series> {
+    let out = match inputs[0].dtype() {
+        DataType::Float32 => {
+            let start_lat = inputs[0].f32().unwrap();
+            let start_long = inputs[1].f32().unwrap();
+            let end_lat = inputs[2].f32().unwrap();
+            let end_long = inputs[3].f32().unwrap();
+            crate::distances::naive_haversine(start_lat, start_long, end_lat, end_long)?
+                .into_series()
+        }
+        DataType::Float64 => {
+            let start_lat = inputs[0].f64().unwrap();
+            let start_long = inputs[1].f64().unwrap();
+            let end_lat = inputs[2].f64().unwrap();
+            let end_long = inputs[3].f64().unwrap();
+            crate::distances::naive_haversine(start_lat, start_long, end_lat, end_long)?
+                .into_series()
+        }
+        _ => polars_bail!(InvalidOperation: "only supported for float types"),
+    };
+    Ok(out)
+}
+```
+
+That's all you need to know to get started. Take a look at this [repo](https://github.com/pola-rs/pyo3-polars/tree/main/example/derive_expression) to see how this all fits together.
diff --git a/docs/user-guide/expressions/user-defined-functions.md b/docs/user-guide/expressions/user-defined-functions.md
index dd83cb13c382..3dd43f035f85 100644
--- a/docs/user-guide/expressions/user-defined-functions.md
+++ b/docs/user-guide/expressions/user-defined-functions.md
@@ -1,4 +1,4 @@
-# User-defined functions
+# User-defined functions (Python)
 
 !!! 
warning "Not updated for Python Polars `0.19.0`" diff --git a/mkdocs.yml b/mkdocs.yml index 0c56ab8298e5..7734cbd11d5b 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -37,6 +37,7 @@ nav: - user-guide/expressions/window.md - user-guide/expressions/folds.md - user-guide/expressions/lists.md + - user-guide/expressions/plugins.md - user-guide/expressions/user-defined-functions.md - user-guide/expressions/structs.md - user-guide/expressions/numpy.md From 0b8be40ccee5c5dcef43db15bd43a77d228afbca Mon Sep 17 00:00:00 2001 From: Weijie Guo Date: Fri, 20 Oct 2023 13:03:07 +0800 Subject: [PATCH 061/103] fix: fix project pushdown for double projection contains count (#11843) --- .../projection_pushdown/projection.rs | 52 ++++++++++++------- py-polars/tests/unit/test_projections.py | 6 +++ 2 files changed, 40 insertions(+), 18 deletions(-) diff --git a/crates/polars-plan/src/logical_plan/optimizer/projection_pushdown/projection.rs b/crates/polars-plan/src/logical_plan/optimizer/projection_pushdown/projection.rs index e49ba5cb5642..8cf418011888 100644 --- a/crates/polars-plan/src/logical_plan/optimizer/projection_pushdown/projection.rs +++ b/crates/polars-plan/src/logical_plan/optimizer/projection_pushdown/projection.rs @@ -8,6 +8,31 @@ fn is_count(node: Node, expr_arena: &Arena) -> bool { } } +/// In this function we check a double projection case +/// df +/// .select(col("foo").alias("bar")) +/// .select(col("bar") +/// +/// In this query, bar cannot pass this projection, as it would not exist in DF. +/// THE ORDER IS IMPORTANT HERE! +/// this removes projection names, so any checks to upstream names should +/// be done before this branch. +fn check_double_projection( + expr: &Node, + expr_arena: &mut Arena, + acc_projections: &mut Vec, + projected_names: &mut PlHashSet>, +) { + for (_, ae) in (&*expr_arena).iter(*expr) { + if let AExpr::Alias(_, name) = ae { + if projected_names.remove(name) { + acc_projections + .retain(|expr| !aexpr_to_leaf_names(*expr, expr_arena).contains(name)); + } + } + } +} + #[allow(clippy::too_many_arguments)] pub(super) fn process_projection( proj_pd: &mut ProjectionPushDown, @@ -29,6 +54,14 @@ pub(super) fn process_projection( // simply select the first column let (first_name, _) = input_schema.try_get_at_index(0)?; let expr = expr_arena.add(AExpr::Column(Arc::from(first_name.as_str()))); + if !acc_projections.is_empty() { + check_double_projection( + &exprs[0], + expr_arena, + &mut acc_projections, + &mut projected_names, + ); + } add_expr_to_accumulated(expr, &mut acc_projections, &mut projected_names, expr_arena); local_projection.push(exprs[0]); } else { @@ -48,24 +81,7 @@ pub(super) fn process_projection( continue; } - // in this branch we check a double projection case - // df - // .select(col("foo").alias("bar")) - // .select(col("bar") - // - // In this query, bar cannot pass this projection, as it would not exist in DF. - // THE ORDER IS IMPORTANT HERE! - // this removes projection names, so any checks to upstream names should - // be done before this branch. - for (_, ae) in (&*expr_arena).iter(*e) { - if let AExpr::Alias(_, name) = ae { - if projected_names.remove(name) { - acc_projections.retain(|expr| { - !aexpr_to_leaf_names(*expr, expr_arena).contains(name) - }); - } - } - } + check_double_projection(e, expr_arena, &mut acc_projections, &mut projected_names); } // do local as we still need the effect of the projection // e.g. 
a projection is more than selecting a column, it can
diff --git a/py-polars/tests/unit/test_projections.py b/py-polars/tests/unit/test_projections.py
index 2e276b837b01..85dc01ce006d 100644
--- a/py-polars/tests/unit/test_projections.py
+++ b/py-polars/tests/unit/test_projections.py
@@ -320,3 +320,9 @@ def test_projection_rename_10595() -> None:
     assert lf.select("a", "b").rename({"b": "a", "a": "b"}).select(
         "a"
     ).collect().schema == {"a": pl.Float32}
+
+
+def test_projection_count_11841() -> None:
+    pl.LazyFrame({"x": 1}).select(records=pl.count()).select(
+        pl.lit(1).alias("x"), pl.all()
+    ).collect()

From d1af5f93ee17de52f75c146912e2ef7fd35d92a5 Mon Sep 17 00:00:00 2001
From: Orson Peters
Date: Fri, 20 Oct 2023 07:07:36 +0200
Subject: [PATCH 062/103] refactor(rust): rename new_from_owned_with_null_bitmap
 (#11828)

---
 crates/polars-core/src/chunked_array/from.rs |  8 ++------
 crates/polars-ops/src/series/ops/rank.rs     | 13 ++++++-------
 py-polars/src/series/numpy_ufunc.rs          |  7 ++-----
 3 files changed, 10 insertions(+), 18 deletions(-)

diff --git a/crates/polars-core/src/chunked_array/from.rs b/crates/polars-core/src/chunked_array/from.rs
index b20ea1cde3ca..1c67b3b75963 100644
--- a/crates/polars-core/src/chunked_array/from.rs
+++ b/crates/polars-core/src/chunked_array/from.rs
@@ -273,12 +273,8 @@ where
         Self::with_chunk(name, to_primitive::<T>(v, None))
     }
 
-    /// Nullify values in slice with an existing null bitmap
-    pub fn new_from_owned_with_null_bitmap(
-        name: &str,
-        values: Vec<T::Native>,
-        buffer: Option<Bitmap>,
-    ) -> Self {
+    /// Create a new ChunkedArray from a Vec and a validity mask.
+    pub fn from_vec_validity(name: &str, values: Vec<T::Native>, buffer: Option<Bitmap>) -> Self {
         let arr = to_array::<T>(values, buffer);
         let mut out = ChunkedArray {
             field: Arc::new(Field::new(name, T::get_dtype())),
diff --git a/crates/polars-ops/src/series/ops/rank.rs b/crates/polars-ops/src/series/ops/rank.rs
index 568ea02345ee..c251aaa6922e 100644
--- a/crates/polars-ops/src/series/ops/rank.rs
+++ b/crates/polars-ops/src/series/ops/rank.rs
@@ -113,7 +113,7 @@ fn rank(s: &Series, method: RankMethod, descending: bool, seed: Option<u64>) ->
                     rank += 1;
                 }
             }
-            IdxCa::new_from_owned_with_null_bitmap(s.name(), out, validity).into_series()
+            IdxCa::from_vec_validity(s.name(), out, validity).into_series()
         } else {
             let sorted_values = unsafe { s.take_unchecked(&sort_idx_ca) };
             let not_consecutive_same = sorted_values
@@ -136,7 +136,7 @@ fn rank(s: &Series, method: RankMethod, descending: bool, seed: Option<u64>) ->
                     rank += 1;
                 }
             });
-            IdxCa::new_from_owned_with_null_bitmap(s.name(), out, validity).into_series()
+            IdxCa::from_vec_validity(s.name(), out, validity).into_series()
         },
         Average => unsafe {
             let mut out = vec![0.0; s.len()];
@@ -149,8 +149,7 @@ fn rank(s: &Series, method: RankMethod, descending: bool, seed: Option<u64>) ->
                     *out.get_unchecked_mut(*i as usize) = avg;
                 }
             });
-            Float64Chunked::new_from_owned_with_null_bitmap(s.name(), out, validity)
-                .into_series()
+            Float64Chunked::from_vec_validity(s.name(), out, validity).into_series()
         },
         Min => unsafe {
            let mut out = vec![0 as IdxSize; s.len()];
@@ -160,7 +159,7 @@ fn rank(s: &Series, method: RankMethod, descending: bool, seed: Option<u64>) ->
                }
                rank += ties.len() as IdxSize;
            });
-            IdxCa::new_from_owned_with_null_bitmap(s.name(), out, validity).into_series()
+            IdxCa::from_vec_validity(s.name(), out, validity).into_series()
        },
        Max => unsafe {
            let mut out = vec![0 as IdxSize; s.len()];
@@ -170,7 +169,7 @@ fn rank(s: &Series, method: RankMethod, descending: bool, seed: Option<u64>) ->
                    *out.get_unchecked_mut(*i as usize) 
= rank - 1; } }); - IdxCa::new_from_owned_with_null_bitmap(s.name(), out, validity).into_series() + IdxCa::from_vec_validity(s.name(), out, validity).into_series() }, Dense => unsafe { let mut out = vec![0 as IdxSize; s.len()]; @@ -180,7 +179,7 @@ fn rank(s: &Series, method: RankMethod, descending: bool, seed: Option) -> } rank += 1; }); - IdxCa::new_from_owned_with_null_bitmap(s.name(), out, validity).into_series() + IdxCa::from_vec_validity(s.name(), out, validity).into_series() }, Ordinal => unreachable!(), } diff --git a/py-polars/src/series/numpy_ufunc.rs b/py-polars/src/series/numpy_ufunc.rs index f37e20e33b91..91265aca0789 100644 --- a/py-polars/src/series/numpy_ufunc.rs +++ b/py-polars/src/series/numpy_ufunc.rs @@ -86,11 +86,8 @@ macro_rules! impl_ufuncs { assert_eq!(get_refcnt(out_array), 3); let validity = self.series.chunks()[0].validity().cloned(); - let ca = ChunkedArray::<$type>::new_from_owned_with_null_bitmap( - self.name(), - av, - validity, - ); + let ca = + ChunkedArray::<$type>::from_vec_validity(self.name(), av, validity); PySeries::new(ca.into_series()) }, Err(e) => { From 7b9f10edd2d3df913f8f03d3991575c8ab24ed2e Mon Sep 17 00:00:00 2001 From: Robbert-Jan 't Hoen <147692816+rjthoen@users.noreply.github.com> Date: Fri, 20 Oct 2023 07:09:14 +0200 Subject: [PATCH 063/103] fix(python): Frame slicing single column (#11825) --- py-polars/polars/dataframe/frame.py | 8 +++----- py-polars/tests/unit/dataframe/test_df.py | 5 ++--- 2 files changed, 5 insertions(+), 8 deletions(-) diff --git a/py-polars/polars/dataframe/frame.py b/py-polars/polars/dataframe/frame.py index 572d9bfdebaa..4f802362ba6b 100644 --- a/py-polars/polars/dataframe/frame.py +++ b/py-polars/polars/dataframe/frame.py @@ -1631,13 +1631,11 @@ def __getitem__( ): df = self[:, col_selection] return df.slice(row_selection, 1) - # df[2, "a"] - if isinstance(col_selection, str): - return self[col_selection][row_selection] - # column selection can be "a" and ["a", "b"] + # df[:, "a"] if isinstance(col_selection, str): - col_selection = [col_selection] + series = self.get_column(col_selection) + return series[row_selection] # df[:, 1] if isinstance(col_selection, int): diff --git a/py-polars/tests/unit/dataframe/test_df.py b/py-polars/tests/unit/dataframe/test_df.py index b30c8bc44e21..0f32b704deff 100644 --- a/py-polars/tests/unit/dataframe/test_df.py +++ b/py-polars/tests/unit/dataframe/test_df.py @@ -124,7 +124,7 @@ def test_selection() -> None: # select columns by mask assert df[:2, :1].rows() == [(1,), (2,)] - assert df[:2, "a"].rows() == [(1,), (2,)] # type: ignore[attr-defined] + assert df[:2, ["a"]].rows() == [(1,), (2,)] # column selection by string(s) in first dimension assert df["a"].to_list() == [1, 2, 3] @@ -136,7 +136,7 @@ def test_selection() -> None: assert_frame_equal(df[-1], pl.DataFrame({"a": [3], "b": [3.0], "c": ["c"]})) # row, column selection when using two dimensions - assert df[:, 0].to_list() == [1, 2, 3] + assert df[:, "a"].to_list() == [1, 2, 3] assert df[:, 1].to_list() == [1.0, 2.0, 3.0] assert df[:2, 2].to_list() == ["a", "b"] @@ -155,7 +155,6 @@ def test_selection() -> None: assert typing.cast(float, df[1, 1]) == 2.0 assert typing.cast(int, df[2, 0]) == 3 - assert df[[0, 1], "b"].rows() == [(1.0,), (2.0,)] # type: ignore[attr-defined] assert df[[2], ["a", "b"]].rows() == [(3, 3.0)] assert df.to_series(0).name == "a" assert (df["a"] == df["a"]).sum() == 3 From 27a4fe2a5c387a7e6e5661e19e265efc259f9396 Mon Sep 17 00:00:00 2001 From: Ritchie Vink Date: Fri, 20 Oct 2023 08:15:42 +0200 
Subject: [PATCH 064/103] fix: recursively check allowed streaming dtypes (#11879) --- .../src/physical_plan/streaming/convert_alp.rs | 4 ++++ .../unit/streaming/test_streaming_categoricals.py | 14 ++++++++++++++ 2 files changed, 18 insertions(+) create mode 100644 py-polars/tests/unit/streaming/test_streaming_categoricals.py diff --git a/crates/polars-lazy/src/physical_plan/streaming/convert_alp.rs b/crates/polars-lazy/src/physical_plan/streaming/convert_alp.rs index e23357f58d40..bc2af56f18d7 100644 --- a/crates/polars-lazy/src/physical_plan/streaming/convert_alp.rs +++ b/crates/polars-lazy/src/physical_plan/streaming/convert_alp.rs @@ -357,6 +357,10 @@ pub(crate) fn insert_streaming_nodes( DataType::Object(_) => false, #[cfg(feature = "dtype-categorical")] DataType::Categorical(_) => string_cache, + DataType::List(inner) => allowed_dtype(inner, string_cache), + DataType::Struct(fields) => fields + .iter() + .all(|fld| allowed_dtype(fld.data_type(), string_cache)), _ => true, } } diff --git a/py-polars/tests/unit/streaming/test_streaming_categoricals.py b/py-polars/tests/unit/streaming/test_streaming_categoricals.py new file mode 100644 index 000000000000..776e0c0ce377 --- /dev/null +++ b/py-polars/tests/unit/streaming/test_streaming_categoricals.py @@ -0,0 +1,14 @@ +import polars as pl + + +def test_streaming_nested_categorical() -> None: + assert ( + pl.LazyFrame({"numbers": [1, 1, 2], "cat": [["str"], ["foo"], ["bar"]]}) + .with_columns(pl.col("cat").cast(pl.List(pl.Categorical))) + .group_by("numbers") + .agg(pl.col("cat").first()) + .sort("numbers") + ).collect(streaming=True).to_dict(False) == { + "numbers": [1, 2], + "cat": [["str"], ["bar"]], + } From d31c30d83af31b913090b678f5727827ad517145 Mon Sep 17 00:00:00 2001 From: Stijn de Gooijer Date: Fri, 20 Oct 2023 08:39:24 +0200 Subject: [PATCH 065/103] ci: Allow manual trigger for docs deployment (#11881) --- .github/workflows/docs-global.yml | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/.github/workflows/docs-global.yml b/.github/workflows/docs-global.yml index 54e02c389469..823278186535 100644 --- a/.github/workflows/docs-global.yml +++ b/.github/workflows/docs-global.yml @@ -9,6 +9,8 @@ on: push: tags: - py-** + # Allow manual trigger until we have properly versioned docs + workflow_dispatch: jobs: markdown-link-check: @@ -72,12 +74,12 @@ jobs: run: mkdocs build - name: Add .nojekyll - if: ${{ github.ref_type == 'tag' }} + if: github.ref_type == 'tag' || github.event_name == 'workflow_dispatch' working-directory: site run: touch .nojekyll - name: Deploy docs - if: ${{ github.ref_type == 'tag' }} + if: github.ref_type == 'tag' || github.event_name == 'workflow_dispatch' uses: JamesIves/github-pages-deploy-action@v4 with: folder: site From 106dce8f037a917b108d86ee968d6828d355c9d5 Mon Sep 17 00:00:00 2001 From: Weijie Guo Date: Fri, 20 Oct 2023 15:39:19 +0800 Subject: [PATCH 066/103] feat(rust, python): Introduce list.sample (#11845) --- .../src/chunked_array/list/iterator.rs | 50 +++++++++++- crates/polars-lazy/Cargo.toml | 1 + crates/polars-ops/Cargo.toml | 1 + .../src/chunked_array/list/namespace.rs | 80 +++++++++++++++++++ crates/polars-plan/Cargo.toml | 1 + .../polars-plan/src/dsl/function_expr/list.rs | 41 ++++++++++ .../polars-plan/src/dsl/function_expr/mod.rs | 13 +++ .../src/dsl/function_expr/schema.rs | 2 + crates/polars-plan/src/dsl/list.rs | 42 ++++++++++ crates/polars/Cargo.toml | 1 + py-polars/Cargo.toml | 2 + .../source/reference/expressions/list.rst | 1 + 
.../docs/source/reference/series/list.rst     |  1 +
 py-polars/polars/expr/expr.py                 |  2 +-
 py-polars/polars/expr/list.py                 | 58 ++++++++++++++
 py-polars/polars/series/list.py               | 40 ++++++++++
 py-polars/src/expr/list.rs                    | 30 +++++++
 py-polars/tests/unit/namespaces/test_list.py  | 31 +++++++
 18 files changed, 393 insertions(+), 4 deletions(-)

diff --git a/crates/polars-core/src/chunked_array/list/iterator.rs b/crates/polars-core/src/chunked_array/list/iterator.rs
index 2dc2c7eb8559..4c51fdb271c7 100644
--- a/crates/polars-core/src/chunked_array/list/iterator.rs
+++ b/crates/polars-core/src/chunked_array/list/iterator.rs
@@ -208,15 +208,14 @@ impl ListChunked {
             .map(|(opt_s, opt_v)| {
                 let out = f(opt_s, opt_v);
                 match out {
-                    Some(out) if out.is_empty() => {
-                        fast_explode = false;
+                    Some(out) => {
+                        fast_explode &= !out.is_empty();
                         Some(out)
                     },
                     None => {
                         fast_explode = false;
                         out
                     },
-                    _ => out,
                 }
             })
             .collect_trusted()
@@ -229,6 +228,51 @@ impl ListChunked {
         out
     }
 
+    pub fn try_zip_and_apply_amortized<'a, T, I, F>(
+        &'a self,
+        ca: &'a ChunkedArray<T>,
+        mut f: F,
+    ) -> PolarsResult<Self>
+    where
+        T: PolarsDataType,
+        &'a ChunkedArray<T>: IntoIterator<IntoIter = I>,
+        I: TrustedLen<Item = Option<T::Physical<'a>>>,
+        F: FnMut(
+            Option<UnstableSeries<'a>>,
+            Option<T::Physical<'a>>,
+        ) -> PolarsResult<Option<Series>>,
+    {
+        if self.is_empty() {
+            return Ok(self.clone());
+        }
+        let mut fast_explode = self.null_count() == 0;
+        // SAFETY: unstable series never lives longer than the iterator.
+        let mut out: ListChunked = unsafe {
+            self.amortized_iter()
+                .zip(ca)
+                .map(|(opt_s, opt_v)| {
+                    let out = f(opt_s, opt_v)?;
+                    match out {
+                        Some(out) => {
+                            fast_explode &= !out.is_empty();
+                            Ok(Some(out))
+                        },
+                        None => {
+                            fast_explode = false;
+                            Ok(out)
+                        },
+                    }
+                })
+                .collect::<PolarsResult<_>>()?
+        };
+
+        out.rename(self.name());
+        if fast_explode {
+            out.set_fast_explode();
+        }
+        Ok(out)
+    }
+
     /// Apply a closure `F` elementwise. 
#[must_use]
     pub fn apply_amortized<'a, F>(&'a self, mut f: F) -> Self
diff --git a/crates/polars-lazy/Cargo.toml b/crates/polars-lazy/Cargo.toml
index 7c6a7770f8d0..785c2aeea51c 100644
--- a/crates/polars-lazy/Cargo.toml
+++ b/crates/polars-lazy/Cargo.toml
@@ -138,6 +138,7 @@ fused = ["polars-plan/fused", "polars-ops/fused"]
 list_sets = ["polars-plan/list_sets", "polars-ops/list_sets"]
 list_any_all = ["polars-ops/list_any_all", "polars-plan/list_any_all"]
 list_drop_nulls = ["polars-ops/list_drop_nulls", "polars-plan/list_drop_nulls"]
+list_sample = ["polars-ops/list_sample", "polars-plan/list_sample"]
 cutqcut = ["polars-plan/cutqcut", "polars-ops/cutqcut"]
 rle = ["polars-plan/rle", "polars-ops/rle"]
 extract_groups = ["polars-plan/extract_groups"]
diff --git a/crates/polars-ops/Cargo.toml b/crates/polars-ops/Cargo.toml
index 1ea76b197cd5..72b00779d69f 100644
--- a/crates/polars-ops/Cargo.toml
+++ b/crates/polars-ops/Cargo.toml
@@ -107,6 +107,7 @@ list_take = []
 list_sets = []
 list_any_all = []
 list_drop_nulls = []
+list_sample = []
 extract_groups = ["dtype-struct", "polars-core/regex"]
 is_in = ["polars-core/reinterpret"]
 convert_index = []
diff --git a/crates/polars-ops/src/chunked_array/list/namespace.rs b/crates/polars-ops/src/chunked_array/list/namespace.rs
index 37165179d187..374a81e5d2a6 100644
--- a/crates/polars-ops/src/chunked_array/list/namespace.rs
+++ b/crates/polars-ops/src/chunked_array/list/namespace.rs
@@ -404,6 +404,86 @@ pub trait ListNameSpaceImpl: AsList {
         list_ca.apply_amortized(|s| s.as_ref().drop_nulls())
     }
 
+    #[cfg(feature = "list_sample")]
+    fn lst_sample_n(
+        &self,
+        n: &Series,
+        with_replacement: bool,
+        shuffle: bool,
+        seed: Option<u64>,
+    ) -> PolarsResult<ListChunked> {
+        let ca = self.as_list();
+
+        let n_s = n.cast(&IDX_DTYPE)?;
+        let n = n_s.idx()?;
+
+        let out = match n.len() {
+            1 => {
+                if let Some(n) = n.get(0) {
+                    ca.try_apply_amortized(|s| {
+                        s.as_ref()
+                            .sample_n(n as usize, with_replacement, shuffle, seed)
+                    })
+                } else {
+                    Ok(ListChunked::full_null_with_dtype(
+                        ca.name(),
+                        ca.len(),
+                        &ca.inner_dtype(),
+                    ))
+                }
+            },
+            _ => ca.try_zip_and_apply_amortized(n, |opt_s, opt_n| match (opt_s, opt_n) {
+                (Some(s), Some(n)) => s
+                    .as_ref()
+                    .sample_n(n as usize, with_replacement, shuffle, seed)
+                    .map(Some),
+                _ => Ok(None),
+            }),
+        };
+        out.map(|ok| self.same_type(ok))
+    }
+
+    #[cfg(feature = "list_sample")]
+    fn lst_sample_fraction(
+        &self,
+        fraction: &Series,
+        with_replacement: bool,
+        shuffle: bool,
+        seed: Option<u64>,
+    ) -> PolarsResult<ListChunked> {
+        let ca = self.as_list();
+
+        let fraction_s = fraction.cast(&DataType::Float64)?;
+        let fraction = fraction_s.f64()?;
+
+        let out = match fraction.len() {
+            1 => {
+                if let Some(fraction) = fraction.get(0) {
+                    ca.try_apply_amortized(|s| {
+                        let n = (s.as_ref().len() as f64 * fraction) as usize;
+                        s.as_ref().sample_n(n, with_replacement, shuffle, seed)
+                    })
+                } else {
+                    Ok(ListChunked::full_null_with_dtype(
+                        ca.name(),
+                        ca.len(),
+                        &ca.inner_dtype(),
+                    ))
+                }
+            },
+            _ => ca.try_zip_and_apply_amortized(fraction, |opt_s, opt_n| match (opt_s, opt_n) {
+                (Some(s), Some(fraction)) => {
+                    let n = (s.as_ref().len() as f64 * fraction) as usize;
+                    s.as_ref()
+                        .sample_n(n, with_replacement, shuffle, seed)
+                        .map(Some)
+                },
+                _ => Ok(None),
+            }),
+        };
+        out.map(|ok| self.same_type(ok))
+    }
+
     fn lst_concat(&self, other: &[Series]) -> PolarsResult<ListChunked> {
         let ca = self.as_list();
         let other_len = other.len();
diff --git a/crates/polars-plan/Cargo.toml b/crates/polars-plan/Cargo.toml
index 1a088bc806ef..460881737b38 100644
--- 
a/crates/polars-plan/Cargo.toml
+++ b/crates/polars-plan/Cargo.toml
@@ -136,6 +136,7 @@ fused = ["polars-ops/fused"]
 list_sets = ["polars-ops/list_sets"]
 list_any_all = ["polars-ops/list_any_all"]
 list_drop_nulls = ["polars-ops/list_drop_nulls"]
+list_sample = ["polars-ops/list_sample"]
 cutqcut = ["polars-ops/cutqcut"]
 rle = ["polars-ops/rle"]
 extract_groups = ["regex", "dtype-struct", "polars-ops/extract_groups"]
diff --git a/crates/polars-plan/src/dsl/function_expr/list.rs b/crates/polars-plan/src/dsl/function_expr/list.rs
index 35155902b788..bb238d09bd75 100644
--- a/crates/polars-plan/src/dsl/function_expr/list.rs
+++ b/crates/polars-plan/src/dsl/function_expr/list.rs
@@ -11,6 +11,13 @@ pub enum ListFunction {
     Contains,
     #[cfg(feature = "list_drop_nulls")]
     DropNulls,
+    #[cfg(feature = "list_sample")]
+    Sample {
+        is_fraction: bool,
+        with_replacement: bool,
+        shuffle: bool,
+        seed: Option<u64>,
+    },
     Slice,
     Shift,
     Get,
@@ -52,6 +59,14 @@ impl Display for ListFunction {
             Contains => "contains",
             #[cfg(feature = "list_drop_nulls")]
             DropNulls => "drop_nulls",
+            #[cfg(feature = "list_sample")]
+            Sample { is_fraction, .. } => {
+                if *is_fraction {
+                    "sample_fraction"
+                } else {
+                    "sample_n"
+                }
+            },
             Slice => "slice",
             Shift => "shift",
             Get => "get",
@@ -107,6 +122,32 @@ pub(super) fn drop_nulls(s: &Series) -> PolarsResult<Series> {
     Ok(list.lst_drop_nulls().into_series())
 }
 
+#[cfg(feature = "list_sample")]
+pub(super) fn sample_n(
+    s: &[Series],
+    with_replacement: bool,
+    shuffle: bool,
+    seed: Option<u64>,
+) -> PolarsResult<Series> {
+    let list = s[0].list()?;
+    let n = &s[1];
+    list.lst_sample_n(n, with_replacement, shuffle, seed)
+        .map(|ok| ok.into_series())
+}
+
+#[cfg(feature = "list_sample")]
+pub(super) fn sample_fraction(
+    s: &[Series],
+    with_replacement: bool,
+    shuffle: bool,
+    seed: Option<u64>,
+) -> PolarsResult<Series> {
+    let list = s[0].list()?;
+    let fraction = &s[1];
+    list.lst_sample_fraction(fraction, with_replacement, shuffle, seed)
+        .map(|ok| ok.into_series())
+}
+
 fn check_slice_arg_shape(slice_len: usize, ca_len: usize, name: &str) -> PolarsResult<()> {
     polars_ensure!(
         slice_len == ca_len,
diff --git a/crates/polars-plan/src/dsl/function_expr/mod.rs b/crates/polars-plan/src/dsl/function_expr/mod.rs
index 40a5eef2f326..3ca6b5f2c7c9 100644
--- a/crates/polars-plan/src/dsl/function_expr/mod.rs
+++ b/crates/polars-plan/src/dsl/function_expr/mod.rs
@@ -671,6 +671,19 @@ impl From<FunctionExpr> for SpecialEq<Arc<dyn SeriesUdf>> {
             Contains => wrap!(list::contains),
             #[cfg(feature = "list_drop_nulls")]
             DropNulls => map!(list::drop_nulls),
+            #[cfg(feature = "list_sample")]
+            Sample {
+                is_fraction,
+                with_replacement,
+                shuffle,
+                seed,
+            } => {
+                if is_fraction {
+                    map_as_slice!(list::sample_fraction, with_replacement, shuffle, seed)
+                } else {
+                    map_as_slice!(list::sample_n, with_replacement, shuffle, seed)
+                }
+            },
             Slice => wrap!(list::slice),
             Shift => map_as_slice!(list::shift),
             Get => wrap!(list::get),
diff --git a/crates/polars-plan/src/dsl/function_expr/schema.rs b/crates/polars-plan/src/dsl/function_expr/schema.rs
index cf2178fcd42c..af99f7f81b52 100644
--- a/crates/polars-plan/src/dsl/function_expr/schema.rs
+++ b/crates/polars-plan/src/dsl/function_expr/schema.rs
@@ -70,6 +70,8 @@ impl FunctionExpr {
                 Contains => mapper.with_dtype(DataType::Boolean),
                 #[cfg(feature = "list_drop_nulls")]
                 DropNulls => mapper.with_same_dtype(),
+                #[cfg(feature = "list_sample")]
+                Sample { .. 
} => mapper.with_same_dtype(),
                 Slice => mapper.with_same_dtype(),
                 Shift => mapper.with_same_dtype(),
                 Get => mapper.map_to_list_inner_dtype(),
diff --git a/crates/polars-plan/src/dsl/list.rs b/crates/polars-plan/src/dsl/list.rs
index 6e9bde5b68eb..c8741dba73ff 100644
--- a/crates/polars-plan/src/dsl/list.rs
+++ b/crates/polars-plan/src/dsl/list.rs
@@ -34,6 +34,48 @@ impl ListNameSpace {
             .map_private(FunctionExpr::ListExpr(ListFunction::DropNulls))
     }
 
+    #[cfg(feature = "list_sample")]
+    pub fn sample_n(
+        self,
+        n: Expr,
+        with_replacement: bool,
+        shuffle: bool,
+        seed: Option<u64>,
+    ) -> Expr {
+        self.0.map_many_private(
+            FunctionExpr::ListExpr(ListFunction::Sample {
+                is_fraction: false,
+                with_replacement,
+                shuffle,
+                seed,
+            }),
+            &[n],
+            false,
+            false,
+        )
+    }
+
+    #[cfg(feature = "list_sample")]
+    pub fn sample_fraction(
+        self,
+        fraction: Expr,
+        with_replacement: bool,
+        shuffle: bool,
+        seed: Option<u64>,
+    ) -> Expr {
+        self.0.map_many_private(
+            FunctionExpr::ListExpr(ListFunction::Sample {
+                is_fraction: true,
+                with_replacement,
+                shuffle,
+                seed,
+            }),
+            &[fraction],
+            false,
+            false,
+        )
+    }
+
     /// Return the number of elements in each list.
     ///
     /// Null values are treated like regular elements in this context.
diff --git a/crates/polars/Cargo.toml b/crates/polars/Cargo.toml
index 369b1d07f9cf..1f397e8e7aa6 100644
--- a/crates/polars/Cargo.toml
+++ b/crates/polars/Cargo.toml
@@ -190,6 +190,7 @@ fused = ["polars-ops/fused", "polars-lazy?/fused"]
 list_sets = ["polars-lazy?/list_sets"]
 list_any_all = ["polars-lazy?/list_any_all"]
 list_drop_nulls = ["polars-lazy?/list_drop_nulls"]
+list_sample = ["polars-lazy?/list_sample"]
 cutqcut = ["polars-lazy?/cutqcut"]
 rle = ["polars-lazy?/rle"]
 extract_groups = ["polars-lazy?/extract_groups"]
diff --git a/py-polars/Cargo.toml b/py-polars/Cargo.toml
index 445646088c1d..f782192c869e 100644
--- a/py-polars/Cargo.toml
+++ b/py-polars/Cargo.toml
@@ -138,6 +138,7 @@ binary_encoding = ["polars/binary_encoding"]
 list_sets = ["polars-lazy/list_sets"]
 list_any_all = ["polars/list_any_all"]
 list_drop_nulls = ["polars/list_drop_nulls"]
+list_sample = ["polars/list_sample"]
 cutqcut = ["polars/cutqcut"]
 rle = ["polars/rle"]
 extract_groups = ["polars/extract_groups"]
@@ -165,6 +166,7 @@ operations = [
   "list_sets",
   "list_any_all",
   "list_drop_nulls",
+  "list_sample",
   "cutqcut",
   "rle",
   "extract_groups",
diff --git a/py-polars/docs/source/reference/expressions/list.rst b/py-polars/docs/source/reference/expressions/list.rst
index d56b44abcc30..f43401e20561 100644
--- a/py-polars/docs/source/reference/expressions/list.rst
+++ b/py-polars/docs/source/reference/expressions/list.rst
@@ -34,6 +34,7 @@ The following methods are available under the `expr.list` attribute.
     Expr.list.mean
     Expr.list.min
     Expr.list.reverse
+    Expr.list.sample
     Expr.list.set_difference
     Expr.list.set_intersection
     Expr.list.set_symmetric_difference
diff --git a/py-polars/docs/source/reference/series/list.rst b/py-polars/docs/source/reference/series/list.rst
index 7f3b709e80db..ad766dd92eb9 100644
--- a/py-polars/docs/source/reference/series/list.rst
+++ b/py-polars/docs/source/reference/series/list.rst
@@ -34,6 +34,7 @@ The following methods are available under the `Series.list` attribute. 
Series.list.mean Series.list.min Series.list.reverse + Series.list.sample Series.list.set_difference Series.list.set_intersection Series.list.set_symmetric_difference diff --git a/py-polars/polars/expr/expr.py b/py-polars/polars/expr/expr.py index 608cb0db4840..8ad68bb58117 100644 --- a/py-polars/polars/expr/expr.py +++ b/py-polars/polars/expr/expr.py @@ -8244,7 +8244,7 @@ def shuffle(self, seed: int | None = None) -> Self: def sample( self, - n: int | Expr | None = None, + n: int | IntoExprColumn | None = None, *, fraction: float | None = None, with_replacement: bool = False, diff --git a/py-polars/polars/expr/list.py b/py-polars/polars/expr/list.py index 5a41f4e2f213..249c17ab5cb4 100644 --- a/py-polars/polars/expr/list.py +++ b/py-polars/polars/expr/list.py @@ -138,6 +138,64 @@ def drop_nulls(self) -> Expr: """ return wrap_expr(self._pyexpr.list_drop_nulls()) + def sample( + self, + n: int | IntoExprColumn | None = None, + *, + fraction: float | IntoExprColumn | None = None, + with_replacement: bool = False, + shuffle: bool = False, + seed: int | None = None, + ) -> Expr: + """ + Sample from this list. + + Parameters + ---------- + n + Number of items to return. Cannot be used with `fraction`. Defaults to 1 if + `fraction` is None. + fraction + Fraction of items to return. Cannot be used with `n`. + with_replacement + Allow values to be sampled more than once. + shuffle + Shuffle the order of sampled data points. + seed + Seed for the random number generator. If set to None (default), a + random seed is generated for each sample operation. + + Examples + -------- + >>> df = pl.DataFrame({"values": [[1, 2, 3], [4, 5]], "n": [2, 1]}) + >>> df.select(pl.col("values").list.sample(n=pl.col("n"), seed=1)) + shape: (2, 1) + ┌───────────┐ + │ values │ + │ --- │ + │ list[i64] │ + ╞═══════════╡ + │ [2, 1] │ + │ [5] │ + └───────────┘ + + """ + if n is not None and fraction is not None: + raise ValueError("cannot specify both `n` and `fraction`") + + if fraction is not None: + fraction = parse_as_expression(fraction) + return wrap_expr( + self._pyexpr.list_sample_fraction( + fraction, with_replacement, shuffle, seed + ) + ) + + if n is None: + n = 1 + n = parse_as_expression(n) + return wrap_expr(self._pyexpr.list_sample_n(n, with_replacement, shuffle, seed)) + def sum(self) -> Expr: """ Sum all the lists in the array. diff --git a/py-polars/polars/series/list.py b/py-polars/polars/series/list.py index 3b883df4dea3..3f4b11d257b5 100644 --- a/py-polars/polars/series/list.py +++ b/py-polars/polars/series/list.py @@ -125,6 +125,46 @@ def drop_nulls(self) -> Series: """ + def sample( + self, + n: int | IntoExprColumn | None = None, + *, + fraction: float | IntoExprColumn | None = None, + with_replacement: bool = False, + shuffle: bool = False, + seed: int | None = None, + ) -> Series: + """ + Sample from this list. + + Parameters + ---------- + n + Number of items to return. Cannot be used with `fraction`. Defaults to 1 if + `fraction` is None. + fraction + Fraction of items to return. Cannot be used with `n`. + with_replacement + Allow values to be sampled more than once. + shuffle + Shuffle the order of sampled data points. + seed + Seed for the random number generator. If set to None (default), a + random seed is generated for each sample operation. 
+
+        Examples
+        --------
+        >>> s = pl.Series("values", [[1, 2, 3], [4, 5]])
+        >>> s.list.sample(n=pl.Series("n", [2, 1]), seed=1)
+        shape: (2,)
+        Series: 'values' [list[i64]]
+        [
+            [2, 1]
+            [5]
+        ]
+
+        """
+
     def sum(self) -> Series:
         """Sum all the arrays in the list."""
 
diff --git a/py-polars/src/expr/list.rs b/py-polars/src/expr/list.rs
index dbac07c08a3b..a8a6db6613b9 100644
--- a/py-polars/src/expr/list.rs
+++ b/py-polars/src/expr/list.rs
@@ -115,6 +115,36 @@ impl PyExpr {
         self.inner.clone().list().drop_nulls().into()
     }
 
+    #[cfg(feature = "list_sample")]
+    fn list_sample_n(
+        &self,
+        n: PyExpr,
+        with_replacement: bool,
+        shuffle: bool,
+        seed: Option<u64>,
+    ) -> Self {
+        self.inner
+            .clone()
+            .list()
+            .sample_n(n.inner, with_replacement, shuffle, seed)
+            .into()
+    }
+
+    #[cfg(feature = "list_sample")]
+    fn list_sample_fraction(
+        &self,
+        fraction: PyExpr,
+        with_replacement: bool,
+        shuffle: bool,
+        seed: Option<u64>,
+    ) -> Self {
+        self.inner
+            .clone()
+            .list()
+            .sample_fraction(fraction.inner, with_replacement, shuffle, seed)
+            .into()
+    }
+
     #[cfg(feature = "list_take")]
     fn list_take(&self, index: PyExpr, null_on_oob: bool) -> Self {
         self.inner
diff --git a/py-polars/tests/unit/namespaces/test_list.py b/py-polars/tests/unit/namespaces/test_list.py
index 9c0eeb51e063..89ceff4e6a0a 100644
--- a/py-polars/tests/unit/namespaces/test_list.py
+++ b/py-polars/tests/unit/namespaces/test_list.py
@@ -179,6 +179,37 @@ def test_list_drop_nulls() -> None:
     assert_frame_equal(df, expected_df)
 
 
+def test_list_sample() -> None:
+    s = pl.Series("values", [[1, 2, 3, None], [None, None], [1, 2], None])
+
+    expected_sample_n = pl.Series("values", [[3, 1], [None], [2], None])
+    assert_series_equal(
+        s.list.sample(n=pl.Series([2, 1, 1, 1]), seed=1), expected_sample_n
+    )
+
+    expected_sample_frac = pl.Series("values", [[3, 1], [None], [1, 2], None])
+    assert_series_equal(
+        s.list.sample(fraction=pl.Series([0.5, 0.5, 1.0, 0.3]), seed=1),
+        expected_sample_frac,
+    )
+
+    df = pl.DataFrame(
+        {
+            "values": [[1, 2, 3, None], [None, None], [3, 4]],
+            "n": [2, 1, 2],
+            "frac": [0.5, 0.5, 1.0],
+        }
+    )
+    df = df.select(
+        sample_n=pl.col("values").list.sample(n=pl.col("n"), seed=1),
+        sample_frac=pl.col("values").list.sample(fraction=pl.col("frac"), seed=1),
+    )
+    expected_df = pl.DataFrame(
+        {"sample_n": [[3, 1], [None], [3, 4]], "sample_frac": [[3, 1], [None], [3, 4]]}
+    )
+    assert_frame_equal(df, expected_df)
+
+
 def test_list_diff() -> None:
     s = pl.Series("a", [[1, 2], [10, 2, 1]])
     expected = pl.Series("a", [[None, 1], [None, -8, -1]])

From f509de91f2a56d0dee6b73bf36d0b0be88442b38 Mon Sep 17 00:00:00 2001
From: Stijn de Gooijer
Date: Fri, 20 Oct 2023 10:11:24 +0200
Subject: [PATCH 067/103] chore: Fix Cargo warning for parquet2 dependency
 (#11882)

---
 Cargo.toml                     | 2 +-
 crates/polars-arrow/Cargo.toml | 2 +-
 crates/polars-error/Cargo.toml | 4 ++--
 crates/polars-io/Cargo.toml    | 2 +-
 crates/polars-lazy/Cargo.toml  | 2 +-
 crates/polars-ops/Cargo.toml   | 4 ++--
 crates/polars-plan/Cargo.toml  | 6 +++---
 crates/polars-time/Cargo.toml  | 2 +-
 crates/polars/Cargo.toml       | 2 +-
 9 files changed, 13 insertions(+), 13 deletions(-)

diff --git a/Cargo.toml b/Cargo.toml
index eda1f6d2a3af..6a5f0b2a835e 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -85,7 +85,7 @@ arrow-array = { version = ">=41", default-features = false }
 arrow-buffer = { version = ">=41", default-features = false }
 arrow-data = { version = ">=41", default-features = false }
 arrow-schema = { version = ">=41", default-features = false }
-parquet2 = { version = "0.17.2", 
features = ["async"] } +parquet2 = { version = "0.17.2", features = ["async"], default-features = false } avro-schema = { version = "0.3" } [workspace.dependencies.arrow] diff --git a/crates/polars-arrow/Cargo.toml b/crates/polars-arrow/Cargo.toml index d464aad88f6b..58fb4d3ff72e 100644 --- a/crates/polars-arrow/Cargo.toml +++ b/crates/polars-arrow/Cargo.toml @@ -74,7 +74,7 @@ arrow-array = { workspace = true, optional = true } arrow-buffer = { workspace = true, optional = true } arrow-data = { workspace = true, optional = true } arrow-schema = { workspace = true, optional = true } -parquet2 = { workspace = true, optional = true, features = ["async"] } +parquet2 = { workspace = true, optional = true, default-features = true, features = ["async"] } [dev-dependencies] avro-rs = { version = "0.13", features = ["snappy"] } diff --git a/crates/polars-error/Cargo.toml b/crates/polars-error/Cargo.toml index 689e755e9a20..60e4800f073f 100644 --- a/crates/polars-error/Cargo.toml +++ b/crates/polars-error/Cargo.toml @@ -11,8 +11,8 @@ description = "Error definitions for the Polars DataFrame library" [dependencies] arrow-format = { version = "0.8.1", optional = true } avro-schema = { workspace = true, optional = true } -object_store = { workspace = true, default-features = false, optional = true } -parquet2 = { workspace = true, optional = true, default-features = false } +object_store = { workspace = true, optional = true } +parquet2 = { workspace = true, optional = true } regex = { workspace = true, optional = true } simdutf8 = { workspace = true } thiserror = { workspace = true } diff --git a/crates/polars-io/Cargo.toml b/crates/polars-io/Cargo.toml index 26819a2c58f0..18e69c0e9647 100644 --- a/crates/polars-io/Cargo.toml +++ b/crates/polars-io/Cargo.toml @@ -10,7 +10,7 @@ description = "IO related logic for the Polars DataFrame library" [dependencies] polars-core = { workspace = true } -polars-error = { workspace = true, default-features = false } +polars-error = { workspace = true } polars-json = { workspace = true, optional = true } polars-time = { workspace = true, features = [], optional = true } polars-utils = { workspace = true } diff --git a/crates/polars-lazy/Cargo.toml b/crates/polars-lazy/Cargo.toml index 785c2aeea51c..554e8eeda37d 100644 --- a/crates/polars-lazy/Cargo.toml +++ b/crates/polars-lazy/Cargo.toml @@ -10,7 +10,7 @@ description = "Lazy query engine for the Polars DataFrame library" [dependencies] arrow = { workspace = true } -polars-core = { workspace = true, features = ["lazy", "zip_with", "random"], default-features = false } +polars-core = { workspace = true, features = ["lazy", "zip_with", "random"] } polars-io = { workspace = true, features = ["lazy"] } polars-json = { workspace = true, optional = true } polars-ops = { workspace = true } diff --git a/crates/polars-ops/Cargo.toml b/crates/polars-ops/Cargo.toml index 72b00779d69f..1651ab8afb38 100644 --- a/crates/polars-ops/Cargo.toml +++ b/crates/polars-ops/Cargo.toml @@ -9,10 +9,10 @@ repository = { workspace = true } description = "More operations on Polars data structures" [dependencies] -polars-core = { workspace = true, features = ["algorithm_group_by"], default-features = false } +polars-core = { workspace = true, features = ["algorithm_group_by"] } polars-error = { workspace = true } polars-json = { workspace = true, optional = true } -polars-utils = { workspace = true, default-features = false } +polars-utils = { workspace = true } ahash = { workspace = true } argminmax = { version = "0.6.1", default-features = 
false, features = ["float"] } diff --git a/crates/polars-plan/Cargo.toml b/crates/polars-plan/Cargo.toml index 460881737b38..d70a90914835 100644 --- a/crates/polars-plan/Cargo.toml +++ b/crates/polars-plan/Cargo.toml @@ -13,10 +13,10 @@ doctest = false [dependencies] libloading = { version = "0.8.0", optional = true } -polars-core = { workspace = true, features = ["lazy", "zip_with", "random"], default-features = false } +polars-core = { workspace = true, features = ["lazy", "zip_with", "random"] } polars-ffi = { workspace = true, optional = true } -polars-io = { workspace = true, features = ["lazy"], default-features = false } -polars-ops = { workspace = true, features = ["zip_with"], default-features = false } +polars-io = { workspace = true, features = ["lazy"] } +polars-ops = { workspace = true, features = ["zip_with"] } polars-time = { workspace = true, optional = true } polars-utils = { workspace = true } diff --git a/crates/polars-time/Cargo.toml b/crates/polars-time/Cargo.toml index 77e8e861a6a1..abd23e909aa3 100644 --- a/crates/polars-time/Cargo.toml +++ b/crates/polars-time/Cargo.toml @@ -10,7 +10,7 @@ description = "Time related code for the Polars DataFrame library" [dependencies] arrow = { workspace = true, features = ["compute", "temporal"] } -polars-core = { workspace = true, default-features = false, features = ["dtype-datetime", "dtype-duration", "dtype-time", "dtype-date"] } +polars-core = { workspace = true, features = ["dtype-datetime", "dtype-duration", "dtype-time", "dtype-date"] } polars-error = { workspace = true } polars-ops = { workspace = true } polars-utils = { workspace = true } diff --git a/crates/polars/Cargo.toml b/crates/polars/Cargo.toml index 1f397e8e7aa6..165fea5806f0 100644 --- a/crates/polars/Cargo.toml +++ b/crates/polars/Cargo.toml @@ -14,7 +14,7 @@ description = "DataFrame library based on Apache Arrow" polars-algo = { workspace = true, optional = true } polars-core = { workspace = true } polars-io = { workspace = true, optional = true } -polars-lazy = { workspace = true, default-features = false, optional = true } +polars-lazy = { workspace = true, optional = true } polars-ops = { workspace = true } polars-sql = { workspace = true, optional = true } polars-time = { workspace = true, optional = true } From d9f2f5f7e59c171019edcd0033a5c1926eca03fc Mon Sep 17 00:00:00 2001 From: Stijn de Gooijer Date: Fri, 20 Oct 2023 11:20:53 +0200 Subject: [PATCH 068/103] depr(python): Deprecate `DataType.is_nested` (#11844) --- py-polars/polars/__init__.py | 2 ++ py-polars/polars/datatypes/__init__.py | 2 ++ py-polars/polars/datatypes/classes.py | 32 +++++++++++++++++-- py-polars/polars/datatypes/constants.py | 4 +++ py-polars/polars/testing/asserts/series.py | 3 +- py-polars/tests/unit/datatypes/test_list.py | 18 +++++++---- py-polars/tests/unit/datatypes/test_struct.py | 2 +- 7 files changed, 53 insertions(+), 10 deletions(-) diff --git a/py-polars/polars/__init__.py b/py-polars/polars/__init__.py index 9abe028cbe25..b0495b3ef1df 100644 --- a/py-polars/polars/__init__.py +++ b/py-polars/polars/__init__.py @@ -29,6 +29,7 @@ DURATION_DTYPES, FLOAT_DTYPES, INTEGER_DTYPES, + NESTED_DTYPES, NUMERIC_DTYPES, TEMPORAL_DTYPES, Array, @@ -253,6 +254,7 @@ "DURATION_DTYPES", "FLOAT_DTYPES", "INTEGER_DTYPES", + "NESTED_DTYPES", "NUMERIC_DTYPES", "TEMPORAL_DTYPES", # polars.type_aliases diff --git a/py-polars/polars/datatypes/__init__.py b/py-polars/polars/datatypes/__init__.py index 9c5136f8d5e6..4576282539c4 100644 --- a/py-polars/polars/datatypes/__init__.py +++ 
b/py-polars/polars/datatypes/__init__.py @@ -40,6 +40,7 @@ FLOAT_DTYPES, INTEGER_DTYPES, N_INFER_DEFAULT, + NESTED_DTYPES, NUMERIC_DTYPES, SIGNED_INTEGER_DTYPES, TEMPORAL_DTYPES, @@ -113,6 +114,7 @@ "DURATION_DTYPES", "FLOAT_DTYPES", "INTEGER_DTYPES", + "NESTED_DTYPES", "NUMERIC_DTYPES", "N_INFER_DEFAULT", "SIGNED_INTEGER_DTYPES", diff --git a/py-polars/polars/datatypes/classes.py b/py-polars/polars/datatypes/classes.py index 0520866d76a8..aa9d316c21d3 100644 --- a/py-polars/polars/datatypes/classes.py +++ b/py-polars/polars/datatypes/classes.py @@ -144,7 +144,21 @@ def is_not(self, other: PolarsDataType) -> bool: @classproperty def is_nested(self) -> bool: - """Check if this data type is nested.""" + """ + Check if this data type is nested. + + .. deprecated:: 0.19.10 + Use `dtype in pl.NESTED_DTYPES` instead. + + """ + from polars.utils.deprecation import issue_deprecation_warning + + message = ( + "`DataType.is_nested` is deprecated and will be removed in the next breaking release." + " It will be changed to a classmethod rather than a property." + " To silence this warning, use `dtype in pl.NESTED_DTYPES` instead." + ) + issue_deprecation_warning(message, version="0.19.10") return False @@ -220,7 +234,21 @@ class NestedType(DataType): @classproperty def is_nested(self) -> bool: - """Check if this data type is nested.""" + """ + Check if this data type is nested. + + .. deprecated:: 0.19.10 + Use `dtype in pl.NESTED_DTYPES` instead. + + """ + from polars.utils.deprecation import issue_deprecation_warning + + message = ( + "`DataType.is_nested` is deprecated and will be removed in the next breaking release." + " It will be changed to a classmethod rather than a property." + " To silence this warning, use `dtype in pl.NESTED_DTYPES` instead." + ) + issue_deprecation_warning(message, version="0.19.10") return True diff --git a/py-polars/polars/datatypes/constants.py b/py-polars/polars/datatypes/constants.py index a1654442ecc9..24219b7e15ba 100644 --- a/py-polars/polars/datatypes/constants.py +++ b/py-polars/polars/datatypes/constants.py @@ -14,6 +14,8 @@ Int16, Int32, Int64, + List, + Struct, Time, UInt8, UInt16, @@ -73,5 +75,7 @@ FLOAT_DTYPES | INTEGER_DTYPES | frozenset([Decimal]) ) +NESTED_DTYPES: frozenset[PolarsDataType] = DataTypeGroup([List, Struct]) + # number of rows to scan by default when inferring datatypes N_INFER_DEFAULT = 100 diff --git a/py-polars/polars/testing/asserts/series.py b/py-polars/polars/testing/asserts/series.py index 028c1cff6db7..e285af716d6c 100644 --- a/py-polars/polars/testing/asserts/series.py +++ b/py-polars/polars/testing/asserts/series.py @@ -3,6 +3,7 @@ from polars import functions as F from polars.datatypes import ( FLOAT_DTYPES, + NESTED_DTYPES, UNSIGNED_INTEGER_DTYPES, Categorical, List, @@ -139,7 +140,7 @@ def _assert_series_values_equal( unequal = unequal | left.is_nan() | right.is_nan() # check nested dtypes in separate function - if left.dtype.is_nested or right.dtype.is_nested: + if left.dtype in NESTED_DTYPES or right.dtype in NESTED_DTYPES: if _assert_series_nested( left=left.filter(unequal), right=right.filter(unequal), diff --git a/py-polars/tests/unit/datatypes/test_list.py b/py-polars/tests/unit/datatypes/test_list.py index 9806a643bf0a..23e1a7be42b4 100644 --- a/py-polars/tests/unit/datatypes/test_list.py +++ b/py-polars/tests/unit/datatypes/test_list.py @@ -43,7 +43,7 @@ def test_dtype() -> None: "dt": pl.List(pl.Date), "dtm": pl.List(pl.Datetime), } - assert all(tp.is_nested for tp in df.dtypes) + assert all(tp in pl.NESTED_DTYPES for 
tp in df.dtypes) assert df.schema["i"].inner == pl.Int8 # type: ignore[union-attr] assert df.rows() == [ ( @@ -69,17 +69,15 @@ def test_categorical() -> None: out = ( df.group_by(["a", "b"]) .agg( - [ - pl.col("c").count().alias("num_different_c"), - pl.col("c").alias("c_values"), - ] + pl.col("c").count().alias("num_different_c"), + pl.col("c").alias("c_values"), ) .filter(pl.col("num_different_c") >= 2) .to_series(3) ) assert out.inner_dtype == pl.Categorical - assert not out.inner_dtype.is_nested + assert out.inner_dtype not in pl.NESTED_DTYPES def test_cast_inner() -> None: @@ -565,3 +563,11 @@ def test_list_inner_cast_physical_11513() -> None: }, ) assert df.select(pl.col("struct").take(0)).to_dict(False) == {"struct": [[]]} + + +@pytest.mark.parametrize( + ("dtype", "expected"), [(pl.List, True), (pl.Struct, True), (pl.Utf8, False)] +) +def test_list_is_nested_deprecated(dtype: PolarsDataType, expected: bool) -> None: + with pytest.deprecated_call(): + assert dtype.is_nested is expected diff --git a/py-polars/tests/unit/datatypes/test_struct.py b/py-polars/tests/unit/datatypes/test_struct.py index b372be41f045..1bc4a7a4b4a8 100644 --- a/py-polars/tests/unit/datatypes/test_struct.py +++ b/py-polars/tests/unit/datatypes/test_struct.py @@ -152,7 +152,7 @@ def test_struct_unnest_multiple() -> None: # List input result = df_structs.unnest(["s1", "s2"]) assert_frame_equal(result, df) - assert all(tp.is_nested for tp in df_structs.dtypes) + assert all(tp in pl.NESTED_DTYPES for tp in df_structs.dtypes) # Positional input result = df_structs.unnest("s1", "s2") From 1a0c17422f08cc95c2785bd9c9005321e44b6375 Mon Sep 17 00:00:00 2001 From: Ritchie Vink Date: Fri, 20 Oct 2023 11:53:27 +0200 Subject: [PATCH 069/103] fix: recursively apply `cast_unchecked` in lists (#11884) --- crates/polars-core/src/chunked_array/cast.rs | 34 ++++++++++++++++++- crates/polars-core/src/series/mod.rs | 3 +- .../tests/unit/datatypes/test_categorical.py | 5 +++ 3 files changed, 40 insertions(+), 2 deletions(-) diff --git a/crates/polars-core/src/chunked_array/cast.rs b/crates/polars-core/src/chunked_array/cast.rs index 801f5e285f3a..bf06636829f7 100644 --- a/crates/polars-core/src/chunked_array/cast.rs +++ b/crates/polars-core/src/chunked_array/cast.rs @@ -337,7 +337,11 @@ impl ChunkCast for ListChunked { } unsafe fn cast_unchecked(&self, data_type: &DataType) -> PolarsResult { - self.cast(data_type) + use DataType::*; + match data_type { + List(child_type) => cast_list_unchecked(self, child_type), + _ => self.cast(data_type), + } } } @@ -386,6 +390,8 @@ impl ChunkCast for ArrayChunked { // Returns inner data type. This is needed because a cast can instantiate the dtype inner // values for instance with categoricals fn cast_list(ca: &ListChunked, child_type: &DataType) -> PolarsResult<(ArrayRef, DataType)> { + // We still rechunk because we must bubble up a single data-type + // TODO!: consider a version that works on chunks and merges the data-types and arrays. let ca = ca.rechunk(); let arr = ca.downcast_iter().next().unwrap(); // safety: inner dtype is passed correctly @@ -409,6 +415,32 @@ fn cast_list(ca: &ListChunked, child_type: &DataType) -> PolarsResult<(ArrayRef, Ok((Box::new(new_arr), inner_dtype)) } +unsafe fn cast_list_unchecked(ca: &ListChunked, child_type: &DataType) -> PolarsResult { + // TODO! add chunked, but this must correct for list offsets. 
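+    // The steps below: flatten to a single chunk, cast the inner values
+    // buffer without checks, then rebuild the ListArray re-using the
+    // original offsets and validity.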
+ let ca = ca.rechunk(); + let arr = ca.downcast_iter().next().unwrap(); + // safety: inner dtype is passed correctly + let s = unsafe { + Series::from_chunks_and_dtype_unchecked("", vec![arr.values().clone()], &ca.inner_dtype()) + }; + let new_inner = s.cast_unchecked(child_type)?; + let new_values = new_inner.array_ref(0).clone(); + + let data_type = ListArray::::default_datatype(new_values.data_type().clone()); + let new_arr = ListArray::::new( + data_type, + arr.offsets().clone(), + new_values, + arr.validity().cloned(), + ); + Ok(ListChunked::from_chunks_and_dtype_unchecked( + ca.name(), + vec![Box::new(new_arr)], + DataType::List(Box::new(child_type.clone())), + ) + .into_series()) +} + // Returns inner data type. This is needed because a cast can instantiate the dtype inner // values for instance with categoricals #[cfg(feature = "dtype-array")] diff --git a/crates/polars-core/src/series/mod.rs b/crates/polars-core/src/series/mod.rs index 5b978a29ce3c..4751f3ad2d23 100644 --- a/crates/polars-core/src/series/mod.rs +++ b/crates/polars-core/src/series/mod.rs @@ -888,7 +888,8 @@ impl Series { /// Packs every element into a list. pub fn as_list(&self) -> ListChunked { let s = self.rechunk(); - let values = s.to_arrow(0); + // don't use `to_arrow` as we need the physical types + let values = s.chunks()[0].clone(); let offsets = (0i64..(s.len() as i64 + 1)).collect::>(); let offsets = unsafe { Offsets::new_unchecked(offsets) }; diff --git a/py-polars/tests/unit/datatypes/test_categorical.py b/py-polars/tests/unit/datatypes/test_categorical.py index af727b187a75..f19a2732bf9c 100644 --- a/py-polars/tests/unit/datatypes/test_categorical.py +++ b/py-polars/tests/unit/datatypes/test_categorical.py @@ -422,3 +422,8 @@ def test_categorical_collect_11408() -> None: "groups": ["a", "b", "c"], "cats": ["a", "b", "c"], } + + +def test_categorical_nested_cast_unchecked() -> None: + s = pl.Series("cat", [["cat"]]).cast(pl.List(pl.Categorical)) + assert pl.Series([s]).to_list() == [[["cat"]]] From 6dae5502de62bde0ce35a4bc3608f0e8721de40d Mon Sep 17 00:00:00 2001 From: Danny van Kooten Date: Fri, 20 Oct 2023 11:59:58 +0200 Subject: [PATCH 070/103] docs: load 40x40 avatar from github and add loading=lazy attribute. 
(#11886) Co-authored-by: Danny van Kooten --- docs/_build/scripts/people.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/docs/_build/scripts/people.py b/docs/_build/scripts/people.py index 10186549d4d8..72ba55c37f56 100644 --- a/docs/_build/scripts/people.py +++ b/docs/_build/scripts/people.py @@ -6,8 +6,7 @@ auth = Auth.Token(token) if token else None g = Github(auth=auth) -ICON_TEMPLATE = "[![{login}]({avatar_url}){{.contributor_icon}}]({html_url})" - +ICON_TEMPLATE = '{login}' def get_people_md(): repo = g.get_repo("pola-rs/polars") From cb92cb8bd48b5c9e1cc469307b74364072de5969 Mon Sep 17 00:00:00 2001 From: Alexander Beedie Date: Fri, 20 Oct 2023 14:05:27 +0400 Subject: [PATCH 071/103] perf(python): optimise `read_database` Databricks queries made using SQLAlchemy connections (#11885) --- py-polars/polars/dataframe/frame.py | 8 +- py-polars/polars/io/database.py | 83 +++++++--- py-polars/tests/unit/io/test_database_read.py | 156 ++++++++++++------ 3 files changed, 168 insertions(+), 79 deletions(-) diff --git a/py-polars/polars/dataframe/frame.py b/py-polars/polars/dataframe/frame.py index 4f802362ba6b..0ffa0629edac 100644 --- a/py-polars/polars/dataframe/frame.py +++ b/py-polars/polars/dataframe/frame.py @@ -3968,8 +3968,8 @@ def filter( Provide multiple filters using `*args` syntax: >>> df.filter( - ... pl.col("foo") == 1, - ... pl.col("ham") == "a", + ... pl.col("foo") <= 2, + ... ~pl.col("ham").is_in(["b", "c"]), ... ) shape: (1, 3) ┌─────┬─────┬─────┐ @@ -3982,14 +3982,14 @@ def filter( Provide multiple filters using `**kwargs` syntax: - >>> df.filter(foo=1, ham="a") + >>> df.filter(foo=2, ham="b") shape: (1, 3) ┌─────┬─────┬─────┐ │ foo ┆ bar ┆ ham │ │ --- ┆ --- ┆ --- │ │ i64 ┆ i64 ┆ str │ ╞═════╪═════╪═════╡ - │ 1 ┆ 6 ┆ a │ + │ 2 ┆ 7 ┆ b │ └─────┴─────┴─────┘ """ diff --git a/py-polars/polars/io/database.py b/py-polars/polars/io/database.py index ed3bfa858d5a..1f0795b3df0d 100644 --- a/py-polars/polars/io/database.py +++ b/py-polars/polars/io/database.py @@ -31,42 +31,49 @@ Selectable: TypeAlias = Any # type: ignore[no-redef] -class _DriverProperties_(TypedDict): - fetch_all: str - fetch_batches: str | None - exact_batch_size: bool | None +class _ArrowDriverProperties_(TypedDict): + fetch_all: str # name of the method that fetches all arrow data + fetch_batches: str | None # name of the method that fetches arrow data in batches + exact_batch_size: bool | None # whether indicated batch size is respected exactly + repeat_batch_calls: bool # repeat batch calls (if False, batch call is generator) -_ARROW_DRIVER_REGISTRY_: dict[str, _DriverProperties_] = { +_ARROW_DRIVER_REGISTRY_: dict[str, _ArrowDriverProperties_] = { "adbc_.*": { "fetch_all": "fetch_arrow_table", "fetch_batches": None, "exact_batch_size": None, + "repeat_batch_calls": False, }, "arrow_odbc_proxy": { "fetch_all": "fetch_record_batches", "fetch_batches": "fetch_record_batches", "exact_batch_size": True, + "repeat_batch_calls": False, }, "databricks": { "fetch_all": "fetchall_arrow", "fetch_batches": "fetchmany_arrow", "exact_batch_size": True, + "repeat_batch_calls": True, }, "duckdb": { "fetch_all": "fetch_arrow_table", "fetch_batches": "fetch_record_batch", "exact_batch_size": True, + "repeat_batch_calls": False, }, "snowflake": { "fetch_all": "fetch_arrow_all", "fetch_batches": "fetch_arrow_batches", "exact_batch_size": False, + "repeat_batch_calls": False, }, "turbodbc": { "fetch_all": "fetchallarrow", "fetch_batches": "fetcharrowbatches", "exact_batch_size": False, + "repeat_batch_calls": 
False, }, } @@ -121,10 +128,9 @@ def fetch_record_batches( class ConnectionExecutor: """Abstraction for querying databases with user-supplied connection objects.""" - # indicate that we acquired a cursor (and are therefore responsible for closing - # it on scope-exit). note that we should never close the underlying connection, - # or a user-supplied cursor. - acquired_cursor: bool = False + # indicate if we can/should close the cursor on scope exit. note that we + # should never close the underlying connection, or a user-supplied cursor. + can_close_cursor: bool = False def __init__(self, connection: ConnectionOrCursor) -> None: self.driver_name = ( @@ -144,24 +150,57 @@ def __exit__( exc_val: BaseException | None, exc_tb: TracebackType | None, ) -> None: - # iif we created it, close the cursor (NOT the connection) - if self.acquired_cursor: + # iif we created it and are finished with it, we can + # close the cursor (but NOT the connection) + if self.can_close_cursor: self.cursor.close() def __repr__(self) -> str: return f"<{type(self).__name__} module={self.driver_name!r}>" + def _arrow_batches( + self, + driver_properties: _ArrowDriverProperties_, + *, + batch_size: int | None, + iter_batches: bool, + ) -> Iterable[pa.RecordBatch]: + """Yield Arrow data in batches, or as a single 'fetchall' batch.""" + fetch_batches = driver_properties["fetch_batches"] + if not iter_batches or fetch_batches is None: + fetch_method = driver_properties["fetch_all"] + yield getattr(self.result, fetch_method)() + else: + size = batch_size if driver_properties["exact_batch_size"] else None + repeat_batch_calls = driver_properties["repeat_batch_calls"] + fetchmany_arrow = getattr(self.result, fetch_batches) + if not repeat_batch_calls: + yield from fetchmany_arrow(size) + else: + while True: + arrow = fetchmany_arrow(size) + if not arrow: + break + yield arrow + def _normalise_cursor(self, conn: ConnectionOrCursor) -> Cursor: """Normalise a connection object such that we have the query executor.""" if self.driver_name == "sqlalchemy" and type(conn).__name__ == "Engine": - # sqlalchemy engine; direct use is deprecated, so prefer the connection - self.acquired_cursor = True - return conn.connect() # type: ignore[union-attr] + self.can_close_cursor = True + if conn.driver == "databricks-sql-python": # type: ignore[union-attr] + # take advantage of the raw connection to get arrow integration + self.driver_name = "databricks" + return conn.raw_connection().cursor() # type: ignore[union-attr] + else: + # sqlalchemy engine; direct use is deprecated, so prefer the connection + return conn.connect() # type: ignore[union-attr] + elif hasattr(conn, "cursor"): # connection has a dedicated cursor; prefer over direct execute cursor = cursor() if callable(cursor := conn.cursor) else cursor - self.acquired_cursor = True + self.can_close_cursor = True return cursor + elif hasattr(conn, "execute"): # can execute directly (given cursor, sqlalchemy connection, etc) return conn # type: ignore[return-value] @@ -206,22 +245,20 @@ def _from_arrow( try: for driver, driver_properties in _ARROW_DRIVER_REGISTRY_.items(): if re.match(f"^{driver}$", self.driver_name): - size = batch_size if driver_properties["exact_batch_size"] else None fetch_batches = driver_properties["fetch_batches"] + self.can_close_cursor = fetch_batches is None or not iter_batches frames = ( from_arrow(batch, schema_overrides=schema_overrides) - for batch in ( - getattr(self.result, fetch_batches)(size) - if (iter_batches and fetch_batches is not None) - else [ - 
getattr(self.result, driver_properties["fetch_all"])() - ] + for batch in self._arrow_batches( + driver_properties, + iter_batches=iter_batches, + batch_size=batch_size, ) ) return frames if iter_batches else next(frames) # type: ignore[arg-type,return-value] except Exception as err: # eg: valid turbodbc/snowflake connection, but no arrow support - # available in the underlying driver or this connection + # compiled in to the underlying driver (or on this connection) arrow_not_supported = ( "does not support Apache Arrow", "Apache Arrow format is not supported", diff --git a/py-polars/tests/unit/io/test_database_read.py b/py-polars/tests/unit/io/test_database_read.py index aa872c00a85b..824ffb989fd2 100644 --- a/py-polars/tests/unit/io/test_database_read.py +++ b/py-polars/tests/unit/io/test_database_read.py @@ -15,9 +15,12 @@ import polars as pl from polars.exceptions import UnsuitableSQLError +from polars.io.database import _ARROW_DRIVER_REGISTRY_ from polars.testing import assert_frame_equal if TYPE_CHECKING: + import pyarrow as pa + from polars.type_aliases import DbReadEngine, SchemaDefinition, SchemaDict @@ -84,6 +87,77 @@ class ExceptionTestParams(NamedTuple): kwargs: dict[str, Any] | None = None +class MockConnection: + """Mock connection class for databases we can't test in CI.""" + + def __init__( + self, + driver: str, + batch_size: int | None, + test_data: pa.Table, + repeat_batch_calls: bool, + ) -> None: + self.__class__.__module__ = driver + self._cursor = MockCursor( + repeat_batch_calls=repeat_batch_calls, + batched=(batch_size is not None), + test_data=test_data, + ) + + def close(self) -> None: # noqa: D102 + pass + + def cursor(self) -> Any: # noqa: D102 + return self._cursor + + +class MockCursor: + """Mock cursor class for databases we can't test in CI.""" + + def __init__( + self, + batched: bool, + test_data: pa.Table, + repeat_batch_calls: bool, + ) -> None: + self.resultset = MockResultSet(test_data, batched, repeat_batch_calls) + self.called: list[str] = [] + self.batched = batched + self.n_calls = 1 + + def __getattr__(self, item: str) -> Any: + if "fetch" in item: + self.called.append(item) + return self.resultset + super().__getattr__(item) # type: ignore[misc] + + def close(self) -> Any: # noqa: D102 + pass + + def execute(self, query: str) -> Any: # noqa: D102 + return self + + +class MockResultSet: + """Mock resultset class for databases we can't test in CI.""" + + def __init__( + self, test_data: pa.Table, batched: bool, repeat_batch_calls: bool = False + ): + self.test_data = test_data + self.repeat_batched_calls = repeat_batch_calls + self.batched = batched + self.n_calls = 1 + + def __call__(self, *args: Any, **kwargs: Any) -> Any: # noqa: D102 + if self.repeat_batched_calls: + res = self.test_data[: None if self.n_calls else 0] + self.n_calls -= 1 + else: + res = iter((self.test_data,)) + return res + + @pytest.mark.write_disk() @pytest.mark.parametrize( ( @@ -307,45 +381,9 @@ def test_read_database_parameterisd(tmp_path: Path) -> None: ) -def test_read_database_mocked() -> None: - arr = pl.DataFrame({"x": [1, 2, 3], "y": ["aa", "bb", "cc"]}).to_arrow() - - class MockConnection: - def __init__(self, driver: str, batch_size: int | None = None) -> None: - self.__class__.__module__ = driver - self._cursor = MockCursor(batched=batch_size is not None) - - def close(self) -> None: - pass - - def cursor(self) -> Any: - return self._cursor - - class MockCursor: - def __init__(self, batched: bool) -> None: - self.called: list[str] = [] - self.batched = batched 
- - def __getattr__(self, item: str) -> Any: - if "fetch" in item: - res = ( - (lambda *args, **kwargs: (arr for _ in range(1))) - if self.batched - else (lambda *args, **kwargs: arr) - ) - self.called.append(item) - return res - super().__getattr__(item) # type: ignore[misc] - - def close(self) -> Any: - pass - - def execute(self, query: str) -> Any: - return self - - # since we don't have access to snowflake/databricks/etc from CI we - # mock them so we can check that we're calling the expected methods - for driver, batch_size, iter_batches, expected_call in ( +@pytest.mark.parametrize( + ("driver", "batch_size", "iter_batches", "expected_call"), + [ ("snowflake", None, False, "fetch_arrow_all"), ("snowflake", 10_000, False, "fetch_arrow_all"), ("snowflake", 10_000, True, "fetch_arrow_batches"), @@ -358,20 +396,34 @@ def execute(self, query: str) -> Any: ("adbc_driver_postgresql", None, False, "fetch_arrow_table"), ("adbc_driver_postgresql", 75_000, False, "fetch_arrow_table"), ("adbc_driver_postgresql", 75_000, True, "fetch_arrow_table"), - ): - mc = MockConnection(driver, batch_size) - res = pl.read_database( # type: ignore[call-overload] - query="SELECT * FROM test_data", - connection=mc, - iter_batches=iter_batches, - batch_size=batch_size, - ) - assert expected_call in mc.cursor().called - if iter_batches: - assert isinstance(res, GeneratorType) - res = pl.concat(res) + ], +) +def test_read_database_mocked( + driver: str, batch_size: int | None, iter_batches: bool, expected_call: str +) -> None: + # since we don't have access to snowflake/databricks/etc from CI we + # mock them so we can check that we're calling the expected methods + arrow = pl.DataFrame({"x": [1, 2, 3], "y": ["aa", "bb", "cc"]}).to_arrow() + mc = MockConnection( + driver, + batch_size, + test_data=arrow, + repeat_batch_calls=_ARROW_DRIVER_REGISTRY_.get(driver, {}).get( # type: ignore[call-overload] + "repeat_batch_calls", False + ), + ) + res = pl.read_database( # type: ignore[call-overload] + query="SELECT * FROM test_data", + connection=mc, + iter_batches=iter_batches, + batch_size=batch_size, + ) + if iter_batches: + assert isinstance(res, GeneratorType) + res = pl.concat(res) - assert res.rows() == [(1, "aa"), (2, "bb"), (3, "cc")] + assert expected_call in mc.cursor().called + assert res.rows() == [(1, "aa"), (2, "bb"), (3, "cc")] @pytest.mark.parametrize( From 0d9f865c21361f9e26f42df27b0c0b3f48c353f6 Mon Sep 17 00:00:00 2001 From: Stijn de Gooijer Date: Fri, 20 Oct 2023 13:30:57 +0200 Subject: [PATCH 072/103] refactor(python): Further assert utils refactor (#11888) --- py-polars/polars/expr/expr.py | 2 +- py-polars/polars/series/series.py | 2 +- py-polars/polars/testing/asserts/frame.py | 2 + py-polars/polars/testing/asserts/series.py | 240 ++++++++++-------- py-polars/polars/testing/asserts/utils.py | 7 +- py-polars/tests/unit/series/test_series.py | 25 ++ .../unit/testing/test_assert_series_equal.py | 100 ++++++-- 7 files changed, 234 insertions(+), 144 deletions(-) diff --git a/py-polars/polars/expr/expr.py b/py-polars/polars/expr/expr.py index 8ad68bb58117..2fb89de12607 100644 --- a/py-polars/polars/expr/expr.py +++ b/py-polars/polars/expr/expr.py @@ -4494,7 +4494,7 @@ def eq(self, other: Any) -> Self: def eq_missing(self, other: Any) -> Self: """ - Method equivalent of equality operator ``expr == other`` where `None` == None`. + Method equivalent of equality operator ``expr == other`` where ``None == None``. This differs from default ``eq`` where null values are propagated. 
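As a concrete illustration of the ``eq`` vs ``eq_missing`` contrast described
above (a minimal sketch; the outputs are assumed for a recent polars version,
not taken from this patch):

    import polars as pl

    s = pl.Series([1, None, 3])
    other = pl.Series([1, None, 4])

    s.eq(other).to_list()          # [True, None, False] -- nulls propagate
    s.eq_missing(other).to_list()  # [True, True, False] -- None == None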
diff --git a/py-polars/polars/series/series.py b/py-polars/polars/series/series.py index 14572ca9480f..e3ea688c33ee 100644 --- a/py-polars/polars/series/series.py +++ b/py-polars/polars/series/series.py @@ -618,7 +618,7 @@ def eq_missing(self, other: Expr) -> Expr: # type: ignore[misc] def eq_missing(self, other: Any) -> Self | Expr: """ - Method equivalent of equality operator ``series == other`` where `None` == None`. + Method equivalent of equality operator ``series == other`` where ``None == None``. This differs from the standard ``ne`` where null values are propagated. diff --git a/py-polars/polars/testing/asserts/frame.py b/py-polars/polars/testing/asserts/frame.py index 4faa4d810050..3920dc57e050 100644 --- a/py-polars/polars/testing/asserts/frame.py +++ b/py-polars/polars/testing/asserts/frame.py @@ -47,6 +47,7 @@ def assert_frame_equal( check_exact Require data values to match exactly. If set to ``False``, values are considered equal when within tolerance of each other (see ``rtol`` and ``atol``). + Logical types like dates are always checked exactly. rtol Relative tolerance for inexact checking. Fraction of values in ``right``. atol @@ -228,6 +229,7 @@ def assert_frame_not_equal( check_exact Require data values to match exactly. If set to ``False``, values are considered equal when within tolerance of each other (see ``rtol`` and ``atol``). + Logical types like dates are always checked exactly. rtol Relative tolerance for inexact checking. Fraction of values in ``right``. atol diff --git a/py-polars/polars/testing/asserts/series.py b/py-polars/polars/testing/asserts/series.py index e285af716d6c..53b41ac335fe 100644 --- a/py-polars/polars/testing/asserts/series.py +++ b/py-polars/polars/testing/asserts/series.py @@ -1,18 +1,19 @@ from __future__ import annotations -from polars import functions as F from polars.datatypes import ( FLOAT_DTYPES, NESTED_DTYPES, + NUMERIC_DTYPES, UNSIGNED_INTEGER_DTYPES, Categorical, + Int64, List, Struct, UInt64, Utf8, - dtype_to_py_type, unpack_dtypes, ) +from polars.exceptions import ComputeError from polars.series import Series from polars.testing.asserts.utils import raise_assertion_error @@ -48,8 +49,10 @@ def assert_series_equal( check_exact Require data values to match exactly. If set to ``False``, values are considered equal when within tolerance of each other (see ``rtol`` and ``atol``). + Logical types like dates are always checked exactly. rtol - Relative tolerance for inexact checking. Fraction of values in ``right``. + Relative tolerance for inexact checking, given as a fraction of the values in + ``right``. atol Absolute tolerance for inexact checking. 
nans_compare_equal @@ -122,24 +125,32 @@ def _assert_series_values_equal( categorical_as_str: bool, ) -> None: """Assert that the values in both Series are equal.""" - if categorical_as_str and left.dtype == Categorical: - left, right = left.cast(Utf8), right.cast(Utf8) - - # create mask of which (if any) values are unequal - unequal = left.ne_missing(right) + # Handle categoricals + if categorical_as_str: + if left.dtype == Categorical: + left = left.cast(Utf8) + if right.dtype == Categorical: + right = right.cast(Utf8) + + # Determine unequal elements + try: + unequal = left.ne_missing(right) + except ComputeError as exc: + raise_assertion_error( + "Series", + "incompatible data types", + left=left.dtype, + right=right.dtype, + cause=exc, + ) - # handle NaN values (which compare unequal to themselves) + # Handle NaN values (which compare unequal to themselves) comparing_floats = left.dtype in FLOAT_DTYPES and right.dtype in FLOAT_DTYPES - if unequal.any() and nans_compare_equal: - # when both dtypes are scalar floats - if comparing_floats: - unequal = unequal & ~( - (left.is_nan() & right.is_nan()).fill_null(F.lit(False)) - ) - if comparing_floats and not nans_compare_equal: - unequal = unequal | left.is_nan() | right.is_nan() + if comparing_floats and nans_compare_equal: + both_nan = (left.is_nan() & right.is_nan()).fill_null(False) + unequal = unequal & ~both_nan - # check nested dtypes in separate function + # Check nested dtypes in separate function if left.dtype in NESTED_DTYPES or right.dtype in NESTED_DTYPES: if _assert_series_nested( left=left.filter(unequal), @@ -152,84 +163,65 @@ def _assert_series_values_equal( ): return - try: - can_be_subtracted = hasattr(dtype_to_py_type(left.dtype), "__sub__") - except NotImplementedError: - can_be_subtracted = False - - check_exact = ( - check_exact or not can_be_subtracted or left.is_boolean() or left.is_temporal() - ) + # If no differences found during exact checking, we're done + if not unequal.any(): + return - # assert exact, or with tolerance - if unequal.any(): - if check_exact: - raise_assertion_error( - "Series", - "exact value mismatch", - left=left.to_list(), - right=right.to_list(), - ) - else: - equal, nan_info = _check_series_equal_inexact( - left, - right, - unequal, - rtol=rtol, - atol=atol, - nans_compare_equal=nans_compare_equal, - comparing_floats=comparing_floats, - ) + # Only do inexact checking for numeric types + if ( + check_exact + or left.dtype not in NUMERIC_DTYPES + or right.dtype not in NUMERIC_DTYPES + ): + raise_assertion_error( + "Series", + "exact value mismatch", + left=left.to_list(), + right=right.to_list(), + ) - if not equal: - raise_assertion_error( - "Series", - f"value mismatch{nan_info}", - left=left.to_list(), - right=right.to_list(), - ) + _assert_series_null_values_match(left, right) + if comparing_floats: + _assert_series_nan_values_match( + left, right, nans_compare_equal=nans_compare_equal + ) + _assert_series_values_within_tolerance( + left, + right, + unequal, + rtol=rtol, + atol=atol, + ) -def _check_series_equal_inexact( - left: Series, - right: Series, - unequal: Series, - *, - rtol: float, - atol: float, - nans_compare_equal: bool, - comparing_floats: bool, -) -> tuple[bool, str]: - # apply check with tolerance (to the known-unequal matches). 
- left, right = left.filter(unequal), right.filter(unequal) - - if all(tp in UNSIGNED_INTEGER_DTYPES for tp in (left.dtype, right.dtype)): - # avoid potential "subtract-with-overflow" panic on uint math - s_diff = Series( - "diff", [abs(v1 - v2) for v1, v2 in zip(left, right)], dtype=UInt64 +def _assert_series_null_values_match(left: Series, right: Series) -> None: + null_value_mismatch = left.is_null() != right.is_null() + if null_value_mismatch.any(): + raise_assertion_error( + "Series", "null value mismatch", left.to_list(), right.to_list() ) - else: - s_diff = (left - right).abs() - equal, nan_info = True, "" - if ((s_diff > (atol + rtol * right.abs())).sum() != 0) or ( - left.is_null() != right.is_null() - ).any(): - equal = False - elif comparing_floats: - # note: take special care with NaN values. - # if NaNs don't compare as equal, any NaN in the left Series is - # sufficient for a mismatch because the if condition above already - # compares the null values. - if not nans_compare_equal and left.is_nan().any(): - equal = False - nan_info = " (nans_compare_equal=False)" - elif (left.is_nan() != right.is_nan()).any(): - equal = False - nan_info = f" (nans_compare_equal={nans_compare_equal})" +def _assert_series_nan_values_match( + left: Series, right: Series, *, nans_compare_equal: bool +) -> None: + if nans_compare_equal: + nan_value_mismatch = left.is_nan() != right.is_nan() + if nan_value_mismatch.any(): + raise_assertion_error( + "Series", + "nan value mismatch - nans compare equal", + left.to_list(), + right.to_list(), + ) - return equal, nan_info + elif left.is_nan().any() or right.is_nan().any(): + raise_assertion_error( + "Series", + "nan value mismatch - nans compare unequal", + left.to_list(), + right.to_list(), + ) def _assert_series_nested( @@ -249,19 +241,9 @@ def _assert_series_nested( # compare nested lists element-wise elif left.dtype == List == right.dtype: for s1, s2 in zip(left, right): - if s1 is None and s2 is None: - if nans_compare_equal: - continue - else: - raise_assertion_error( - "Series", - f"Nested value mismatch (nans_compare_equal={nans_compare_equal})", - s1, - s2, - ) - elif (s1 is None and s2 is not None) or (s2 is None and s1 is not None): + if (s1 is None and s2 is not None) or (s2 is None and s1 is not None): raise_assertion_error("Series", "nested value mismatch", s1, s2) - elif len(s1) != len(s2): + elif s1.len() != s2.len(): raise_assertion_error( "Series", "nested list length mismatch", len(s1), len(s2) ) @@ -308,6 +290,42 @@ def _assert_series_nested( return False +def _assert_series_values_within_tolerance( + left: Series, + right: Series, + unequal: Series, + *, + rtol: float, + atol: float, +) -> None: + left_unequal, right_unequal = left.filter(unequal), right.filter(unequal) + + difference = _calc_absolute_diff(left_unequal, right_unequal) + tolerance = atol + rtol * right_unequal.abs() + exceeds_tolerance = difference > tolerance + + if exceeds_tolerance.any(): + raise_assertion_error( + "Series", + "value mismatch", + left.to_list(), + right.to_list(), + ) + + +def _calc_absolute_diff(left: Series, right: Series) -> Series: + if left.dtype in UNSIGNED_INTEGER_DTYPES and right.dtype in UNSIGNED_INTEGER_DTYPES: + try: + left = left.cast(Int64) + right = right.cast(Int64) + except ComputeError: + # Handle big UInt64 values through conversion to Python + diff = [abs(v1 - v2) for v1, v2 in zip(left, right)] + return Series(diff, dtype=UInt64) + + return (left - right).abs() + + def assert_series_not_equal( left: Series, right: Series, @@ 
-328,25 +346,27 @@ def assert_series_not_equal( Parameters ---------- left - the series to compare. + The first Series to compare. right - the series to compare with. + The second Series to compare. check_dtype - if True, data types need to match exactly. + Require data types to match. check_names - if True, names need to match. + Require names to match. check_exact - if False, test if values are within tolerance of each other - (see `rtol` & `atol`). + Require data values to match exactly. If set to ``False``, values are considered + equal when within tolerance of each other (see ``rtol`` and ``atol``). + Logical types like dates are always checked exactly. rtol - relative tolerance for inexact checking. Fraction of values in `right`. + Relative tolerance for inexact checking, given as a fraction of the values in + ``right``. atol - absolute tolerance for inexact checking. + Absolute tolerance for inexact checking. nans_compare_equal - if your assert/test requires float NaN != NaN, set this to False. + Consider NaN values to be equal. categorical_as_str Cast categorical columns to string before comparing. Enabling this helps - compare DataFrames that do not share the same string cache. + compare columns that do not share the same string cache. See Also -------- diff --git a/py-polars/polars/testing/asserts/utils.py b/py-polars/polars/testing/asserts/utils.py index 713e57170ac1..1b7ac40c7814 100644 --- a/py-polars/polars/testing/asserts/utils.py +++ b/py-polars/polars/testing/asserts/utils.py @@ -4,12 +4,9 @@ def raise_assertion_error( - objects: str, - detail: str, - left: Any, - right: Any, + objects: str, detail: str, left: Any, right: Any, *, cause: Exception | None = None ) -> NoReturn: """Raise a detailed assertion error.""" __tracebackhide__ = True msg = f"{objects} are different ({detail})\n[left]: {left}\n[right]: {right}" - raise AssertionError(msg) + raise AssertionError(msg) from cause diff --git a/py-polars/tests/unit/series/test_series.py b/py-polars/tests/unit/series/test_series.py index 6d31979f973c..b830442250c2 100644 --- a/py-polars/tests/unit/series/test_series.py +++ b/py-polars/tests/unit/series/test_series.py @@ -2760,3 +2760,28 @@ def test_series_getitem_out_of_bounds_negative() -> None: IndexError, match="index -10 is out of bounds for sequence of length 2" ): s[-10] + + +def test_series_cmp_fast_paths() -> None: + assert ( + pl.Series([None], dtype=pl.Int32) != pl.Series([1, 2], dtype=pl.Int32) + ).to_list() == [None, None] + assert ( + pl.Series([None], dtype=pl.Int32) == pl.Series([1, 2], dtype=pl.Int32) + ).to_list() == [None, None] + + assert ( + pl.Series([None], dtype=pl.Utf8) != pl.Series(["a", "b"], dtype=pl.Utf8) + ).to_list() == [None, None] + assert ( + pl.Series([None], dtype=pl.Utf8) == pl.Series(["a", "b"], dtype=pl.Utf8) + ).to_list() == [None, None] + + assert ( + pl.Series([None], dtype=pl.Boolean) + != pl.Series([True, False], dtype=pl.Boolean) + ).to_list() == [None, None] + assert ( + pl.Series([None], dtype=pl.Boolean) + == pl.Series([False, False], dtype=pl.Boolean) + ).to_list() == [None, None] diff --git a/py-polars/tests/unit/testing/test_assert_series_equal.py b/py-polars/tests/unit/testing/test_assert_series_equal.py index 13c298fc1746..b7dbd56c57d6 100644 --- a/py-polars/tests/unit/testing/test_assert_series_equal.py +++ b/py-polars/tests/unit/testing/test_assert_series_equal.py @@ -99,31 +99,6 @@ def test_compare_series_nulls() -> None: assert_series_equal(srs1, srs2) -def test_series_cmp_fast_paths() -> None: - assert ( - 
pl.Series([None], dtype=pl.Int32) != pl.Series([1, 2], dtype=pl.Int32) - ).to_list() == [None, None] - assert ( - pl.Series([None], dtype=pl.Int32) == pl.Series([1, 2], dtype=pl.Int32) - ).to_list() == [None, None] - - assert ( - pl.Series([None], dtype=pl.Utf8) != pl.Series(["a", "b"], dtype=pl.Utf8) - ).to_list() == [None, None] - assert ( - pl.Series([None], dtype=pl.Utf8) == pl.Series(["a", "b"], dtype=pl.Utf8) - ).to_list() == [None, None] - - assert ( - pl.Series([None], dtype=pl.Boolean) - != pl.Series([True, False], dtype=pl.Boolean) - ).to_list() == [None, None] - assert ( - pl.Series([None], dtype=pl.Boolean) - == pl.Series([False, False], dtype=pl.Boolean) - ).to_list() == [None, None] - - def test_compare_series_value_mismatch_string() -> None: srs1 = pl.Series(["hello", "no"]) srs2 = pl.Series(["hello", "yes"]) @@ -159,7 +134,7 @@ def test_compare_series_name_mismatch() -> None: assert_series_equal(srs1, srs2) -def test_compare_series_shape_mismatch() -> None: +def test_compare_series_length_mismatch() -> None: srs1 = pl.Series(values=[1, 2, 3, 4], name="srs1") srs2 = pl.Series(values=[1, 2, 3], name="srs2") @@ -621,7 +596,7 @@ def test_assert_series_equal_raises_assertion_error( def test_assert_series_equal_categorical() -> None: s1 = pl.Series(["a", "b", "a"], dtype=pl.Categorical) s2 = pl.Series(["a", "b", "a"], dtype=pl.Categorical) - with pytest.raises(pl.ComputeError, match="cannot compare categoricals"): + with pytest.raises(AssertionError, match="incompatible data types"): assert_series_equal(s1, s2) assert_series_equal(s1, s2, categorical_as_str=True) @@ -634,6 +609,17 @@ def test_assert_series_equal_categorical_vs_str() -> None: with pytest.raises(AssertionError, match="dtype mismatch"): assert_series_equal(s1, s2, categorical_as_str=True) + assert_series_equal(s1, s2, check_dtype=False, categorical_as_str=True) + assert_series_equal(s2, s1, check_dtype=False, categorical_as_str=True) + + +def test_assert_series_equal_incompatible_data_types() -> None: + s1 = pl.Series(["a", "b", "a"], dtype=pl.Categorical) + s2 = pl.Series([0, 1, 0], dtype=pl.Int8) + + with pytest.raises(AssertionError, match="incompatible data types"): + assert_series_equal(s1, s2, check_dtype=False) + def test_assert_series_equal_full_series() -> None: s1 = pl.Series([1, 2, 3]) @@ -651,3 +637,63 @@ def test_assert_series_not_equal() -> None: s = pl.Series("a", [1, 2]) with pytest.raises(AssertionError, match="Series are equal"): assert_series_not_equal(s, s) + + +def test_assert_series_equal_nested_list_float() -> None: + # First entry has only integers + s1 = pl.Series([[1.0, 2.0], [3.0, 4.0]], dtype=pl.List(pl.Float64)) + s2 = pl.Series([[1.0, 2.0], [3.0, 4.9]], dtype=pl.List(pl.Float64)) + + with pytest.raises(AssertionError): + assert_series_equal(s1, s2) + + +def test_assert_series_equal_nested_struct_float() -> None: + s1 = pl.Series( + [{"a": 1.0, "b": 2.0}, {"a": 3.0, "b": 4.0}], + dtype=pl.Struct({"a": pl.Float64, "b": pl.Float64}), + ) + s2 = pl.Series( + [{"a": 1.0, "b": 2.0}, {"a": 3.0, "b": 4.9}], + dtype=pl.Struct({"a": pl.Float64, "b": pl.Float64}), + ) + + with pytest.raises(AssertionError): + assert_series_equal(s1, s2) + + +def test_assert_series_equal_nested_list_full_null() -> None: + # First entry has only integers + s1 = pl.Series([None, None], dtype=pl.List(pl.Float64)) + s2 = pl.Series([None, None], dtype=pl.List(pl.Float64)) + + assert_series_equal(s1, s2) + + +def test_assert_series_equal_nested_list_nan() -> None: + s1 = pl.Series([[1.0, 2.0], [3.0, float("nan")]], 
dtype=pl.List(pl.Float64))
+    s2 = pl.Series([[1.0, 2.0], [3.0, float("nan")]], dtype=pl.List(pl.Float64))
+
+    with pytest.raises(AssertionError):
+        assert_series_equal(s1, s2, nans_compare_equal=False)
+
+
+def test_assert_series_equal_nested_list_none() -> None:
+    s1 = pl.Series([[1.0, 2.0], None], dtype=pl.List(pl.Float64))
+    s2 = pl.Series([[1.0, 2.0], None], dtype=pl.List(pl.Float64))
+
+    assert_series_equal(s1, s2, nans_compare_equal=False)
+
+
+def test_assert_series_equal_full_none_nested_not_nested() -> None:
+    s1 = pl.Series([None, None], dtype=pl.List(pl.Float64))
+    s2 = pl.Series([None, None], dtype=pl.Float64)
+
+    assert_series_equal(s1, s2, check_dtype=False)
+
+
+def test_assert_series_equal_unsigned_ints_underflow() -> None:
+    s1 = pl.Series([1, 3], dtype=pl.UInt8)
+    s2 = pl.Series([2, 4], dtype=pl.Int64)
+
+    assert_series_equal(s1, s2, atol=1, check_dtype=False)
From d9c63161fa62ac67b560e40e5a85a6ee8c559a9b Mon Sep 17 00:00:00 2001
From: Marshall
Date: Fri, 20 Oct 2023 07:58:58 -0400
Subject: [PATCH 073/103] fix(python): Add `include_nulls` parameter to `update` (#11830)

--- py-polars/polars/dataframe/frame.py          | 34 ++++++++++--
 py-polars/polars/lazyframe/frame.py          | 56 ++++++++++++++++----
 py-polars/tests/unit/operations/test_join.py | 22 ++++++++
 3 files changed, 100 insertions(+), 12 deletions(-)

diff --git a/py-polars/polars/dataframe/frame.py b/py-polars/polars/dataframe/frame.py
index 0ffa0629edac..ffc8f0f4223e 100644
--- a/py-polars/polars/dataframe/frame.py
+++ b/py-polars/polars/dataframe/frame.py
@@ -9754,13 +9754,19 @@ def update(
         left_on: str | Sequence[str] | None = None,
         right_on: str | Sequence[str] | None = None,
         how: Literal["left", "inner", "outer"] = "left",
+        include_nulls: bool | None = False,
     ) -> DataFrame:
         """
-        Update the values in this `DataFrame` with the non-null values in `other`.
+        Update the values in this `DataFrame` with the values in `other`.
+
+        By default, null values in the right dataframe are ignored. Use
+        `include_nulls=True` to overwrite values in this frame with null values from
+        the other frame.
 
         Notes
         -----
-        This is syntactic sugar for a left/inner join + coalesce
+        This is syntactic sugar for a left/inner join, with an optional coalesce when
+        `include_nulls = False`.
 
         Warnings
         --------
@@ -9784,6 +9790,9 @@
         * 'inner' keeps only those rows where the key exists in both frames.
         * 'outer' will update existing rows where the key matches while also
           adding any new rows contained in the given frame.
+        include_nulls
+            If True, null values from the right dataframe will be used to update the
+            left dataframe.
 
         Examples
         --------
@@ -9859,10 +9868,29 @@
             │ 5   ┆ -66 │
             └─────┴─────┘
 
+        Update `df` values including null values in `new_df`, using an outer join
+        strategy that defines explicit join columns in each frame:
+
+        >>> df.update(
+        ...     new_df, left_on="A", right_on="C", how="outer", include_nulls=True
+        ... 
) + shape: (5, 2) + ┌─────┬──────┐ + │ A ┆ B │ + │ --- ┆ --- │ + │ i64 ┆ i64 │ + ╞═════╪══════╡ + │ 1 ┆ -99 │ + │ 2 ┆ 500 │ + │ 3 ┆ null │ + │ 4 ┆ 700 │ + │ 5 ┆ -66 │ + └─────┴──────┘ + """ return ( self.lazy() - .update(other.lazy(), on, left_on, right_on, how) + .update(other.lazy(), on, left_on, right_on, how, include_nulls) .collect(_eager=True) ) diff --git a/py-polars/polars/lazyframe/frame.py b/py-polars/polars/lazyframe/frame.py index 3a53048b723d..784c2afa7583 100644 --- a/py-polars/polars/lazyframe/frame.py +++ b/py-polars/polars/lazyframe/frame.py @@ -5656,6 +5656,7 @@ def update( left_on: str | Sequence[str] | None = None, right_on: str | Sequence[str] | None = None, how: Literal["left", "inner", "outer"] = "left", + include_nulls: bool | None = False, ) -> Self: """ Update the values in this `LazyFrame` with the non-null values in `other`. @@ -5677,10 +5678,14 @@ def update( * 'inner' keeps only those rows where the key exists in both frames. * 'outer' will update existing rows where the key matches while also adding any new rows contained in the given frame. + include_nulls + If True, null values from the right dataframe will be used to update the + left dataframe. Notes ----- - This is syntactic sugar for a join + coalesce (upsert) operation. + This is syntactic sugar for a left/inner join, with an optional coalesce when + `include_nulls = False`. Examples -------- @@ -5756,6 +5761,25 @@ def update( │ 5 ┆ -66 │ └─────┴─────┘ + Update `df` values including null values in `new_df`, using an outer join + strategy that defines explicit join columns in each frame: + + >>> lf.update( + ... new_lf, left_on="A", right_on="C", how="outer", include_nulls=True + ... ).collect() + shape: (5, 2) + ┌─────┬──────┐ + │ A ┆ B │ + │ --- ┆ --- │ + │ i64 ┆ i64 │ + ╞═════╪══════╡ + │ 1 ┆ -99 │ + │ 2 ┆ 500 │ + │ 3 ┆ null │ + │ 4 ┆ 700 │ + │ 5 ┆ -66 │ + └─────┴──────┘ + """ if how not in ("left", "inner", "outer"): raise ValueError( @@ -5804,24 +5828,38 @@ def update( # only use non-idx right columns present in left frame right_other = set(other.columns).intersection(self.columns) - set(right_on) + # When include_nulls is True, we need to distinguish records after the join that + # were originally null in the right frame, as opposed to records that were null + # because the key was missing from the right frame. + # Add a validity column to track whether row was matched or not. 
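+        # For example, after the join with suffix "__POLARS_RIGHT", a row can be:
+        #   B__POLARS_RIGHT = null, __POLARS_VALIDITY = true  -> matched; write null
+        #   B__POLARS_RIGHT = null, __POLARS_VALIDITY = null  -> no match; keep left B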
+        if include_nulls:
+            validity = ("__POLARS_VALIDITY",)
+            other = other.with_columns(F.lit(True).alias(validity[0]))
+        else:
+            validity = ()  # type: ignore[assignment]
+
         tmp_name = "__POLARS_RIGHT"
+        drop_columns = [*(f"{name}{tmp_name}" for name in right_other), *validity]
         result = (
             self.join(
-                other.select(*right_on, *right_other),
+                other.select(*right_on, *right_other, *validity),
                 left_on=left_on,
                 right_on=right_on,
                 how=how,
                 suffix=tmp_name,
             )
             .with_columns(
-                [
-                    F.coalesce([f"{column_name}{tmp_name}", F.col(column_name)]).alias(
-                        column_name
-                    )
-                    for column_name in right_other
-                ]
+                (
+                    # use left value only when right value failed to join
+                    F.when(F.col(validity).is_null())
+                    .then(F.col(name))
+                    .otherwise(F.col(f"{name}{tmp_name}"))
+                    if include_nulls
+                    else F.coalesce([f"{name}{tmp_name}", F.col(name)])
+                ).alias(name)
+                for name in right_other
             )
-            .drop([f"{name}{tmp_name}" for name in right_other])
+            .drop(drop_columns)
         )
         if row_count_used:
             result = result.drop(row_count_name)
diff --git a/py-polars/tests/unit/operations/test_join.py b/py-polars/tests/unit/operations/test_join.py
index ce529cea450a..844b02988f3e 100644
--- a/py-polars/tests/unit/operations/test_join.py
+++ b/py-polars/tests/unit/operations/test_join.py
@@ -561,6 +561,28 @@ def test_update() -> None:
         a.update(b.rename({"b": "a"}), how="outer", on="a").collect().to_series()
     )
 
+    # check behavior of include_nulls=True
+    df = pl.DataFrame(
+        {
+            "A": [1, 2, 3, 4],
+            "B": [400, 500, 600, 700],
+        }
+    )
+    new_df = pl.DataFrame(
+        {
+            "B": [-66, None, -99],
+            "C": [5, 3, 1],
+        }
+    )
+    out = df.update(new_df, left_on="A", right_on="C", how="outer", include_nulls=True)
+    expected = pl.DataFrame(
+        {
+            "A": [1, 2, 3, 4, 5],
+            "B": [-99, 500, None, 700, -66],
+        }
+    )
+    assert_frame_equal(out, expected)
+
     # edge-case #11684
     x = pl.DataFrame({"a": [0, 1]})
     y = pl.DataFrame({"a": [2, 3]})
From fada98bbed67faf70e64323fbf1d1140563eec26 Mon Sep 17 00:00:00 2001
From: Ritchie Vink
Date: Fri, 20 Oct 2023 16:27:49 +0200
Subject: [PATCH 074/103] fix: use physical append (#11894)

--- crates/polars-lazy/src/physical_plan/planner/expr.rs | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/crates/polars-lazy/src/physical_plan/planner/expr.rs b/crates/polars-lazy/src/physical_plan/planner/expr.rs
index 0e614c4cb683..72b0b80c5d3c 100644
--- a/crates/polars-lazy/src/physical_plan/planner/expr.rs
+++ b/crates/polars-lazy/src/physical_plan/planner/expr.rs
@@ -603,10 +603,11 @@
 where
     let mut iter = chunks.into_iter();
     let first = iter.next().unwrap();
-    let out = iter.fold(first, |mut acc, s| {
-        acc.append(&s).unwrap();
+    let dtype = first.dtype();
+    let out = iter.fold(first.to_physical_repr().into_owned(), |mut acc, s| {
+        acc.append(&s.to_physical_repr()).unwrap();
         acc
     });
 
-    f(out).map(Some)
+    unsafe { f(out.cast_unchecked(dtype).unwrap()).map(Some) }
 }
From 5425f6a03fb985266a78054d779e6f60656b8bd3 Mon Sep 17 00:00:00 2001
From: Orson Peters
Date: Fri, 20 Oct 2023 16:28:59 +0200
Subject: [PATCH 075/103] perf: fix quadratic behavior in append sorted check (#11893)

--- .../src/chunked_array/ops/append.rs | 33 ++++++++++++++-----
 1 file changed, 24 insertions(+), 9 deletions(-)

diff --git a/crates/polars-core/src/chunked_array/ops/append.rs b/crates/polars-core/src/chunked_array/ops/append.rs
index c14405ed377d..3aff4cc9e51c 100644
--- a/crates/polars-core/src/chunked_array/ops/append.rs
+++ b/crates/polars-core/src/chunked_array/ops/append.rs
@@ -19,16 +19,20 @@
 where
     T: PolarsDataType,
    for<'a> T::Physical<'a>: 
TotalOrd, { - // If either is empty (or completely null), copy the sorted flag from the other. - if ca.len() == ca.null_count() { + // TODO: attempt to maintain sortedness better in case of nulls. + + // If either is empty, copy the sorted flag from the other. + if ca.is_empty() { ca.set_sorted_flag(other.is_sorted_flag()); return; } - if other.len() == other.null_count() { + if other.is_empty() { return; } - // Both need to be sorted, in the same order. + // Both need to be sorted, in the same order, if the order is maintained. + // TODO: rework sorted flags, ascending and descending are not mutually + // exclusive for all-equal/all-null arrays. let ls = ca.is_sorted_flag(); let rs = other.is_sorted_flag(); if ls != rs || ls == IsSorted::Not || rs == IsSorted::Not { @@ -38,12 +42,23 @@ where // Check the order is maintained. let still_sorted = { - let left = ca.get(ca.last_non_null().unwrap()).unwrap(); - let right = other.get(other.first_non_null().unwrap()).unwrap(); - if ca.is_sorted_ascending_flag() { - left.tot_le(&right) + // To prevent potential quadratic append behavior we do not find + // the last non-null element in ca. + if let Some(left) = ca.last() { + if let Some(right_idx) = other.first_non_null() { + let right = other.get(right_idx).unwrap(); + if ca.is_sorted_ascending_flag() { + left.tot_le(&right) + } else { + left.tot_ge(&right) + } + } else { + // Right is only nulls, trivially sorted. + true + } } else { - left.tot_ge(&right) + // Last element in left is null, pessimistically assume not sorted. + false } }; if !still_sorted { From c69722df98e90d2ea60d41a5c356973d7e86730e Mon Sep 17 00:00:00 2001 From: Ritchie Vink Date: Fri, 20 Oct 2023 18:20:35 +0200 Subject: [PATCH 076/103] perf: fix accidental quadratic behavior; cache null_count (#11889) --- .../polars-core/src/chunked_array/builder/binary.rs | 6 +++++- .../polars-core/src/chunked_array/builder/boolean.rs | 4 ++-- .../src/chunked_array/builder/primitive.rs | 4 ++-- crates/polars-core/src/chunked_array/builder/utf8.rs | 4 ++-- crates/polars-core/src/chunked_array/from.rs | 11 ++++++++++- crates/polars-core/src/chunked_array/mod.rs | 5 ++++- .../polars-core/src/chunked_array/object/builder.rs | 6 ++++++ crates/polars-core/src/chunked_array/ops/append.rs | 4 ++++ crates/polars-core/src/chunked_array/ops/apply.rs | 1 + crates/polars-core/src/chunked_array/ops/chunkops.rs | 5 +++++ .../polars-core/src/chunked_array/upstream_traits.rs | 2 ++ crates/polars-core/src/series/mod.rs | 6 ++++++ crates/polars-core/src/utils/mod.rs | 3 +++ .../src/physical_plan/streaming/convert_alp.rs | 1 + 14 files changed, 53 insertions(+), 9 deletions(-) diff --git a/crates/polars-core/src/chunked_array/builder/binary.rs b/crates/polars-core/src/chunked_array/builder/binary.rs index 119dc461c7ed..bed05a434ba1 100644 --- a/crates/polars-core/src/chunked_array/builder/binary.rs +++ b/crates/polars-core/src/chunked_array/builder/binary.rs @@ -1,3 +1,5 @@ +use polars_error::constants::LENGTH_LIMIT_MSG; + use super::*; pub struct BinaryChunkedBuilder { @@ -40,7 +42,8 @@ impl BinaryChunkedBuilder { pub fn finish(mut self) -> BinaryChunked { let arr = self.builder.as_box(); - let length = arr.len() as IdxSize; + let length = IdxSize::try_from(arr.len()).expect(LENGTH_LIMIT_MSG); + let null_count = arr.null_count() as IdxSize; ChunkedArray { field: Arc::new(self.field), @@ -48,6 +51,7 @@ impl BinaryChunkedBuilder { phantom: PhantomData, bit_settings: Default::default(), length, + null_count, } } diff --git 
a/crates/polars-core/src/chunked_array/builder/boolean.rs b/crates/polars-core/src/chunked_array/builder/boolean.rs index 655d94ff1a7d..407bc3abcf53 100644 --- a/crates/polars-core/src/chunked_array/builder/boolean.rs +++ b/crates/polars-core/src/chunked_array/builder/boolean.rs @@ -21,14 +21,14 @@ impl ChunkedBuilder for BooleanChunkedBuilder { fn finish(mut self) -> BooleanChunked { let arr = self.array_builder.as_box(); - let length = arr.len() as IdxSize; let mut ca = ChunkedArray { field: Arc::new(self.field), chunks: vec![arr], phantom: PhantomData, bit_settings: Default::default(), - length, + length: 0, + null_count: 0, }; ca.compute_len(); ca diff --git a/crates/polars-core/src/chunked_array/builder/primitive.rs b/crates/polars-core/src/chunked_array/builder/primitive.rs index f5314a5fb62a..eae7977612fe 100644 --- a/crates/polars-core/src/chunked_array/builder/primitive.rs +++ b/crates/polars-core/src/chunked_array/builder/primitive.rs @@ -27,13 +27,13 @@ where fn finish(mut self) -> ChunkedArray { let arr = self.array_builder.as_box(); - let length = arr.len() as IdxSize; let mut ca = ChunkedArray { field: Arc::new(self.field), chunks: vec![arr], phantom: PhantomData, bit_settings: Default::default(), - length, + length: 0, + null_count: 0, }; ca.compute_len(); ca diff --git a/crates/polars-core/src/chunked_array/builder/utf8.rs b/crates/polars-core/src/chunked_array/builder/utf8.rs index 49f933c790ed..1a1c793563ed 100644 --- a/crates/polars-core/src/chunked_array/builder/utf8.rs +++ b/crates/polars-core/src/chunked_array/builder/utf8.rs @@ -41,14 +41,14 @@ impl Utf8ChunkedBuilder { pub fn finish(mut self) -> Utf8Chunked { let arr = self.builder.as_box(); - let length = arr.len() as IdxSize; let mut ca = ChunkedArray { field: Arc::new(self.field), chunks: vec![arr], phantom: PhantomData, bit_settings: Default::default(), - length, + length: 0, + null_count: 0, }; ca.compute_len(); ca diff --git a/crates/polars-core/src/chunked_array/from.rs b/crates/polars-core/src/chunked_array/from.rs index 1c67b3b75963..c384dec3e241 100644 --- a/crates/polars-core/src/chunked_array/from.rs +++ b/crates/polars-core/src/chunked_array/from.rs @@ -1,3 +1,5 @@ +use polars_error::constants::LENGTH_LIMIT_MSG; + use super::*; #[allow(clippy::all)] @@ -143,10 +145,12 @@ where ); let mut length = 0; + let mut null_count = 0; let chunks = chunks .into_iter() .map(|x| { length += x.len(); + null_count += x.null_count(); Box::new(x) as Box }) .collect(); @@ -156,7 +160,8 @@ where chunks, phantom: PhantomData, bit_settings: Default::default(), - length: length.try_into().unwrap(), + length: length.try_into().expect(LENGTH_LIMIT_MSG), + null_count: null_count as IdxSize, } } @@ -184,6 +189,7 @@ where phantom: PhantomData, bit_settings: Default::default(), length: 0, + null_count: 0, }; out.compute_len(); out @@ -213,6 +219,7 @@ where phantom: PhantomData, bit_settings: Default::default(), length: 0, + null_count: 0, }; out.compute_len(); out @@ -235,6 +242,7 @@ where phantom: PhantomData, bit_settings, length: 0, + null_count: 0, }; out.compute_len(); if !keep_sorted { @@ -258,6 +266,7 @@ where phantom: PhantomData, bit_settings: Default::default(), length: 0, + null_count: 0, }; out.compute_len(); out diff --git a/crates/polars-core/src/chunked_array/mod.rs b/crates/polars-core/src/chunked_array/mod.rs index fa79caea754c..3a619c61d4a5 100644 --- a/crates/polars-core/src/chunked_array/mod.rs +++ b/crates/polars-core/src/chunked_array/mod.rs @@ -140,6 +140,7 @@ pub struct ChunkedArray { phantom: PhantomData, 
pub(crate) bit_settings: Settings, length: IdxSize, + null_count: IdxSize, } bitflags! { @@ -303,6 +304,7 @@ impl ChunkedArray { /// /// # Safety /// The caller must ensure to not change the [`DataType`] or `length` of any of the chunks. + /// And the `null_count` remains correct. #[inline] pub unsafe fn chunks_mut(&mut self) -> &mut Vec { &mut self.chunks @@ -316,7 +318,7 @@ impl ChunkedArray { /// Count the null values. #[inline] pub fn null_count(&self) -> usize { - self.chunks.iter().map(|arr| arr.null_count()).sum() + self.null_count as usize } /// Create a new [`ChunkedArray`] from self, where the chunks are replaced. @@ -610,6 +612,7 @@ impl Clone for ChunkedArray { phantom: PhantomData, bit_settings: self.bit_settings, length: self.length, + null_count: self.null_count, } } } diff --git a/crates/polars-core/src/chunked_array/object/builder.rs b/crates/polars-core/src/chunked_array/object/builder.rs index 351cdc58a383..a6f8b9072c98 100644 --- a/crates/polars-core/src/chunked_array/object/builder.rs +++ b/crates/polars-core/src/chunked_array/object/builder.rs @@ -59,6 +59,10 @@ where let null_bitmap: Option = self.bitmask_builder.into(); let len = self.values.len(); + let null_count = null_bitmap + .as_ref() + .map(|validity| validity.unset_bits()) + .unwrap_or(0) as IdxSize; let arr = Box::new(ObjectArray { values: Arc::new(self.values), @@ -72,6 +76,7 @@ where phantom: PhantomData, bit_settings: Default::default(), length: len as IdxSize, + null_count, } } } @@ -136,6 +141,7 @@ where phantom: PhantomData, bit_settings: Default::default(), length: len as IdxSize, + null_count: 0, } } diff --git a/crates/polars-core/src/chunked_array/ops/append.rs b/crates/polars-core/src/chunked_array/ops/append.rs index 3aff4cc9e51c..027ccb09d168 100644 --- a/crates/polars-core/src/chunked_array/ops/append.rs +++ b/crates/polars-core/src/chunked_array/ops/append.rs @@ -78,6 +78,7 @@ where update_sorted_flag_before_append::(self, other); let len = self.len(); self.length += other.length; + self.null_count += other.null_count; new_chunks(&mut self.chunks, &other.chunks, len); } } @@ -90,6 +91,7 @@ impl ListChunked { let len = self.len(); self.length += other.length; + self.null_count += other.null_count; new_chunks(&mut self.chunks, &other.chunks, len); self.set_sorted_flag(IsSorted::Not); if !other._can_fast_explode() { @@ -108,6 +110,7 @@ impl ArrayChunked { let len = self.len(); self.length += other.length; + self.null_count += other.null_count; new_chunks(&mut self.chunks, &other.chunks, len); self.set_sorted_flag(IsSorted::Not); Ok(()) @@ -120,6 +123,7 @@ impl ObjectChunked { pub fn append(&mut self, other: &Self) { let len = self.len(); self.length += other.length; + self.null_count += other.null_count; self.set_sorted_flag(IsSorted::Not); new_chunks(&mut self.chunks, &other.chunks, len); } diff --git a/crates/polars-core/src/chunked_array/ops/apply.rs b/crates/polars-core/src/chunked_array/ops/apply.rs index 093e6c172d95..1254363eaa75 100644 --- a/crates/polars-core/src/chunked_array/ops/apply.rs +++ b/crates/polars-core/src/chunked_array/ops/apply.rs @@ -220,6 +220,7 @@ impl ChunkedArray { .for_each(|arr| arrow::compute::arity_assign::unary(arr, f)) }; // can be in any order now + self.compute_len(); self.set_sorted_flag(IsSorted::Not); } } diff --git a/crates/polars-core/src/chunked_array/ops/chunkops.rs b/crates/polars-core/src/chunked_array/ops/chunkops.rs index 076dc6476702..a60502afb130 100644 --- a/crates/polars-core/src/chunked_array/ops/chunkops.rs +++ 
b/crates/polars-core/src/chunked_array/ops/chunkops.rs @@ -74,6 +74,11 @@ impl ChunkedArray { } } self.length = IdxSize::try_from(inner(&self.chunks)).expect(LENGTH_LIMIT_MSG); + self.null_count = self + .chunks + .iter() + .map(|arr| arr.null_count()) + .sum::() as IdxSize; if self.length <= 1 { self.set_sorted_flag(IsSorted::Ascending) diff --git a/crates/polars-core/src/chunked_array/upstream_traits.rs b/crates/polars-core/src/chunked_array/upstream_traits.rs index af24444fdf14..fac284c4615e 100644 --- a/crates/polars-core/src/chunked_array/upstream_traits.rs +++ b/crates/polars-core/src/chunked_array/upstream_traits.rs @@ -30,6 +30,7 @@ impl Default for ChunkedArray { phantom: PhantomData, bit_settings: Default::default(), length: 0, + null_count: 0, } } } @@ -330,6 +331,7 @@ impl FromIterator> for ObjectChunked { phantom: PhantomData, bit_settings: Default::default(), length: 0, + null_count: 0, }; out.compute_len(); out diff --git a/crates/polars-core/src/series/mod.rs b/crates/polars-core/src/series/mod.rs index 4751f3ad2d23..0cce71998a2a 100644 --- a/crates/polars-core/src/series/mod.rs +++ b/crates/polars-core/src/series/mod.rs @@ -181,6 +181,7 @@ impl Series { /// # Safety /// The caller must ensure the length and the data types of `ArrayRef` does not change. + /// And that the null_count is updated (e.g. with a `compute_len()`) pub unsafe fn chunks_mut(&mut self) -> &mut Vec { #[allow(unused_mut)] let mut ca = self._get_inner_mut(); @@ -254,6 +255,11 @@ impl Series { Ok(self) } + /// Redo a length and null_count compute + pub fn compute_len(&mut self) { + self._get_inner_mut().compute_len() + } + /// Extend the memory backed by this array with the values from `other`. /// /// See [`ChunkedArray::extend`] and [`ChunkedArray::append`]. diff --git a/crates/polars-core/src/utils/mod.rs b/crates/polars-core/src/utils/mod.rs index 43ee28b644f9..c85c807096f5 100644 --- a/crates/polars-core/src/utils/mod.rs +++ b/crates/polars-core/src/utils/mod.rs @@ -879,6 +879,7 @@ pub fn coalesce_nulls<'a, T: PolarsDataType>( *arr_b = arr_b.with_validity(arr.validity().cloned()) } } + b.compute_len(); (Cow::Owned(a), Cow::Owned(b)) } else { (Cow::Borrowed(a), Cow::Borrowed(b)) @@ -899,6 +900,8 @@ pub fn coalesce_nulls_series(a: &Series, b: &Series) -> (Series, Series) { *arr_a = arr_a.with_validity(validity.clone()); *arr_b = arr_b.with_validity(validity); } + a.compute_len(); + b.compute_len(); (a, b) } else { (a.clone(), b.clone()) diff --git a/crates/polars-lazy/src/physical_plan/streaming/convert_alp.rs b/crates/polars-lazy/src/physical_plan/streaming/convert_alp.rs index bc2af56f18d7..e3d7125b9e27 100644 --- a/crates/polars-lazy/src/physical_plan/streaming/convert_alp.rs +++ b/crates/polars-lazy/src/physical_plan/streaming/convert_alp.rs @@ -358,6 +358,7 @@ pub(crate) fn insert_streaming_nodes( #[cfg(feature = "dtype-categorical")] DataType::Categorical(_) => string_cache, DataType::List(inner) => allowed_dtype(inner, string_cache), + #[cfg(feature = "dtype-struct")] DataType::Struct(fields) => fields .iter() .all(|fld| allowed_dtype(fld.data_type(), string_cache)), From eb469b407e52b9835baf059bd63993581bbeda1d Mon Sep 17 00:00:00 2001 From: Ritchie Vink Date: Fri, 20 Oct 2023 18:43:20 +0200 Subject: [PATCH 077/103] python polars 0.19.10 (#11895) --- py-polars/Cargo.lock | 2 +- py-polars/Cargo.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/py-polars/Cargo.lock b/py-polars/Cargo.lock index 549ce47c3d89..8e2c6af72fd0 100644 --- a/py-polars/Cargo.lock +++ 
b/py-polars/Cargo.lock @@ -1925,7 +1925,7 @@ dependencies = [ [[package]] name = "py-polars" -version = "0.19.9" +version = "0.19.10" dependencies = [ "ahash", "built", diff --git a/py-polars/Cargo.toml b/py-polars/Cargo.toml index f782192c869e..e6950298d037 100644 --- a/py-polars/Cargo.toml +++ b/py-polars/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "py-polars" -version = "0.19.9" +version = "0.19.10" edition = "2021" [lib] From fe04f4a264a686886f134a8e79fd28c21a0e727f Mon Sep 17 00:00:00 2001 From: Alexander Beedie Date: Sat, 21 Oct 2023 00:31:39 +0400 Subject: [PATCH 078/103] fix(python): raise a suitable error from `read_excel` and/or `read_ods` when target sheet does not exist (#11906) --- py-polars/polars/io/spreadsheet/functions.py | 4 +++ py-polars/tests/unit/io/test_spreadsheet.py | 30 ++++++++++++++++++++ 2 files changed, 34 insertions(+) diff --git a/py-polars/polars/io/spreadsheet/functions.py b/py-polars/polars/io/spreadsheet/functions.py index 198b69404290..7e88687a6493 100644 --- a/py-polars/polars/io/spreadsheet/functions.py +++ b/py-polars/polars/io/spreadsheet/functions.py @@ -445,6 +445,10 @@ def _read_spreadsheet( if hasattr(parser, "close"): parser.close() + if not parsed_sheets: + param, value = ("id", sheet_id) if sheet_name is None else ("name", sheet_name) + raise ValueError(f"no matching sheets found when `sheet_{param}` is {value!r}") + if return_multi: return parsed_sheets return next(iter(parsed_sheets.values())) diff --git a/py-polars/tests/unit/io/test_spreadsheet.py b/py-polars/tests/unit/io/test_spreadsheet.py index a777e8af318b..79ad6300fd7e 100644 --- a/py-polars/tests/unit/io/test_spreadsheet.py +++ b/py-polars/tests/unit/io/test_spreadsheet.py @@ -194,6 +194,36 @@ def test_read_excel_basic_datatypes( assert_frame_equal(df, df) +@pytest.mark.parametrize( + ("read_spreadsheet", "source", "params"), + [ + (pl.read_excel, "path_xlsx", {"engine": "xlsx2csv"}), + (pl.read_excel, "path_xlsx", {"engine": "openpyxl"}), + (pl.read_excel, "path_xlsb", {"engine": "pyxlsb"}), + (pl.read_ods, "path_ods", {}), + ], +) +def test_read_invalid_worksheet( + read_spreadsheet: Callable[..., dict[str, pl.DataFrame]], + source: str, + params: dict[str, str], + request: pytest.FixtureRequest, +) -> None: + spreadsheet_path = request.getfixturevalue(source) + for param, sheet_id, sheet_name in ( + ("id", 999, None), + ("name", None, "not_a_sheet_name"), + ): + value = sheet_id if param == "id" else sheet_name + with pytest.raises( + ValueError, + match=f"no matching sheets found when `sheet_{param}` is {value!r}", + ): + read_spreadsheet( + spreadsheet_path, sheet_id=sheet_id, sheet_name=sheet_name, **params + ) + + @pytest.mark.parametrize("engine", ["xlsx2csv", "openpyxl"]) def test_write_excel_bytes(engine: Literal["xlsx2csv", "openpyxl", "pyxlsb"]) -> None: df = pl.DataFrame({"A": [1, 2, 3, 4, 5]}) From d847f69767eecee3ed2e88ac86ee2cf25231b679 Mon Sep 17 00:00:00 2001 From: Brayan Jules Date: Fri, 20 Oct 2023 22:21:50 -0300 Subject: [PATCH 079/103] fix: fixed the bug that incorrectly enabled the conversion from epoch string to datetime. 
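Casting a Utf8 column that holds raw epoch integers to a temporal dtype used to be special-cased through `validate_is_number`; that shortcut is removed below, so epoch values stored as strings must be cast to a numeric dtype before conversion. A minimal sketch of the behavior after this patch (the values mirror the new test; the recovery path is illustrative):

import polars as pl
from polars.exceptions import ComputeError

ldf = pl.LazyFrame([pl.Series("timestamp_ms", [1147880044 * 1_000]).cast(pl.Utf8)])
try:
    # Utf8 epoch strings no longer parse as datetimes
    ldf.select(pl.from_epoch(pl.col("timestamp_ms"), time_unit="ms")).collect()
except ComputeError:
    # cast to a numeric dtype first, then interpret as epoch
    ldf.select(pl.from_epoch(pl.col("timestamp_ms").cast(pl.Int64), time_unit="ms")).collect()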
--- crates/polars-core/src/chunked_array/cast.rs | 6 ++---- .../src/chunked_array/temporal/mod.rs | 19 +------------------ py-polars/tests/unit/test_lazy.py | 19 ++++++++++++++++++- 3 files changed, 21 insertions(+), 23 deletions(-) diff --git a/crates/polars-core/src/chunked_array/cast.rs b/crates/polars-core/src/chunked_array/cast.rs index 96f5fa1ead04..8daddeac1d81 100644 --- a/crates/polars-core/src/chunked_array/cast.rs +++ b/crates/polars-core/src/chunked_array/cast.rs @@ -5,8 +5,6 @@ use arrow::compute::cast::CastOptions; #[cfg(feature = "dtype-categorical")] use crate::chunked_array::categorical::CategoricalChunkedBuilder; -#[cfg(feature = "temporal")] -use crate::chunked_array::temporal::validate_is_number; #[cfg(feature = "timezones")] use crate::chunked_array::temporal::validate_time_zone; use crate::prelude::DataType::Datetime; @@ -199,13 +197,13 @@ impl ChunkCast for Utf8Chunked { }, }, #[cfg(feature = "dtype-date")] - DataType::Date if !validate_is_number(&self.chunks) => { + DataType::Date => { let result = cast_chunks(&self.chunks, data_type, true)?; let out = Series::try_from((self.name(), result))?; Ok(out) }, #[cfg(feature = "dtype-datetime")] - DataType::Datetime(tu, tz) if !validate_is_number(&self.chunks) => { + DataType::Datetime(tu, tz) => { let out = match tz { #[cfg(feature = "timezones")] Some(tz) => { diff --git a/crates/polars-core/src/chunked_array/temporal/mod.rs b/crates/polars-core/src/chunked_array/temporal/mod.rs index 3b6a38aede8b..0a89825f6959 100644 --- a/crates/polars-core/src/chunked_array/temporal/mod.rs +++ b/crates/polars-core/src/chunked_array/temporal/mod.rs @@ -19,9 +19,9 @@ use chrono_tz::Tz; pub use time::time_to_time64ns; pub use self::conversion::*; +use crate::prelude::ArrayRef; #[cfg(feature = "timezones")] use crate::prelude::{polars_bail, PolarsResult}; -use crate::prelude::{ArrayRef, LargeStringArray}; pub fn unix_time() -> NaiveDateTime { NaiveDateTime::from_timestamp_opt(0, 0).unwrap() @@ -36,20 +36,3 @@ pub(crate) fn validate_time_zone(tz: &str) -> PolarsResult<()> { }, } } - -pub(crate) fn validate_is_number(vec_array: &[ArrayRef]) -> bool { - vec_array.iter().all(is_parsable_as_number) -} - -fn is_parsable_as_number(array: &ArrayRef) -> bool { - if let Some(array) = array.as_any().downcast_ref::() { - array.iter().all(|value| { - value - .expect("Unable to parse int string to datetime") - .parse::() - .is_ok() - }) - } else { - false - } -} diff --git a/py-polars/tests/unit/test_lazy.py b/py-polars/tests/unit/test_lazy.py index 2fedbf853435..7dc6478ab5d2 100644 --- a/py-polars/tests/unit/test_lazy.py +++ b/py-polars/tests/unit/test_lazy.py @@ -1375,7 +1375,7 @@ def test_quadratic_behavior_4736() -> None: ldf.select(reduce(add, (pl.col(fld) for fld in ldf.columns))) -@pytest.mark.parametrize("input_dtype", [pl.Int64, pl.Float64, pl.Utf8]) +@pytest.mark.parametrize("input_dtype", [pl.Int64, pl.Float64]) def test_from_epoch(input_dtype: pl.PolarsDataType) -> None: ldf = pl.LazyFrame( [ @@ -1415,6 +1415,23 @@ def test_from_epoch(input_dtype: pl.PolarsDataType) -> None: _ = ldf.select(pl.from_epoch(ts_col, time_unit="s2")) # type: ignore[call-overload] +def test_from_epoch_str() -> None: + ldf = pl.LazyFrame( + [ + pl.Series("timestamp_ms", [1147880044 * 1_000]).cast(pl.Utf8), + pl.Series("timestamp_us", [1147880044 * 1_000_000]).cast(pl.Utf8), + ] + ) + + with pytest.raises(ComputeError): + ldf.select( + [ + pl.from_epoch(pl.col("timestamp_ms"), time_unit="ms"), + pl.from_epoch(pl.col("timestamp_us"), time_unit="us"), + ] + 
).collect() + + def test_cumagg_types() -> None: ldf = pl.LazyFrame({"a": [1, 2], "b": [True, False], "c": [1.3, 2.4]}) cumsum_lf = ldf.select( From c5459f1226218c09478fb86ffa41453dd62c80a7 Mon Sep 17 00:00:00 2001 From: Brayan Jules Date: Fri, 20 Oct 2023 22:37:26 -0300 Subject: [PATCH 080/103] fix: removed unused import --- crates/polars-core/src/chunked_array/temporal/mod.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/crates/polars-core/src/chunked_array/temporal/mod.rs b/crates/polars-core/src/chunked_array/temporal/mod.rs index 0a89825f6959..737ff5086d47 100644 --- a/crates/polars-core/src/chunked_array/temporal/mod.rs +++ b/crates/polars-core/src/chunked_array/temporal/mod.rs @@ -19,7 +19,6 @@ use chrono_tz::Tz; pub use time::time_to_time64ns; pub use self::conversion::*; -use crate::prelude::ArrayRef; #[cfg(feature = "timezones")] use crate::prelude::{polars_bail, PolarsResult}; From 6a1731eed2417eae46d00d72f145cd75393d9449 Mon Sep 17 00:00:00 2001 From: Alexander Beedie Date: Sat, 21 Oct 2023 09:37:24 +0400 Subject: [PATCH 081/103] fix(python): address issue with inadvertently shared options dict in `read_excel` (#11908) --- py-polars/polars/io/spreadsheet/functions.py | 13 +++++++------ py-polars/tests/unit/io/test_spreadsheet.py | 17 +++++++++++++++++ 2 files changed, 24 insertions(+), 6 deletions(-) diff --git a/py-polars/polars/io/spreadsheet/functions.py b/py-polars/polars/io/spreadsheet/functions.py index 7e88687a6493..29bfac7acec1 100644 --- a/py-polars/polars/io/spreadsheet/functions.py +++ b/py-polars/polars/io/spreadsheet/functions.py @@ -28,7 +28,7 @@ def read_excel( engine: Literal["xlsx2csv", "openpyxl", "pyxlsb"] | None = ..., xlsx2csv_options: dict[str, Any] | None = ..., read_csv_options: dict[str, Any] | None = ..., - schema_overrides: SchemaDict | None = None, + schema_overrides: SchemaDict | None = ..., raise_if_empty: bool = ..., ) -> pl.DataFrame: ... @@ -43,7 +43,7 @@ def read_excel( engine: Literal["xlsx2csv", "openpyxl", "pyxlsb"] | None = ..., xlsx2csv_options: dict[str, Any] | None = ..., read_csv_options: dict[str, Any] | None = ..., - schema_overrides: SchemaDict | None = None, + schema_overrides: SchemaDict | None = ..., raise_if_empty: bool = ..., ) -> pl.DataFrame: ... @@ -58,7 +58,7 @@ def read_excel( engine: Literal["xlsx2csv", "openpyxl", "pyxlsb"] | None = ..., xlsx2csv_options: dict[str, Any] | None = ..., read_csv_options: dict[str, Any] | None = ..., - schema_overrides: SchemaDict | None = None, + schema_overrides: SchemaDict | None = ..., raise_if_empty: bool = ..., ) -> NoReturn: ... @@ -75,7 +75,7 @@ def read_excel( engine: Literal["xlsx2csv", "openpyxl", "pyxlsb"] | None = ..., xlsx2csv_options: dict[str, Any] | None = ..., read_csv_options: dict[str, Any] | None = ..., - schema_overrides: SchemaDict | None = None, + schema_overrides: SchemaDict | None = ..., raise_if_empty: bool = ..., ) -> dict[str, pl.DataFrame]: ... @@ -90,7 +90,7 @@ def read_excel( engine: Literal["xlsx2csv", "openpyxl", "pyxlsb"] | None = ..., xlsx2csv_options: dict[str, Any] | None = ..., read_csv_options: dict[str, Any] | None = ..., - schema_overrides: SchemaDict | None = None, + schema_overrides: SchemaDict | None = ..., raise_if_empty: bool = ..., ) -> pl.DataFrame: ... 
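The root cause here is plain dict aliasing: the same `read_csv_options` mapping was reused for every sheet, so the first sheet's merged `dtypes` leaked into later reads. A standalone sketch of the pitfall and the fix applied below (names are illustrative, not polars API):

def merge_dtypes(sheet_dtypes: dict, read_csv_options: dict) -> dict:
    # copy before mutating, otherwise every subsequent caller sees the merged dtypes
    read_csv_options = read_csv_options.copy()
    read_csv_options["dtypes"] = {**read_csv_options.get("dtypes", {}), **sheet_dtypes}
    return read_csv_options

shared = {"dtypes": {"a": "Int32"}}
merge_dtypes({"b": "Float64"}, shared)
assert shared["dtypes"] == {"a": "Int32"}  # caller's dict is unchanged thanks to the copy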
@@ -105,7 +105,7 @@ def read_excel( engine: Literal["xlsx2csv", "openpyxl", "pyxlsb"] | None = ..., xlsx2csv_options: dict[str, Any] | None = ..., read_csv_options: dict[str, Any] | None = ..., - schema_overrides: SchemaDict | None = None, + schema_overrides: SchemaDict | None = ..., raise_if_empty: bool = ..., ) -> dict[str, pl.DataFrame]: ... @@ -552,6 +552,7 @@ def _csv_buffer_to_frame( raise ParameterCollisionError( "cannot specify columns in both `schema_overrides` and `read_csv_options['dtypes']`" ) + read_csv_options = read_csv_options.copy() read_csv_options["dtypes"] = {**csv_dtypes, **schema_overrides} # otherwise rewind the buffer and parse as csv diff --git a/py-polars/tests/unit/io/test_spreadsheet.py b/py-polars/tests/unit/io/test_spreadsheet.py index 79ad6300fd7e..94868f0e467d 100644 --- a/py-polars/tests/unit/io/test_spreadsheet.py +++ b/py-polars/tests/unit/io/test_spreadsheet.py @@ -1,6 +1,7 @@ from __future__ import annotations import warnings +from collections import OrderedDict from datetime import date, datetime from io import BytesIO from typing import TYPE_CHECKING, Any, Callable, Literal @@ -305,6 +306,22 @@ def test_schema_overrides(path_xlsx: Path, path_xlsb: Path, path_ods: Path) -> N read_csv_options={"dtypes": {"cardinality": pl.Int32}}, ) + # read multiple sheets in conjunction with 'schema_overrides' + # (note: reading the same sheet twice simulates the issue in #11850) + overrides = OrderedDict( + [ + ("cardinality", pl.UInt32), + ("rows_by_key", pl.Float32), + ("iter_groups", pl.Float64), + ] + ) + df = pl.read_excel( # type: ignore[call-overload] + path_xlsx, + sheet_name=["test4", "test4"], + schema_overrides=overrides, + ) + assert df["test4"].schema == overrides + def test_unsupported_engine() -> None: with pytest.raises(NotImplementedError): From a75eacae6efe1eef220a4c070a58591067886fd9 Mon Sep 17 00:00:00 2001 From: Alexander Beedie Date: Sat, 21 Oct 2023 09:37:52 +0400 Subject: [PATCH 082/103] docs(python): add missing 'diagonal_relaxed' to `pl.concat` "how" param docstring signature (#11909) --- py-polars/polars/functions/eager.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/py-polars/polars/functions/eager.py b/py-polars/polars/functions/eager.py index c953991eaa09..6c225c37f47b 100644 --- a/py-polars/polars/functions/eager.py +++ b/py-polars/polars/functions/eager.py @@ -33,7 +33,7 @@ def concat( ---------- items DataFrames, LazyFrames, or Series to concatenate. - how : {'vertical', 'vertical_relaxed', 'diagonal', 'horizontal', 'align'} + how : {'vertical', 'vertical_relaxed', 'diagonal', 'diagonal_relaxed', 'horizontal', 'align'} Series only support the `vertical` strategy. LazyFrames do not support the `horizontal` strategy. 
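Alongside this docstring fix, a minimal sketch of what the `diagonal_relaxed` strategy does (frame contents are illustrative, not taken from the patch):

import polars as pl

df1 = pl.DataFrame({"a": [1], "b": [3]})
df2 = pl.DataFrame({"a": [2.5], "c": ["x"]})
# columns are unioned and null-filled, and "a" is relaxed to its Float64 supertype
out = pl.concat([df1, df2], how="diagonal_relaxed")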
@@ -125,7 +125,7 @@ def concat( │ 3 ┆ null ┆ 6 ┆ 8 │ └─────┴──────┴──────┴──────┘ - """ + """ # noqa: W505 # unpack/standardise (handles generator input) elems = list(items) From 96b465ef51f9dfa4427591dd22d5a60063ffcc45 Mon Sep 17 00:00:00 2001 From: Alexander Beedie Date: Sat, 21 Oct 2023 09:38:32 +0400 Subject: [PATCH 083/103] fix(python): address DataFrame construction error with lists of `numpy` arrays (#11905) --- py-polars/polars/utils/_construction.py | 6 +++++- py-polars/tests/unit/test_constructors.py | 4 ++++ py-polars/tests/unit/test_errors.py | 2 +- 3 files changed, 10 insertions(+), 2 deletions(-) diff --git a/py-polars/polars/utils/_construction.py b/py-polars/polars/utils/_construction.py index 9e38b3e4d38f..8baf75cd19fb 100644 --- a/py-polars/polars/utils/_construction.py +++ b/py-polars/polars/utils/_construction.py @@ -1024,7 +1024,11 @@ def _sequence_of_sequence_to_pydf( local_schema_override = ( include_unknowns(schema_overrides, column_names) if schema_overrides else {} ) - if column_names and first_element and len(first_element) != len(column_names): + if ( + column_names + and len(first_element) > 0 + and len(first_element) != len(column_names) + ): raise ShapeError("the row data does not match the number of columns") unpack_nested = False diff --git a/py-polars/tests/unit/test_constructors.py b/py-polars/tests/unit/test_constructors.py index 82263ba09a81..a09eb469eeab 100644 --- a/py-polars/tests/unit/test_constructors.py +++ b/py-polars/tests/unit/test_constructors.py @@ -592,6 +592,10 @@ def test_init_ndarray(monkeypatch: Any) -> None: assert df.shape == (2, 1) assert df.rows() == [([0, 1, 2, 3, 4],), ([5, 6, 7, 8, 9],)] + test_rows = [(1, 2), (3, 4)] + df = pl.DataFrame([np.array(test_rows[0]), np.array(test_rows[1])], orient="row") + assert_frame_equal(df, pl.DataFrame(test_rows, orient="row")) + # numpy arrays containing NaN df0 = pl.DataFrame( data={"x": [1.0, 2.5, float("nan")], "y": [4.0, float("nan"), 6.5]}, diff --git a/py-polars/tests/unit/test_errors.py b/py-polars/tests/unit/test_errors.py index 059c63776c75..f45ee6445bf1 100644 --- a/py-polars/tests/unit/test_errors.py +++ b/py-polars/tests/unit/test_errors.py @@ -105,7 +105,7 @@ def test_string_numeric_comp_err() -> None: def test_panic_error() -> None: with pytest.raises( pl.PolarsPanicError, - match="""dimensions cannot be empty""", + match="dimensions cannot be empty", ): pl.Series("a", [1, 2, 3]).reshape(()) From 6155e7f13d6b7629dceffaccce433c02ef332a30 Mon Sep 17 00:00:00 2001 From: Marshall Date: Sat, 21 Oct 2023 01:44:13 -0400 Subject: [PATCH 084/103] feat(python): upcast int->float and date->datetime for certain Series comparisons (#11779) --- py-polars/polars/series/series.py | 35 +++++++++++++++------- py-polars/polars/utils/convert.py | 2 +- py-polars/tests/unit/series/test_series.py | 15 ++++++++++ 3 files changed, 41 insertions(+), 11 deletions(-) diff --git a/py-polars/polars/series/series.py b/py-polars/polars/series/series.py index e3ea688c33ee..5041b8991899 100644 --- a/py-polars/polars/series/series.py +++ b/py-polars/polars/series/series.py @@ -481,14 +481,30 @@ def _comp(self, other: Any, op: ComparisonOperator) -> Series: return self.clone() elif (other is False and op == "eq") or (other is True and op == "neq"): return ~self - - if isinstance(other, datetime) and self.dtype == Datetime: - time_zone = self.dtype.time_zone # type: ignore[union-attr] - if str(other.tzinfo) != str(time_zone): - raise TypeError( - f"Datetime time zone {other.tzinfo!r} does not match Series timezone 
{time_zone!r}" + elif isinstance(other, float) and self.dtype in INTEGER_DTYPES: + # require upcast when comparing int series to float value + self = self.cast(Float64) + f = get_ffi_func(op + "_<>", Float64, self._s) + assert f is not None + return self._from_pyseries(f(other)) + elif isinstance(other, datetime): + if self.dtype == Date: + # require upcast when comparing date series to datetime + self = self.cast(Datetime("us")) + time_unit = "us" + elif self.dtype == Datetime: + # Use local time zone info + time_zone = self.dtype.time_zone # type: ignore[union-attr] + if str(other.tzinfo) != str(time_zone): + raise TypeError( + f"Datetime time zone {other.tzinfo!r} does not match Series timezone {time_zone!r}" + ) + time_unit = self.dtype.time_unit # type: ignore[union-attr] + else: + raise ValueError( + f"cannot compare datetime.datetime to series of type {self.dtype}" ) - ts = _datetime_to_pl_timestamp(other, self.dtype.time_unit) # type: ignore[union-attr] + ts = _datetime_to_pl_timestamp(other, time_unit) # type: ignore[arg-type] f = get_ffi_func(op + "_<>", Int64, self._s) assert f is not None return self._from_pyseries(f(ts)) @@ -497,14 +513,13 @@ def _comp(self, other: Any, op: ComparisonOperator) -> Series: f = get_ffi_func(op + "_<>", Int64, self._s) assert f is not None return self._from_pyseries(f(d)) + elif self.dtype == Categorical and not isinstance(other, Series): + other = Series([other]) elif isinstance(other, date) and self.dtype == Date: d = _date_to_pl_date(other) f = get_ffi_func(op + "_<>", Int32, self._s) assert f is not None return self._from_pyseries(f(d)) - elif self.dtype == Categorical and not isinstance(other, Series): - other = Series([other]) - if isinstance(other, Sequence) and not isinstance(other, str): other = Series("", other, dtype_if_empty=self.dtype) if isinstance(other, Series): diff --git a/py-polars/polars/utils/convert.py b/py-polars/polars/utils/convert.py index 85caae3266b3..ec380519b8f6 100644 --- a/py-polars/polars/utils/convert.py +++ b/py-polars/polars/utils/convert.py @@ -97,7 +97,7 @@ def _negate_duration(duration: str) -> str: return f"-{duration}" -def _datetime_to_pl_timestamp(dt: datetime, time_unit: TimeUnit | None) -> int: +def _datetime_to_pl_timestamp(dt: datetime, time_unit: TimeUnit) -> int: """Convert a python datetime to a timestamp in given time unit.""" if dt.tzinfo is None: # Make sure to use UTC rather than system time zone. 
diff --git a/py-polars/tests/unit/series/test_series.py b/py-polars/tests/unit/series/test_series.py index b830442250c2..c77cb2cecb5b 100644 --- a/py-polars/tests/unit/series/test_series.py +++ b/py-polars/tests/unit/series/test_series.py @@ -1568,6 +1568,21 @@ def test_comparisons_int_series_to_float() -> None: assert_series_equal(srs_int - True, pl.Series([0, 1, 2, 3])) +def test_comparisons_int_series_to_float_scalar() -> None: + srs_int = pl.Series([1, 2, 3, 4]) + + assert_series_equal(srs_int < 1.5, pl.Series([True, False, False, False])) + assert_series_equal(srs_int > 1.5, pl.Series([False, True, True, True])) + + +def test_comparisons_datetime_series_to_date_scalar() -> None: + srs_date = pl.Series([date(2023, 1, 1), date(2023, 1, 2), date(2023, 1, 3)]) + dt = datetime(2023, 1, 1, 12, 0, 0) + + assert_series_equal(srs_date < dt, pl.Series([True, False, False])) + assert_series_equal(srs_date > dt, pl.Series([False, True, True])) + + def test_comparisons_float_series_to_int() -> None: srs_float = pl.Series([1.0, 2.0, 3.0, 4.0]) From ff358caa23e594970e1f27698b355be2a36ffdd6 Mon Sep 17 00:00:00 2001 From: Weijie Guo Date: Sat, 21 Oct 2023 13:47:03 +0800 Subject: [PATCH 085/103] fix: predicate push-down remove predicate refers to alias for more branch (#11887) --- .../optimizer/predicate_pushdown/utils.rs | 47 +++++++++++-------- py-polars/tests/unit/test_predicates.py | 11 +++++ 2 files changed, 39 insertions(+), 19 deletions(-) diff --git a/crates/polars-plan/src/logical_plan/optimizer/predicate_pushdown/utils.rs b/crates/polars-plan/src/logical_plan/optimizer/predicate_pushdown/utils.rs index 58179b8aae8b..9d83ec2edfa4 100644 --- a/crates/polars-plan/src/logical_plan/optimizer/predicate_pushdown/utils.rs +++ b/crates/polars-plan/src/logical_plan/optimizer/predicate_pushdown/utils.rs @@ -213,7 +213,7 @@ fn rename_predicate_columns_due_to_aliased_projection( ) -> LoopBehavior { let projection_aexpr = expr_arena.get(projection_node); if let AExpr::Alias(_, alias_name) = projection_aexpr { - let alias_name = alias_name.as_ref(); + let alias_name = alias_name.clone(); let projection_leaves = aexpr_to_leaf_names(projection_node, expr_arena); // this means the leaf is a literal @@ -223,9 +223,10 @@ fn rename_predicate_columns_due_to_aliased_projection( // if this alias refers to one of the predicates in the upper nodes // we rename the column of the predicate before we push it downwards. - if let Some(predicate) = acc_predicates.remove(alias_name) { + if let Some(predicate) = acc_predicates.remove(&alias_name) { if projection_maybe_boundary { local_predicates.push(predicate); + remove_predicate_refers_to_alias(acc_predicates, local_predicates, &alias_name); return LoopBehavior::Continue; } if projection_leaves.len() == 1 { @@ -240,28 +241,36 @@ fn rename_predicate_columns_due_to_aliased_projection( // on this projected column so we do filter locally. local_predicates.push(predicate) } - } else { - // we could not find the alias name - // that could still mean that a predicate that is a complicated binary expression - // refers to the aliased name. If we find it, we remove it for now - // TODO! rename the expression. 
- let mut remove_names = vec![]; - for (composed_name, _) in acc_predicates.iter() { - if key_has_name(composed_name, alias_name) { - remove_names.push(composed_name.clone()); - break; - } - } - - for composed_name in remove_names { - let predicate = acc_predicates.remove(&composed_name).unwrap(); - local_predicates.push(predicate) - } } + + remove_predicate_refers_to_alias(acc_predicates, local_predicates, &alias_name); } LoopBehavior::Nothing } +/// we could not find the alias name +/// that could still mean that a predicate that is a complicated binary expression +/// refers to the aliased name. If we find it, we remove it for now +/// TODO! rename the expression. +fn remove_predicate_refers_to_alias( + acc_predicates: &mut PlHashMap, Node>, + local_predicates: &mut Vec, + alias_name: &str, +) { + let mut remove_names = vec![]; + for (composed_name, _) in acc_predicates.iter() { + if key_has_name(composed_name, alias_name) { + remove_names.push(composed_name.clone()); + break; + } + } + + for composed_name in remove_names { + let predicate = acc_predicates.remove(&composed_name).unwrap(); + local_predicates.push(predicate) + } +} + /// Implementation for both Hstack and Projection pub(super) fn rewrite_projection_node( expr_arena: &mut Arena, diff --git a/py-polars/tests/unit/test_predicates.py b/py-polars/tests/unit/test_predicates.py index 2af6376d05ef..41e2207e844c 100644 --- a/py-polars/tests/unit/test_predicates.py +++ b/py-polars/tests/unit/test_predicates.py @@ -192,3 +192,14 @@ def test_predicate_pushdown_group_by_keys() -> None: .filter(pl.col("group") == 1) .explain() ) + + +def test_no_predicate_push_down_with_cast_and_alias_11883() -> None: + df = pl.DataFrame({"a": [1, 2, 3]}) + out = ( + df.lazy() + .select(pl.col("a").cast(pl.Int64).alias("b")) + .filter(pl.col("b") == 1) + .filter((pl.col("b") >= 1) & (pl.col("b") < 1)) + ) + assert 'SELECTION: "None"' in out.explain(predicate_pushdown=True) From 6e8ce9c391b0ddec16259c83d96368c61ecf42c5 Mon Sep 17 00:00:00 2001 From: Ritchie Vink Date: Sat, 21 Oct 2023 08:42:53 +0200 Subject: [PATCH 086/103] fix(python): set null_count on categorical append (#11914) --- .../chunked_array/logical/categorical/from.rs | 4 +- .../chunked_array/logical/categorical/mod.rs | 56 ++++++++-------- .../logical/categorical/ops/append.rs | 22 +++++-- .../logical/categorical/ops/unique.rs | 16 ++--- .../logical/categorical/ops/zip.rs | 6 +- crates/polars-core/src/chunked_array/mod.rs | 8 +-- .../src/chunked_array/ops/chunkops.rs | 7 ++ .../src/chunked_array/ops/compare_inner.rs | 2 +- .../ops/sort/arg_sort_multiple.rs | 2 +- .../src/chunked_array/ops/sort/categorical.rs | 12 ++-- .../polars-core/src/frame/group_by/perfect.rs | 6 +- crates/polars-core/src/series/comparison.rs | 19 +++--- .../src/series/implementations/categorical.rs | 64 +++++++++---------- crates/polars-core/src/series/into.rs | 2 +- .../src/frame/join/hash_join/zip_outer.rs | 4 +- crates/polars-ops/src/series/ops/is_in.rs | 2 +- .../tests/unit/datatypes/test_categorical.py | 10 +++ 17 files changed, 136 insertions(+), 106 deletions(-) diff --git a/crates/polars-core/src/chunked_array/logical/categorical/from.rs b/crates/polars-core/src/chunked_array/logical/categorical/from.rs index e268144c9ddd..1c7e028fdf87 100644 --- a/crates/polars-core/src/chunked_array/logical/categorical/from.rs +++ b/crates/polars-core/src/chunked_array/logical/categorical/from.rs @@ -7,7 +7,7 @@ use crate::using_string_cache; impl From<&CategoricalChunked> for DictionaryArray { fn from(ca: 
&CategoricalChunked) -> Self { - let keys = ca.logical().rechunk(); + let keys = ca.physical().rechunk(); let keys = keys.downcast_iter().next().unwrap(); let map = &**ca.get_rev_map(); let dtype = ArrowDataType::Dictionary( @@ -42,7 +42,7 @@ impl From<&CategoricalChunked> for DictionaryArray { } impl From<&CategoricalChunked> for DictionaryArray { fn from(ca: &CategoricalChunked) -> Self { - let keys = ca.logical().rechunk(); + let keys = ca.physical().rechunk(); let keys = keys.downcast_iter().next().unwrap(); let map = &**ca.get_rev_map(); let dtype = ArrowDataType::Dictionary( diff --git a/crates/polars-core/src/chunked_array/logical/categorical/mod.rs b/crates/polars-core/src/chunked_array/logical/categorical/mod.rs index d66e6318b5ef..d1671e6e7f05 100644 --- a/crates/polars-core/src/chunked_array/logical/categorical/mod.rs +++ b/crates/polars-core/src/chunked_array/logical/categorical/mod.rs @@ -23,7 +23,7 @@ bitflags! { #[derive(Clone)] pub struct CategoricalChunked { - logical: Logical, + physical: Logical, /// 1st bit: original local categorical /// meaning that n_unique is the same as the cat map length /// 2nd bit: use lexical sorting @@ -32,7 +32,7 @@ pub struct CategoricalChunked { impl CategoricalChunked { pub(crate) fn field(&self) -> Field { - let name = self.logical().name(); + let name = self.physical().name(); Field::new(name, self.dtype().clone()) } @@ -40,23 +40,29 @@ impl CategoricalChunked { self.len() == 0 } + #[inline] pub fn len(&self) -> usize { - self.logical.len() + self.physical.len() + } + + #[inline] + pub fn null_count(&self) -> usize { + self.physical.null_count() } pub fn name(&self) -> &str { - self.logical.name() + self.physical.name() } // TODO: Rename this /// Get a reference to the physical array (the categories). - pub fn logical(&self) -> &UInt32Chunked { - &self.logical + pub fn physical(&self) -> &UInt32Chunked { + &self.physical } /// Get a mutable reference to the physical array (the categories). - pub(crate) fn logical_mut(&mut self) -> &mut UInt32Chunked { - &mut self.logical + pub(crate) fn physical_mut(&mut self) -> &mut UInt32Chunked { + &mut self.physical } /// Convert a categorical column to its local representation. @@ -72,7 +78,7 @@ impl CategoricalChunked { // if all physical map keys are equal to their values, // we can skip the apply and only update the rev_map let local_ca = self - .logical() + .physical() .apply(|opt_v| opt_v.map(|v| *physical_map.get(&v).unwrap())); let mut out = @@ -84,12 +90,12 @@ impl CategoricalChunked { } pub(crate) fn get_flags(&self) -> Settings { - self.logical().get_flags() + self.physical().get_flags() } /// Set flags for the Chunked Array pub(crate) fn set_flags(&mut self, flags: Settings) { - self.logical_mut().set_flags(flags) + self.physical_mut().set_flags(flags) } /// Build a categorical from an original RevMap. That means that the number of categories in the `RevMapping == self.unique().len()`. @@ -105,7 +111,7 @@ impl CategoricalChunked { let mut bit_settings = BitSettings::default(); bit_settings.insert(BitSettings::ORIGINAL); Self { - logical, + physical: logical, bit_settings, } } @@ -135,7 +141,7 @@ impl CategoricalChunked { let mut logical = Logical::::new_logical::(idx); logical.2 = Some(DataType::Categorical(Some(rev_map))); Self { - logical, + physical: logical, bit_settings: Default::default(), } } @@ -143,14 +149,14 @@ impl CategoricalChunked { /// # Safety /// The existing index values must be in bounds of the new [`RevMapping`]. 
pub(crate) unsafe fn set_rev_map(&mut self, rev_map: Arc, keep_fast_unique: bool) { - self.logical.2 = Some(DataType::Categorical(Some(rev_map))); + self.physical.2 = Some(DataType::Categorical(Some(rev_map))); if !keep_fast_unique { self.set_fast_unique(false) } } pub(crate) fn can_fast_unique(&self) -> bool { - self.bit_settings.contains(BitSettings::ORIGINAL) && self.logical.chunks.len() == 1 + self.bit_settings.contains(BitSettings::ORIGINAL) && self.physical.chunks.len() == 1 } pub(crate) fn set_fast_unique(&mut self, toggle: bool) { @@ -163,7 +169,7 @@ impl CategoricalChunked { /// Get a reference to the mapping of categorical types to the string values. pub fn get_rev_map(&self) -> &Arc { - if let DataType::Categorical(Some(rev_map)) = &self.logical.2.as_ref().unwrap() { + if let DataType::Categorical(Some(rev_map)) = &self.physical.2.as_ref().unwrap() { rev_map } else { panic!("implementation error") @@ -172,7 +178,7 @@ impl CategoricalChunked { /// Create an `[Iterator]` that iterates over the `&str` values of the `[CategoricalChunked]`. pub fn iter_str(&self) -> CatIter<'_> { - let iter = self.logical().into_iter(); + let iter = self.physical().into_iter(); CatIter { rev: self.get_rev_map(), iter, @@ -182,7 +188,7 @@ impl CategoricalChunked { impl LogicalType for CategoricalChunked { fn dtype(&self) -> &DataType { - self.logical.2.as_ref().unwrap() + self.physical.2.as_ref().unwrap() } fn get_any_value(&self, i: usize) -> PolarsResult> { @@ -191,7 +197,7 @@ impl LogicalType for CategoricalChunked { } unsafe fn get_any_value_unchecked(&self, i: usize) -> AnyValue<'_> { - match self.logical.0.get_unchecked(i) { + match self.physical.0.get_unchecked(i) { Some(i) => AnyValue::Categorical(i, self.get_rev_map(), SyncPtr::new_null()), None => AnyValue::Null, } @@ -203,16 +209,16 @@ impl LogicalType for CategoricalChunked { let mapping = &**self.get_rev_map(); let mut builder = - Utf8ChunkedBuilder::new(self.logical.name(), self.len(), self.len() * 5); + Utf8ChunkedBuilder::new(self.physical.name(), self.len(), self.len() * 5); let f = |idx: u32| mapping.get(idx); - if !self.logical.has_validity() { - self.logical + if !self.physical.has_validity() { + self.physical .into_no_null_iter() .for_each(|idx| builder.append_value(f(idx))); } else { - self.logical.into_iter().for_each(|opt_idx| { + self.physical.into_iter().for_each(|opt_idx| { builder.append_option(opt_idx.map(f)); }); } @@ -222,13 +228,13 @@ impl LogicalType for CategoricalChunked { }, DataType::UInt32 => { let ca = unsafe { - UInt32Chunked::from_chunks(self.logical.name(), self.logical.chunks.clone()) + UInt32Chunked::from_chunks(self.physical.name(), self.physical.chunks.clone()) }; Ok(ca.into_series()) }, #[cfg(feature = "dtype-categorical")] DataType::Categorical(_) => Ok(self.clone().into_series()), - _ => self.logical.cast(dtype), + _ => self.physical.cast(dtype), } } } diff --git a/crates/polars-core/src/chunked_array/logical/categorical/ops/append.rs b/crates/polars-core/src/chunked_array/logical/categorical/ops/append.rs index 6385cfba3a1b..190c46dbf352 100644 --- a/crates/polars-core/src/chunked_array/logical/categorical/ops/append.rs +++ b/crates/polars-core/src/chunked_array/logical/categorical/ops/append.rs @@ -1,13 +1,23 @@ +use polars_error::constants::LENGTH_LIMIT_MSG; + use super::*; use crate::chunked_array::ops::append::new_chunks; use crate::series::IsSorted; impl CategoricalChunked { + fn set_lengths(&mut self, other: &Self) { + let length_self = &mut self.physical_mut().length; + *length_self = 
length_self + .checked_add(other.len() as IdxSize) + .expect(LENGTH_LIMIT_MSG); + self.physical_mut().null_count += other.null_count() as IdxSize; + } + pub fn append(&mut self, other: &Self) -> PolarsResult<()> { - if self.logical.null_count() == self.len() && other.logical.null_count() == other.len() { + if self.physical.null_count() == self.len() && other.physical.null_count() == other.len() { let len = self.len(); - self.logical_mut().length += other.len() as IdxSize; - new_chunks(&mut self.logical.chunks, &other.logical().chunks, len); + self.set_lengths(other); + new_chunks(&mut self.physical.chunks, &other.physical().chunks, len); return Ok(()); } let is_local_different_source = @@ -23,10 +33,10 @@ impl CategoricalChunked { let new_rev_map = self._merge_categorical_map(other)?; unsafe { self.set_rev_map(new_rev_map, false) }; - self.logical_mut().length += other.len() as IdxSize; - new_chunks(&mut self.logical.chunks, &other.logical().chunks, len); + self.set_lengths(other); + new_chunks(&mut self.physical.chunks, &other.physical().chunks, len); } - self.logical.set_sorted_flag(IsSorted::Not); + self.physical.set_sorted_flag(IsSorted::Not); Ok(()) } } diff --git a/crates/polars-core/src/chunked_array/logical/categorical/ops/unique.rs b/crates/polars-core/src/chunked_array/logical/categorical/ops/unique.rs index 9ac7d32ae749..42a448ecbbd8 100644 --- a/crates/polars-core/src/chunked_array/logical/categorical/ops/unique.rs +++ b/crates/polars-core/src/chunked_array/logical/categorical/ops/unique.rs @@ -7,10 +7,10 @@ impl CategoricalChunked { if self.can_fast_unique() { let ca = match &**cat_map { RevMapping::Local(a) => { - UInt32Chunked::from_iter_values(self.logical().name(), 0..(a.len() as u32)) + UInt32Chunked::from_iter_values(self.physical().name(), 0..(a.len() as u32)) }, RevMapping::Global(map, _, _) => { - UInt32Chunked::from_iter_values(self.logical().name(), map.keys().copied()) + UInt32Chunked::from_iter_values(self.physical().name(), map.keys().copied()) }, }; // safety: @@ -22,7 +22,7 @@ impl CategoricalChunked { Ok(out) } } else { - let ca = self.logical().unique()?; + let ca = self.physical().unique()?; // safety: // we only removed some indexes so we are still in bounds unsafe { @@ -38,14 +38,14 @@ impl CategoricalChunked { if self.can_fast_unique() { Ok(self.get_rev_map().len()) } else { - self.logical().n_unique() + self.physical().n_unique() } } pub fn value_counts(&self) -> PolarsResult { - let groups = self.logical().group_tuples(true, false).unwrap(); - let logical_values = unsafe { - self.logical() + let groups = self.physical().group_tuples(true, false).unwrap(); + let physical_values = unsafe { + self.physical() .clone() .into_series() .agg_first(&groups) @@ -55,7 +55,7 @@ impl CategoricalChunked { }; let mut values = self.clone(); - *values.logical_mut() = logical_values; + *values.physical_mut() = physical_values; let mut counts = groups.group_count(); counts.rename("counts"); diff --git a/crates/polars-core/src/chunked_array/logical/categorical/ops/zip.rs b/crates/polars-core/src/chunked_array/logical/categorical/ops/zip.rs index 7fcad0c73cbe..8ece943cb0fc 100644 --- a/crates/polars-core/src/chunked_array/logical/categorical/ops/zip.rs +++ b/crates/polars-core/src/chunked_array/logical/categorical/ops/zip.rs @@ -10,10 +10,10 @@ impl CategoricalChunked { RevMapping::Local(rev_map) => { // the logic for merging the rev maps will concatenate utf8 arrays // to make sure the indexes still make sense we need to offset the right hand side - self.logical() - 
.zip_with(mask, &(other.logical() + rev_map.len() as u32))? + self.physical() + .zip_with(mask, &(other.physical() + rev_map.len() as u32))? }, - _ => self.logical().zip_with(mask, other.logical())?, + _ => self.physical().zip_with(mask, other.physical())?, }; let new_state = self._merge_categorical_map(other)?; diff --git a/crates/polars-core/src/chunked_array/mod.rs b/crates/polars-core/src/chunked_array/mod.rs index 3a619c61d4a5..2ae1f7fd0c3a 100644 --- a/crates/polars-core/src/chunked_array/mod.rs +++ b/crates/polars-core/src/chunked_array/mod.rs @@ -315,12 +315,6 @@ impl ChunkedArray { self.chunks.len() == 1 && self.null_count() == 0 } - /// Count the null values. - #[inline] - pub fn null_count(&self) -> usize { - self.null_count as usize - } - /// Create a new [`ChunkedArray`] from self, where the chunks are replaced. /// /// # Safety @@ -836,7 +830,7 @@ pub(crate) mod test { let ca = Utf8Chunked::new("", &[Some("foo"), None, Some("bar"), Some("ham")]); let ca = ca.cast(&DataType::Categorical(None)).unwrap(); let ca = ca.categorical().unwrap(); - let v: Vec<_> = ca.logical().into_iter().collect(); + let v: Vec<_> = ca.physical().into_iter().collect(); assert_eq!(v, &[Some(0), None, Some(1), Some(2)]); } diff --git a/crates/polars-core/src/chunked_array/ops/chunkops.rs b/crates/polars-core/src/chunked_array/ops/chunkops.rs index a60502afb130..b4cc3b6c5ec2 100644 --- a/crates/polars-core/src/chunked_array/ops/chunkops.rs +++ b/crates/polars-core/src/chunked_array/ops/chunkops.rs @@ -55,10 +55,17 @@ fn slice( impl ChunkedArray { /// Get the length of the ChunkedArray + #[inline] pub fn len(&self) -> usize { self.length as usize } + /// Count the null values. + #[inline] + pub fn null_count(&self) -> usize { + self.null_count as usize + } + /// Check if ChunkedArray is empty. 
pub fn is_empty(&self) -> bool { self.len() == 0 diff --git a/crates/polars-core/src/chunked_array/ops/compare_inner.rs b/crates/polars-core/src/chunked_array/ops/compare_inner.rs index c62d291e8de2..7d710d92361a 100644 --- a/crates/polars-core/src/chunked_array/ops/compare_inner.rs +++ b/crates/polars-core/src/chunked_array/ops/compare_inner.rs @@ -155,7 +155,7 @@ impl<'a> GetInner for GlobalCategorical<'a> { #[cfg(feature = "dtype-categorical")] impl<'a> IntoPartialOrdInner<'a> for &'a CategoricalChunked { fn into_partial_ord_inner(self) -> Box { - let cats = self.logical(); + let cats = self.physical(); match &**self.get_rev_map() { RevMapping::Global(p1, p2, _) => Box::new(GlobalCategorical { p1, p2, cats }), RevMapping::Local(rev_map) => Box::new(LocalCategorical { rev_map, cats }), diff --git a/crates/polars-core/src/chunked_array/ops/sort/arg_sort_multiple.rs b/crates/polars-core/src/chunked_array/ops/sort/arg_sort_multiple.rs index 9f0d5a4ebcf2..d74dcfca91f2 100644 --- a/crates/polars-core/src/chunked_array/ops/sort/arg_sort_multiple.rs +++ b/crates/polars-core/src/chunked_array/ops/sort/arg_sort_multiple.rs @@ -75,7 +75,7 @@ pub fn _get_rows_encoded_compat_array(by: &Series) -> PolarsResult { if ca.uses_lexical_ordering() { by.to_arrow(0) } else { - ca.logical().chunks[0].clone() + ca.physical().chunks[0].clone() } }, _ => by.to_arrow(0), diff --git a/crates/polars-core/src/chunked_array/ops/sort/categorical.rs b/crates/polars-core/src/chunked_array/ops/sort/categorical.rs index 337dee580b2a..de25449f976e 100644 --- a/crates/polars-core/src/chunked_array/ops/sort/categorical.rs +++ b/crates/polars-core/src/chunked_array/ops/sort/categorical.rs @@ -31,7 +31,7 @@ impl CategoricalChunked { if self.uses_lexical_ordering() { let mut vals = self - .logical() + .physical() .into_no_null_iter() .zip(self.iter_str()) .collect_trusted::>(); @@ -57,7 +57,7 @@ impl CategoricalChunked { ) }; } - let cats = self.logical().sort_with(options); + let cats = self.physical().sort_with(options); // safety: // we only reordered the indexes so we are still in bounds unsafe { @@ -84,11 +84,11 @@ impl CategoricalChunked { self.name(), iters, options, - self.logical().null_count(), + self.physical().null_count(), self.len(), ) } else { - self.logical().arg_sort(options) + self.physical().arg_sort(options) } } @@ -96,7 +96,7 @@ impl CategoricalChunked { pub(crate) fn arg_sort_multiple(&self, options: &SortMultipleOptions) -> PolarsResult { if self.uses_lexical_ordering() { - args_validate(self.logical(), &options.other, &options.descending)?; + args_validate(self.physical(), &options.other, &options.descending)?; let mut count: IdxSize = 0; // we use bytes to save a monomorphisized str impl @@ -112,7 +112,7 @@ impl CategoricalChunked { arg_sort_multiple_impl(vals, options) } else { - self.logical().arg_sort_multiple(options) + self.physical().arg_sort_multiple(options) } } } diff --git a/crates/polars-core/src/frame/group_by/perfect.rs b/crates/polars-core/src/frame/group_by/perfect.rs index 9e44c883315d..2d050a157d1d 100644 --- a/crates/polars-core/src/frame/group_by/perfect.rs +++ b/crates/polars-core/src/frame/group_by/perfect.rs @@ -196,7 +196,7 @@ impl CategoricalChunked { if self.is_empty() { return GroupsProxy::Idx(GroupsIdx::new(vec![], vec![], true)); } - let cats = self.logical(); + let cats = self.physical(); let mut out = match &**rev_map { RevMapping::Local(cached) => { @@ -208,7 +208,7 @@ impl CategoricalChunked { // but on huge tables, this can be > 2x faster 
cats.group_tuples_perfect(cached.len() - 1, multithreaded, 0) } else { - self.logical().group_tuples(multithreaded, sorted).unwrap() + self.physical().group_tuples(multithreaded, sorted).unwrap() } }, RevMapping::Global(_mapping, _cached, _) => { @@ -216,7 +216,7 @@ impl CategoricalChunked { // the problem is that the global categories are not guaranteed packed together // so we might need to deref them first to local ones, but that might be more // expensive than just hashing (benchmark first) - self.logical().group_tuples(multithreaded, sorted).unwrap() + self.physical().group_tuples(multithreaded, sorted).unwrap() }, }; if sorted { diff --git a/crates/polars-core/src/series/comparison.rs b/crates/polars-core/src/series/comparison.rs index e274770e88fe..f65a6e502e64 100644 --- a/crates/polars-core/src/series/comparison.rs +++ b/crates/polars-core/src/series/comparison.rs @@ -126,7 +126,7 @@ impl ChunkCompare<&Series> for Series { #[cfg(feature = "dtype-categorical")] (Categorical(Some(rev_map_l)), Categorical(Some(rev_map_r)), _, _) => { if rev_map_l.same_src(rev_map_r) { - let rhs = rhs.categorical().unwrap().logical(); + let rhs = rhs.categorical().unwrap().physical(); // first check the rev-map if rhs.len() == 1 && rhs.null_count() == 0 { @@ -136,7 +136,7 @@ impl ChunkCompare<&Series> for Series { } } - self.categorical().unwrap().logical().equal(rhs) + self.categorical().unwrap().physical().equal(rhs) } else { polars_bail!( ComputeError: @@ -182,7 +182,7 @@ impl ChunkCompare<&Series> for Series { #[cfg(feature = "dtype-categorical")] (Categorical(Some(rev_map_l)), Categorical(Some(rev_map_r)), _, _) => { if rev_map_l.same_src(rev_map_r) { - let rhs = rhs.categorical().unwrap().logical(); + let rhs = rhs.categorical().unwrap().physical(); // first check the rev-map if rhs.len() == 1 && rhs.null_count() == 0 { @@ -192,7 +192,7 @@ impl ChunkCompare<&Series> for Series { } } - self.categorical().unwrap().logical().equal_missing(rhs) + self.categorical().unwrap().physical().equal_missing(rhs) } else { polars_bail!( ComputeError: @@ -238,7 +238,7 @@ impl ChunkCompare<&Series> for Series { #[cfg(feature = "dtype-categorical")] (Categorical(Some(rev_map_l)), Categorical(Some(rev_map_r)), _, _) => { if rev_map_l.same_src(rev_map_r) { - let rhs = rhs.categorical().unwrap().logical(); + let rhs = rhs.categorical().unwrap().physical(); // first check the rev-map if rhs.len() == 1 && rhs.null_count() == 0 { @@ -248,7 +248,7 @@ impl ChunkCompare<&Series> for Series { } } - self.categorical().unwrap().logical().not_equal(rhs) + self.categorical().unwrap().physical().not_equal(rhs) } else { polars_bail!( ComputeError: @@ -294,7 +294,7 @@ impl ChunkCompare<&Series> for Series { #[cfg(feature = "dtype-categorical")] (Categorical(Some(rev_map_l)), Categorical(Some(rev_map_r)), _, _) => { if rev_map_l.same_src(rev_map_r) { - let rhs = rhs.categorical().unwrap().logical(); + let rhs = rhs.categorical().unwrap().physical(); // first check the rev-map if rhs.len() == 1 && rhs.null_count() == 0 { @@ -304,7 +304,10 @@ impl ChunkCompare<&Series> for Series { } } - self.categorical().unwrap().logical().not_equal_missing(rhs) + self.categorical() + .unwrap() + .physical() + .not_equal_missing(rhs) } else { polars_bail!( ComputeError: diff --git a/crates/polars-core/src/series/implementations/categorical.rs b/crates/polars-core/src/series/implementations/categorical.rs index a3b096e54e4c..1544f5ee1f8c 100644 --- a/crates/polars-core/src/series/implementations/categorical.rs +++ 
b/crates/polars-core/src/series/implementations/categorical.rs @@ -35,7 +35,7 @@ impl SeriesWrap { where F: Fn(&UInt32Chunked) -> UInt32Chunked, { - let cats = apply(self.0.logical()); + let cats = apply(self.0.physical()); self.finish_with_state(keep_fast_unique, cats) } @@ -47,14 +47,14 @@ impl SeriesWrap { where F: for<'b> Fn(&'a UInt32Chunked) -> PolarsResult, { - let cats = apply(self.0.logical())?; + let cats = apply(self.0.physical())?; Ok(self.finish_with_state(keep_fast_unique, cats)) } } impl private::PrivateSeries for SeriesWrap { fn compute_len(&mut self) { - self.0.logical_mut().compute_len() + self.0.physical_mut().compute_len() } fn _field(&self) -> Cow { Cow::Owned(self.0.field()) @@ -78,7 +78,7 @@ impl private::PrivateSeries for SeriesWrap { } unsafe fn equal_element(&self, idx_self: usize, idx_other: usize, other: &Series) -> bool { - self.0.logical().equal_element(idx_self, idx_other, other) + self.0.physical().equal_element(idx_self, idx_other, other) } #[cfg(feature = "zip_with")] @@ -91,24 +91,24 @@ impl private::PrivateSeries for SeriesWrap { if self.0.uses_lexical_ordering() { (&self.0).into_partial_ord_inner() } else { - self.0.logical().into_partial_ord_inner() + self.0.physical().into_partial_ord_inner() } } fn vec_hash(&self, random_state: RandomState, buf: &mut Vec) -> PolarsResult<()> { - self.0.logical().vec_hash(random_state, buf)?; + self.0.physical().vec_hash(random_state, buf)?; Ok(()) } fn vec_hash_combine(&self, build_hasher: RandomState, hashes: &mut [u64]) -> PolarsResult<()> { - self.0.logical().vec_hash_combine(build_hasher, hashes)?; + self.0.physical().vec_hash_combine(build_hasher, hashes)?; Ok(()) } #[cfg(feature = "algorithm_group_by")] unsafe fn agg_list(&self, groups: &GroupsProxy) -> Series { // we cannot cast and dispatch as the inner type of the list would be incorrect - let list = self.0.logical().agg_list(groups); + let list = self.0.physical().agg_list(groups); let mut list = list.list().unwrap().clone(); list.to_logical(self.dtype().clone()); list.into_series() @@ -122,7 +122,7 @@ impl private::PrivateSeries for SeriesWrap { } #[cfg(not(feature = "performant"))] { - self.0.logical().group_tuples(multithreaded, sorted) + self.0.physical().group_tuples(multithreaded, sorted) } } @@ -133,24 +133,24 @@ impl private::PrivateSeries for SeriesWrap { impl SeriesTrait for SeriesWrap { fn rename(&mut self, name: &str) { - self.0.logical_mut().rename(name); + self.0.physical_mut().rename(name); } fn chunk_lengths(&self) -> ChunkIdIter { - self.0.logical().chunk_id() + self.0.physical().chunk_id() } fn name(&self) -> &str { - self.0.logical().name() + self.0.physical().name() } fn chunks(&self) -> &Vec { - self.0.logical().chunks() + self.0.physical().chunks() } unsafe fn chunks_mut(&mut self) -> &mut Vec { - self.0.logical_mut().chunks_mut() + self.0.physical_mut().chunks_mut() } fn shrink_to_fit(&mut self) { - self.0.logical_mut().shrink_to_fit() + self.0.physical_mut().shrink_to_fit() } fn slice(&self, offset: i64, length: usize) -> Series { @@ -166,7 +166,7 @@ impl SeriesTrait for SeriesWrap { fn extend(&mut self, other: &Series) -> PolarsResult<()> { polars_ensure!(self.0.dtype() == other.dtype(), extend); let other = other.categorical()?; - self.0.logical_mut().extend(other.logical()); + self.0.physical_mut().extend(other.physical()); let new_rev_map = self.0._merge_categorical_map(other)?; // SAFETY // rev_maps are merged @@ -181,13 +181,13 @@ impl SeriesTrait for SeriesWrap { #[cfg(feature = "chunked_ids")] unsafe fn 
_take_chunked_unchecked(&self, by: &[ChunkId], sorted: IsSorted) -> Series { - let cats = self.0.logical().take_chunked_unchecked(by, sorted); + let cats = self.0.physical().take_chunked_unchecked(by, sorted); self.finish_with_state(false, cats).into_series() } #[cfg(feature = "chunked_ids")] unsafe fn _take_opt_chunked_unchecked(&self, by: &[Option]) -> Series { - let cats = self.0.logical().take_opt_chunked_unchecked(by); + let cats = self.0.physical().take_opt_chunked_unchecked(by); self.finish_with_state(false, cats).into_series() } @@ -246,11 +246,11 @@ impl SeriesTrait for SeriesWrap { } fn null_count(&self) -> usize { - self.0.logical().null_count() + self.0.physical().null_count() } fn has_validity(&self) -> bool { - self.0.logical().has_validity() + self.0.physical().has_validity() } #[cfg(feature = "algorithm_group_by")] @@ -265,15 +265,15 @@ impl SeriesTrait for SeriesWrap { #[cfg(feature = "algorithm_group_by")] fn arg_unique(&self) -> PolarsResult { - self.0.logical().arg_unique() + self.0.physical().arg_unique() } fn is_null(&self) -> BooleanChunked { - self.0.logical().is_null() + self.0.physical().is_null() } fn is_not_null(&self) -> BooleanChunked { - self.0.logical().is_not_null() + self.0.physical().is_not_null() } fn reverse(&self) -> Series { @@ -281,7 +281,7 @@ impl SeriesTrait for SeriesWrap { } fn as_single_ptr(&mut self) -> PolarsResult { - self.0.logical_mut().as_single_ptr() + self.0.physical_mut().as_single_ptr() } fn shift(&self, periods: i64) -> Series { @@ -289,29 +289,29 @@ impl SeriesTrait for SeriesWrap { } fn _sum_as_series(&self) -> Series { - CategoricalChunked::full_null(self.0.logical().name(), 1).into_series() + CategoricalChunked::full_null(self.0.physical().name(), 1).into_series() } fn max_as_series(&self) -> Series { - CategoricalChunked::full_null(self.0.logical().name(), 1).into_series() + CategoricalChunked::full_null(self.0.physical().name(), 1).into_series() } fn min_as_series(&self) -> Series { - CategoricalChunked::full_null(self.0.logical().name(), 1).into_series() + CategoricalChunked::full_null(self.0.physical().name(), 1).into_series() } fn median_as_series(&self) -> Series { - CategoricalChunked::full_null(self.0.logical().name(), 1).into_series() + CategoricalChunked::full_null(self.0.physical().name(), 1).into_series() } fn var_as_series(&self, _ddof: u8) -> Series { - CategoricalChunked::full_null(self.0.logical().name(), 1).into_series() + CategoricalChunked::full_null(self.0.physical().name(), 1).into_series() } fn std_as_series(&self, _ddof: u8) -> Series { - CategoricalChunked::full_null(self.0.logical().name(), 1).into_series() + CategoricalChunked::full_null(self.0.physical().name(), 1).into_series() } fn quantile_as_series( &self, _quantile: f64, _interpol: QuantileInterpolOptions, ) -> PolarsResult { - Ok(CategoricalChunked::full_null(self.0.logical().name(), 1).into_series()) + Ok(CategoricalChunked::full_null(self.0.physical().name(), 1).into_series()) } fn clone_inner(&self) -> Arc { @@ -324,6 +324,6 @@ impl private::PrivateSeriesNumeric for SeriesWrap { false } fn bit_repr_small(&self) -> UInt32Chunked { - self.0.logical().clone() + self.0.physical().clone() } } diff --git a/crates/polars-core/src/series/into.rs b/crates/polars-core/src/series/into.rs index 3f97222d17a7..e718bca831fc 100644 --- a/crates/polars-core/src/series/into.rs +++ b/crates/polars-core/src/series/into.rs @@ -59,7 +59,7 @@ impl Series { #[cfg(feature = "dtype-categorical")] DataType::Categorical(_) => { let ca = self.categorical().unwrap(); - let 
arr = ca.logical().chunks()[chunk_idx].clone(); + let arr = ca.physical().chunks()[chunk_idx].clone(); // SAFETY: categoricals are always u32's. let cats = unsafe { UInt32Chunked::from_chunks("", vec![arr]) }; diff --git a/crates/polars-ops/src/frame/join/hash_join/zip_outer.rs b/crates/polars-ops/src/frame/join/hash_join/zip_outer.rs index ba23f32c2910..57cf99db752e 100644 --- a/crates/polars-ops/src/frame/join/hash_join/zip_outer.rs +++ b/crates/polars-ops/src/frame/join/hash_join/zip_outer.rs @@ -17,11 +17,11 @@ pub(crate) unsafe fn zip_outer_join_column( let new_rev_map = left_column ._merge_categorical_map(right_column.categorical().unwrap()) .unwrap(); - let left = left_column.logical(); + let left = left_column.physical(); let right = right_column .categorical() .unwrap() - .logical() + .physical() .clone() .into_series(); diff --git a/crates/polars-ops/src/series/ops/is_in.rs b/crates/polars-ops/src/series/ops/is_in.rs index cbc7f822eeda..9c1e7f1e4ee6 100644 --- a/crates/polars-ops/src/series/ops/is_in.rs +++ b/crates/polars-ops/src/series/ops/is_in.rs @@ -343,7 +343,7 @@ pub fn is_in(s: &Series, other: &Series) -> PolarsResult { use crate::frame::join::_check_categorical_src; _check_categorical_src(s.dtype(), other.dtype())?; let ca = s.categorical().unwrap(); - let ca = ca.logical(); + let ca = ca.physical(); is_in_numeric(ca, &other.to_physical_repr()) }, #[cfg(feature = "dtype-struct")] diff --git a/py-polars/tests/unit/datatypes/test_categorical.py b/py-polars/tests/unit/datatypes/test_categorical.py index f19a2732bf9c..be707cb207a4 100644 --- a/py-polars/tests/unit/datatypes/test_categorical.py +++ b/py-polars/tests/unit/datatypes/test_categorical.py @@ -427,3 +427,13 @@ def test_categorical_collect_11408() -> None: def test_categorical_nested_cast_unchecked() -> None: s = pl.Series("cat", [["cat"]]).cast(pl.List(pl.Categorical)) assert pl.Series([s]).to_list() == [[["cat"]]] + + +def test_categorical_update_lengths() -> None: + with pl.StringCache(): + s1 = pl.Series(["", ""], dtype=pl.Categorical) + s2 = pl.Series([None, "", ""], dtype=pl.Categorical) + + s = pl.concat([s1, s2], rechunk=False) + assert s.null_count() == 1 + assert s.len() == 5 From 3251703baba52c9febe9499a203c2a0d6ed587de Mon Sep 17 00:00:00 2001 From: Ritchie Vink Date: Sat, 21 Oct 2023 10:11:19 +0200 Subject: [PATCH 087/103] refactor(rust): prepare for multiple files in a node (#11918) --- .../src/physical_plan/executors/scan/csv.rs | 2 +- .../src/physical_plan/executors/scan/ipc.rs | 2 +- .../physical_plan/executors/scan/parquet.rs | 2 +- .../src/physical_plan/planner/lp.rs | 60 +++++++++++-------- crates/polars-lazy/src/utils.rs | 8 ++- crates/polars-pipe/src/pipeline/convert.rs | 8 ++- crates/polars-plan/src/dot.rs | 30 ++++++---- crates/polars-plan/src/logical_plan/alp.rs | 6 +- .../polars-plan/src/logical_plan/builder.rs | 10 ++-- .../src/logical_plan/conversion.rs | 8 +-- crates/polars-plan/src/logical_plan/format.rs | 22 +++++-- crates/polars-plan/src/logical_plan/mod.rs | 2 +- .../src/logical_plan/optimizer/cse.rs | 4 +- .../logical_plan/optimizer/file_caching.rs | 22 +++---- .../optimizer/predicate_pushdown/mod.rs | 9 +-- .../optimizer/projection_pushdown/mod.rs | 4 +- .../optimizer/slice_pushdown_lp.rs | 8 +-- 17 files changed, 122 insertions(+), 85 deletions(-) diff --git a/crates/polars-lazy/src/physical_plan/executors/scan/csv.rs b/crates/polars-lazy/src/physical_plan/executors/scan/csv.rs index 8542432b31bc..bac591b84f86 100644 --- a/crates/polars-lazy/src/physical_plan/executors/scan/csv.rs 
+++ b/crates/polars-lazy/src/physical_plan/executors/scan/csv.rs @@ -54,7 +54,7 @@ impl CsvExec { impl Executor for CsvExec { fn execute(&mut self, state: &mut ExecutionState) -> PolarsResult { let finger_print = FileFingerPrint { - path: self.path.clone(), + paths: Arc::new([self.path.clone()]), predicate: self .predicate .as_ref() diff --git a/crates/polars-lazy/src/physical_plan/executors/scan/ipc.rs b/crates/polars-lazy/src/physical_plan/executors/scan/ipc.rs index e5ee49c06a16..5256252d3a5d 100644 --- a/crates/polars-lazy/src/physical_plan/executors/scan/ipc.rs +++ b/crates/polars-lazy/src/physical_plan/executors/scan/ipc.rs @@ -34,7 +34,7 @@ impl IpcExec { impl Executor for IpcExec { fn execute(&mut self, state: &mut ExecutionState) -> PolarsResult { let finger_print = FileFingerPrint { - path: self.path.clone(), + paths: Arc::new([self.path.clone()]), predicate: self .predicate .as_ref() diff --git a/crates/polars-lazy/src/physical_plan/executors/scan/parquet.rs b/crates/polars-lazy/src/physical_plan/executors/scan/parquet.rs index 3579bd9de004..9f99c8580870 100644 --- a/crates/polars-lazy/src/physical_plan/executors/scan/parquet.rs +++ b/crates/polars-lazy/src/physical_plan/executors/scan/parquet.rs @@ -100,7 +100,7 @@ impl ParquetExec { impl Executor for ParquetExec { fn execute(&mut self, state: &mut ExecutionState) -> PolarsResult { let finger_print = FileFingerPrint { - path: self.path.clone(), + paths: Arc::new([self.path.clone()]), predicate: self .predicate .as_ref() diff --git a/crates/polars-lazy/src/physical_plan/planner/lp.rs b/crates/polars-lazy/src/physical_plan/planner/lp.rs index 45af5e1f3a11..71fc599be227 100644 --- a/crates/polars-lazy/src/physical_plan/planner/lp.rs +++ b/crates/polars-lazy/src/physical_plan/planner/lp.rs @@ -189,7 +189,7 @@ pub fn create_physical_plan( ))) }, Scan { - path, + paths, file_info, output_schema, scan_type, @@ -213,35 +213,47 @@ pub fn create_physical_plan( #[cfg(feature = "csv")] FileScan::Csv { options: csv_options, - } => Ok(Box::new(executors::CsvExec { - path, - schema: file_info.schema, - options: csv_options, - predicate, - file_options, - })), + } => { + assert_eq!(paths.len(), 1); + let path = paths[0].clone(); + Ok(Box::new(executors::CsvExec { + path, + schema: file_info.schema, + options: csv_options, + predicate, + file_options, + })) + }, #[cfg(feature = "ipc")] - FileScan::Ipc { options } => Ok(Box::new(executors::IpcExec { - path, - schema: file_info.schema, - predicate, - options, - file_options, - })), + FileScan::Ipc { options } => { + assert_eq!(paths.len(), 1); + let path = paths[0].clone(); + Ok(Box::new(executors::IpcExec { + path, + schema: file_info.schema, + predicate, + options, + file_options, + })) + }, #[cfg(feature = "parquet")] FileScan::Parquet { options, cloud_options, metadata - } => Ok(Box::new(executors::ParquetExec::new( - path, - file_info, - predicate, - options, - cloud_options, - file_options, - metadata - ))), + } => { + assert_eq!(paths.len(), 1); + let path = paths[0].clone(); + Ok(Box::new(executors::ParquetExec::new( + path, + file_info, + predicate, + options, + cloud_options, + file_options, + metadata + ))) + }, FileScan::Anonymous { function, .. 
diff --git a/crates/polars-lazy/src/utils.rs b/crates/polars-lazy/src/utils.rs index e8fa1ed4df79..fac410b109fb 100644 --- a/crates/polars-lazy/src/utils.rs +++ b/crates/polars-lazy/src/utils.rs @@ -6,13 +6,15 @@ use polars_plan::prelude::*; /// Get a set of the data source paths in this LogicalPlan pub(crate) fn agg_source_paths( root_lp: Node, - paths: &mut PlHashSet, + acc_paths: &mut PlHashSet, lp_arena: &Arena, ) { lp_arena.iter(root_lp).for_each(|(_, lp)| { use ALogicalPlan::*; - if let Scan { path, .. } = lp { - paths.insert(path.clone()); + if let Scan { paths, .. } = lp { + for path in paths.as_ref() { + acc_paths.insert(path.clone()); + } } }) } diff --git a/crates/polars-pipe/src/pipeline/convert.rs b/crates/polars-pipe/src/pipeline/convert.rs index 0c2d48ffa89e..04fb4c287e62 100644 --- a/crates/polars-pipe/src/pipeline/convert.rs +++ b/crates/polars-pipe/src/pipeline/convert.rs @@ -68,7 +68,7 @@ where Ok(Box::new(sources::DataFrameSource::from_df(df)) as Box) }, Scan { - path, + paths, file_info, file_options, predicate, @@ -87,8 +87,9 @@ where FileScan::Csv { options: csv_options, } => { + assert_eq!(paths.len(), 1); let src = sources::CsvSource::new( - path, + paths[0].clone(), file_info.schema, csv_options, file_options, @@ -102,8 +103,9 @@ where cloud_options, metadata, } => { + assert_eq!(paths.len(), 1); let src = sources::ParquetSource::new( - path, + paths[0].clone(), parquet_options, cloud_options, metadata, diff --git a/crates/polars-plan/src/dot.rs b/crates/polars-plan/src/dot.rs index a581eb7dafe6..0c3bb9ce9085 100644 --- a/crates/polars-plan/src/dot.rs +++ b/crates/polars-plan/src/dot.rs @@ -1,5 +1,6 @@ +use std::borrow::Cow; use std::fmt::{Display, Write}; -use std::path::Path; +use std::path::PathBuf; use polars_core::prelude::*; @@ -150,9 +151,9 @@ impl LogicalPlan { count, } => { let fmt = if *count == usize::MAX { - "CACHE".to_string() + Cow::Borrowed("CACHE") } else { - format!("CACHE: {}times", *count) + Cow::Owned(format!("CACHE: {}times", *count)) }; let current_node = DotNode { branch: *cache_id, @@ -181,7 +182,7 @@ impl LogicalPlan { acc_str, prev_node, "PYTHON", - Path::new(""), + &[], options.with_columns.as_ref().map(|s| s.as_slice()), options.schema.len(), &options.predicate, @@ -312,7 +313,7 @@ impl LogicalPlan { } }, Scan { - path, + paths, file_info, predicate, scan_type, @@ -324,7 +325,7 @@ impl LogicalPlan { acc_str, prev_node, name, - path.as_ref(), + paths.as_ref(), options.with_columns.as_ref().map(|cols| cols.as_slice()), file_info.schema.len(), predicate, @@ -409,7 +410,7 @@ impl LogicalPlan { acc_str: &mut String, prev_node: DotNode, name: &str, - path: &Path, + path: &[PathBuf], with_columns: Option<&[String]>, total_columns: usize, predicate: &Option
<P>
, @@ -422,13 +423,20 @@ impl LogicalPlan { n_columns_fmt = format!("{}", columns.len()); } + let fmt = if path.len() == 1 { + path[0].to_string_lossy() + } else { + Cow::Owned(format!( + "{} files: first file: {}", + path.len(), + path[0].to_string_lossy() + )) + }; + let pred = fmt_predicate(predicate.as_ref()); let fmt = format!( "{name} SCAN {};\nπ {}/{};\nσ {}", - path.to_string_lossy(), - n_columns_fmt, - total_columns, - pred, + fmt, n_columns_fmt, total_columns, pred, ); let current_node = DotNode { branch, diff --git a/crates/polars-plan/src/logical_plan/alp.rs b/crates/polars-plan/src/logical_plan/alp.rs index d6a96e2394a1..1c63851a0844 100644 --- a/crates/polars-plan/src/logical_plan/alp.rs +++ b/crates/polars-plan/src/logical_plan/alp.rs @@ -30,7 +30,7 @@ pub enum ALogicalPlan { predicate: Node, }, Scan { - path: PathBuf, + paths: Arc<[PathBuf]>, file_info: FileInfo, predicate: Option, /// schema of the projected file @@ -293,7 +293,7 @@ impl ALogicalPlan { options: *options, }, Scan { - path, + paths, file_info, output_schema, predicate, @@ -305,7 +305,7 @@ impl ALogicalPlan { new_predicate = exprs.pop() } Scan { - path: path.clone(), + paths: paths.clone(), file_info: file_info.clone(), output_schema: output_schema.clone(), file_options: options.clone(), diff --git a/crates/polars-plan/src/logical_plan/builder.rs b/crates/polars-plan/src/logical_plan/builder.rs index 6ccfd66afb93..b736c6cc8a98 100644 --- a/crates/polars-plan/src/logical_plan/builder.rs +++ b/crates/polars-plan/src/logical_plan/builder.rs @@ -119,7 +119,7 @@ impl LogicalPlanBuilder { }; Ok(LogicalPlan::Scan { - path: "".into(), + paths: Arc::new([]), file_info, predicate: None, file_options, @@ -201,7 +201,7 @@ impl LogicalPlanBuilder { hive_partitioning, }; Ok(LogicalPlan::Scan { - path, + paths: Arc::new([path]), file_info, file_options: options, predicate: None, @@ -253,7 +253,7 @@ impl LogicalPlanBuilder { hive_partitioning: false, }; Ok(LogicalPlan::Scan { - path, + paths: Arc::new([path]), file_info, file_options, predicate: None, @@ -299,6 +299,8 @@ impl LogicalPlanBuilder { } })?; + let paths = Arc::new([path]); + let mut magic_nr = [0u8; 2]; let res = file.read_exact(&mut magic_nr); if raise_if_empty { @@ -362,7 +364,7 @@ impl LogicalPlanBuilder { hive_partitioning: false, }; Ok(LogicalPlan::Scan { - path, + paths, file_info, file_options: options, predicate: None, diff --git a/crates/polars-plan/src/logical_plan/conversion.rs b/crates/polars-plan/src/logical_plan/conversion.rs index f1910f2be2a9..2aecd7a693fc 100644 --- a/crates/polars-plan/src/logical_plan/conversion.rs +++ b/crates/polars-plan/src/logical_plan/conversion.rs @@ -168,13 +168,13 @@ pub fn to_alp( let v = match lp { LogicalPlan::Scan { file_info, - path, + paths, predicate, scan_type, file_options: options, } => ALogicalPlan::Scan { file_info, - path, + paths, output_schema: None, predicate: predicate.map(|expr| to_aexpr(expr, expr_arena)), scan_type, @@ -597,14 +597,14 @@ impl ALogicalPlan { }; match lp { ALogicalPlan::Scan { - path, + paths, file_info, predicate, scan_type, output_schema: _, file_options: options, } => LogicalPlan::Scan { - path, + paths, file_info, predicate: predicate.map(|n| node_to_expr(n, expr_arena)), scan_type, diff --git a/crates/polars-plan/src/logical_plan/format.rs b/crates/polars-plan/src/logical_plan/format.rs index ae7e4e48efd6..3f6631163716 100644 --- a/crates/polars-plan/src/logical_plan/format.rs +++ b/crates/polars-plan/src/logical_plan/format.rs @@ -1,7 +1,7 @@ use std::borrow::Cow; use std::fmt; 
use std::fmt::{Debug, Display, Formatter, Write}; -use std::path::Path; +use std::path::PathBuf; use crate::prelude::*; @@ -9,7 +9,7 @@ use crate::prelude::*; fn write_scan( f: &mut Formatter, name: &str, - path: &Path, + path: &[PathBuf], indent: usize, n_columns: i64, total_columns: usize, @@ -19,7 +19,17 @@ fn write_scan( if indent != 0 { writeln!(f)?; } - write!(f, "{:indent$}{} SCAN {}", "", name, path.display())?; + let path_fmt = if path.len() == 1 { + path[0].to_string_lossy() + } else { + Cow::Owned(format!( + "{} files: first file: {}", + path.len(), + path[0].to_string_lossy() + )) + }; + + write!(f, "{:indent$}{} SCAN {}", "", name, path_fmt)?; if n_columns > 0 { write!( f, @@ -58,7 +68,7 @@ impl LogicalPlan { write_scan( f, "PYTHON", - Path::new(""), + &[], sub_indent, n_columns, total_columns, @@ -91,7 +101,7 @@ impl LogicalPlan { input._format(f, sub_indent) }, Scan { - path, + paths, file_info, predicate, scan_type, @@ -106,7 +116,7 @@ impl LogicalPlan { write_scan( f, scan_type.into(), - path, + paths, sub_indent, n_columns, file_info.schema.len(), diff --git a/crates/polars-plan/src/logical_plan/mod.rs b/crates/polars-plan/src/logical_plan/mod.rs index ecb6d1e8917b..d394327e41f4 100644 --- a/crates/polars-plan/src/logical_plan/mod.rs +++ b/crates/polars-plan/src/logical_plan/mod.rs @@ -156,7 +156,7 @@ pub enum LogicalPlan { count: usize, }, Scan { - path: PathBuf, + paths: Arc<[PathBuf]>, file_info: FileInfo, predicate: Option, file_options: FileScanOptions, diff --git a/crates/polars-plan/src/logical_plan/optimizer/cse.rs b/crates/polars-plan/src/logical_plan/optimizer/cse.rs index 11e4c45ca925..dce16147e60a 100644 --- a/crates/polars-plan/src/logical_plan/optimizer/cse.rs +++ b/crates/polars-plan/src/logical_plan/optimizer/cse.rs @@ -117,13 +117,13 @@ fn lp_node_equal(a: &ALogicalPlan, b: &ALogicalPlan, expr_arena: &Arena) ) => Arc::ptr_eq(left_df, right_df), ( Scan { - path: path_left, + paths: path_left, predicate: predicate_left, scan_type: scan_type_left, .. }, Scan { - path: path_right, + paths: path_right, predicate: predicate_right, scan_type: scan_type_right, .. 
diff --git a/crates/polars-plan/src/logical_plan/optimizer/file_caching.rs b/crates/polars-plan/src/logical_plan/optimizer/file_caching.rs index 92e47d12c303..23791d3dd6b0 100644 --- a/crates/polars-plan/src/logical_plan/optimizer/file_caching.rs +++ b/crates/polars-plan/src/logical_plan/optimizer/file_caching.rs @@ -1,4 +1,4 @@ -use std::path::{Path, PathBuf}; +use std::path::PathBuf; use std::sync::Arc; use polars_core::datatypes::PlHashMap; @@ -9,14 +9,14 @@ use crate::prelude::*; #[derive(Hash, Eq, PartialEq, Clone, Debug)] pub struct FileFingerPrint { - pub path: PathBuf, + pub paths: Arc<[PathBuf]>, pub predicate: Option, pub slice: (usize, Option), } #[allow(clippy::type_complexity)] fn process_with_columns( - path: &Path, + paths: &Arc<[PathBuf]>, with_columns: Option<&Vec>, predicate: Option, slice: (usize, Option), @@ -25,7 +25,7 @@ fn process_with_columns( ) { let cols = file_count_and_column_union .entry(FileFingerPrint { - path: path.into(), + paths: paths.clone(), predicate, slice, }) @@ -59,7 +59,7 @@ pub fn collect_fingerprints( use ALogicalPlan::*; match lp_arena.get(root) { Scan { - path, + paths, file_options: options, predicate, scan_type, @@ -68,7 +68,7 @@ pub fn collect_fingerprints( let slice = (scan_type.skip_rows(), options.n_rows); let predicate = predicate.map(|node| node_to_expr(node, expr_arena)); let fp = FileFingerPrint { - path: path.clone(), + paths: paths.clone(), predicate, slice, }; @@ -96,7 +96,7 @@ pub fn find_column_union_and_fingerprints( use ALogicalPlan::*; match lp_arena.get(root) { Scan { - path, + paths, file_options: options, predicate, file_info, @@ -106,7 +106,7 @@ pub fn find_column_union_and_fingerprints( let slice = (scan_type.skip_rows(), options.n_rows); let predicate = predicate.map(|node| node_to_expr(node, expr_arena)); process_with_columns( - path, + paths, options.with_columns.as_deref(), predicate, slice, @@ -204,7 +204,7 @@ impl FileCacher { let lp = lp_arena.take(root); match lp { ALogicalPlan::Scan { - path, + paths, file_info, predicate, output_schema, @@ -213,7 +213,7 @@ impl FileCacher { } => { let predicate_expr = predicate.map(|node| node_to_expr(node, expr_arena)); let finger_print = FileFingerPrint { - path, + paths, predicate: predicate_expr, slice: (scan_type.skip_rows(), options.n_rows), }; @@ -230,7 +230,7 @@ impl FileCacher { options.with_columns = with_columns; let lp = ALogicalPlan::Scan { - path: finger_print.path.clone(), + paths: finger_print.paths.clone(), file_info, output_schema, predicate, diff --git a/crates/polars-plan/src/logical_plan/optimizer/predicate_pushdown/mod.rs b/crates/polars-plan/src/logical_plan/optimizer/predicate_pushdown/mod.rs index eec0ddaff940..615d3f6dcc8a 100644 --- a/crates/polars-plan/src/logical_plan/optimizer/predicate_pushdown/mod.rs +++ b/crates/polars-plan/src/logical_plan/optimizer/predicate_pushdown/mod.rs @@ -225,7 +225,7 @@ impl<'a> PredicatePushDown<'a> { Ok(lp) } Scan { - path, + paths, file_info, predicate, scan_type, @@ -235,12 +235,13 @@ impl<'a> PredicatePushDown<'a> { let local_predicates = partition_by_full_context(&mut acc_predicates, expr_arena); let predicate = predicate_at_scan(acc_predicates, predicate, expr_arena); + // TODO! this still assumes a single file. 
Fix hive partitioning for multiple files if let (Some(hive_part_stats), Some(predicate)) = (file_info.hive_parts.as_deref(), predicate) { if let Some(io_expr) = self.hive_partition_eval.unwrap()(predicate, expr_arena) { if let Some(stats_evaluator) = io_expr.as_stats_evaluator() { if !stats_evaluator.should_read(hive_part_stats.get_statistics())? { if self.verbose { - eprintln!("hive partitioning: skipped: {}", path.display()) + eprintln!("hive partitioning: skipped: {}", paths[0].display()) } let schema = output_schema.as_ref().unwrap_or(&file_info.schema); let df = DataFrame::from(schema.as_ref()); @@ -267,7 +268,7 @@ impl<'a> PredicatePushDown<'a> { let lp = if do_optimization { Scan { - path, + paths, file_info, predicate, file_options: options, @@ -276,7 +277,7 @@ impl<'a> PredicatePushDown<'a> { } } else { let lp = Scan { - path, + paths, file_info, predicate: None, file_options: options, diff --git a/crates/polars-plan/src/logical_plan/optimizer/projection_pushdown/mod.rs b/crates/polars-plan/src/logical_plan/optimizer/projection_pushdown/mod.rs index 3ff672683211..b1ed7963aab6 100644 --- a/crates/polars-plan/src/logical_plan/optimizer/projection_pushdown/mod.rs +++ b/crates/polars-plan/src/logical_plan/optimizer/projection_pushdown/mod.rs @@ -377,7 +377,7 @@ impl ProjectionPushDown { Ok(PythonScan { options, predicate }) }, Scan { - path, + paths, file_info, scan_type, predicate, @@ -421,7 +421,7 @@ impl ProjectionPushDown { } let lp = Scan { - path, + paths, file_info, output_schema, scan_type, diff --git a/crates/polars-plan/src/logical_plan/optimizer/slice_pushdown_lp.rs b/crates/polars-plan/src/logical_plan/optimizer/slice_pushdown_lp.rs index 66887f25ee62..63e5a9fdd2af 100644 --- a/crates/polars-plan/src/logical_plan/optimizer/slice_pushdown_lp.rs +++ b/crates/polars-plan/src/logical_plan/optimizer/slice_pushdown_lp.rs @@ -121,7 +121,7 @@ impl SlicePushDown { } #[cfg(feature = "csv")] (Scan { - path, + paths, file_info, output_schema, file_options: mut options, @@ -132,7 +132,7 @@ impl SlicePushDown { csv_options.skip_rows += state.offset as usize; let lp = Scan { - path, + paths, file_info, output_schema, scan_type: FileScan::Csv {options: csv_options}, @@ -143,7 +143,7 @@ impl SlicePushDown { }, // TODO! we currently skip slice pushdown if there is a predicate. (Scan { - path, + paths, file_info, output_schema, file_options: mut options, @@ -152,7 +152,7 @@ impl SlicePushDown { }, Some(state)) if state.offset == 0 && predicate.is_none() => { options.n_rows = Some(state.len as usize); let lp = Scan { - path, + paths, file_info, output_schema, predicate, From 8af94b0b8b9c46abacbcfcb31b162019d8ed8a77 Mon Sep 17 00:00:00 2001 From: Ritchie Vink Date: Sat, 21 Oct 2023 11:49:59 +0200 Subject: [PATCH 088/103] docs: fix some typos and add polars-business to curated plugin list (#11916) Co-authored-by: Stijn de Gooijer --- docs/user-guide/concepts/lazy-vs-eager.md | 2 +- docs/user-guide/expressions/plugins.md | 16 +++++++++++----- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/docs/user-guide/concepts/lazy-vs-eager.md b/docs/user-guide/concepts/lazy-vs-eager.md index 1b84a0272aa5..987d07aa8807 100644 --- a/docs/user-guide/concepts/lazy-vs-eager.md +++ b/docs/user-guide/concepts/lazy-vs-eager.md @@ -6,7 +6,7 @@ In this example we use the eager API to: -1. Read the iris [dataset](https://archive.ics.uci.edu/ml/datasets/iris). +1. Read the iris [dataset](https://archive.ics.uci.edu/dataset/53/iris). 1. Filter the dataset based on sepal length 1. 
Calculate the mean of the sepal width per species diff --git a/docs/user-guide/expressions/plugins.md b/docs/user-guide/expressions/plugins.md index bc39ecfb7eca..3c40ef2b8cf4 100644 --- a/docs/user-guide/expressions/plugins.md +++ b/docs/user-guide/expressions/plugins.md @@ -5,7 +5,7 @@ and register that as an expression into the polars library. The polars engine wi and your expression will run almost as fast as native expressions. Note that this works without any interference of python and thus no GIL contention. -They will benefit from the same benefits default expression have: +They will benefit from the same benefits default expressions have: - Optimization - Parallelism @@ -18,8 +18,8 @@ To get started we will see what is needed to create a custom expression. For our first expression we are going to create a pig latin converter. Pig latin is a silly language where in every word the first letter is removed, added to the back and finally "ay" is added. So the word "pig" would convert to "igpay". -We could of course already do that with expressions, e.g. `col(..) + col(..).str.slice(0, 1) + "ay"`, but a specialized -function for this would perform better and allows us to learn about the plugins. +We could of course already do that with expressions, e.g. `col("name").str.slice(1) + col("name").str.slice(0, 1) + "ay"`, +but a specialized function for this would perform better and allows us to learn about the plugins. ### Setting up @@ -37,9 +37,9 @@ crate-type = ["cdylib"] [dependencies] polars = { version = "*" } -pyo3 = { version = "0.20.0", features = ["extension-module"] } +pyo3 = { version = "*", features = ["extension-module"] } pyo3-polars = { version = "*", features = ["derive"] } -serde = { version = "1", features = ["derive"] } +serde = { version = "*", features = ["derive"] } ``` ### Writing the expression @@ -229,3 +229,9 @@ fn haversine(inputs: &[Series]) -> PolarsResult { ``` That's all you need to know to get started. Take a look at this [repo](https://github.com/pola-rs/pyo3-polars/tree/main/example/derive_expression) to see how this all fits together. + +## Community plugins + +Here is a curated (non-exhaustive) list of community implemented plugins. + +- [polars-business](https://github.com/MarcoGorelli/polars-business) Polars extension offering utilities for business day operations From 5e96abd2df5a1bdcf429fb653caf93160fe49836 Mon Sep 17 00:00:00 2001 From: Ritchie Vink Date: Sat, 21 Oct 2023 11:50:33 +0200 Subject: [PATCH 089/103] feat: error instead of panic in unsupported sinks (#11915) --- .../src/physical_plan/planner/lp.rs | 34 +++++++++---------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/crates/polars-lazy/src/physical_plan/planner/lp.rs b/crates/polars-lazy/src/physical_plan/planner/lp.rs index 71fc599be227..1e2683d4bdd4 100644 --- a/crates/polars-lazy/src/physical_plan/planner/lp.rs +++ b/crates/polars-lazy/src/physical_plan/planner/lp.rs @@ -150,16 +150,20 @@ pub fn create_physical_plan( match logical_plan { #[cfg(feature = "python")] PythonScan { options, .. } => Ok(Box::new(executors::PythonScanExec { options })), - Sink { payload, .. } => { - match payload { - SinkType::Memory => panic!("Memory Sink not supported in the standard engine."), - SinkType::File{file_type, ..} => panic!( + Sink { payload, .. } => match payload { + SinkType::Memory => { + polars_bail!(InvalidOperation: "memory sink not supported in the standard engine") + }, + SinkType::File { file_type, .. 
} => { + polars_bail!(InvalidOperation: "sink_{file_type:?} not yet supported in standard engine. Use 'collect().write_parquet()'" - ), - #[cfg(feature = "cloud")] - SinkType::Cloud{..} => panic!("Cloud Sink not supported in standard engine.") - } - } + ) + }, + #[cfg(feature = "cloud")] + SinkType::Cloud { .. } => { + polars_bail!(InvalidOperation: "cloud sink not supported in standard engine.") + }, + }, Union { inputs, options } => { let inputs = inputs .into_iter() @@ -240,7 +244,7 @@ pub fn create_physical_plan( FileScan::Parquet { options, cloud_options, - metadata + metadata, } => { assert_eq!(paths.len(), 1); let path = paths[0].clone(); @@ -251,13 +255,10 @@ pub fn create_physical_plan( options, cloud_options, file_options, - metadata + metadata, ))) }, - FileScan::Anonymous { - function, - .. - } => { + FileScan::Anonymous { function, .. } => { Ok(Box::new(executors::AnonymousScanExec { function, predicate, @@ -266,8 +267,7 @@ pub fn create_physical_plan( output_schema, predicate_has_windows: state.has_windows, })) - - } + }, } }, Projection { From 492a3c137792ac551c34cd380d74eb94bb918e20 Mon Sep 17 00:00:00 2001 From: Stijn de Gooijer Date: Sat, 21 Oct 2023 15:55:02 +0200 Subject: [PATCH 090/103] fix(python): Fix `Array` data type initialization (#11907) --- .../python/user-guide/expressions/lists.py | 5 +- py-polars/polars/datatypes/classes.py | 53 +++++++++++++++---- py-polars/polars/datatypes/constants.py | 3 +- py-polars/polars/datatypes/convert.py | 3 +- py-polars/polars/expr/array.py | 8 +-- py-polars/polars/series/array.py | 8 +-- py-polars/src/conversion.rs | 5 +- py-polars/tests/unit/datatypes/test_array.py | 37 ++++++++----- py-polars/tests/unit/datatypes/test_struct.py | 4 +- py-polars/tests/unit/namespaces/test_array.py | 8 +-- .../tests/unit/operations/test_explode.py | 2 +- 11 files changed, 94 insertions(+), 42 deletions(-) diff --git a/docs/src/python/user-guide/expressions/lists.py b/docs/src/python/user-guide/expressions/lists.py index 5703a01a5518..de4b97fc8d87 100644 --- a/docs/src/python/user-guide/expressions/lists.py +++ b/docs/src/python/user-guide/expressions/lists.py @@ -97,7 +97,10 @@ pl.Series("Array_1", [[1, 3], [2, 5]]), pl.Series("Array_2", [[1, 7, 3], [8, 1, 0]]), ], - schema={"Array_1": pl.Array(2, pl.Int64), "Array_2": pl.Array(3, pl.Int64)}, + schema={ + "Array_1": pl.Array(inner=pl.Int64, width=2), + "Array_2": pl.Array(inner=pl.Int64, width=3), + }, ) print(array_df) # --8<-- [end:array_df] diff --git a/py-polars/polars/datatypes/classes.py b/py-polars/polars/datatypes/classes.py index aa9d316c21d3..4e830cffe76e 100644 --- a/py-polars/polars/datatypes/classes.py +++ b/py-polars/polars/datatypes/classes.py @@ -456,18 +456,18 @@ class Unknown(DataType): class List(NestedType): - """Nested list/array type with variable length of inner lists.""" + """Variable length list type.""" inner: PolarsDataType | None = None def __init__(self, inner: PolarsDataType | PythonDataType): """ - Nested list/array type with variable length of inner lists. + Variable length list type. Parameters ---------- inner - The `DataType` of values within the list + The ``DataType`` of the values within each list. 
Examples -------- @@ -518,26 +518,31 @@ def __repr__(self) -> str: class Array(NestedType): - """Nested list/array type with fixed length of inner arrays.""" + """Fixed length list type.""" inner: PolarsDataType | None = None width: int - def __init__(self, width: int, inner: PolarsDataType | PythonDataType = Null): + def __init__( # noqa: D417 + self, + *args: Any, + width: int | None = None, + inner: PolarsDataType | PythonDataType | None = None, + ): """ - Nested list/array type with fixed length of inner arrays. + Fixed length list type. Parameters ---------- width - The fixed size length of the inner arrays. + The length of the arrays. inner - The `DataType` of values within the inner arrays + The ``DataType`` of the values within each array. Examples -------- >>> s = pl.Series( - ... "a", [[1, 2], [4, 3]], dtype=pl.Array(width=2, inner=pl.Int64) + ... "a", [[1, 2], [4, 3]], dtype=pl.Array(inner=pl.Int64, width=2) ... ) >>> s shape: (2,) @@ -548,6 +553,32 @@ def __init__(self, width: int, inner: PolarsDataType | PythonDataType = Null): ] """ + from polars.utils.deprecation import issue_deprecation_warning + + if args: + # TODO: When removing this deprecation, update the `to_object` + # implementation in py-polars/src/conversion.rs to use `call1` instead of + # `call` + issue_deprecation_warning( + "Parameters `inner` and `width` will change positions in the next breaking release." + " Use keyword arguments to keep current behavior and silence this warning.", + version="0.19.11", + ) + if len(args) == 1: + width = args[0] + else: + width, inner = args[:2] + if width is None: + raise TypeError("`width` must be specified when initializing an `Array`") + + if inner is None: + issue_deprecation_warning( + "The default value for the `inner` parameter of `Array` will be removed in the next breaking release." 
+ " Pass `inner=pl.Null`to keep current behavior and silence this warning.", + version="0.19.11", + ) + inner = Null + self.width = width self.inner = polars.datatypes.py_type_to_dtype(inner) @@ -570,11 +601,11 @@ def __eq__(self, other: PolarsDataType) -> bool: # type: ignore[override] return False def __hash__(self) -> int: - return hash((self.__class__, self.inner)) + return hash((self.__class__, self.inner, self.width)) def __repr__(self) -> str: class_name = self.__class__.__name__ - return f"{class_name}({self.inner!r})" + return f"{class_name}({self.inner!r}, {self.width})" class Field: diff --git a/py-polars/polars/datatypes/constants.py b/py-polars/polars/datatypes/constants.py index 24219b7e15ba..f3f2efc56431 100644 --- a/py-polars/polars/datatypes/constants.py +++ b/py-polars/polars/datatypes/constants.py @@ -3,6 +3,7 @@ from typing import TYPE_CHECKING from polars.datatypes import ( + Array, DataTypeGroup, Date, Datetime, @@ -75,7 +76,7 @@ FLOAT_DTYPES | INTEGER_DTYPES | frozenset([Decimal]) ) -NESTED_DTYPES: frozenset[PolarsDataType] = DataTypeGroup([List, Struct]) +NESTED_DTYPES: frozenset[PolarsDataType] = DataTypeGroup([List, Struct, Array]) # number of rows to scan by default when inferring datatypes N_INFER_DEFAULT = 100 diff --git a/py-polars/polars/datatypes/convert.py b/py-polars/polars/datatypes/convert.py index 7d5d61e45c7b..496ed957d427 100644 --- a/py-polars/polars/datatypes/convert.py +++ b/py-polars/polars/datatypes/convert.py @@ -21,6 +21,7 @@ ) from polars.datatypes import ( + Array, Binary, Boolean, Categorical, @@ -203,7 +204,7 @@ def unpack_dtypes( unpacked: set[PolarsDataType] = set() for tp in dtypes: - if isinstance(tp, List): + if isinstance(tp, (List, Array)): if include_compound: unpacked.add(tp) unpacked.update(unpack_dtypes(tp.inner, include_compound=include_compound)) diff --git a/py-polars/polars/expr/array.py b/py-polars/polars/expr/array.py index 1f80cee68899..d845fb7263e9 100644 --- a/py-polars/polars/expr/array.py +++ b/py-polars/polars/expr/array.py @@ -24,7 +24,7 @@ def min(self) -> Expr: -------- >>> df = pl.DataFrame( ... data={"a": [[1, 2], [4, 3]]}, - ... schema={"a": pl.Array(width=2, inner=pl.Int64)}, + ... schema={"a": pl.Array(inner=pl.Int64, width=2)}, ... ) >>> df.select(pl.col("a").arr.min()) shape: (2, 1) @@ -48,7 +48,7 @@ def max(self) -> Expr: -------- >>> df = pl.DataFrame( ... data={"a": [[1, 2], [4, 3]]}, - ... schema={"a": pl.Array(width=2, inner=pl.Int64)}, + ... schema={"a": pl.Array(inner=pl.Int64, width=2)}, ... ) >>> df.select(pl.col("a").arr.max()) shape: (2, 1) @@ -72,7 +72,7 @@ def sum(self) -> Expr: -------- >>> df = pl.DataFrame( ... data={"a": [[1, 2], [4, 3]]}, - ... schema={"a": pl.Array(width=2, inner=pl.Int64)}, + ... schema={"a": pl.Array(inner=pl.Int64, width=2)}, ... ) >>> df.select(pl.col("a").arr.sum()) shape: (2, 1) @@ -103,7 +103,7 @@ def unique(self, *, maintain_order: bool = False) -> Expr: ... { ... "a": [[1, 1, 2]], ... }, - ... schema_overrides={"a": pl.Array(width=3, inner=pl.Int64)}, + ... schema_overrides={"a": pl.Array(inner=pl.Int64, width=3)}, ... ) >>> df.select(pl.col("a").arr.unique()) shape: (1, 1) diff --git a/py-polars/polars/series/array.py b/py-polars/polars/series/array.py index 2ece17871ed5..5c61e4ab54f8 100644 --- a/py-polars/polars/series/array.py +++ b/py-polars/polars/series/array.py @@ -25,7 +25,7 @@ def min(self) -> Series: Examples -------- >>> s = pl.Series( - ... "a", [[1, 2], [4, 3]], dtype=pl.Array(width=2, inner=pl.Int64) + ... 
"a", [[1, 2], [4, 3]], dtype=pl.Array(inner=pl.Int64, width=2) ... ) >>> s.arr.min() shape: (2,) @@ -44,7 +44,7 @@ def max(self) -> Series: Examples -------- >>> s = pl.Series( - ... "a", [[1, 2], [4, 3]], dtype=pl.Array(width=2, inner=pl.Int64) + ... "a", [[1, 2], [4, 3]], dtype=pl.Array(inner=pl.Int64, width=2) ... ) >>> s.arr.max() shape: (2,) @@ -64,7 +64,7 @@ def sum(self) -> Series: -------- >>> df = pl.DataFrame( ... data={"a": [[1, 2], [4, 3]]}, - ... schema={"a": pl.Array(width=2, inner=pl.Int64)}, + ... schema={"a": pl.Array(inner=pl.Int64, width=2)}, ... ) >>> df.select(pl.col("a").arr.sum()) shape: (2, 1) @@ -94,7 +94,7 @@ def unique(self, *, maintain_order: bool = False) -> Series: ... { ... "a": [[1, 1, 2]], ... }, - ... schema_overrides={"a": pl.Array(width=3, inner=pl.Int64)}, + ... schema_overrides={"a": pl.Array(inner=pl.Int64, width=3)}, ... ) >>> df.select(pl.col("a").arr.unique()) shape: (1, 1) diff --git a/py-polars/src/conversion.rs b/py-polars/src/conversion.rs index 71871a48bfca..cdc0e9b69b8c 100644 --- a/py-polars/src/conversion.rs +++ b/py-polars/src/conversion.rs @@ -312,7 +312,10 @@ impl ToPyObject for Wrap { DataType::Array(inner, size) => { let inner = Wrap(*inner.clone()).to_object(py); let list_class = pl.getattr(intern!(py, "Array")).unwrap(); - list_class.call1((*size, inner)).unwrap().into() + let kwargs = PyDict::new(py); + kwargs.set_item("inner", inner).unwrap(); + kwargs.set_item("width", size).unwrap(); + list_class.call((), Some(kwargs)).unwrap().into() }, DataType::List(inner) => { let inner = Wrap(*inner.clone()).to_object(py); diff --git a/py-polars/tests/unit/datatypes/test_array.py b/py-polars/tests/unit/datatypes/test_array.py index a0eaeb3d7e0c..fc4064addbb6 100644 --- a/py-polars/tests/unit/datatypes/test_array.py +++ b/py-polars/tests/unit/datatypes/test_array.py @@ -9,7 +9,7 @@ def test_cast_list_array() -> None: payload = [[1, 2, 3], [4, 2, 3]] s = pl.Series(payload) - dtype = pl.Array(width=3, inner=pl.Int64) + dtype = pl.Array(inner=pl.Int64, width=3) out = s.cast(dtype) assert out.dtype == dtype assert out.to_list() == payload @@ -20,19 +20,19 @@ def test_cast_list_array() -> None: pl.ComputeError, match=r"incompatible offsets in source list", ): - s.cast(pl.Array(width=2, inner=pl.Int64)) + s.cast(pl.Array(inner=pl.Int64, width=2)) def test_array_construction() -> None: payload = [[1, 2, 3], [4, 2, 3]] - dtype = pl.Array(width=3, inner=pl.Int64) + dtype = pl.Array(inner=pl.Int64, width=3) s = pl.Series(payload, dtype=dtype) assert s.dtype == dtype assert s.to_list() == payload # inner type - dtype = pl.Array(2, pl.UInt8) + dtype = pl.Array(inner=pl.UInt8, width=2) payload = [[1, 2], [3, 4]] s = pl.Series(payload, dtype=dtype) assert s.dtype == dtype @@ -41,13 +41,13 @@ def test_array_construction() -> None: # create using schema df = pl.DataFrame( schema={ - "a": pl.Array(width=3, inner=pl.Float32), - "b": pl.Array(width=5, inner=pl.Datetime("ms")), + "a": pl.Array(inner=pl.Float32, width=3), + "b": pl.Array(inner=pl.Datetime("ms"), width=5), } ) assert df.dtypes == [ - pl.Array(width=3, inner=pl.Float32), - pl.Array(width=5, inner=pl.Datetime("ms")), + pl.Array(inner=pl.Float32, width=3), + pl.Array(inner=pl.Datetime("ms"), width=5), ] assert df.rows() == [] @@ -56,7 +56,9 @@ def test_array_in_group_by() -> None: df = pl.DataFrame( [ pl.Series("id", [1, 2]), - pl.Series("list", [[1, 2], [5, 5]], dtype=pl.Array(2, pl.UInt8)), + pl.Series( + "list", [[1, 2], [5, 5]], dtype=pl.Array(inner=pl.UInt8, width=2) + ), ] ) @@ -83,7 +85,7 @@ 
def test_array_in_group_by() -> None: def test_array_invalid_operation() -> None: s = pl.Series( [[1, 2], [8, 9]], - dtype=pl.Array(width=2, inner=pl.Int32), + dtype=pl.Array(inner=pl.Int32, width=2), ) with pytest.raises( InvalidOperationError, @@ -94,11 +96,22 @@ def test_array_invalid_operation() -> None: def test_array_concat() -> None: a_df = pl.DataFrame({"a": [[0, 1], [1, 0]]}).select( - pl.col("a").cast(pl.Array(width=2, inner=pl.Int32)) + pl.col("a").cast(pl.Array(inner=pl.Int32, width=2)) ) b_df = pl.DataFrame({"a": [[1, 1], [0, 0]]}).select( - pl.col("a").cast(pl.Array(width=2, inner=pl.Int32)) + pl.col("a").cast(pl.Array(inner=pl.Int32, width=2)) ) assert pl.concat([a_df, b_df]).to_dict(False) == { "a": [[0, 1], [1, 0], [1, 1], [0, 0]] } + + +def test_array_init_deprecation() -> None: + with pytest.deprecated_call(): + pl.Array(2) + with pytest.deprecated_call(): + pl.Array(2, pl.Utf8) + with pytest.deprecated_call(): + pl.Array(2, inner=pl.Utf8) + with pytest.deprecated_call(): + pl.Array(width=2) diff --git a/py-polars/tests/unit/datatypes/test_struct.py b/py-polars/tests/unit/datatypes/test_struct.py index 1bc4a7a4b4a8..542c6f3b1b4d 100644 --- a/py-polars/tests/unit/datatypes/test_struct.py +++ b/py-polars/tests/unit/datatypes/test_struct.py @@ -645,8 +645,8 @@ def test_empty_struct() -> None: pl.List, pl.List(pl.Null), pl.List(pl.Utf8), - pl.Array(32), - pl.Array(16, inner=pl.UInt8), + pl.Array(inner=pl.Null, width=32), + pl.Array(inner=pl.UInt8, width=16), pl.Struct, pl.Struct([pl.Field("", pl.Null)]), pl.Struct([pl.Field("x", pl.UInt32), pl.Field("y", pl.Float64)]), diff --git a/py-polars/tests/unit/namespaces/test_array.py b/py-polars/tests/unit/namespaces/test_array.py index ac69510cd8ed..cc20cba7feca 100644 --- a/py-polars/tests/unit/namespaces/test_array.py +++ b/py-polars/tests/unit/namespaces/test_array.py @@ -5,19 +5,19 @@ def test_arr_min_max() -> None: - s = pl.Series("a", [[1, 2], [4, 3]], dtype=pl.Array(width=2, inner=pl.Int64)) + s = pl.Series("a", [[1, 2], [4, 3]], dtype=pl.Array(inner=pl.Int64, width=2)) assert s.arr.max().to_list() == [2, 4] assert s.arr.min().to_list() == [1, 3] def test_arr_sum() -> None: - s = pl.Series("a", [[1, 2], [4, 3]], dtype=pl.Array(width=2, inner=pl.Int64)) + s = pl.Series("a", [[1, 2], [4, 3]], dtype=pl.Array(inner=pl.Int64, width=2)) assert s.arr.sum().to_list() == [3, 7] def test_arr_unique() -> None: df = pl.DataFrame( - {"a": pl.Series("a", [[1, 1], [4, 3]], dtype=pl.Array(width=2, inner=pl.Int64))} + {"a": pl.Series("a", [[1, 1], [4, 3]], dtype=pl.Array(inner=pl.Int64, width=2))} ) out = df.select(pl.col("a").arr.unique(maintain_order=True)) @@ -26,5 +26,5 @@ def test_arr_unique() -> None: def test_array_to_numpy() -> None: - s = pl.Series([[1, 2], [3, 4], [5, 6]], dtype=pl.Array(width=2, inner=pl.Int64)) + s = pl.Series([[1, 2], [3, 4], [5, 6]], dtype=pl.Array(inner=pl.Int64, width=2)) assert (s.to_numpy() == np.array([[1, 2], [3, 4], [5, 6]])).all() diff --git a/py-polars/tests/unit/operations/test_explode.py b/py-polars/tests/unit/operations/test_explode.py index 754a3e27bc1c..85b7bf892dba 100644 --- a/py-polars/tests/unit/operations/test_explode.py +++ b/py-polars/tests/unit/operations/test_explode.py @@ -309,7 +309,7 @@ def test_explode_inner_null() -> None: def test_explode_array() -> None: df = pl.LazyFrame( {"a": [[1, 2], [2, 3]], "b": [1, 2]}, - schema_overrides={"a": pl.Array(2, inner=pl.Int64)}, + schema_overrides={"a": pl.Array(inner=pl.Int64, width=2)}, ) expected = pl.DataFrame({"a": [1, 2, 2, 3], "b": [1, 1, 
2, 2]}) for ex in ("a", ~cs.integer()): From 04357ef2d13915858a5f9e0aa46dba60ceeacc21 Mon Sep 17 00:00:00 2001 From: Laurynas Date: Sat, 21 Oct 2023 15:18:22 +0100 Subject: [PATCH 091/103] docs(python): Fix docstring for `diff` methods (#11921) --- py-polars/polars/expr/expr.py | 2 +- py-polars/polars/expr/list.py | 2 +- py-polars/polars/series/list.py | 2 +- py-polars/polars/series/series.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/py-polars/polars/expr/expr.py b/py-polars/polars/expr/expr.py index 2fb89de12607..05eb15530917 100644 --- a/py-polars/polars/expr/expr.py +++ b/py-polars/polars/expr/expr.py @@ -7409,7 +7409,7 @@ def rank( def diff(self, n: int = 1, null_behavior: NullBehavior = "ignore") -> Self: """ - Calculate the n-th discrete difference. + Calculate the first discrete difference between shifted items. Parameters ---------- diff --git a/py-polars/polars/expr/list.py b/py-polars/polars/expr/list.py index 249c17ab5cb4..88c4cb993402 100644 --- a/py-polars/polars/expr/list.py +++ b/py-polars/polars/expr/list.py @@ -664,7 +664,7 @@ def arg_max(self) -> Expr: def diff(self, n: int = 1, null_behavior: NullBehavior = "ignore") -> Expr: """ - Calculate the n-th discrete difference of every sublist. + Calculate the first discrete difference between shifted items of every sublist. Parameters ---------- diff --git a/py-polars/polars/series/list.py b/py-polars/polars/series/list.py index 3f4b11d257b5..6387362d04e9 100644 --- a/py-polars/polars/series/list.py +++ b/py-polars/polars/series/list.py @@ -347,7 +347,7 @@ def arg_max(self) -> Series: def diff(self, n: int = 1, null_behavior: NullBehavior = "ignore") -> Series: """ - Calculate the n-th discrete difference of every sublist. + Calculate the first discrete difference between shifted items of every sublist. Parameters ---------- diff --git a/py-polars/polars/series/series.py b/py-polars/polars/series/series.py index 5041b8991899..90591799cdf2 100644 --- a/py-polars/polars/series/series.py +++ b/py-polars/polars/series/series.py @@ -6003,7 +6003,7 @@ def rank( def diff(self, n: int = 1, null_behavior: NullBehavior = "ignore") -> Series: """ - Calculate the n-th discrete difference. + Calculate the first discrete difference between shifted items. 
Parameters ---------- From 3386abd22f288d13e4290c43af3f31a28a5b8eca Mon Sep 17 00:00:00 2001 From: Brayan Jules Date: Tue, 15 Aug 2023 17:54:54 -0400 Subject: [PATCH 092/103] feat(rust): utf8 to temporal casting --- crates/polars-core/src/chunked_array/cast.rs | 31 ++++++++++++++++++++ py-polars/tests/unit/test_lazy.py | 2 +- py-polars/tests/unit/test_queries.py | 27 +++++++++++++++++ 3 files changed, 59 insertions(+), 1 deletion(-) diff --git a/crates/polars-core/src/chunked_array/cast.rs b/crates/polars-core/src/chunked_array/cast.rs index bf06636829f7..de0fdf2f8e61 100644 --- a/crates/polars-core/src/chunked_array/cast.rs +++ b/crates/polars-core/src/chunked_array/cast.rs @@ -7,6 +7,7 @@ use arrow::compute::cast::CastOptions; use crate::chunked_array::categorical::CategoricalChunkedBuilder; #[cfg(feature = "timezones")] use crate::chunked_array::temporal::validate_time_zone; +use crate::prelude::DataType::Datetime; use crate::prelude::*; pub(crate) fn cast_chunks( @@ -195,6 +196,36 @@ impl ChunkCast for Utf8Chunked { polars_bail!(ComputeError: "expected 'precision' or 'scale' when casting to Decimal") }, }, + #[cfg(feature = "dtype-date")] + DataType::Date => { + let result = cast_chunks(&self.chunks, &data_type, true)?; + let out = Series::try_from((self.name(), result))?; + Ok(out) + }, + #[cfg(feature = "dtype-datetime")] + DataType::Datetime(tu, tz) => { + let out = match tz { + #[cfg(feature = "timezones")] + Some(tz) => { + validate_time_zone(tz)?; + let result = cast_chunks( + &self.chunks, + &Datetime(TimeUnit::Nanoseconds, Some(tz.clone())), + true, + )?; + Series::try_from((self.name(), result)) + }, + _ => { + let result = cast_chunks( + &self.chunks, + &Datetime(TimeUnit::Nanoseconds, None), + true, + )?; + Series::try_from((self.name(), result)) + }, + }; + out + }, _ => cast_impl(self.name(), &self.chunks, data_type), } } diff --git a/py-polars/tests/unit/test_lazy.py b/py-polars/tests/unit/test_lazy.py index 7fe0dae5f371..30b30740e215 100644 --- a/py-polars/tests/unit/test_lazy.py +++ b/py-polars/tests/unit/test_lazy.py @@ -1375,7 +1375,7 @@ def test_quadratic_behavior_4736() -> None: ldf.select(reduce(add, (pl.col(fld) for fld in ldf.columns))) -@pytest.mark.parametrize("input_dtype", [pl.Utf8, pl.Int64, pl.Float64]) +@pytest.mark.parametrize("input_dtype", [pl.Int64, pl.Float64]) def test_from_epoch(input_dtype: pl.PolarsDataType) -> None: ldf = pl.LazyFrame( [ diff --git a/py-polars/tests/unit/test_queries.py b/py-polars/tests/unit/test_queries.py index a63c2da3ed24..d239d23f65fe 100644 --- a/py-polars/tests/unit/test_queries.py +++ b/py-polars/tests/unit/test_queries.py @@ -370,3 +370,30 @@ def test_shift_drop_nulls_10875() -> None: assert pl.LazyFrame({"a": [1, 2, 3]}).shift(1).drop_nulls().collect()[ "a" ].to_list() == [1, 2] + + +def test_utf8_date() -> None: + df = pl.DataFrame({"x1": ["2021-01-01"]}).with_columns( + **{"x1-date": pl.col("x1").cast(pl.Date)} + ) + out = df.select(pl.col("x1-date")) + assert out.shape == (1, 1) + assert out.dtypes == [pl.Date] + + +def test_utf8_datetime() -> None: + df = pl.DataFrame( + {"x1": ["2021-12-19T16:39:57-02:00", "2022-12-19T16:39:57"]} + ).with_columns( + **{ + "x1-datetime-ns": pl.col("x1").cast(pl.Datetime(time_unit="ns")), + "x1-datetime-ms": pl.col("x1").cast(pl.Datetime(time_unit="ms")), + "x1-datetime-us": pl.col("x1").cast(pl.Datetime(time_unit="us")), + } + ) + + out = df.select( + pl.col("x1-datetime-ns"), pl.col("x1-datetime-ms"), pl.col("x1-datetime-us") + ) + assert out.shape == (2, 3) + assert out.dtypes == 
[pl.Datetime, pl.Datetime, pl.Datetime] From bce21ec2730d705164fd17d04d9f0f9aebd0f230 Mon Sep 17 00:00:00 2001 From: Brayan Jules Date: Tue, 15 Aug 2023 17:54:54 -0400 Subject: [PATCH 093/103] feat(rust): utf8 to temporal casting --- crates/polars-core/src/chunked_array/cast.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/crates/polars-core/src/chunked_array/cast.rs b/crates/polars-core/src/chunked_array/cast.rs index de0fdf2f8e61..a3184ff7fca3 100644 --- a/crates/polars-core/src/chunked_array/cast.rs +++ b/crates/polars-core/src/chunked_array/cast.rs @@ -198,12 +198,12 @@ impl ChunkCast for Utf8Chunked { }, #[cfg(feature = "dtype-date")] DataType::Date => { - let result = cast_chunks(&self.chunks, &data_type, true)?; + let result = cast_chunks(&self.chunks, data_type, true)?; let out = Series::try_from((self.name(), result))?; Ok(out) }, #[cfg(feature = "dtype-datetime")] - DataType::Datetime(tu, tz) => { + DataType::Datetime(_tu, tz) => { let out = match tz { #[cfg(feature = "timezones")] Some(tz) => { From 4b08b6b3565c803cc9e10e289e77de5c90dfbe42 Mon Sep 17 00:00:00 2001 From: Brayan Jules Date: Mon, 9 Oct 2023 23:37:07 -0300 Subject: [PATCH 094/103] feat: utf8 to timestamp/date casting support Support for different timeunits added in nano-arrow --- crates/polars-arrow/src/compute/cast/mod.rs | 36 +++++++++++++-- .../polars-arrow/src/compute/cast/utf8_to.rs | 32 +++++++------ .../polars-arrow/src/temporal_conversions.rs | 46 +++++++------------ crates/polars-core/src/chunked_array/cast.rs | 11 ++--- py-polars/tests/unit/test_queries.py | 20 +++++++- 5 files changed, 90 insertions(+), 55 deletions(-) diff --git a/crates/polars-arrow/src/compute/cast/mod.rs b/crates/polars-arrow/src/compute/cast/mod.rs index bbabbe279439..b6778116b091 100644 --- a/crates/polars-arrow/src/compute/cast/mod.rs +++ b/crates/polars-arrow/src/compute/cast/mod.rs @@ -585,9 +585,23 @@ pub fn cast( LargeUtf8 => Ok(Box::new(utf8_to_large_utf8( array.as_any().downcast_ref().unwrap(), ))), - Timestamp(TimeUnit::Nanosecond, None) => utf8_to_naive_timestamp_ns_dyn::(array), + Timestamp(TimeUnit::Nanosecond, None) => { + utf8_to_naive_timestamp_dyn::(array, TimeUnit::Nanosecond) + }, + Timestamp(TimeUnit::Millisecond, None) => { + utf8_to_naive_timestamp_dyn::(array, TimeUnit::Millisecond) + }, + Timestamp(TimeUnit::Microsecond, None) => { + utf8_to_naive_timestamp_dyn::(array, TimeUnit::Microsecond) + }, Timestamp(TimeUnit::Nanosecond, Some(tz)) => { - utf8_to_timestamp_ns_dyn::(array, tz.clone()) + utf8_to_timestamp_dyn::(array, tz.clone(), TimeUnit::Nanosecond) + }, + Timestamp(TimeUnit::Millisecond, Some(tz)) => { + utf8_to_timestamp_dyn::(array, tz.clone(), TimeUnit::Millisecond) + }, + Timestamp(TimeUnit::Microsecond, Some(tz)) => { + utf8_to_timestamp_dyn::(array, tz.clone(), TimeUnit::Microsecond) }, _ => polars_bail!(InvalidOperation: "casting from {from_type:?} to {to_type:?} not supported", @@ -612,9 +626,23 @@ pub fn cast( to_type.clone(), ) .boxed()), - Timestamp(TimeUnit::Nanosecond, None) => utf8_to_naive_timestamp_ns_dyn::(array), + Timestamp(TimeUnit::Nanosecond, None) => { + utf8_to_naive_timestamp_dyn::(array, TimeUnit::Nanosecond) + }, + Timestamp(TimeUnit::Millisecond, None) => { + utf8_to_naive_timestamp_dyn::(array, TimeUnit::Millisecond) + }, + Timestamp(TimeUnit::Microsecond, None) => { + utf8_to_naive_timestamp_dyn::(array, TimeUnit::Microsecond) + }, Timestamp(TimeUnit::Nanosecond, Some(tz)) => { - utf8_to_timestamp_ns_dyn::(array, tz.clone()) + 
utf8_to_timestamp_dyn::<i64>(array, tz.clone(), TimeUnit::Nanosecond)
+ },
+ Timestamp(TimeUnit::Millisecond, Some(tz)) => {
+ utf8_to_timestamp_dyn::<i64>(array, tz.clone(), TimeUnit::Millisecond)
+ },
+ Timestamp(TimeUnit::Microsecond, Some(tz)) => {
+ utf8_to_timestamp_dyn::<i64>(array, tz.clone(), TimeUnit::Microsecond)
 },
 _ => polars_bail!(InvalidOperation:
 "casting from {from_type:?} to {to_type:?} not supported",
diff --git a/crates/polars-arrow/src/compute/cast/utf8_to.rs b/crates/polars-arrow/src/compute/cast/utf8_to.rs
index 6af294c00e44..85a252544e5e 100644
--- a/crates/polars-arrow/src/compute/cast/utf8_to.rs
+++ b/crates/polars-arrow/src/compute/cast/utf8_to.rs
@@ -3,11 +3,11 @@ use polars_error::PolarsResult;
 use super::CastOptions;
 use crate::array::*;
-use crate::datatypes::DataType;
+use crate::datatypes::{DataType, TimeUnit};
 use crate::offset::Offset;
 use crate::temporal_conversions::{
-    utf8_to_naive_timestamp_ns as utf8_to_naive_timestamp_ns_,
-    utf8_to_timestamp_ns as utf8_to_timestamp_ns_, EPOCH_DAYS_FROM_CE,
+    utf8_to_naive_timestamp as utf8_to_naive_timestamp_, utf8_to_timestamp as utf8_to_timestamp_,
+    EPOCH_DAYS_FROM_CE,
 };
 use crate::types::NativeType;
@@ -110,34 +110,40 @@ pub fn utf8_to_dictionary<O: Offset, K: DictionaryKey>(
 Ok(array.into())
 }
-pub(super) fn utf8_to_naive_timestamp_ns_dyn<O: Offset>(
+pub(super) fn utf8_to_naive_timestamp_dyn<O: Offset>(
 from: &dyn Array,
+ tu: TimeUnit,
 ) -> PolarsResult<Box<dyn Array>> {
 let from = from.as_any().downcast_ref().unwrap();
- Ok(Box::new(utf8_to_naive_timestamp_ns::<O>(from)))
+ Ok(Box::new(utf8_to_naive_timestamp::<O>(from, tu)))
 }
-/// [`crate::temporal_conversions::utf8_to_timestamp_ns`] applied for RFC3339 formatting
-pub fn utf8_to_naive_timestamp_ns<O: Offset>(from: &Utf8Array<O>) -> PrimitiveArray<i64> {
- utf8_to_naive_timestamp_ns_(from, RFC3339)
+/// [`crate::temporal_conversions::utf8_to_timestamp`] applied for RFC3339 formatting
+pub fn utf8_to_naive_timestamp<O: Offset>(
+ from: &Utf8Array<O>,
+ tu: TimeUnit,
+) -> PrimitiveArray<i64> {
+ utf8_to_naive_timestamp_(from, RFC3339, tu)
 }
-pub(super) fn utf8_to_timestamp_ns_dyn<O: Offset>(
+pub(super) fn utf8_to_timestamp_dyn<O: Offset>(
 from: &dyn Array,
 timezone: String,
+ tu: TimeUnit,
 ) -> PolarsResult<Box<dyn Array>> {
 let from = from.as_any().downcast_ref().unwrap();
- utf8_to_timestamp_ns::<O>(from, timezone)
+ utf8_to_timestamp::<O>(from, timezone, tu)
 .map(Box::new)
 .map(|x| x as Box<dyn Array>)
 }
-/// [`crate::temporal_conversions::utf8_to_timestamp_ns`] applied for RFC3339 formatting
-pub fn utf8_to_timestamp_ns<O: Offset>(
+/// [`crate::temporal_conversions::utf8_to_timestamp`] applied for RFC3339 formatting
+pub fn utf8_to_timestamp<O: Offset>(
 from: &Utf8Array<O>,
 timezone: String,
+ tu: TimeUnit,
) -> PolarsResult<PrimitiveArray<i64>> {
- utf8_to_timestamp_ns_(from, RFC3339, timezone)
+ utf8_to_timestamp_(from, RFC3339, timezone, tu)
 }
 /// Conversion of utf8
diff --git a/crates/polars-arrow/src/temporal_conversions.rs b/crates/polars-arrow/src/temporal_conversions.rs
index bcaa4875363d..07981e017437 100644
--- a/crates/polars-arrow/src/temporal_conversions.rs
+++ b/crates/polars-arrow/src/temporal_conversions.rs
@@ -321,17 +321,6 @@ pub fn parse_offset(offset: &str) -> PolarsResult<FixedOffset> {
 .expect("FixedOffset::east out of bounds"))
 }
-/// Parses `value` to `Option<i64>` consistent with the Arrow's definition of timestamp with timezone.
-/// `tz` must be built from `timezone` (either via [`parse_offset`] or `chrono-tz`).
-#[inline]
-pub fn utf8_to_timestamp_ns_scalar<T: chrono::TimeZone>(
- value: &str,
- fmt: &str,
- tz: &T,
-) -> Option<i64> {
- utf8_to_timestamp_scalar(value, fmt, tz, &TimeUnit::Nanosecond)
-}
-
 /// Parses `value` to `Option<i64>` consistent with the Arrow's definition of timestamp with timezone.
 /// `tz` must be built from `timezone` (either via [`parse_offset`] or `chrono-tz`).
 /// Returns in scale `tz` of `TimeUnit`.
@@ -362,12 +351,6 @@ pub fn utf8_to_timestamp_scalar<T: chrono::TimeZone>(
 }
 }
-/// Parses `value` to `Option<i64>` consistent with the Arrow's definition of timestamp without timezone.
-#[inline]
-pub fn utf8_to_naive_timestamp_ns_scalar(value: &str, fmt: &str) -> Option<i64> {
- utf8_to_naive_timestamp_scalar(value, fmt, &TimeUnit::Nanosecond)
-}
-
 /// Parses `value` to `Option<i64>` consistent with the Arrow's definition of timestamp without timezone.
 /// Returns in scale `tz` of `TimeUnit`.
 #[inline]
@@ -386,18 +369,18 @@ pub fn utf8_to_naive_timestamp_scalar(value: &str, fmt: &str, tu: &TimeUnit) ->
 .ok()
 }
-fn utf8_to_timestamp_ns_impl<O: Offset, T: chrono::TimeZone>(
+fn utf8_to_timestamp_impl<O: Offset, T: chrono::TimeZone>(
 array: &Utf8Array<O>,
 fmt: &str,
 timezone: String,
 tz: T,
+ tu: TimeUnit,
 ) -> PrimitiveArray<i64> {
 let iter = array
 .iter()
- .map(|x| x.and_then(|x| utf8_to_timestamp_ns_scalar(x, fmt, &tz)));
+ .map(|x| x.and_then(|x| utf8_to_timestamp_scalar(x, fmt, &tz, &tu)));
- PrimitiveArray::from_trusted_len_iter(iter)
- .to(DataType::Timestamp(TimeUnit::Nanosecond, Some(timezone)))
+ PrimitiveArray::from_trusted_len_iter(iter).to(DataType::Timestamp(tu, Some(timezone)))
 }
 /// Parses `value` to a [`chrono_tz::Tz`] with the Arrow's definition of timestamp with a timezone.
@@ -411,13 +394,14 @@ pub fn parse_offset_tz(timezone: &str) -> PolarsResult<Tz> {
 #[cfg(feature = "chrono-tz")]
 #[cfg_attr(docsrs, doc(cfg(feature = "chrono-tz")))]
-fn chrono_tz_utf_to_timestamp_ns<O: Offset>(
+fn chrono_tz_utf_to_timestamp<O: Offset>(
 array: &Utf8Array<O>,
 fmt: &str,
 timezone: String,
+ tu: TimeUnit,
 ) -> PolarsResult<PrimitiveArray<i64>> {
 let tz = parse_offset_tz(&timezone)?;
- Ok(utf8_to_timestamp_ns_impl(array, fmt, timezone, tz))
+ Ok(utf8_to_timestamp_impl(array, fmt, timezone, tz, tu))
 }
 #[cfg(not(feature = "chrono-tz"))]
@@ -432,22 +416,23 @@ fn chrono_tz_utf_to_timestamp_ns<O: Offset>(
 /// Parses a [`Utf8Array`] to a timeozone-aware timestamp, i.e. [`PrimitiveArray`] with type `Timestamp(Nanosecond, Some(timezone))`.
 /// # Implementation
 /// * parsed values with timezone other than `timezone` are converted to `timezone`.
-/// * parsed values without timezone are null. Use [`utf8_to_naive_timestamp_ns`] to parse naive timezones.
+/// * parsed values without timezone are null. Use [`utf8_to_naive_timestamp`] to parse naive timezones.
 /// * Null elements remain null; non-parsable elements are null.
 /// The feature `"chrono-tz"` enables IANA and zoneinfo formats for `timezone`.
 /// # Error
 /// This function errors iff `timezone` is not parsable to an offset.
-pub fn utf8_to_timestamp_ns<O: Offset>(
+pub fn utf8_to_timestamp<O: Offset>(
 array: &Utf8Array<O>,
 fmt: &str,
 timezone: String,
+ tu: TimeUnit,
 ) -> PolarsResult<PrimitiveArray<i64>> {
 let tz = parse_offset(timezone.as_str());
 if let Ok(tz) = tz {
- Ok(utf8_to_timestamp_ns_impl(array, fmt, timezone, tz))
+ Ok(utf8_to_timestamp_impl(array, fmt, timezone, tz, tu))
 } else {
- chrono_tz_utf_to_timestamp_ns(array, fmt, timezone)
+ chrono_tz_utf_to_timestamp(array, fmt, timezone, tu)
 }
 }
 /// Parses a [`Utf8Array`] to naive timestamp, i.e.
 /// [`PrimitiveArray`] with type `Timestamp(Nanosecond, None)`.
 /// Timezones are ignored.
 /// Null elements remain null; non-parsable elements are set to null.
-pub fn utf8_to_naive_timestamp_ns( +pub fn utf8_to_naive_timestamp( array: &Utf8Array, fmt: &str, + tu: TimeUnit, ) -> PrimitiveArray { let iter = array .iter() - .map(|x| x.and_then(|x| utf8_to_naive_timestamp_ns_scalar(x, fmt))); + .map(|x| x.and_then(|x| utf8_to_naive_timestamp_scalar(x, fmt, &tu))); - PrimitiveArray::from_trusted_len_iter(iter).to(DataType::Timestamp(TimeUnit::Nanosecond, None)) + PrimitiveArray::from_trusted_len_iter(iter).to(DataType::Timestamp(tu, None)) } fn add_month(year: i32, month: u32, months: i32) -> chrono::NaiveDate { diff --git a/crates/polars-core/src/chunked_array/cast.rs b/crates/polars-core/src/chunked_array/cast.rs index a3184ff7fca3..84976edbb513 100644 --- a/crates/polars-core/src/chunked_array/cast.rs +++ b/crates/polars-core/src/chunked_array/cast.rs @@ -203,24 +203,21 @@ impl ChunkCast for Utf8Chunked { Ok(out) }, #[cfg(feature = "dtype-datetime")] - DataType::Datetime(_tu, tz) => { + DataType::Datetime(tu, tz) => { let out = match tz { #[cfg(feature = "timezones")] Some(tz) => { validate_time_zone(tz)?; let result = cast_chunks( &self.chunks, - &Datetime(TimeUnit::Nanoseconds, Some(tz.clone())), + &Datetime(tu.to_owned(), Some(tz.clone())), true, )?; Series::try_from((self.name(), result)) }, _ => { - let result = cast_chunks( - &self.chunks, - &Datetime(TimeUnit::Nanoseconds, None), - true, - )?; + let result = + cast_chunks(&self.chunks, &Datetime(tu.to_owned(), None), true)?; Series::try_from((self.name(), result)) }, }; diff --git a/py-polars/tests/unit/test_queries.py b/py-polars/tests/unit/test_queries.py index d239d23f65fe..13ec031ffc54 100644 --- a/py-polars/tests/unit/test_queries.py +++ b/py-polars/tests/unit/test_queries.py @@ -383,7 +383,25 @@ def test_utf8_date() -> None: def test_utf8_datetime() -> None: df = pl.DataFrame( - {"x1": ["2021-12-19T16:39:57-02:00", "2022-12-19T16:39:57"]} + {"x1": ["2021-12-19T00:39:57", "2022-12-19T16:39:57"]} + ).with_columns( + **{ + "x1-datetime-ns": pl.col("x1").cast(pl.Datetime(time_unit="ns")), + "x1-datetime-ms": pl.col("x1").cast(pl.Datetime(time_unit="ms")), + "x1-datetime-us": pl.col("x1").cast(pl.Datetime(time_unit="us")), + } + ) + + out = df.select( + pl.col("x1-datetime-ns"), pl.col("x1-datetime-ms"), pl.col("x1-datetime-us") + ) + assert out.shape == (2, 3) + assert out.dtypes == [pl.Datetime, pl.Datetime, pl.Datetime] + + +def test_utf8_datetime_timezone() -> None: + df = pl.DataFrame( + {"x1": ["1996-12-19T16:39:57-02:00", "2022-12-19T00:39:57-03:00"]} ).with_columns( **{ "x1-datetime-ns": pl.col("x1").cast(pl.Datetime(time_unit="ns")), From 013916acad2ef482aa5abc101107132252701fe9 Mon Sep 17 00:00:00 2001 From: Brayan Jules Date: Mon, 16 Oct 2023 08:57:57 -0300 Subject: [PATCH 095/103] feat: added missing tests for failure scenarios, also fixed casting from int to date. 
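
For reviewers, a minimal sketch of the end-to-end behavior the new
failure-scenario tests pin down (illustration only, not part of the diff;
assumes a polars build that includes this branch -- `ComputeError` is
re-exported at the package root, as the tests below import it):

    import polars as pl

    # A well-formed ISO-8601 string casts cleanly to a temporal dtype.
    ok = pl.DataFrame({"x1": ["2021-01-01"]}).with_columns(
        pl.col("x1").cast(pl.Date)
    )
    assert ok.dtypes == [pl.Date]

    # A malformed string raises ComputeError instead of silently
    # yielding nulls.
    try:
        pl.DataFrame({"x1": ["2021-01-aa"]}).with_columns(
            pl.col("x1").cast(pl.Date)
        )
    except pl.ComputeError as exc:
        print(f"cast failed as expected: {exc}")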
--- .../polars-arrow/src/temporal_conversions.rs | 3 +- crates/polars-core/src/chunked_array/cast.rs | 5 ++- .../src/chunked_array/temporal/mod.rs | 14 +++++++ py-polars/tests/unit/test_lazy.py | 2 +- py-polars/tests/unit/test_queries.py | 38 ++++++++++++++++--- 5 files changed, 53 insertions(+), 9 deletions(-) diff --git a/crates/polars-arrow/src/temporal_conversions.rs b/crates/polars-arrow/src/temporal_conversions.rs index 07981e017437..598305618e9b 100644 --- a/crates/polars-arrow/src/temporal_conversions.rs +++ b/crates/polars-arrow/src/temporal_conversions.rs @@ -405,10 +405,11 @@ fn chrono_tz_utf_to_timestamp( } #[cfg(not(feature = "chrono-tz"))] -fn chrono_tz_utf_to_timestamp_ns( +fn chrono_tz_utf_to_timestamp( _: &Utf8Array, _: &str, timezone: String, + _: TimeUnit, ) -> PolarsResult> { panic!("timezone \"{timezone}\" cannot be parsed (feature chrono-tz is not active)") } diff --git a/crates/polars-core/src/chunked_array/cast.rs b/crates/polars-core/src/chunked_array/cast.rs index 84976edbb513..132ae62d7d8c 100644 --- a/crates/polars-core/src/chunked_array/cast.rs +++ b/crates/polars-core/src/chunked_array/cast.rs @@ -5,6 +5,7 @@ use arrow::compute::cast::CastOptions; #[cfg(feature = "dtype-categorical")] use crate::chunked_array::categorical::CategoricalChunkedBuilder; +use crate::chunked_array::temporal::{validate_is_number}; #[cfg(feature = "timezones")] use crate::chunked_array::temporal::validate_time_zone; use crate::prelude::DataType::Datetime; @@ -197,13 +198,13 @@ impl ChunkCast for Utf8Chunked { }, }, #[cfg(feature = "dtype-date")] - DataType::Date => { + DataType::Date if !validate_is_number(&self.chunks) => { let result = cast_chunks(&self.chunks, data_type, true)?; let out = Series::try_from((self.name(), result))?; Ok(out) }, #[cfg(feature = "dtype-datetime")] - DataType::Datetime(tu, tz) => { + DataType::Datetime(tu, tz) if !validate_is_number(&self.chunks) => { let out = match tz { #[cfg(feature = "timezones")] Some(tz) => { diff --git a/crates/polars-core/src/chunked_array/temporal/mod.rs b/crates/polars-core/src/chunked_array/temporal/mod.rs index 737ff5086d47..9e0759a9b31d 100644 --- a/crates/polars-core/src/chunked_array/temporal/mod.rs +++ b/crates/polars-core/src/chunked_array/temporal/mod.rs @@ -15,6 +15,7 @@ use chrono::NaiveDateTime; use chrono::NaiveTime; #[cfg(feature = "timezones")] use chrono_tz::Tz; +use polars_arrow::prelude::ArrayRef; #[cfg(feature = "dtype-time")] pub use time::time_to_time64ns; @@ -35,3 +36,16 @@ pub(crate) fn validate_time_zone(tz: &str) -> PolarsResult<()> { }, } } + +pub(crate) fn validate_is_number(vec_array: &Vec) -> bool { + vec_array.iter().all(|array|is_parsable_as_number(array)) +} + +fn is_parsable_as_number(array: &ArrayRef) -> bool { + if let Some(array) = array.as_any().downcast_ref::() { + array.iter().all(|value| value.expect("Unable to parse int string to datetime").parse::().is_ok()) + } else { + false + } +} + diff --git a/py-polars/tests/unit/test_lazy.py b/py-polars/tests/unit/test_lazy.py index 30b30740e215..2fedbf853435 100644 --- a/py-polars/tests/unit/test_lazy.py +++ b/py-polars/tests/unit/test_lazy.py @@ -1375,7 +1375,7 @@ def test_quadratic_behavior_4736() -> None: ldf.select(reduce(add, (pl.col(fld) for fld in ldf.columns))) -@pytest.mark.parametrize("input_dtype", [pl.Int64, pl.Float64]) +@pytest.mark.parametrize("input_dtype", [pl.Int64, pl.Float64, pl.Utf8]) def test_from_epoch(input_dtype: pl.PolarsDataType) -> None: ldf = pl.LazyFrame( [ diff --git a/py-polars/tests/unit/test_queries.py 
b/py-polars/tests/unit/test_queries.py index 13ec031ffc54..8db9c1283904 100644 --- a/py-polars/tests/unit/test_queries.py +++ b/py-polars/tests/unit/test_queries.py @@ -5,8 +5,10 @@ import numpy as np import pandas as pd +import pytest import polars as pl +from polars import ComputeError from polars.testing import assert_frame_equal @@ -381,6 +383,13 @@ def test_utf8_date() -> None: assert out.dtypes == [pl.Date] +def test_wrong_utf8_date() -> None: + df = pl.DataFrame({"x1": ["2021-01-aa"]}) + + with pytest.raises(ComputeError): + df.with_columns(**{"x1-date": pl.col("x1").cast(pl.Date)}) + + def test_utf8_datetime() -> None: df = pl.DataFrame( {"x1": ["2021-12-19T00:39:57", "2022-12-19T16:39:57"]} @@ -399,19 +408,38 @@ def test_utf8_datetime() -> None: assert out.dtypes == [pl.Datetime, pl.Datetime, pl.Datetime] +def test_wrong_utf8_datetime() -> None: + df = pl.DataFrame({"x1": ["2021-12-19 00:39:57", "2022-12-19 16:39:57"]}) + with pytest.raises(ComputeError): + df.with_columns( + **{"x1-datetime-ns": pl.col("x1").cast(pl.Datetime(time_unit="ns"))} + ) + + def test_utf8_datetime_timezone() -> None: df = pl.DataFrame( - {"x1": ["1996-12-19T16:39:57-02:00", "2022-12-19T00:39:57-03:00"]} + {"x1": ["1996-12-19T16:39:57 +00:00", "2022-12-19T00:39:57 +00:00"]} ).with_columns( **{ - "x1-datetime-ns": pl.col("x1").cast(pl.Datetime(time_unit="ns")), - "x1-datetime-ms": pl.col("x1").cast(pl.Datetime(time_unit="ms")), - "x1-datetime-us": pl.col("x1").cast(pl.Datetime(time_unit="us")), + "x1-datetime-ns": pl.col("x1").cast( + pl.Datetime(time_unit="ns", time_zone="America/Caracas") + ), + "x1-datetime-ms": pl.col("x1").cast( + pl.Datetime(time_unit="ms", time_zone="America/Santiago") + ), + "x1-datetime-us": pl.col("x1").cast( + pl.Datetime(time_unit="us", time_zone="UTC") + ), } ) out = df.select( pl.col("x1-datetime-ns"), pl.col("x1-datetime-ms"), pl.col("x1-datetime-us") ) + assert out.shape == (2, 3) - assert out.dtypes == [pl.Datetime, pl.Datetime, pl.Datetime] + assert out.dtypes == [ + pl.Datetime("ns", "America/Caracas"), + pl.Datetime("ms", "America/Santiago"), + pl.Datetime("us", "UTC"), + ] From dab1451fd27896a8ab06de16b5335f8f1a56a030 Mon Sep 17 00:00:00 2001 From: Brayan Jules Date: Mon, 16 Oct 2023 09:31:44 -0300 Subject: [PATCH 096/103] fix: fixed issue regarding arrow libraries import and code formatting --- crates/polars-core/src/chunked_array/cast.rs | 2 +- .../polars-core/src/chunked_array/temporal/mod.rs | 14 +++++++++----- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/crates/polars-core/src/chunked_array/cast.rs b/crates/polars-core/src/chunked_array/cast.rs index 132ae62d7d8c..9ee0a25e2961 100644 --- a/crates/polars-core/src/chunked_array/cast.rs +++ b/crates/polars-core/src/chunked_array/cast.rs @@ -5,7 +5,7 @@ use arrow::compute::cast::CastOptions; #[cfg(feature = "dtype-categorical")] use crate::chunked_array::categorical::CategoricalChunkedBuilder; -use crate::chunked_array::temporal::{validate_is_number}; +use crate::chunked_array::temporal::validate_is_number; #[cfg(feature = "timezones")] use crate::chunked_array::temporal::validate_time_zone; use crate::prelude::DataType::Datetime; diff --git a/crates/polars-core/src/chunked_array/temporal/mod.rs b/crates/polars-core/src/chunked_array/temporal/mod.rs index 9e0759a9b31d..c6ea220b7d21 100644 --- a/crates/polars-core/src/chunked_array/temporal/mod.rs +++ b/crates/polars-core/src/chunked_array/temporal/mod.rs @@ -15,13 +15,13 @@ use chrono::NaiveDateTime; use chrono::NaiveTime; #[cfg(feature = 
"timezones")] use chrono_tz::Tz; -use polars_arrow::prelude::ArrayRef; #[cfg(feature = "dtype-time")] pub use time::time_to_time64ns; pub use self::conversion::*; #[cfg(feature = "timezones")] use crate::prelude::{polars_bail, PolarsResult}; +use crate::prelude::{ArrayRef, LargeStringArray}; pub fn unix_time() -> NaiveDateTime { NaiveDateTime::from_timestamp_opt(0, 0).unwrap() @@ -38,14 +38,18 @@ pub(crate) fn validate_time_zone(tz: &str) -> PolarsResult<()> { } pub(crate) fn validate_is_number(vec_array: &Vec) -> bool { - vec_array.iter().all(|array|is_parsable_as_number(array)) + vec_array.iter().all(|array| is_parsable_as_number(array)) } fn is_parsable_as_number(array: &ArrayRef) -> bool { - if let Some(array) = array.as_any().downcast_ref::() { - array.iter().all(|value| value.expect("Unable to parse int string to datetime").parse::().is_ok()) + if let Some(array) = array.as_any().downcast_ref::() { + array.iter().all(|value| { + value + .expect("Unable to parse int string to datetime") + .parse::() + .is_ok() + }) } else { false } } - From 374d946058767ed265a563d777115eb246a95824 Mon Sep 17 00:00:00 2001 From: Brayan Jules Date: Mon, 16 Oct 2023 18:45:03 -0300 Subject: [PATCH 097/103] fix: fixed validate_is_number import issue, also added missing dataframe validation on unit tests --- crates/polars-core/src/chunked_array/cast.rs | 3 +- .../src/chunked_array/temporal/mod.rs | 4 +- py-polars/tests/unit/test_queries.py | 49 +++++++++++++++++-- 3 files changed, 48 insertions(+), 8 deletions(-) diff --git a/crates/polars-core/src/chunked_array/cast.rs b/crates/polars-core/src/chunked_array/cast.rs index 9ee0a25e2961..cb5414ad70ff 100644 --- a/crates/polars-core/src/chunked_array/cast.rs +++ b/crates/polars-core/src/chunked_array/cast.rs @@ -5,9 +5,8 @@ use arrow::compute::cast::CastOptions; #[cfg(feature = "dtype-categorical")] use crate::chunked_array::categorical::CategoricalChunkedBuilder; -use crate::chunked_array::temporal::validate_is_number; #[cfg(feature = "timezones")] -use crate::chunked_array::temporal::validate_time_zone; +use crate::chunked_array::temporal::{validate_is_number, validate_time_zone}; use crate::prelude::DataType::Datetime; use crate::prelude::*; diff --git a/crates/polars-core/src/chunked_array/temporal/mod.rs b/crates/polars-core/src/chunked_array/temporal/mod.rs index c6ea220b7d21..3b6a38aede8b 100644 --- a/crates/polars-core/src/chunked_array/temporal/mod.rs +++ b/crates/polars-core/src/chunked_array/temporal/mod.rs @@ -37,8 +37,8 @@ pub(crate) fn validate_time_zone(tz: &str) -> PolarsResult<()> { } } -pub(crate) fn validate_is_number(vec_array: &Vec) -> bool { - vec_array.iter().all(|array| is_parsable_as_number(array)) +pub(crate) fn validate_is_number(vec_array: &[ArrayRef]) -> bool { + vec_array.iter().all(is_parsable_as_number) } fn is_parsable_as_number(array: &ArrayRef) -> bool { diff --git a/py-polars/tests/unit/test_queries.py b/py-polars/tests/unit/test_queries.py index 8db9c1283904..af623feb0e2c 100644 --- a/py-polars/tests/unit/test_queries.py +++ b/py-polars/tests/unit/test_queries.py @@ -1,6 +1,6 @@ from __future__ import annotations -from datetime import datetime, timedelta +from datetime import datetime, timedelta, date from typing import Any import numpy as np @@ -378,9 +378,11 @@ def test_utf8_date() -> None: df = pl.DataFrame({"x1": ["2021-01-01"]}).with_columns( **{"x1-date": pl.col("x1").cast(pl.Date)} ) + expected = pl.DataFrame({"x1-date":[date(2021,1,1)]}) out = df.select(pl.col("x1-date")) assert out.shape == (1, 1) assert out.dtypes 
== [pl.Date] + assert_frame_equal(expected, out) def test_wrong_utf8_date() -> None: @@ -400,12 +402,26 @@ def test_utf8_datetime() -> None: "x1-datetime-us": pl.col("x1").cast(pl.Datetime(time_unit="us")), } ) + first_row = datetime(year=2021, month=12, day=19, hour=00, minute=39, second=57) + second_row = datetime(year=2022, month=12, day=19, hour=16, minute=39, second=57) + expected = pl.DataFrame( + { + "x1-datetime-ns": [first_row, second_row], + "x1-datetime-ms": [first_row, second_row], + "x1-datetime-us": [first_row, second_row] + } + ).select( + pl.col("x1-datetime-ns").dt.cast_time_unit("ns"), + pl.col("x1-datetime-ms").dt.cast_time_unit("ms"), + pl.col("x1-datetime-us").dt.cast_time_unit("us"), + ) out = df.select( pl.col("x1-datetime-ns"), pl.col("x1-datetime-ms"), pl.col("x1-datetime-us") ) assert out.shape == (2, 3) assert out.dtypes == [pl.Datetime, pl.Datetime, pl.Datetime] + assert_frame_equal(expected, out) def test_wrong_utf8_datetime() -> None: @@ -417,22 +433,46 @@ def test_wrong_utf8_datetime() -> None: def test_utf8_datetime_timezone() -> None: + ccs_tz = "America/Caracas" + stg_tz = "America/Santiago" + utc_tz = "UTC" df = pl.DataFrame( {"x1": ["1996-12-19T16:39:57 +00:00", "2022-12-19T00:39:57 +00:00"]} ).with_columns( **{ "x1-datetime-ns": pl.col("x1").cast( - pl.Datetime(time_unit="ns", time_zone="America/Caracas") + pl.Datetime(time_unit="ns", time_zone=ccs_tz) ), "x1-datetime-ms": pl.col("x1").cast( - pl.Datetime(time_unit="ms", time_zone="America/Santiago") + pl.Datetime(time_unit="ms", time_zone=stg_tz) ), "x1-datetime-us": pl.col("x1").cast( - pl.Datetime(time_unit="us", time_zone="UTC") + pl.Datetime(time_unit="us", time_zone=utc_tz) ), } ) + expected = pl.DataFrame( + { + "x1-datetime-ns": [ + datetime(year=1996, month=12, day=19, hour=12, minute=39, second=57), + datetime(year=2022, month=12, day=18, hour=20, minute=39, second=57), + ], + "x1-datetime-ms": [ + datetime(year=1996, month=12, day=19, hour=13, minute=39, second=57), + datetime(year=2022, month=12, day=18, hour=21, minute=39, second=57), + ], + "x1-datetime-us": [ + datetime(year=1996, month=12, day=19, hour=16, minute=39, second=57), + datetime(year=2022, month=12, day=19, hour=00, minute=39, second=57), + ], + } + ).select( + pl.col("x1-datetime-ns").dt.cast_time_unit("ns").dt.replace_time_zone(ccs_tz), + pl.col("x1-datetime-ms").dt.cast_time_unit("ms").dt.replace_time_zone(stg_tz), + pl.col("x1-datetime-us").dt.cast_time_unit("us").dt.replace_time_zone(utc_tz), + ) + out = df.select( pl.col("x1-datetime-ns"), pl.col("x1-datetime-ms"), pl.col("x1-datetime-us") ) @@ -443,3 +483,4 @@ def test_utf8_datetime_timezone() -> None: pl.Datetime("ms", "America/Santiago"), pl.Datetime("us", "UTC"), ] + assert_frame_equal(expected, out) From 68f9d693286799add0931481cdc17e92ada06710 Mon Sep 17 00:00:00 2001 From: Brayan Jules Date: Mon, 16 Oct 2023 19:52:53 -0300 Subject: [PATCH 098/103] fix: fixed linter issues. 
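
For reviewers: the tests touched here pin behavior by building an expected frame and coercing its dtypes explicitly, rather than asserting on shapes and dtypes alone. A minimal, self-contained sketch of that pattern (column name and values are illustrative, assuming the casting behavior this branch introduces):

    from datetime import datetime

    import polars as pl
    from polars.testing import assert_frame_equal

    # Cast a utf8 column carrying a UTC offset to a zone-aware datetime.
    out = pl.DataFrame({"ts": ["1996-12-19T16:39:57+00:00"]}).select(
        pl.col("ts").cast(pl.Datetime(time_unit="ms", time_zone="UTC"))
    )

    # Build the expectation, then coerce its time unit and time zone so both
    # frames carry identical dtypes before comparing.
    expected = pl.DataFrame({"ts": [datetime(1996, 12, 19, 16, 39, 57)]}).select(
        pl.col("ts").dt.cast_time_unit("ms").dt.replace_time_zone("UTC")
    )

    assert_frame_equal(expected, out)
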
--- crates/polars-core/src/chunked_array/cast.rs | 4 +++- py-polars/tests/unit/test_queries.py | 6 +++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/crates/polars-core/src/chunked_array/cast.rs b/crates/polars-core/src/chunked_array/cast.rs index cb5414ad70ff..6b7dbbb6835f 100644 --- a/crates/polars-core/src/chunked_array/cast.rs +++ b/crates/polars-core/src/chunked_array/cast.rs @@ -5,8 +5,10 @@ use arrow::compute::cast::CastOptions; #[cfg(feature = "dtype-categorical")] use crate::chunked_array::categorical::CategoricalChunkedBuilder; +#[cfg(feature = "temporal")] +use crate::chunked_array::temporal::validate_is_number; #[cfg(feature = "timezones")] -use crate::chunked_array::temporal::{validate_is_number, validate_time_zone}; +use crate::chunked_array::temporal::validate_time_zone; use crate::prelude::DataType::Datetime; use crate::prelude::*; diff --git a/py-polars/tests/unit/test_queries.py b/py-polars/tests/unit/test_queries.py index af623feb0e2c..48e116c39656 100644 --- a/py-polars/tests/unit/test_queries.py +++ b/py-polars/tests/unit/test_queries.py @@ -1,6 +1,6 @@ from __future__ import annotations -from datetime import datetime, timedelta, date +from datetime import date, datetime, timedelta from typing import Any import numpy as np @@ -378,7 +378,7 @@ def test_utf8_date() -> None: df = pl.DataFrame({"x1": ["2021-01-01"]}).with_columns( **{"x1-date": pl.col("x1").cast(pl.Date)} ) - expected = pl.DataFrame({"x1-date":[date(2021,1,1)]}) + expected = pl.DataFrame({"x1-date": [date(2021, 1, 1)]}) out = df.select(pl.col("x1-date")) assert out.shape == (1, 1) assert out.dtypes == [pl.Date] @@ -408,7 +408,7 @@ def test_utf8_datetime() -> None: { "x1-datetime-ns": [first_row, second_row], "x1-datetime-ms": [first_row, second_row], - "x1-datetime-us": [first_row, second_row] + "x1-datetime-us": [first_row, second_row], } ).select( pl.col("x1-datetime-ns").dt.cast_time_unit("ns"), From 8e810ff4199f4de27ffdadff224e01d5d41f7eaa Mon Sep 17 00:00:00 2001 From: Brayan Jules Date: Tue, 17 Oct 2023 18:37:40 -0300 Subject: [PATCH 099/103] fix: remove redundant asserts from unit tests and simplify timestamp cast pattern matching --- crates/polars-arrow/src/compute/cast/mod.rs | 40 ++++----------------- py-polars/tests/unit/test_queries.py | 10 ------ 2 files changed, 6 insertions(+), 44 deletions(-) diff --git a/crates/polars-arrow/src/compute/cast/mod.rs b/crates/polars-arrow/src/compute/cast/mod.rs index b6778116b091..f02daf2a364f 100644 --- a/crates/polars-arrow/src/compute/cast/mod.rs +++ b/crates/polars-arrow/src/compute/cast/mod.rs @@ -585,23 +585,9 @@ pub fn cast( LargeUtf8 => Ok(Box::new(utf8_to_large_utf8( array.as_any().downcast_ref().unwrap(), ))), - Timestamp(TimeUnit::Nanosecond, None) => { - utf8_to_naive_timestamp_dyn::<i32>(array, TimeUnit::Nanosecond) - }, - Timestamp(TimeUnit::Millisecond, None) => { - utf8_to_naive_timestamp_dyn::<i32>(array, TimeUnit::Millisecond) - }, - Timestamp(TimeUnit::Microsecond, None) => { - utf8_to_naive_timestamp_dyn::<i32>(array, TimeUnit::Microsecond) - }, - Timestamp(TimeUnit::Nanosecond, Some(tz)) => { - utf8_to_timestamp_dyn::<i32>(array, tz.clone(), TimeUnit::Nanosecond) - }, - Timestamp(TimeUnit::Millisecond, Some(tz)) => { - utf8_to_timestamp_dyn::<i32>(array, tz.clone(), TimeUnit::Millisecond) - }, - Timestamp(TimeUnit::Microsecond, Some(tz)) => { - utf8_to_timestamp_dyn::<i32>(array, tz.clone(), TimeUnit::Microsecond) + Timestamp(tu, None) => utf8_to_naive_timestamp_dyn::<i32>(array, tu.to_owned()), + Timestamp(tu, Some(tz)) => { + utf8_to_timestamp_dyn::<i32>(array, tz.clone(), tu.to_owned()) }, _ => polars_bail!(InvalidOperation: "casting from {from_type:?} to {to_type:?} not supported", @@ -626,23 +612,9 @@ pub fn cast( to_type.clone(), ) .boxed()), - Timestamp(TimeUnit::Nanosecond, None) => { - utf8_to_naive_timestamp_dyn::<i64>(array, TimeUnit::Nanosecond) - }, - Timestamp(TimeUnit::Millisecond, None) => { - utf8_to_naive_timestamp_dyn::<i64>(array, TimeUnit::Millisecond) - }, - Timestamp(TimeUnit::Microsecond, None) => { - utf8_to_naive_timestamp_dyn::<i64>(array, TimeUnit::Microsecond) - }, - Timestamp(TimeUnit::Nanosecond, Some(tz)) => { - utf8_to_timestamp_dyn::<i64>(array, tz.clone(), TimeUnit::Nanosecond) - }, - Timestamp(TimeUnit::Millisecond, Some(tz)) => { - utf8_to_timestamp_dyn::<i64>(array, tz.clone(), TimeUnit::Millisecond) - }, - Timestamp(TimeUnit::Microsecond, Some(tz)) => { - utf8_to_timestamp_dyn::<i64>(array, tz.clone(), TimeUnit::Microsecond) + Timestamp(tu, None) => utf8_to_naive_timestamp_dyn::<i64>(array, tu.to_owned()), + Timestamp(tu, Some(tz)) => { + utf8_to_timestamp_dyn::<i64>(array, tz.clone(), tu.to_owned()) }, _ => polars_bail!(InvalidOperation: "casting from {from_type:?} to {to_type:?} not supported", diff --git a/py-polars/tests/unit/test_queries.py b/py-polars/tests/unit/test_queries.py index 48e116c39656..14835176cc80 100644 --- a/py-polars/tests/unit/test_queries.py +++ b/py-polars/tests/unit/test_queries.py @@ -380,8 +380,6 @@ def test_utf8_date() -> None: ) expected = pl.DataFrame({"x1-date": [date(2021, 1, 1)]}) out = df.select(pl.col("x1-date")) - assert out.shape == (1, 1) - assert out.dtypes == [pl.Date] assert_frame_equal(expected, out) @@ -419,8 +417,6 @@ def test_utf8_datetime() -> None: out = df.select( pl.col("x1-datetime-ns"), pl.col("x1-datetime-ms"), pl.col("x1-datetime-us") ) - assert out.shape == (2, 3) - assert out.dtypes == [pl.Datetime, pl.Datetime, pl.Datetime] assert_frame_equal(expected, out) @@ -477,10 +473,4 @@ def test_utf8_datetime_timezone() -> None: pl.col("x1-datetime-ns"), pl.col("x1-datetime-ms"), pl.col("x1-datetime-us") ) - assert out.shape == (2, 3) - assert out.dtypes == [ - pl.Datetime("ns", "America/Caracas"), - pl.Datetime("ms", "America/Santiago"), - pl.Datetime("us", "UTC"), - ] + assert_frame_equal(expected, out) From 5225444705750b4b56b3c42fcabd336e2db45a5d Mon Sep 17 00:00:00 2001 From: Brayan Jules Date: Fri, 20 Oct 2023 22:21:50 -0300 Subject: [PATCH 100/103] fix: fix the bug that incorrectly enabled casting epoch strings to datetime
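
Before this change, a column of digit-only strings (epoch timestamps) could slip through the utf8-to-temporal cast path instead of being rejected. After it, such casts raise, and epoch values have to be converted explicitly. A minimal sketch of the intended behavior, mirroring the tests below (illustrative values, assuming this patch is applied):

    import polars as pl
    from polars import ComputeError

    df = pl.DataFrame({"iso": ["2021-01-01"], "epoch": ["1147880044"]})

    # ISO-formatted strings still cast to temporal dtypes.
    dates = df.select(pl.col("iso").cast(pl.Date))

    # Digit-only epoch strings no longer cast silently; the cast raises.
    try:
        df.select(pl.col("epoch").cast(pl.Datetime(time_unit="us")))
    except ComputeError:
        pass

    # Epoch values should be integers, converted explicitly with from_epoch.
    epochs = df.select(pl.from_epoch(pl.col("epoch").cast(pl.Int64), time_unit="s"))
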
--- crates/polars-core/src/chunked_array/cast.rs | 6 ++---- .../src/chunked_array/temporal/mod.rs | 19 +------------------ py-polars/tests/unit/test_lazy.py | 19 ++++++++++++++++++- 3 files changed, 21 insertions(+), 23 deletions(-) diff --git a/crates/polars-core/src/chunked_array/cast.rs b/crates/polars-core/src/chunked_array/cast.rs index 6b7dbbb6835f..84976edbb513 100644 --- a/crates/polars-core/src/chunked_array/cast.rs +++ b/crates/polars-core/src/chunked_array/cast.rs @@ -5,8 +5,6 @@ use arrow::compute::cast::CastOptions; #[cfg(feature = "dtype-categorical")] use crate::chunked_array::categorical::CategoricalChunkedBuilder; -#[cfg(feature = "temporal")] -use crate::chunked_array::temporal::validate_is_number; #[cfg(feature = "timezones")] use crate::chunked_array::temporal::validate_time_zone; use crate::prelude::DataType::Datetime; use crate::prelude::*; @@ -199,13 +197,13 @@ impl ChunkCast for Utf8Chunked { }, }, #[cfg(feature = "dtype-date")] - DataType::Date if !validate_is_number(&self.chunks) => { + DataType::Date => { let result = cast_chunks(&self.chunks, data_type, true)?; let out = Series::try_from((self.name(), result))?; Ok(out) }, #[cfg(feature = "dtype-datetime")] - DataType::Datetime(tu, tz) if !validate_is_number(&self.chunks) => { + DataType::Datetime(tu, tz) => { let out = match tz { #[cfg(feature = "timezones")] Some(tz) => { diff --git a/crates/polars-core/src/chunked_array/temporal/mod.rs b/crates/polars-core/src/chunked_array/temporal/mod.rs index 3b6a38aede8b..0a89825f6959 100644 --- a/crates/polars-core/src/chunked_array/temporal/mod.rs +++ b/crates/polars-core/src/chunked_array/temporal/mod.rs @@ -19,9 +19,9 @@ use chrono_tz::Tz; pub use time::time_to_time64ns; pub use self::conversion::*; +use crate::prelude::ArrayRef; #[cfg(feature = "timezones")] use crate::prelude::{polars_bail, PolarsResult}; -use crate::prelude::{ArrayRef, LargeStringArray}; pub fn unix_time() -> NaiveDateTime { NaiveDateTime::from_timestamp_opt(0, 0).unwrap() @@ -36,20 +36,3 @@ pub(crate) fn validate_time_zone(tz: &str) -> PolarsResult<()> { } } - -pub(crate) fn validate_is_number(vec_array: &[ArrayRef]) -> bool { - vec_array.iter().all(is_parsable_as_number) -} - -fn is_parsable_as_number(array: &ArrayRef) -> bool { - if let Some(array) = array.as_any().downcast_ref::<LargeStringArray>() { - array.iter().all(|value| { - value - .expect("Unable to parse int string to datetime") - .parse::<i64>() - .is_ok() - }) - } else { - false - } -} diff --git a/py-polars/tests/unit/test_lazy.py b/py-polars/tests/unit/test_lazy.py index 2fedbf853435..7dc6478ab5d2 100644 --- a/py-polars/tests/unit/test_lazy.py +++ b/py-polars/tests/unit/test_lazy.py @@ -1375,7 +1375,7 @@ def test_quadratic_behavior_4736() -> None: ldf.select(reduce(add, (pl.col(fld) for fld in ldf.columns))) -@pytest.mark.parametrize("input_dtype", [pl.Int64, pl.Float64, pl.Utf8]) +@pytest.mark.parametrize("input_dtype", [pl.Int64, pl.Float64]) def test_from_epoch(input_dtype: pl.PolarsDataType) -> None: ldf = pl.LazyFrame( [ @@ -1415,6 +1415,23 @@ def test_from_epoch(input_dtype: pl.PolarsDataType) -> None: _ = ldf.select(pl.from_epoch(ts_col, time_unit="s2")) # type: ignore[call-overload] +def test_from_epoch_str() -> None: + ldf = pl.LazyFrame( + [ + pl.Series("timestamp_ms", [1147880044 * 1_000]).cast(pl.Utf8), + pl.Series("timestamp_us", [1147880044 * 1_000_000]).cast(pl.Utf8), + ] + ) + + with pytest.raises(ComputeError): + ldf.select( + [ + pl.from_epoch(pl.col("timestamp_ms"), time_unit="ms"), + pl.from_epoch(pl.col("timestamp_us"), time_unit="us"), + ]
).collect() + + def test_cumagg_types() -> None: ldf = pl.LazyFrame({"a": [1, 2], "b": [True, False], "c": [1.3, 2.4]}) cumsum_lf = ldf.select( From fa5a0e193d214e9656f753734cddd9bfac4b225d Mon Sep 17 00:00:00 2001 From: Brayan Jules Date: Fri, 20 Oct 2023 22:37:26 -0300 Subject: [PATCH 101/103] fix: remove unused import --- crates/polars-core/src/chunked_array/temporal/mod.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/crates/polars-core/src/chunked_array/temporal/mod.rs b/crates/polars-core/src/chunked_array/temporal/mod.rs index 0a89825f6959..737ff5086d47 100644 --- a/crates/polars-core/src/chunked_array/temporal/mod.rs +++ b/crates/polars-core/src/chunked_array/temporal/mod.rs @@ -19,7 +19,6 @@ use chrono_tz::Tz; pub use time::time_to_time64ns; pub use self::conversion::*; -use crate::prelude::ArrayRef; #[cfg(feature = "timezones")] use crate::prelude::{polars_bail, PolarsResult}; From 30dcdc3f182d315f7b76df907001c897d4ebbce7 Mon Sep 17 00:00:00 2001 From: Brayan Jules Date: Sat, 21 Oct 2023 13:14:37 -0300 Subject: [PATCH 102/103] fix: rename tu to time_unit and tz to time_zone --- crates/polars-arrow/src/compute/cast/mod.rs | 16 ++++++++++------ crates/polars-arrow/src/compute/cast/utf8_to.rs | 12 ++++++------ crates/polars-core/src/chunked_array/cast.rs | 12 ++++++------ 3 files changed, 22 insertions(+), 18 deletions(-) diff --git a/crates/polars-arrow/src/compute/cast/mod.rs b/crates/polars-arrow/src/compute/cast/mod.rs index f02daf2a364f..23edbd1c9056 100644 --- a/crates/polars-arrow/src/compute/cast/mod.rs +++ b/crates/polars-arrow/src/compute/cast/mod.rs @@ -585,9 +585,11 @@ pub fn cast( LargeUtf8 => Ok(Box::new(utf8_to_large_utf8( array.as_any().downcast_ref().unwrap(), ))), - Timestamp(tu, None) => utf8_to_naive_timestamp_dyn::<i32>(array, tu.to_owned()), - Timestamp(tu, Some(tz)) => { - utf8_to_timestamp_dyn::<i32>(array, tz.clone(), tu.to_owned()) + Timestamp(time_unit, None) => { + utf8_to_naive_timestamp_dyn::<i32>(array, time_unit.to_owned()) + }, + Timestamp(time_unit, Some(time_zone)) => { + utf8_to_timestamp_dyn::<i32>(array, time_zone.clone(), time_unit.to_owned()) }, _ => polars_bail!(InvalidOperation: "casting from {from_type:?} to {to_type:?} not supported", @@ -612,9 +614,11 @@ pub fn cast( to_type.clone(), ) .boxed()), - Timestamp(tu, None) => utf8_to_naive_timestamp_dyn::<i64>(array, tu.to_owned()), - Timestamp(tu, Some(tz)) => { - utf8_to_timestamp_dyn::<i64>(array, tz.clone(), tu.to_owned()) + Timestamp(time_unit, None) => { + utf8_to_naive_timestamp_dyn::<i64>(array, time_unit.to_owned()) + }, + Timestamp(time_unit, Some(time_zone)) => { + utf8_to_timestamp_dyn::<i64>(array, time_zone.clone(), time_unit.to_owned()) }, _ => polars_bail!(InvalidOperation: "casting from {from_type:?} to {to_type:?} not supported", diff --git a/crates/polars-arrow/src/compute/cast/utf8_to.rs b/crates/polars-arrow/src/compute/cast/utf8_to.rs index 85a252544e5e..4b6ac51fafd5 100644 --- a/crates/polars-arrow/src/compute/cast/utf8_to.rs +++ b/crates/polars-arrow/src/compute/cast/utf8_to.rs @@ -112,10 +112,10 @@ pub fn utf8_to_dictionary( pub(super) fn utf8_to_naive_timestamp_dyn<O: Offset>( from: &dyn Array, - tu: TimeUnit, + time_unit: TimeUnit, ) -> PolarsResult<Box<dyn Array>> { let from = from.as_any().downcast_ref().unwrap(); - Ok(Box::new(utf8_to_naive_timestamp::<O>(from, tu))) + Ok(Box::new(utf8_to_naive_timestamp::<O>(from, time_unit))) } /// [`crate::temporal_conversions::utf8_to_timestamp`] applied for RFC3339 formatting @@ -129,10 +129,10 @@ pub fn utf8_to_naive_timestamp( pub(super) fn utf8_to_timestamp_dyn<O: Offset>( from: &dyn Array, timezone: String, - tu: TimeUnit, + time_unit: TimeUnit, ) -> PolarsResult<Box<dyn Array>> { let from = from.as_any().downcast_ref().unwrap(); - utf8_to_timestamp::<O>(from, timezone, tu) + utf8_to_timestamp::<O>(from, timezone, time_unit) .map(Box::new) .map(|x| x as Box<dyn Array>) } @@ -141,9 +141,9 @@ pub(super) fn utf8_to_timestamp_dyn( pub fn utf8_to_timestamp<O: Offset>( from: &Utf8Array<O>, timezone: String, - tu: TimeUnit, + time_unit: TimeUnit, ) -> PolarsResult<PrimitiveArray<i64>> { - utf8_to_timestamp_(from, RFC3339, timezone, tu) + utf8_to_timestamp_(from, RFC3339, timezone, time_unit) } /// Conversion of utf8 diff --git a/crates/polars-core/src/chunked_array/cast.rs b/crates/polars-core/src/chunked_array/cast.rs index 84976edbb513..2c216a69731a 100644 --- a/crates/polars-core/src/chunked_array/cast.rs +++ b/crates/polars-core/src/chunked_array/cast.rs @@ -203,21 +203,21 @@ impl ChunkCast for Utf8Chunked { Ok(out) }, #[cfg(feature = "dtype-datetime")] - DataType::Datetime(tu, tz) => { - let out = match tz { + DataType::Datetime(time_unit, time_zone) => { + let out = match time_zone { #[cfg(feature = "timezones")] - Some(tz) => { - validate_time_zone(tz)?; + Some(time_zone) => { + validate_time_zone(time_zone)?; let result = cast_chunks( &self.chunks, - &Datetime(tu.to_owned(), Some(tz.clone())), + &Datetime(time_unit.to_owned(), Some(time_zone.clone())), true, )?; Series::try_from((self.name(), result)) }, _ => { let result = - cast_chunks(&self.chunks, &Datetime(tu.to_owned(), None), true)?; + cast_chunks(&self.chunks, &Datetime(time_unit.to_owned(), None), true)?; Series::try_from((self.name(), result)) }, }; From 20f94d9abba7f60da19b056c6d0c348bf310f832 Mon Sep 17 00:00:00 2001 From: Brayan Jules Date: Mon, 23 Oct 2023 22:38:35 -0300 Subject: [PATCH 103/103] Revert "Merge remote-tracking branch 'origin/utf8-to-temporal-cast' into utf8-to-temporal-cast" This reverts commit 215990bf85e355cb203caa246b643ff7d1c14c19, reversing changes made to c5459f1226218c09478fb86ffa41453dd62c80a7.
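
As a reference point while reviewing this large revert, the user-facing behavior the series aims to land can be sketched as follows (illustrative values, assuming the final state of the branch):

    import polars as pl

    naive = pl.DataFrame({"x1": ["2021-12-19T16:39:57"]})
    aware = pl.DataFrame({"x1": ["2021-12-19T16:39:57+00:00"]})

    # Naive utf8 timestamps cast to Datetime at any supported time unit.
    out = naive.select(
        pl.col("x1").cast(pl.Datetime(time_unit="ns")).alias("ns"),
        pl.col("x1").cast(pl.Datetime(time_unit="us")).alias("us"),
        pl.col("x1").cast(pl.Datetime(time_unit="ms")).alias("ms"),
    )

    # Offset-annotated strings cast to zone-aware dtypes; values are converted
    # into the requested time zone.
    tz_out = aware.select(
        pl.col("x1").cast(pl.Datetime(time_unit="us", time_zone="America/Caracas"))
    )
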
--- .github/dependabot.yml | 12 - .github/workflows/docs-global.yml | 8 +- .github/workflows/lint-global.yml | 2 +- .github/workflows/test-rust.yml | 18 +- CONTRIBUTING.md | 2 +- Cargo.toml | 5 +- crates/Makefile | 24 +- crates/polars-arrow/Cargo.toml | 6 +- crates/polars-arrow/src/array/ord.rs | 180 ++++++ .../src/compute/arithmetics/basic/add.rs | 314 +++++++++- .../src/compute/arithmetics/basic/div.rs | 63 ++ .../src/compute/arithmetics/basic/mod.rs | 2 + .../src/compute/arithmetics/basic/mul.rs | 315 +++++++++- .../src/compute/arithmetics/basic/pow.rs | 49 ++ .../src/compute/arithmetics/basic/rem.rs | 111 +++- .../src/compute/arithmetics/basic/sub.rs | 314 +++++++++- .../src/compute/arithmetics/decimal/add.rs | 220 +++++++ .../src/compute/arithmetics/decimal/div.rs | 301 +++++++++ .../src/compute/arithmetics/decimal/mod.rs | 120 ++++ .../src/compute/arithmetics/decimal/mul.rs | 313 ++++++++++ .../src/compute/arithmetics/decimal/sub.rs | 237 +++++++ .../src/compute/arithmetics/mod.rs | 131 ++++ .../src/compute/arithmetics/time.rs | 432 +++++++++++++ crates/polars-arrow/src/compute/cast/mod.rs | 23 +- .../polars-arrow/src/compute/cast/utf8_to.rs | 12 +- .../polars-arrow/src/compute/if_then_else.rs | 18 + crates/polars-arrow/src/ffi/schema.rs | 4 - crates/polars-arrow/src/io/ipc/mod.rs | 48 ++ .../src/io/ipc/write/file_async.rs | 40 ++ .../src/io/ipc/write/stream_async.rs | 31 + .../parquet/read/deserialize/nested_utils.rs | 15 +- .../polars-arrow/src/io/parquet/write/sink.rs | 41 ++ .../src/legacy/kernels/ewm/mod.rs | 5 +- crates/polars-core/Cargo.toml | 8 + .../src/chunked_array/builder/binary.rs | 6 +- .../src/chunked_array/builder/boolean.rs | 4 +- .../src/chunked_array/builder/primitive.rs | 4 +- .../src/chunked_array/builder/utf8.rs | 4 +- crates/polars-core/src/chunked_array/cast.rs | 46 +- crates/polars-core/src/chunked_array/from.rs | 19 +- .../src/chunked_array/list/iterator.rs | 50 +- .../chunked_array/logical/categorical/from.rs | 4 +- .../chunked_array/logical/categorical/mod.rs | 56 +- .../logical/categorical/ops/append.rs | 22 +- .../logical/categorical/ops/unique.rs | 16 +- .../logical/categorical/ops/zip.rs | 6 +- crates/polars-core/src/chunked_array/mod.rs | 11 +- .../src/chunked_array/object/builder.rs | 6 - .../src/chunked_array/ops/append.rs | 37 +- .../src/chunked_array/ops/apply.rs | 1 - .../src/chunked_array/ops/chunkops.rs | 12 - .../src/chunked_array/ops/compare_inner.rs | 2 +- .../src/chunked_array/ops/cum_agg.rs | 176 ++++++ .../polars-core/src/chunked_array/ops/mod.rs | 22 + .../ops/sort/arg_sort_multiple.rs | 2 +- .../src/chunked_array/ops/sort/categorical.rs | 12 +- .../src/chunked_array/upstream_traits.rs | 2 - crates/polars-core/src/fmt.rs | 168 +---- .../polars-core/src/frame/group_by/perfect.rs | 6 +- crates/polars-core/src/series/comparison.rs | 19 +- .../src/series/implementations/categorical.rs | 64 +- .../src/series/implementations/dates_time.rs | 11 + .../src/series/implementations/datetime.rs | 16 + .../src/series/implementations/duration.rs | 16 + .../src/series/implementations/floats.rs | 10 + .../src/series/implementations/mod.rs | 10 + crates/polars-core/src/series/into.rs | 2 +- crates/polars-core/src/series/mod.rs | 106 +++- crates/polars-core/src/series/ops/diff.rs | 24 + crates/polars-core/src/series/ops/ewm.rs | 104 ++++ crates/polars-core/src/series/ops/mod.rs | 8 + .../polars-core/src/series/ops/pct_change.rs | 48 ++ .../src/series/ops/round.rs | 66 +- crates/polars-core/src/series/series_trait.rs | 12 + 
crates/polars-core/src/utils/mod.rs | 3 - crates/polars-error/Cargo.toml | 4 +- crates/polars-ffi/src/lib.rs | 22 +- crates/polars-io/Cargo.toml | 2 +- .../polars-io/src/cloud/object_store_setup.rs | 43 +- crates/polars-io/src/parquet/read.rs | 7 - crates/polars-io/src/parquet/read_impl.rs | 49 +- crates/polars-lazy/Cargo.toml | 3 +- crates/polars-lazy/src/frame/mod.rs | 16 +- .../src/physical_plan/executors/join.rs | 2 - .../src/physical_plan/executors/scan/csv.rs | 2 +- .../src/physical_plan/executors/scan/ipc.rs | 2 +- .../physical_plan/executors/scan/parquet.rs | 5 +- .../physical_plan/expressions/aggregation.rs | 2 - .../src/physical_plan/expressions/binary.rs | 2 - .../src/physical_plan/expressions/sort.rs | 1 - .../src/physical_plan/expressions/window.rs | 1 - .../src/physical_plan/planner/expr.rs | 7 +- .../src/physical_plan/planner/lp.rs | 90 ++- .../physical_plan/streaming/convert_alp.rs | 5 - crates/polars-lazy/src/prelude.rs | 3 +- crates/polars-lazy/src/utils.rs | 8 +- crates/polars-ops/Cargo.toml | 10 +- .../src/chunked_array/interpolate.rs | 27 +- .../src/chunked_array/list/namespace.rs | 84 +-- .../polars-ops/src/frame/join/asof/groups.rs | 35 +- .../src/frame/join/hash_join/mod.rs | 49 +- .../src/frame/join/hash_join/zip_outer.rs | 4 +- crates/polars-ops/src/frame/join/mod.rs | 20 +- .../series/ops/approx_algo/hyperloglogplus.rs | 1 - crates/polars-ops/src/series/ops/cum_agg.rs | 230 ------- crates/polars-ops/src/series/ops/diff.rs | 22 - crates/polars-ops/src/series/ops/ewm.rs | 103 ---- crates/polars-ops/src/series/ops/is_in.rs | 2 +- crates/polars-ops/src/series/ops/mod.rs | 20 - .../polars-ops/src/series/ops/pct_change.rs | 25 - crates/polars-ops/src/series/ops/rank.rs | 21 +- .../src/executors/sources/parquet.rs | 3 +- crates/polars-pipe/src/pipeline/convert.rs | 8 +- crates/polars-plan/Cargo.toml | 17 +- crates/polars-plan/src/dot.rs | 30 +- .../polars-plan/src/dsl/function_expr/cum.rs | 20 +- .../src/dsl/function_expr/dispatch.rs | 36 +- .../polars-plan/src/dsl/function_expr/ewm.rs | 13 - .../polars-plan/src/dsl/function_expr/list.rs | 41 -- .../polars-plan/src/dsl/function_expr/mod.rs | 120 +--- .../src/dsl/function_expr/plugin.rs | 72 +-- .../src/dsl/function_expr/schema.rs | 33 +- .../polars-plan/src/dsl/functions/temporal.rs | 2 - crates/polars-plan/src/dsl/list.rs | 42 -- crates/polars-plan/src/dsl/mod.rs | 123 +++- crates/polars-plan/src/logical_plan/alp.rs | 6 +- .../polars-plan/src/logical_plan/builder.rs | 10 +- .../src/logical_plan/conversion.rs | 8 +- crates/polars-plan/src/logical_plan/format.rs | 22 +- crates/polars-plan/src/logical_plan/mod.rs | 2 +- .../src/logical_plan/optimizer/cse.rs | 4 +- .../logical_plan/optimizer/file_caching.rs | 22 +- .../optimizer/predicate_pushdown/mod.rs | 9 +- .../optimizer/predicate_pushdown/utils.rs | 47 +- .../optimizer/projection_pushdown/mod.rs | 4 +- .../projection_pushdown/projection.rs | 52 +- .../optimizer/slice_pushdown_lp.rs | 8 +- .../src/logical_plan/projection.rs | 19 +- crates/polars-plan/src/logical_plan/schema.rs | 6 +- crates/polars-sql/Cargo.toml | 2 +- crates/polars-sql/src/functions.rs | 44 +- crates/polars-time/Cargo.toml | 2 +- crates/polars/Cargo.toml | 13 +- crates/polars/src/lib.rs | 8 +- docs/_build/scripts/people.py | 3 +- docs/index.md | 4 + docs/requirements.txt | 6 +- .../python/user-guide/expressions/lists.py | 5 +- .../user-guide/expressions/operators.py | 2 +- docs/src/python/user-guide/sql/intro.py | 16 +- docs/user-guide/concepts/lazy-vs-eager.md | 2 +- 
docs/user-guide/expressions/plugins.md | 237 ------- .../expressions/user-defined-functions.md | 2 +- docs/user-guide/io/cloud-storage.md | 2 +- docs/user-guide/migration/pandas.md | 25 +- .../transformations/time-series/timezones.md | 14 +- .../python_rust_compiled_function/Cargo.toml | 2 +- examples/read_parquet_cloud/Cargo.toml | 2 +- examples/write_parquet_cloud/Cargo.toml | 2 +- mkdocs.yml | 5 +- py-polars/Cargo.lock | 33 +- py-polars/Cargo.toml | 7 +- py-polars/docs/requirements-docs.txt | 2 +- .../reference/dataframe/modify_select.rst | 1 - .../source/reference/expressions/list.rst | 1 - .../reference/lazyframe/modify_select.rst | 1 - .../docs/source/reference/series/list.rst | 1 - py-polars/docs/source/reference/testing.rst | 2 - py-polars/polars/__init__.py | 2 - py-polars/polars/config.py | 4 +- py-polars/polars/dataframe/frame.py | 135 +--- py-polars/polars/dataframe/group_by.py | 6 +- py-polars/polars/datatypes/__init__.py | 6 +- py-polars/polars/datatypes/classes.py | 136 ++-- py-polars/polars/datatypes/constants.py | 5 - py-polars/polars/datatypes/convert.py | 3 +- py-polars/polars/expr/array.py | 8 +- py-polars/polars/expr/expr.py | 47 +- py-polars/polars/expr/list.py | 60 +- py-polars/polars/functions/eager.py | 4 +- py-polars/polars/io/database.py | 83 +-- py-polars/polars/io/spreadsheet/functions.py | 17 +- py-polars/polars/lazyframe/frame.py | 143 +---- py-polars/polars/series/array.py | 8 +- py-polars/polars/series/list.py | 42 +- py-polars/polars/series/series.py | 71 +-- py-polars/polars/series/utils.py | 2 +- py-polars/polars/testing/__init__.py | 6 +- py-polars/polars/testing/_private.py | 38 ++ py-polars/polars/testing/asserts.py | 581 ++++++++++++++++++ py-polars/polars/testing/asserts/__init__.py | 9 - py-polars/polars/testing/asserts/frame.py | 276 --------- py-polars/polars/testing/asserts/series.py | 403 ------------ py-polars/polars/testing/asserts/utils.py | 12 - .../polars/testing/parametric/primitives.py | 3 +- py-polars/polars/type_aliases.py | 4 +- py-polars/polars/utils/_async.py | 4 +- py-polars/polars/utils/_construction.py | 8 +- py-polars/polars/utils/convert.py | 2 +- py-polars/pyproject.toml | 1 - py-polars/requirements-dev.txt | 2 +- py-polars/requirements-lint.txt | 8 +- py-polars/src/conversion.rs | 5 +- py-polars/src/dataframe.rs | 2 +- py-polars/src/error.rs | 22 +- py-polars/src/expr/general.rs | 12 +- py-polars/src/expr/list.rs | 30 - py-polars/src/lazyframe.rs | 2 +- py-polars/src/lazygroupby.rs | 8 +- py-polars/src/series/export.rs | 5 - py-polars/src/series/numpy_ufunc.rs | 7 +- .../tests/parametric/test_groupby_rolling.py | 4 +- py-polars/tests/test_udfs.py | 2 +- py-polars/tests/unit/dataframe/test_df.py | 5 +- py-polars/tests/unit/datatypes/test_array.py | 37 +- .../tests/unit/datatypes/test_categorical.py | 15 - py-polars/tests/unit/datatypes/test_list.py | 26 +- py-polars/tests/unit/datatypes/test_struct.py | 6 +- .../tests/unit/datatypes/test_temporal.py | 22 +- py-polars/tests/unit/io/test_database_read.py | 156 ++--- py-polars/tests/unit/io/test_hive.py | 12 - py-polars/tests/unit/io/test_spreadsheet.py | 47 -- py-polars/tests/unit/namespaces/test_array.py | 8 +- py-polars/tests/unit/namespaces/test_list.py | 31 - .../unit/operations/map/test_map_groups.py | 6 +- .../unit/operations/rolling/test_rolling.py | 28 +- .../unit/operations/test_aggregations.py | 1 + .../tests/unit/operations/test_explode.py | 2 +- .../tests/unit/operations/test_group_by.py | 43 +- .../unit/operations/test_group_by_dynamic.py | 2 +- 
.../unit/operations/test_group_by_rolling.py | 50 +- py-polars/tests/unit/operations/test_join.py | 22 - py-polars/tests/unit/series/test_series.py | 83 +-- py-polars/tests/unit/sql/test_sql.py | 41 +- .../streaming/test_streaming_categoricals.py | 14 - py-polars/tests/unit/test_constructors.py | 4 - py-polars/tests/unit/test_errors.py | 2 +- py-polars/tests/unit/test_exprs.py | 4 +- py-polars/tests/unit/test_interop.py | 8 - py-polars/tests/unit/test_predicates.py | 11 - py-polars/tests/unit/test_projections.py | 6 - ...assert_series_equal.py => test_testing.py} | 525 +++++++++++++--- py-polars/tests/unit/testing/__init__.py | 0 .../unit/testing/test_assert_frame_equal.py | 420 ------------- 244 files changed, 6012 insertions(+), 4654 deletions(-) create mode 100644 crates/polars-arrow/src/array/ord.rs create mode 100644 crates/polars-arrow/src/compute/arithmetics/basic/pow.rs create mode 100644 crates/polars-arrow/src/compute/arithmetics/decimal/add.rs create mode 100644 crates/polars-arrow/src/compute/arithmetics/decimal/div.rs create mode 100644 crates/polars-arrow/src/compute/arithmetics/decimal/mod.rs create mode 100644 crates/polars-arrow/src/compute/arithmetics/decimal/mul.rs create mode 100644 crates/polars-arrow/src/compute/arithmetics/decimal/sub.rs create mode 100644 crates/polars-arrow/src/compute/arithmetics/time.rs create mode 100644 crates/polars-core/src/chunked_array/ops/cum_agg.rs create mode 100644 crates/polars-core/src/series/ops/diff.rs create mode 100644 crates/polars-core/src/series/ops/ewm.rs create mode 100644 crates/polars-core/src/series/ops/pct_change.rs rename crates/{polars-ops => polars-core}/src/series/ops/round.rs (55%) delete mode 100644 crates/polars-ops/src/series/ops/cum_agg.rs delete mode 100644 crates/polars-ops/src/series/ops/diff.rs delete mode 100644 crates/polars-ops/src/series/ops/ewm.rs delete mode 100644 crates/polars-ops/src/series/ops/pct_change.rs delete mode 100644 crates/polars-plan/src/dsl/function_expr/ewm.rs delete mode 100644 docs/user-guide/expressions/plugins.md create mode 100644 py-polars/polars/testing/_private.py create mode 100644 py-polars/polars/testing/asserts.py delete mode 100644 py-polars/polars/testing/asserts/__init__.py delete mode 100644 py-polars/polars/testing/asserts/frame.py delete mode 100644 py-polars/polars/testing/asserts/series.py delete mode 100644 py-polars/polars/testing/asserts/utils.py delete mode 100644 py-polars/tests/unit/streaming/test_streaming_categoricals.py rename py-polars/tests/unit/{testing/test_assert_series_equal.py => test_testing.py} (57%) delete mode 100644 py-polars/tests/unit/testing/__init__.py delete mode 100644 py-polars/tests/unit/testing/test_assert_frame_equal.py diff --git a/.github/dependabot.yml b/.github/dependabot.yml index ee769d681e8a..31d8d580266f 100644 --- a/.github/dependabot.yml +++ b/.github/dependabot.yml @@ -48,15 +48,3 @@ updates: prefix: build(python) prefix-development: chore(python) labels: ['skip changelog'] - - # Documentation - - package-ecosystem: pip - directory: docs - schedule: - interval: monthly - ignore: - - dependency-name: '*' - update-types: ['version-update:semver-patch'] - commit-message: - prefix: chore(python) - labels: ['skip changelog'] diff --git a/.github/workflows/docs-global.yml b/.github/workflows/docs-global.yml index 823278186535..6e8f12bcae5e 100644 --- a/.github/workflows/docs-global.yml +++ b/.github/workflows/docs-global.yml @@ -9,8 +9,6 @@ on: push: tags: - py-** - # Allow manual trigger until we have properly versioned docs - 
workflow_dispatch: jobs: markdown-link-check: @@ -28,7 +26,7 @@ jobs: - uses: psf/black@stable with: src: docs/src/python - version: "23.10.0" + version: "23.9.1" deploy: runs-on: ubuntu-latest @@ -74,12 +72,12 @@ jobs: run: mkdocs build - name: Add .nojekyll - if: github.ref_type == 'tag' || github.event_name == 'workflow_dispatch' + if: ${{ github.ref_type == 'tag' }} working-directory: site run: touch .nojekyll - name: Deploy docs - if: github.ref_type == 'tag' || github.event_name == 'workflow_dispatch' + if: ${{ github.ref_type == 'tag' }} uses: JamesIves/github-pages-deploy-action@v4 with: folder: site diff --git a/.github/workflows/lint-global.yml b/.github/workflows/lint-global.yml index 07344c893be9..2ebcc0dca3b0 100644 --- a/.github/workflows/lint-global.yml +++ b/.github/workflows/lint-global.yml @@ -15,4 +15,4 @@ jobs: - name: Lint Markdown and TOML uses: dprint/check@v2.2 - name: Spell Check with Typos - uses: crate-ci/typos@v1.16.20 + uses: crate-ci/typos@v1.16.8 diff --git a/.github/workflows/test-rust.yml b/.github/workflows/test-rust.yml index 381235da5d6b..d4085dcfa82c 100644 --- a/.github/workflows/test-rust.yml +++ b/.github/workflows/test-rust.yml @@ -45,29 +45,27 @@ jobs: - name: Compile tests run: > cargo test --all-features --no-run - -p polars-core - -p polars-io -p polars-lazy - -p polars-ops -p polars-plan - -p polars-row - -p polars-sql + -p polars-io + -p polars-core -p polars-time -p polars-utils + -p polars-row + -p polars-sql - name: Run tests if: github.ref_name != 'main' run: > cargo test --all-features - -p polars-core - -p polars-io -p polars-lazy - -p polars-ops -p polars-plan - -p polars-row - -p polars-sql + -p polars-io + -p polars-core -p polars-time -p polars-utils + -p polars-row + -p polars-sql integration-test: runs-on: ${{ matrix.os }} diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 280e5e03e581..44321d2f35bb 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -157,7 +157,7 @@ The user guide is maintained in the `docs/user-guide` folder. Before creating a The user guide is built using [MkDocs](https://www.mkdocs.org/). You install the dependencies for building the user guide by running `make requirements` in the root of the repo. -Run `mkdocs serve` to build and serve the user guide, so you can view it locally and see updates as you make changes. +Run `mkdocs serve` to build and serve the user guide so you can view it locally and see updates as you make changes. 
#### Creating a new user guide page diff --git a/Cargo.toml b/Cargo.toml index 6a5f0b2a835e..d99bf9e0bf25 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -45,7 +45,7 @@ rayon = "1.8" regex = "1.9" serde = "1.0.188" serde_json = "1" -simd-json = { version = "0.12", features = ["known-key"] } +simd-json = { version = "0.11", features = ["allow-non-simd", "known-key"] } smartstring = "1" sqlparser = "0.38" strum_macros = "0.25" @@ -85,7 +85,7 @@ arrow-array = { version = ">=41", default-features = false } arrow-buffer = { version = ">=41", default-features = false } arrow-data = { version = ">=41", default-features = false } arrow-schema = { version = ">=41", default-features = false } -parquet2 = { version = "0.17.2", features = ["async"], default-features = false } +parquet2 = { version = "0.17.2", features = ["async"] } avro-schema = { version = "0.3" } [workspace.dependencies.arrow] @@ -106,6 +106,5 @@ features = [ ] [patch.crates-io] -ahash = { git = "https://github.com/orlp/aHash", branch = "fix-arm-intrinsics" } # packed_simd_2 = { git = "https://github.com/rust-lang/packed_simd", rev = "e57c7ba11386147e6d2cbad7c88f376aab4bdc86" } # simd-json = { git = "https://github.com/ritchie46/simd-json", branch = "alignment" } diff --git a/crates/Makefile b/crates/Makefile index eaf0f02d5cce..718f3dde5580 100644 --- a/crates/Makefile +++ b/crates/Makefile @@ -42,30 +42,30 @@ miri: ## Run miri .PHONY: test test: ## Run tests cargo test --all-features \ - -p polars-core \ - -p polars-io \ -p polars-lazy \ - -p polars-ops \ - -p polars-plan \ - -p polars-row \ - -p polars-sql \ + -p polars-io \ + -p polars-core \ -p polars-time \ -p polars-utils \ + -p polars-row \ + -p polars-sql \ + -p polars-ops \ + -p polars-plan \ -- \ --test-threads=2 .PHONY: nextest nextest: ## Run tests with nextest cargo nextest run --all-features \ - -p polars-core \ - -p polars-io \ -p polars-lazy \ - -p polars-ops \ - -p polars-plan \ - -p polars-row \ - -p polars-sql \ + -p polars-io \ + -p polars-core \ -p polars-time \ -p polars-utils \ + -p polars-row \ + -p polars-sql \ + -p polars-ops \ + -p polars-plan \ .PHONY: integration-tests integration-tests: ## Run integration tests diff --git a/crates/polars-arrow/Cargo.toml b/crates/polars-arrow/Cargo.toml index 58fb4d3ff72e..39828ff7e59f 100644 --- a/crates/polars-arrow/Cargo.toml +++ b/crates/polars-arrow/Cargo.toml @@ -36,7 +36,7 @@ lexical-core = { workspace = true, optional = true } fallible-streaming-iterator = { workspace = true, optional = true } regex = { workspace = true, optional = true } -regex-syntax = { version = "0.8", optional = true } +regex-syntax = { version = "0.7", optional = true } streaming-iterator = { workspace = true } indexmap = { workspace = true, optional = true } @@ -47,7 +47,7 @@ hex = { workspace = true, optional = true } # for IPC compression lz4 = { version = "1.24", optional = true } -zstd = { version = "0.13", optional = true } +zstd = { version = "0.12", optional = true } base64 = { workspace = true, optional = true } @@ -74,7 +74,7 @@ arrow-array = { workspace = true, optional = true } arrow-buffer = { workspace = true, optional = true } arrow-data = { workspace = true, optional = true } arrow-schema = { workspace = true, optional = true } -parquet2 = { workspace = true, optional = true, default-features = true, features = ["async"] } +parquet2 = { workspace = true, optional = true, features = ["async"] } [dev-dependencies] avro-rs = { version = "0.13", features = ["snappy"] } diff --git a/crates/polars-arrow/src/array/ord.rs 
b/crates/polars-arrow/src/array/ord.rs new file mode 100644 index 000000000000..b585d67600e4 --- /dev/null +++ b/crates/polars-arrow/src/array/ord.rs @@ -0,0 +1,180 @@ +//! Contains functions and function factories to order values within arrays. +use std::cmp::Ordering; +use polars_error::polars_bail; + +use crate::array::*; +use crate::datatypes::*; +use crate::offset::Offset; +use crate::types::NativeType; +use crate::util::total_ord::TotalOrd; + +/// Compare the values at two arbitrary indices in two arrays. +pub type DynComparator = Box Ordering + Send + Sync>; + +fn compare_primitives( + left: &dyn Array, + right: &dyn Array, +) -> DynComparator { + let left = left + .as_any() + .downcast_ref::>() + .unwrap() + .clone(); + let right = right + .as_any() + .downcast_ref::>() + .unwrap() + .clone(); + Box::new(move |i, j| left.value(i).tot_cmp(&right.value(j))) +} + +fn compare_boolean(left: &dyn Array, right: &dyn Array) -> DynComparator { + let left = left + .as_any() + .downcast_ref::() + .unwrap() + .clone(); + let right = right + .as_any() + .downcast_ref::() + .unwrap() + .clone(); + Box::new(move |i, j| left.value(i).cmp(&right.value(j))) +} + +fn compare_string(left: &dyn Array, right: &dyn Array) -> DynComparator { + let left = left + .as_any() + .downcast_ref::>() + .unwrap() + .clone(); + let right = right + .as_any() + .downcast_ref::>() + .unwrap() + .clone(); + Box::new(move |i, j| left.value(i).cmp(right.value(j))) +} + +fn compare_binary(left: &dyn Array, right: &dyn Array) -> DynComparator { + let left = left + .as_any() + .downcast_ref::>() + .unwrap() + .clone(); + let right = right + .as_any() + .downcast_ref::>() + .unwrap() + .clone(); + Box::new(move |i, j| left.value(i).cmp(right.value(j))) +} + +fn compare_dict(left: &DictionaryArray, right: &DictionaryArray) -> Result +where + K: DictionaryKey, +{ + let left_keys = left.keys().values().clone(); + let right_keys = right.keys().values().clone(); + + let comparator = build_compare(left.values().as_ref(), right.values().as_ref())?; + + Ok(Box::new(move |i: usize, j: usize| { + // safety: all dictionaries keys are guaranteed to be castable to usize + let key_left = unsafe { left_keys[i].as_usize() }; + let key_right = unsafe { right_keys[j].as_usize() }; + (comparator)(key_left, key_right) + })) +} + +macro_rules! dyn_dict { + ($key:ty, $lhs:expr, $rhs:expr) => {{ + let lhs = $lhs.as_any().downcast_ref().unwrap(); + let rhs = $rhs.as_any().downcast_ref().unwrap(); + compare_dict::<$key>(lhs, rhs)? + }}; +} + +/// returns a comparison function that compares values at two different slots +/// between two [`Array`]. +/// # Example +/// ``` +/// use polars_arrow::array::{ord::build_compare, PrimitiveArray}; +/// +/// # fn main() -> polars_arrow::error::Result<()> { +/// let array1 = PrimitiveArray::from_slice([1, 2]); +/// let array2 = PrimitiveArray::from_slice([3, 4]); +/// +/// let cmp = build_compare(&array1, &array2)?; +/// +/// // 1 (index 0 of array1) is smaller than 4 (index 1 of array2) +/// assert_eq!(std::cmp::Ordering::Less, (cmp)(0, 1)); +/// # Ok(()) +/// # } +/// ``` +/// # Error +/// The arrays' [`DataType`] must be equal and the types must have a natural order. +// This is a factory of comparisons. 
+pub fn build_compare(left: &dyn Array, right: &dyn Array) -> Result { + use DataType::*; + use IntervalUnit::*; + use TimeUnit::*; + Ok(match (left.data_type(), right.data_type()) { + (a, b) if a != b => { + polars_bail!(ComputeError: + "Can't compare arrays of different types".to_string(), + ); + }, + (Boolean, Boolean) => compare_boolean(left, right), + (UInt8, UInt8) => compare_primitives::(left, right), + (UInt16, UInt16) => compare_primitives::(left, right), + (UInt32, UInt32) => compare_primitives::(left, right), + (UInt64, UInt64) => compare_primitives::(left, right), + (Int8, Int8) => compare_primitives::(left, right), + (Int16, Int16) => compare_primitives::(left, right), + (Int32, Int32) + | (Date32, Date32) + | (Time32(Second), Time32(Second)) + | (Time32(Millisecond), Time32(Millisecond)) + | (Interval(YearMonth), Interval(YearMonth)) => compare_primitives::(left, right), + (Int64, Int64) + | (Date64, Date64) + | (Time64(Microsecond), Time64(Microsecond)) + | (Time64(Nanosecond), Time64(Nanosecond)) + | (Timestamp(Second, None), Timestamp(Second, None)) + | (Timestamp(Millisecond, None), Timestamp(Millisecond, None)) + | (Timestamp(Microsecond, None), Timestamp(Microsecond, None)) + | (Timestamp(Nanosecond, None), Timestamp(Nanosecond, None)) + | (Duration(Second), Duration(Second)) + | (Duration(Millisecond), Duration(Millisecond)) + | (Duration(Microsecond), Duration(Microsecond)) + | (Duration(Nanosecond), Duration(Nanosecond)) => compare_primitives::(left, right), + (Float32, Float32) => compare_primitives::(left, right), + (Float64, Float64) => compare_primitives::(left, right), + (Decimal(_, _), Decimal(_, _)) => compare_primitives::(left, right), + (Utf8, Utf8) => compare_string::(left, right), + (LargeUtf8, LargeUtf8) => compare_string::(left, right), + (Binary, Binary) => compare_binary::(left, right), + (LargeBinary, LargeBinary) => compare_binary::(left, right), + (Dictionary(key_type_lhs, ..), Dictionary(key_type_rhs, ..)) => { + match (key_type_lhs, key_type_rhs) { + (IntegerType::UInt8, IntegerType::UInt8) => dyn_dict!(u8, left, right), + (IntegerType::UInt16, IntegerType::UInt16) => dyn_dict!(u16, left, right), + (IntegerType::UInt32, IntegerType::UInt32) => dyn_dict!(u32, left, right), + (IntegerType::UInt64, IntegerType::UInt64) => dyn_dict!(u64, left, right), + (IntegerType::Int8, IntegerType::Int8) => dyn_dict!(i8, left, right), + (IntegerType::Int16, IntegerType::Int16) => dyn_dict!(i16, left, right), + (IntegerType::Int32, IntegerType::Int32) => dyn_dict!(i32, left, right), + (IntegerType::Int64, IntegerType::Int64) => dyn_dict!(i64, left, right), + (lhs, _) => { + return Err(Error::InvalidArgumentError(format!( + "Dictionaries do not support keys of type {lhs:?}" + ))) + }, + } + }, + _ => { + unimplemented!() + }, + }) +} diff --git a/crates/polars-arrow/src/compute/arithmetics/basic/add.rs b/crates/polars-arrow/src/compute/arithmetics/basic/add.rs index ec941edc2381..4ac4fb8bd02f 100644 --- a/crates/polars-arrow/src/compute/arithmetics/basic/add.rs +++ b/crates/polars-arrow/src/compute/arithmetics/basic/add.rs @@ -1,12 +1,33 @@ //! 
Definition of basic add operations with primitive arrays use std::ops::Add; +use num_traits::ops::overflowing::OverflowingAdd; +use num_traits::{CheckedAdd, SaturatingAdd, WrappingAdd}; + use super::NativeArithmetics; use crate::array::PrimitiveArray; -use crate::compute::arity::{binary, unary}; +use crate::bitmap::Bitmap; +use crate::compute::arithmetics::{ + ArrayAdd, ArrayCheckedAdd, ArrayOverflowingAdd, ArraySaturatingAdd, ArrayWrappingAdd, +}; +use crate::compute::arity::{ + binary, binary_checked, binary_with_bitmap, unary, unary_checked, unary_with_bitmap, +}; /// Adds two primitive arrays with the same type. /// Panics if the sum of one pair of values overflows. +/// +/// # Examples +/// ``` +/// use polars_arrow::compute::arithmetics::basic::add; +/// use polars_arrow::array::PrimitiveArray; +/// +/// let a = PrimitiveArray::from([None, Some(6), None, Some(6)]); +/// let b = PrimitiveArray::from([Some(5), None, None, Some(6)]); +/// let result = add(&a, &b); +/// let expected = PrimitiveArray::from([None, None, None, Some(12)]); +/// assert_eq!(result, expected) +/// ``` pub fn add(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray where T: NativeArithmetics + Add, @@ -14,8 +35,166 @@ where binary(lhs, rhs, lhs.data_type().clone(), |a, b| a + b) } +/// Wrapping addition of two [`PrimitiveArray`]s. +/// It wraps around at the boundary of the type if the result overflows. +/// +/// # Examples +/// ``` +/// use polars_arrow::compute::arithmetics::basic::wrapping_add; +/// use polars_arrow::array::PrimitiveArray; +/// +/// let a = PrimitiveArray::from([Some(-100i8), Some(100i8), Some(100i8)]); +/// let b = PrimitiveArray::from([Some(0i8), Some(100i8), Some(0i8)]); +/// let result = wrapping_add(&a, &b); +/// let expected = PrimitiveArray::from([Some(-100i8), Some(-56i8), Some(100i8)]); +/// assert_eq!(result, expected); +/// ``` +pub fn wrapping_add(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray +where + T: NativeArithmetics + WrappingAdd, +{ + let op = move |a: T, b: T| a.wrapping_add(&b); + + binary(lhs, rhs, lhs.data_type().clone(), op) +} + +/// Checked addition of two primitive arrays. If the result from the sum +/// overflows, the validity for that index is changed to None +/// +/// # Examples +/// ``` +/// use polars_arrow::compute::arithmetics::basic::checked_add; +/// use polars_arrow::array::PrimitiveArray; +/// +/// let a = PrimitiveArray::from([Some(100i8), Some(100i8), Some(100i8)]); +/// let b = PrimitiveArray::from([Some(0i8), Some(100i8), Some(0i8)]); +/// let result = checked_add(&a, &b); +/// let expected = PrimitiveArray::from([Some(100i8), None, Some(100i8)]); +/// assert_eq!(result, expected); +/// ``` +pub fn checked_add(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray +where + T: NativeArithmetics + CheckedAdd, +{ + let op = move |a: T, b: T| a.checked_add(&b); + + binary_checked(lhs, rhs, lhs.data_type().clone(), op) +} + +/// Saturating addition of two primitive arrays. If the result from the sum is +/// larger than the possible number for this type, the result for the operation +/// will be the saturated value. 
+/// +/// # Examples +/// ``` +/// use polars_arrow::compute::arithmetics::basic::saturating_add; +/// use polars_arrow::array::PrimitiveArray; +/// +/// let a = PrimitiveArray::from([Some(100i8)]); +/// let b = PrimitiveArray::from([Some(100i8)]); +/// let result = saturating_add(&a, &b); +/// let expected = PrimitiveArray::from([Some(127)]); +/// assert_eq!(result, expected); +/// ``` +pub fn saturating_add(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray +where + T: NativeArithmetics + SaturatingAdd, +{ + let op = move |a: T, b: T| a.saturating_add(&b); + + binary(lhs, rhs, lhs.data_type().clone(), op) +} + +/// Overflowing addition of two primitive arrays. If the result from the sum is +/// larger than the possible number for this type, the result for the operation +/// will be an array with overflowed values and a validity array indicating +/// the overflowing elements from the array. +/// +/// # Examples +/// ``` +/// use polars_arrow::compute::arithmetics::basic::overflowing_add; +/// use polars_arrow::array::PrimitiveArray; +/// +/// let a = PrimitiveArray::from([Some(1i8), Some(100i8)]); +/// let b = PrimitiveArray::from([Some(1i8), Some(100i8)]); +/// let (result, overflow) = overflowing_add(&a, &b); +/// let expected = PrimitiveArray::from([Some(2i8), Some(-56i8)]); +/// assert_eq!(result, expected); +/// ``` +pub fn overflowing_add( + lhs: &PrimitiveArray, + rhs: &PrimitiveArray, +) -> (PrimitiveArray, Bitmap) +where + T: NativeArithmetics + OverflowingAdd, +{ + let op = move |a: T, b: T| a.overflowing_add(&b); + + binary_with_bitmap(lhs, rhs, lhs.data_type().clone(), op) +} + +// Implementation of ArrayAdd trait for PrimitiveArrays +impl ArrayAdd> for PrimitiveArray +where + T: NativeArithmetics + Add, +{ + fn add(&self, rhs: &PrimitiveArray) -> Self { + add(self, rhs) + } +} + +impl ArrayWrappingAdd> for PrimitiveArray +where + T: NativeArithmetics + WrappingAdd, +{ + fn wrapping_add(&self, rhs: &PrimitiveArray) -> Self { + wrapping_add(self, rhs) + } +} + +// Implementation of ArrayCheckedAdd trait for PrimitiveArrays +impl ArrayCheckedAdd> for PrimitiveArray +where + T: NativeArithmetics + CheckedAdd, +{ + fn checked_add(&self, rhs: &PrimitiveArray) -> Self { + checked_add(self, rhs) + } +} + +// Implementation of ArraySaturatingAdd trait for PrimitiveArrays +impl ArraySaturatingAdd> for PrimitiveArray +where + T: NativeArithmetics + SaturatingAdd, +{ + fn saturating_add(&self, rhs: &PrimitiveArray) -> Self { + saturating_add(self, rhs) + } +} + +// Implementation of ArraySaturatingAdd trait for PrimitiveArrays +impl ArrayOverflowingAdd> for PrimitiveArray +where + T: NativeArithmetics + OverflowingAdd, +{ + fn overflowing_add(&self, rhs: &PrimitiveArray) -> (Self, Bitmap) { + overflowing_add(self, rhs) + } +} + /// Adds a scalar T to a primitive array of type T. /// Panics if the sum of the values overflows. +/// +/// # Examples +/// ``` +/// use polars_arrow::compute::arithmetics::basic::add_scalar; +/// use polars_arrow::array::PrimitiveArray; +/// +/// let a = PrimitiveArray::from([None, Some(6), None, Some(6)]); +/// let result = add_scalar(&a, &1i32); +/// let expected = PrimitiveArray::from([None, Some(7), None, Some(7)]); +/// assert_eq!(result, expected) +/// ``` pub fn add_scalar(lhs: &PrimitiveArray, rhs: &T) -> PrimitiveArray where T: NativeArithmetics + Add, @@ -23,3 +202,136 @@ where let rhs = *rhs; unary(lhs, |a| a + rhs, lhs.data_type().clone()) } + +/// Wrapping addition of a scalar T to a [`PrimitiveArray`] of type T. 
+/// It do nothing if the result overflows. +/// +/// # Examples +/// ``` +/// use polars_arrow::compute::arithmetics::basic::wrapping_add_scalar; +/// use polars_arrow::array::Int8Array; +/// +/// let a = Int8Array::from(&[None, Some(100)]); +/// let result = wrapping_add_scalar(&a, &100i8); +/// let expected = Int8Array::from(&[None, Some(-56)]); +/// assert_eq!(result, expected); +/// ``` +pub fn wrapping_add_scalar(lhs: &PrimitiveArray, rhs: &T) -> PrimitiveArray +where + T: NativeArithmetics + WrappingAdd, +{ + unary(lhs, |a| a.wrapping_add(rhs), lhs.data_type().clone()) +} + +/// Checked addition of a scalar T to a primitive array of type T. If the +/// result from the sum overflows then the validity index for that value is +/// changed to None +/// +/// # Examples +/// ``` +/// use polars_arrow::compute::arithmetics::basic::checked_add_scalar; +/// use polars_arrow::array::Int8Array; +/// +/// let a = Int8Array::from(&[None, Some(100), None, Some(100)]); +/// let result = checked_add_scalar(&a, &100i8); +/// let expected = Int8Array::from(&[None, None, None, None]); +/// assert_eq!(result, expected); +/// ``` +pub fn checked_add_scalar(lhs: &PrimitiveArray, rhs: &T) -> PrimitiveArray +where + T: NativeArithmetics + CheckedAdd, +{ + let rhs = *rhs; + let op = move |a: T| a.checked_add(&rhs); + + unary_checked(lhs, op, lhs.data_type().clone()) +} + +/// Saturated addition of a scalar T to a primitive array of type T. If the +/// result from the sum is larger than the possible number for this type, then +/// the result will be saturated +/// +/// # Examples +/// ``` +/// use polars_arrow::compute::arithmetics::basic::saturating_add_scalar; +/// use polars_arrow::array::PrimitiveArray; +/// +/// let a = PrimitiveArray::from([Some(100i8)]); +/// let result = saturating_add_scalar(&a, &100i8); +/// let expected = PrimitiveArray::from([Some(127)]); +/// assert_eq!(result, expected); +/// ``` +pub fn saturating_add_scalar(lhs: &PrimitiveArray, rhs: &T) -> PrimitiveArray +where + T: NativeArithmetics + SaturatingAdd, +{ + let rhs = *rhs; + let op = move |a: T| a.saturating_add(&rhs); + + unary(lhs, op, lhs.data_type().clone()) +} + +/// Overflowing addition of a scalar T to a primitive array of type T. 
+/// If the result from the sum is larger than the possible number for this
+/// type, then the result will be an array with overflowed values and a
+/// validity array indicating the overflowing elements from the array.
+///
+/// # Examples
+/// ```
+/// use polars_arrow::compute::arithmetics::basic::overflowing_add_scalar;
+/// use polars_arrow::array::PrimitiveArray;
+///
+/// let a = PrimitiveArray::from([Some(1i8), Some(100i8)]);
+/// let (result, overflow) = overflowing_add_scalar(&a, &100i8);
+/// let expected = PrimitiveArray::from([Some(101i8), Some(-56i8)]);
+/// assert_eq!(result, expected);
+/// ```
+pub fn overflowing_add_scalar<T>(lhs: &PrimitiveArray<T>, rhs: &T) -> (PrimitiveArray<T>, Bitmap)
+where
+    T: NativeArithmetics + OverflowingAdd,
+{
+    let rhs = *rhs;
+    let op = move |a: T| a.overflowing_add(&rhs);
+
+    unary_with_bitmap(lhs, op, lhs.data_type().clone())
+}
+
+// Implementation of ArrayAdd trait for PrimitiveArrays with a scalar
+impl<T> ArrayAdd<T> for PrimitiveArray<T>
+where
+    T: NativeArithmetics + Add<Output = T>,
+{
+    fn add(&self, rhs: &T) -> Self {
+        add_scalar(self, rhs)
+    }
+}
+
+// Implementation of ArrayCheckedAdd trait for PrimitiveArrays with a scalar
+impl<T> ArrayCheckedAdd<T> for PrimitiveArray<T>
+where
+    T: NativeArithmetics + CheckedAdd,
+{
+    fn checked_add(&self, rhs: &T) -> Self {
+        checked_add_scalar(self, rhs)
+    }
+}
+
+// Implementation of ArraySaturatingAdd trait for PrimitiveArrays with a scalar
+impl<T> ArraySaturatingAdd<T> for PrimitiveArray<T>
+where
+    T: NativeArithmetics + SaturatingAdd,
+{
+    fn saturating_add(&self, rhs: &T) -> Self {
+        saturating_add_scalar(self, rhs)
+    }
+}
+
+// Implementation of ArrayOverflowingAdd trait for PrimitiveArrays with a scalar
+impl<T> ArrayOverflowingAdd<T> for PrimitiveArray<T>
+where
+    T: NativeArithmetics + OverflowingAdd,
+{
+    fn overflowing_add(&self, rhs: &T) -> (Self, Bitmap) {
+        overflowing_add_scalar(self, rhs)
+    }
+}
diff --git a/crates/polars-arrow/src/compute/arithmetics/basic/div.rs b/crates/polars-arrow/src/compute/arithmetics/basic/div.rs
index 9b5220b1b1ef..4b27001543e0 100644
--- a/crates/polars-arrow/src/compute/arithmetics/basic/div.rs
+++ b/crates/polars-arrow/src/compute/arithmetics/basic/div.rs
@@ -8,6 +8,7 @@ use strength_reduce::{
 
 use super::NativeArithmetics;
 use crate::array::{Array, PrimitiveArray};
+use crate::compute::arithmetics::{ArrayCheckedDiv, ArrayDiv};
 use crate::compute::arity::{binary, binary_checked, unary, unary_checked};
 use crate::compute::utils::check_same_len;
 use crate::datatypes::PrimitiveType;
@@ -66,8 +67,39 @@ where
     binary_checked(lhs, rhs, lhs.data_type().clone(), op)
 }
 
+// Implementation of ArrayDiv trait for PrimitiveArrays
+impl<T> ArrayDiv<PrimitiveArray<T>> for PrimitiveArray<T>
+where
+    T: NativeArithmetics + Div<Output = T>,
+{
+    fn div(&self, rhs: &PrimitiveArray<T>) -> Self {
+        div(self, rhs)
+    }
+}
+
+// Implementation of ArrayCheckedDiv trait for PrimitiveArrays
+impl<T> ArrayCheckedDiv<PrimitiveArray<T>> for PrimitiveArray<T>
+where
+    T: NativeArithmetics + CheckedDiv,
+{
+    fn checked_div(&self, rhs: &PrimitiveArray<T>) -> Self {
+        checked_div(self, rhs)
+    }
+}
+
 /// Divide a primitive array of type T by a scalar T.
 /// Panics if the divisor is zero.
+///
+/// # Examples
+/// ```
+/// use polars_arrow::compute::arithmetics::basic::div_scalar;
+/// use polars_arrow::array::Int32Array;
+///
+/// let a = Int32Array::from(&[None, Some(6), None, Some(6)]);
+/// let result = div_scalar(&a, &2i32);
+/// let expected = Int32Array::from(&[None, Some(3), None, Some(3)]);
+/// assert_eq!(result, expected)
+/// ```
 pub fn div_scalar<T>(lhs: &PrimitiveArray<T>, rhs: &T) -> PrimitiveArray<T>
 where
     T: NativeArithmetics + Div<Output = T> + NumCast,
@@ -129,6 +161,17 @@ where
 
 /// Checked division of a primitive array of type T by a scalar T. If the
 /// divisor is zero then the validity array is changed to None.
+///
+/// # Examples
+/// ```
+/// use polars_arrow::compute::arithmetics::basic::checked_div_scalar;
+/// use polars_arrow::array::Int8Array;
+///
+/// let a = Int8Array::from(&[Some(-100i8)]);
+/// let result = checked_div_scalar(&a, &100i8);
+/// let expected = Int8Array::from(&[Some(-1i8)]);
+/// assert_eq!(result, expected);
+/// ```
 pub fn checked_div_scalar<T>(lhs: &PrimitiveArray<T>, rhs: &T) -> PrimitiveArray<T>
 where
     T: NativeArithmetics + CheckedDiv,
@@ -138,3 +181,23 @@ where
 
     unary_checked(lhs, op, lhs.data_type().clone())
 }
+
+// Implementation of ArrayDiv trait for PrimitiveArrays with a scalar
+impl<T> ArrayDiv<T> for PrimitiveArray<T>
+where
+    T: NativeArithmetics + Div<Output = T> + NumCast,
+{
+    fn div(&self, rhs: &T) -> Self {
+        div_scalar(self, rhs)
+    }
+}
+
+// Implementation of ArrayCheckedDiv trait for PrimitiveArrays with a scalar
+impl<T> ArrayCheckedDiv<T> for PrimitiveArray<T>
+where
+    T: NativeArithmetics + CheckedDiv,
+{
+    fn checked_div(&self, rhs: &T) -> Self {
+        checked_div_scalar(self, rhs)
+    }
+}
diff --git a/crates/polars-arrow/src/compute/arithmetics/basic/mod.rs b/crates/polars-arrow/src/compute/arithmetics/basic/mod.rs
index faa55af6bbd9..b01e31c5a214 100644
--- a/crates/polars-arrow/src/compute/arithmetics/basic/mod.rs
+++ b/crates/polars-arrow/src/compute/arithmetics/basic/mod.rs
@@ -11,6 +11,8 @@ mod div;
 pub use div::*;
 mod mul;
 pub use mul::*;
+mod pow;
+pub use pow::*;
 mod rem;
 pub use rem::*;
 mod sub;
diff --git a/crates/polars-arrow/src/compute/arithmetics/basic/mul.rs b/crates/polars-arrow/src/compute/arithmetics/basic/mul.rs
index a1ed463f0195..becdce1eba4a 100644
--- a/crates/polars-arrow/src/compute/arithmetics/basic/mul.rs
+++ b/crates/polars-arrow/src/compute/arithmetics/basic/mul.rs
@@ -1,12 +1,33 @@
 //! Definition of basic mul operations with primitive arrays
 use std::ops::Mul;
 
+use num_traits::ops::overflowing::OverflowingMul;
+use num_traits::{CheckedMul, SaturatingMul, WrappingMul};
+
 use super::NativeArithmetics;
 use crate::array::PrimitiveArray;
-use crate::compute::arity::{binary, unary};
+use crate::bitmap::Bitmap;
+use crate::compute::arithmetics::{
+    ArrayCheckedMul, ArrayMul, ArrayOverflowingMul, ArraySaturatingMul, ArrayWrappingMul,
+};
+use crate::compute::arity::{
+    binary, binary_checked, binary_with_bitmap, unary, unary_checked, unary_with_bitmap,
+};
 
 /// Multiplies two primitive arrays with the same type.
 /// Panics if the multiplication of one pair of values overflows.
+///
+/// # Examples
+/// ```
+/// use polars_arrow::compute::arithmetics::basic::mul;
+/// use polars_arrow::array::Int32Array;
+///
+/// let a = Int32Array::from(&[None, Some(6), None, Some(6)]);
+/// let b = Int32Array::from(&[Some(5), None, None, Some(6)]);
+/// let result = mul(&a, &b);
+/// let expected = Int32Array::from(&[None, None, None, Some(36)]);
+/// assert_eq!(result, expected)
+/// ```
 pub fn mul<T>(lhs: &PrimitiveArray<T>, rhs: &PrimitiveArray<T>) -> PrimitiveArray<T>
 where
     T: NativeArithmetics + Mul<Output = T>,
@@ -14,8 +35,167 @@ where
     binary(lhs, rhs, lhs.data_type().clone(), |a, b| a * b)
 }
 
+/// Wrapping multiplication of two [`PrimitiveArray`]s.
+/// It wraps around at the boundary of the type if the result overflows.
+///
+/// # Examples
+/// ```
+/// use polars_arrow::compute::arithmetics::basic::wrapping_mul;
+/// use polars_arrow::array::PrimitiveArray;
+///
+/// let a = PrimitiveArray::from([Some(100i8), Some(0x10i8), Some(100i8)]);
+/// let b = PrimitiveArray::from([Some(0i8), Some(0x10i8), Some(0i8)]);
+/// let result = wrapping_mul(&a, &b);
+/// let expected = PrimitiveArray::from([Some(0), Some(0), Some(0)]);
+/// assert_eq!(result, expected);
+/// ```
+pub fn wrapping_mul<T>(lhs: &PrimitiveArray<T>, rhs: &PrimitiveArray<T>) -> PrimitiveArray<T>
+where
+    T: NativeArithmetics + WrappingMul,
+{
+    let op = move |a: T, b: T| a.wrapping_mul(&b);
+
+    binary(lhs, rhs, lhs.data_type().clone(), op)
+}
+
+/// Checked multiplication of two primitive arrays. If the result from the
+/// multiplication overflows, the validity for that index is changed to None.
+///
+/// # Examples
+/// ```
+/// use polars_arrow::compute::arithmetics::basic::checked_mul;
+/// use polars_arrow::array::Int8Array;
+///
+/// let a = Int8Array::from(&[Some(100i8), Some(100i8), Some(100i8)]);
+/// let b = Int8Array::from(&[Some(1i8), Some(100i8), Some(1i8)]);
+/// let result = checked_mul(&a, &b);
+/// let expected = Int8Array::from(&[Some(100i8), None, Some(100i8)]);
+/// assert_eq!(result, expected);
+/// ```
+pub fn checked_mul<T>(lhs: &PrimitiveArray<T>, rhs: &PrimitiveArray<T>) -> PrimitiveArray<T>
+where
+    T: NativeArithmetics + CheckedMul,
+{
+    let op = move |a: T, b: T| a.checked_mul(&b);
+
+    binary_checked(lhs, rhs, lhs.data_type().clone(), op)
+}
+
+/// Saturating multiplication of two primitive arrays. If the result from the
+/// multiplication overflows, the result for the
+/// operation will be the saturated value.
+///
+/// # Examples
+/// ```
+/// use polars_arrow::compute::arithmetics::basic::saturating_mul;
+/// use polars_arrow::array::Int8Array;
+///
+/// let a = Int8Array::from(&[Some(-100i8)]);
+/// let b = Int8Array::from(&[Some(100i8)]);
+/// let result = saturating_mul(&a, &b);
+/// let expected = Int8Array::from(&[Some(-128)]);
+/// assert_eq!(result, expected);
+/// ```
+pub fn saturating_mul<T>(lhs: &PrimitiveArray<T>, rhs: &PrimitiveArray<T>) -> PrimitiveArray<T>
+where
+    T: NativeArithmetics + SaturatingMul,
+{
+    let op = move |a: T, b: T| a.saturating_mul(&b);
+
+    binary(lhs, rhs, lhs.data_type().clone(), op)
+}
+
+/// Overflowing multiplication of two primitive arrays. If the result from the
+/// multiplication overflows, the result for the operation will be an array with
+/// overflowed values and a validity array indicating the overflowing elements
+/// from the array.
+///
+/// # Examples
+/// ```
+/// use polars_arrow::compute::arithmetics::basic::overflowing_mul;
+/// use polars_arrow::array::Int8Array;
+///
+/// let a = Int8Array::from(&[Some(1i8), Some(-100i8)]);
+/// let b = Int8Array::from(&[Some(1i8), Some(100i8)]);
+/// let (result, overflow) = overflowing_mul(&a, &b);
+/// let expected = Int8Array::from(&[Some(1i8), Some(-16i8)]);
+/// assert_eq!(result, expected);
+/// ```
+pub fn overflowing_mul<T>(
+    lhs: &PrimitiveArray<T>,
+    rhs: &PrimitiveArray<T>,
+) -> (PrimitiveArray<T>, Bitmap)
+where
+    T: NativeArithmetics + OverflowingMul,
+{
+    let op = move |a: T, b: T| a.overflowing_mul(&b);
+
+    binary_with_bitmap(lhs, rhs, lhs.data_type().clone(), op)
+}
+
+// Implementation of ArrayMul trait for PrimitiveArrays
+impl<T> ArrayMul<PrimitiveArray<T>> for PrimitiveArray<T>
+where
+    T: NativeArithmetics + Mul<Output = T>,
+{
+    fn mul(&self, rhs: &PrimitiveArray<T>) -> Self {
+        mul(self, rhs)
+    }
+}
+
+impl<T> ArrayWrappingMul<PrimitiveArray<T>> for PrimitiveArray<T>
+where
+    T: NativeArithmetics + WrappingMul,
+{
+    fn wrapping_mul(&self, rhs: &PrimitiveArray<T>) -> Self {
+        wrapping_mul(self, rhs)
+    }
+}
+
+// Implementation of ArrayCheckedMul trait for PrimitiveArrays
+impl<T> ArrayCheckedMul<PrimitiveArray<T>> for PrimitiveArray<T>
+where
+    T: NativeArithmetics + CheckedMul,
+{
+    fn checked_mul(&self, rhs: &PrimitiveArray<T>) -> Self {
+        checked_mul(self, rhs)
+    }
+}
+
+// Implementation of ArraySaturatingMul trait for PrimitiveArrays
+impl<T> ArraySaturatingMul<PrimitiveArray<T>> for PrimitiveArray<T>
+where
+    T: NativeArithmetics + SaturatingMul,
+{
+    fn saturating_mul(&self, rhs: &PrimitiveArray<T>) -> Self {
+        saturating_mul(self, rhs)
+    }
+}
+
+// Implementation of ArrayOverflowingMul trait for PrimitiveArrays
+impl<T> ArrayOverflowingMul<PrimitiveArray<T>> for PrimitiveArray<T>
+where
+    T: NativeArithmetics + OverflowingMul,
+{
+    fn overflowing_mul(&self, rhs: &PrimitiveArray<T>) -> (Self, Bitmap) {
+        overflowing_mul(self, rhs)
+    }
+}
+
 /// Multiply a scalar T to a primitive array of type T.
 /// Panics if the multiplication of the values overflows.
+///
+/// # Examples
+/// ```
+/// use polars_arrow::compute::arithmetics::basic::mul_scalar;
+/// use polars_arrow::array::Int32Array;
+///
+/// let a = Int32Array::from(&[None, Some(6), None, Some(6)]);
+/// let result = mul_scalar(&a, &2i32);
+/// let expected = Int32Array::from(&[None, Some(12), None, Some(12)]);
+/// assert_eq!(result, expected)
+/// ```
 pub fn mul_scalar<T>(lhs: &PrimitiveArray<T>, rhs: &T) -> PrimitiveArray<T>
 where
     T: NativeArithmetics + Mul<Output = T>,
@@ -23,3 +203,136 @@ where
     let rhs = *rhs;
     unary(lhs, |a| a * rhs, lhs.data_type().clone())
 }
+
+/// Wrapping multiplication of a scalar T to a [`PrimitiveArray`] of type T.
+/// It wraps around at the boundary of the type if the result overflows.
+///
+/// # Examples
+/// ```
+/// use polars_arrow::compute::arithmetics::basic::wrapping_mul_scalar;
+/// use polars_arrow::array::Int8Array;
+///
+/// let a = Int8Array::from(&[None, Some(0x10)]);
+/// let result = wrapping_mul_scalar(&a, &0x10);
+/// let expected = Int8Array::from(&[None, Some(0)]);
+/// assert_eq!(result, expected);
+/// ```
+pub fn wrapping_mul_scalar<T>(lhs: &PrimitiveArray<T>, rhs: &T) -> PrimitiveArray<T>
+where
+    T: NativeArithmetics + WrappingMul,
+{
+    unary(lhs, |a| a.wrapping_mul(rhs), lhs.data_type().clone())
+}
+
+/// Checked multiplication of a scalar T to a primitive array of type T.
+/// If the result from the multiplication overflows, then the validity for
+/// that index is changed to None.
+///
+/// # Examples
+/// ```
+/// use polars_arrow::compute::arithmetics::basic::checked_mul_scalar;
+/// use polars_arrow::array::Int8Array;
+///
+/// let a = Int8Array::from(&[None, Some(100), None, Some(100)]);
+/// let result = checked_mul_scalar(&a, &100i8);
+/// let expected = Int8Array::from(&[None, None, None, None]);
+/// assert_eq!(result, expected);
+/// ```
+pub fn checked_mul_scalar<T>(lhs: &PrimitiveArray<T>, rhs: &T) -> PrimitiveArray<T>
+where
+    T: NativeArithmetics + CheckedMul,
+{
+    let rhs = *rhs;
+    let op = move |a: T| a.checked_mul(&rhs);
+
+    unary_checked(lhs, op, lhs.data_type().clone())
+}
+
+/// Saturated multiplication of a scalar T to a primitive array of type T. If
+/// the result from the multiplication overflows for this type, then
+/// the result will be saturated.
+///
+/// # Examples
+/// ```
+/// use polars_arrow::compute::arithmetics::basic::saturating_mul_scalar;
+/// use polars_arrow::array::Int8Array;
+///
+/// let a = Int8Array::from(&[Some(-100i8)]);
+/// let result = saturating_mul_scalar(&a, &100i8);
+/// let expected = Int8Array::from(&[Some(-128i8)]);
+/// assert_eq!(result, expected);
+/// ```
+pub fn saturating_mul_scalar<T>(lhs: &PrimitiveArray<T>, rhs: &T) -> PrimitiveArray<T>
+where
+    T: NativeArithmetics + SaturatingMul,
+{
+    let rhs = *rhs;
+    let op = move |a: T| a.saturating_mul(&rhs);
+
+    unary(lhs, op, lhs.data_type().clone())
+}
+
+/// Overflowing multiplication of a scalar T to a primitive array of type T. If
+/// the result from the multiplication overflows for this type,
+/// then the result will be an array with overflowed values and a validity
+/// array indicating the overflowing elements from the array.
+///
+/// # Examples
+/// ```
+/// use polars_arrow::compute::arithmetics::basic::overflowing_mul_scalar;
+/// use polars_arrow::array::Int8Array;
+///
+/// let a = Int8Array::from(&[Some(1i8), Some(100i8)]);
+/// let (result, overflow) = overflowing_mul_scalar(&a, &100i8);
+/// let expected = Int8Array::from(&[Some(100i8), Some(16i8)]);
+/// assert_eq!(result, expected);
+/// ```
+pub fn overflowing_mul_scalar<T>(lhs: &PrimitiveArray<T>, rhs: &T) -> (PrimitiveArray<T>, Bitmap)
+where
+    T: NativeArithmetics + OverflowingMul,
+{
+    let rhs = *rhs;
+    let op = move |a: T| a.overflowing_mul(&rhs);
+
+    unary_with_bitmap(lhs, op, lhs.data_type().clone())
+}
+
+// Implementation of ArrayMul trait for PrimitiveArrays with a scalar
+impl<T> ArrayMul<T> for PrimitiveArray<T>
+where
+    T: NativeArithmetics + Mul<Output = T>,
+{
+    fn mul(&self, rhs: &T) -> Self {
+        mul_scalar(self, rhs)
+    }
+}
+
+// Implementation of ArrayCheckedMul trait for PrimitiveArrays with a scalar
+impl<T> ArrayCheckedMul<T> for PrimitiveArray<T>
+where
+    T: NativeArithmetics + CheckedMul,
+{
+    fn checked_mul(&self, rhs: &T) -> Self {
+        checked_mul_scalar(self, rhs)
+    }
+}
+
+// Implementation of ArraySaturatingMul trait for PrimitiveArrays with a scalar
+impl<T> ArraySaturatingMul<T> for PrimitiveArray<T>
+where
+    T: NativeArithmetics + SaturatingMul,
+{
+    fn saturating_mul(&self, rhs: &T) -> Self {
+        saturating_mul_scalar(self, rhs)
+    }
+}
+
+// Implementation of ArrayOverflowingMul trait for PrimitiveArrays with a scalar
+impl<T> ArrayOverflowingMul<T> for PrimitiveArray<T>
+where
+    T: NativeArithmetics + OverflowingMul,
+{
+    fn overflowing_mul(&self, rhs: &T) -> (Self, Bitmap) {
+        overflowing_mul_scalar(self, rhs)
+    }
+}
diff --git a/crates/polars-arrow/src/compute/arithmetics/basic/pow.rs b/crates/polars-arrow/src/compute/arithmetics/basic/pow.rs
new file mode 100644
index 000000000000..173c4a351aa5
--- /dev/null
+++ b/crates/polars-arrow/src/compute/arithmetics/basic/pow.rs
@@ -0,0 +1,49 @@
+//! Definition of basic pow operations with primitive arrays
+use num_traits::{checked_pow, CheckedMul, One, Pow};
+
+use super::NativeArithmetics;
+use crate::array::PrimitiveArray;
+use crate::compute::arity::{unary, unary_checked};
+
+/// Raises an array of primitives to the power of exponent. Panics if one of
+/// the values overflows.
+///
+/// # Examples
+/// ```
+/// use polars_arrow::compute::arithmetics::basic::powf_scalar;
+/// use polars_arrow::array::Float32Array;
+///
+/// let a = Float32Array::from(&[Some(2f32), None]);
+/// let actual = powf_scalar(&a, 2.0);
+/// let expected = Float32Array::from(&[Some(4f32), None]);
+/// assert_eq!(expected, actual);
+/// ```
+pub fn powf_scalar<T>(array: &PrimitiveArray<T>, exponent: T) -> PrimitiveArray<T>
+where
+    T: NativeArithmetics + Pow<T, Output = T>,
+{
+    unary(array, |x| x.pow(exponent), array.data_type().clone())
+}
+
+/// Checked operation of raising an array of primitives to the power of
+/// exponent. If the result from the multiplication overflows, the validity
+/// for that index is changed to None.
+///
+/// # Examples
+/// ```
+/// use polars_arrow::compute::arithmetics::basic::checked_powf_scalar;
+/// use polars_arrow::array::Int8Array;
+///
+/// let a = Int8Array::from(&[Some(1i8), None, Some(7i8)]);
+/// let actual = checked_powf_scalar(&a, 8usize);
+/// let expected = Int8Array::from(&[Some(1i8), None, None]);
+/// assert_eq!(expected, actual);
+/// ```
+pub fn checked_powf_scalar<T>(array: &PrimitiveArray<T>, exponent: usize) -> PrimitiveArray<T>
+where
+    T: NativeArithmetics + CheckedMul + One,
+{
+    let op = move |a: T| checked_pow(a, exponent);
+
+    unary_checked(array, op, array.data_type().clone())
+}
diff --git a/crates/polars-arrow/src/compute/arithmetics/basic/rem.rs b/crates/polars-arrow/src/compute/arithmetics/basic/rem.rs
index 46eeb16cb8c6..d0ac512b5604 100644
--- a/crates/polars-arrow/src/compute/arithmetics/basic/rem.rs
+++ b/crates/polars-arrow/src/compute/arithmetics/basic/rem.rs
@@ -1,17 +1,30 @@
 use std::ops::Rem;
 
-use num_traits::NumCast;
+use num_traits::{CheckedRem, NumCast};
 use strength_reduce::{
     StrengthReducedU16, StrengthReducedU32, StrengthReducedU64, StrengthReducedU8,
 };
 
 use super::NativeArithmetics;
 use crate::array::{Array, PrimitiveArray};
-use crate::compute::arity::{binary, unary};
+use crate::compute::arithmetics::{ArrayCheckedRem, ArrayRem};
+use crate::compute::arity::{binary, binary_checked, unary, unary_checked};
 use crate::datatypes::PrimitiveType;
 
 /// Remainder of two primitive arrays with the same type.
 /// Panics if the divisor is zero or one pair of values overflows.
+///
+/// # Examples
+/// ```
+/// use polars_arrow::compute::arithmetics::basic::rem;
+/// use polars_arrow::array::Int32Array;
+///
+/// let a = Int32Array::from(&[Some(10), Some(7)]);
+/// let b = Int32Array::from(&[Some(5), Some(6)]);
+/// let result = rem(&a, &b);
+/// let expected = Int32Array::from(&[Some(0), Some(1)]);
+/// assert_eq!(result, expected)
+/// ```
 pub fn rem<T>(lhs: &PrimitiveArray<T>, rhs: &PrimitiveArray<T>) -> PrimitiveArray<T>
 where
     T: NativeArithmetics + Rem<Output = T>,
@@ -19,8 +32,61 @@ where
     binary(lhs, rhs, lhs.data_type().clone(), |a, b| a % b)
 }
 
+/// Checked remainder of two primitive arrays.
+/// If the divisor is zero or the operation overflows, the validity for that
+/// index is changed to None.
+///
+/// # Examples
+/// ```
+/// use polars_arrow::compute::arithmetics::basic::checked_rem;
+/// use polars_arrow::array::Int8Array;
+///
+/// let a = Int8Array::from(&[Some(-100i8), Some(10i8)]);
+/// let b = Int8Array::from(&[Some(100i8), Some(0i8)]);
+/// let result = checked_rem(&a, &b);
+/// let expected = Int8Array::from(&[Some(-0i8), None]);
+/// assert_eq!(result, expected);
+/// ```
+pub fn checked_rem<T>(lhs: &PrimitiveArray<T>, rhs: &PrimitiveArray<T>) -> PrimitiveArray<T>
+where
+    T: NativeArithmetics + CheckedRem,
+{
+    let op = move |a: T, b: T| a.checked_rem(&b);
+
+    binary_checked(lhs, rhs, lhs.data_type().clone(), op)
+}
+
+impl<T> ArrayRem<PrimitiveArray<T>> for PrimitiveArray<T>
+where
+    T: NativeArithmetics + Rem<Output = T>,
+{
+    fn rem(&self, rhs: &PrimitiveArray<T>) -> Self {
+        rem(self, rhs)
+    }
+}
+
+impl<T> ArrayCheckedRem<PrimitiveArray<T>> for PrimitiveArray<T>
+where
+    T: NativeArithmetics + CheckedRem,
+{
+    fn checked_rem(&self, rhs: &PrimitiveArray<T>) -> Self {
+        checked_rem(self, rhs)
+    }
+}
+
 /// Remainder a primitive array of type T by a scalar T.
 /// Panics if the divisor is zero.
+///
+/// # Examples
+/// ```
+/// use polars_arrow::compute::arithmetics::basic::rem_scalar;
+/// use polars_arrow::array::Int32Array;
+///
+/// let a = Int32Array::from(&[None, Some(6), None, Some(7)]);
+/// let result = rem_scalar(&a, &2i32);
+/// let expected = Int32Array::from(&[None, Some(0), None, Some(1)]);
+/// assert_eq!(result, expected)
+/// ```
 pub fn rem_scalar<T>(lhs: &PrimitiveArray<T>, rhs: &T) -> PrimitiveArray<T>
 where
     T: NativeArithmetics + Rem<Output = T> + NumCast,
@@ -87,3 +153,44 @@ where
         _ => unary(lhs, |a| a % rhs, lhs.data_type().clone()),
     }
 }
+
+/// Checked remainder of a primitive array of type T by a scalar T. If the
+/// divisor is zero then the validity array is changed to None.
+///
+/// # Examples
+/// ```
+/// use polars_arrow::compute::arithmetics::basic::checked_rem_scalar;
+/// use polars_arrow::array::Int8Array;
+///
+/// let a = Int8Array::from(&[Some(-100i8)]);
+/// let result = checked_rem_scalar(&a, &100i8);
+/// let expected = Int8Array::from(&[Some(0i8)]);
+/// assert_eq!(result, expected);
+/// ```
+pub fn checked_rem_scalar<T>(lhs: &PrimitiveArray<T>, rhs: &T) -> PrimitiveArray<T>
+where
+    T: NativeArithmetics + CheckedRem,
+{
+    let rhs = *rhs;
+    let op = move |a: T| a.checked_rem(&rhs);
+
+    unary_checked(lhs, op, lhs.data_type().clone())
+}
+
+impl<T> ArrayRem<T> for PrimitiveArray<T>
+where
+    T: NativeArithmetics + Rem<Output = T> + NumCast,
+{
+    fn rem(&self, rhs: &T) -> Self {
+        rem_scalar(self, rhs)
+    }
+}
+
+impl<T> ArrayCheckedRem<T> for PrimitiveArray<T>
+where
+    T: NativeArithmetics + CheckedRem,
+{
+    fn checked_rem(&self, rhs: &T) -> Self {
+        checked_rem_scalar(self, rhs)
+    }
+}
diff --git a/crates/polars-arrow/src/compute/arithmetics/basic/sub.rs b/crates/polars-arrow/src/compute/arithmetics/basic/sub.rs
index 33acb99b3ef6..43f267c6bf13 100644
--- a/crates/polars-arrow/src/compute/arithmetics/basic/sub.rs
+++ b/crates/polars-arrow/src/compute/arithmetics/basic/sub.rs
@@ -1,12 +1,33 @@
 //! Definition of basic sub operations with primitive arrays
 use std::ops::Sub;
 
+use num_traits::ops::overflowing::OverflowingSub;
+use num_traits::{CheckedSub, SaturatingSub, WrappingSub};
+
 use super::NativeArithmetics;
 use crate::array::PrimitiveArray;
-use crate::compute::arity::{binary, unary};
+use crate::bitmap::Bitmap;
+use crate::compute::arithmetics::{
+    ArrayCheckedSub, ArrayOverflowingSub, ArraySaturatingSub, ArraySub, ArrayWrappingSub,
+};
+use crate::compute::arity::{
+    binary, binary_checked, binary_with_bitmap, unary, unary_checked, unary_with_bitmap,
+};
 
 /// Subtracts two primitive arrays with the same type.
 /// Panics if the subtraction of one pair of values overflows.
+///
+/// # Examples
+/// ```
+/// use polars_arrow::compute::arithmetics::basic::sub;
+/// use polars_arrow::array::Int32Array;
+///
+/// let a = Int32Array::from(&[None, Some(6), None, Some(6)]);
+/// let b = Int32Array::from(&[Some(5), None, None, Some(6)]);
+/// let result = sub(&a, &b);
+/// let expected = Int32Array::from(&[None, None, None, Some(0)]);
+/// assert_eq!(result, expected)
+/// ```
 pub fn sub<T>(lhs: &PrimitiveArray<T>, rhs: &PrimitiveArray<T>) -> PrimitiveArray<T>
 where
     T: NativeArithmetics + Sub<Output = T>,
@@ -14,8 +35,166 @@ where
     binary(lhs, rhs, lhs.data_type().clone(), |a, b| a - b)
 }
 
+/// Wrapping subtraction of two [`PrimitiveArray`]s.
+/// It wraps around at the boundary of the type if the result overflows.
+///
+/// # Examples
+/// ```
+/// use polars_arrow::compute::arithmetics::basic::wrapping_sub;
+/// use polars_arrow::array::PrimitiveArray;
+///
+/// let a = PrimitiveArray::from([Some(-100i8), Some(-100i8), Some(100i8)]);
+/// let b = PrimitiveArray::from([Some(0i8), Some(100i8), Some(0i8)]);
+/// let result = wrapping_sub(&a, &b);
+/// let expected = PrimitiveArray::from([Some(-100i8), Some(56i8), Some(100i8)]);
+/// assert_eq!(result, expected);
+/// ```
+pub fn wrapping_sub<T>(lhs: &PrimitiveArray<T>, rhs: &PrimitiveArray<T>) -> PrimitiveArray<T>
+where
+    T: NativeArithmetics + WrappingSub,
+{
+    let op = move |a: T, b: T| a.wrapping_sub(&b);
+
+    binary(lhs, rhs, lhs.data_type().clone(), op)
+}
+
+/// Checked subtraction of two primitive arrays. If the result from the
+/// subtraction overflows, the validity for that index is changed to None.
+///
+/// # Examples
+/// ```
+/// use polars_arrow::compute::arithmetics::basic::checked_sub;
+/// use polars_arrow::array::Int8Array;
+///
+/// let a = Int8Array::from(&[Some(100i8), Some(-100i8), Some(100i8)]);
+/// let b = Int8Array::from(&[Some(1i8), Some(100i8), Some(0i8)]);
+/// let result = checked_sub(&a, &b);
+/// let expected = Int8Array::from(&[Some(99i8), None, Some(100i8)]);
+/// assert_eq!(result, expected);
+/// ```
+pub fn checked_sub<T>(lhs: &PrimitiveArray<T>, rhs: &PrimitiveArray<T>) -> PrimitiveArray<T>
+where
+    T: NativeArithmetics + CheckedSub,
+{
+    let op = move |a: T, b: T| a.checked_sub(&b);
+
+    binary_checked(lhs, rhs, lhs.data_type().clone(), op)
+}
+
+/// Saturating subtraction of two primitive arrays. If the result from the
+/// subtraction is smaller than the possible number for this type, the result
+/// for the operation will be the saturated value.
+///
+/// # Examples
+/// ```
+/// use polars_arrow::compute::arithmetics::basic::saturating_sub;
+/// use polars_arrow::array::Int8Array;
+///
+/// let a = Int8Array::from(&[Some(-100i8)]);
+/// let b = Int8Array::from(&[Some(100i8)]);
+/// let result = saturating_sub(&a, &b);
+/// let expected = Int8Array::from(&[Some(-128)]);
+/// assert_eq!(result, expected);
+/// ```
+pub fn saturating_sub<T>(lhs: &PrimitiveArray<T>, rhs: &PrimitiveArray<T>) -> PrimitiveArray<T>
+where
+    T: NativeArithmetics + SaturatingSub,
+{
+    let op = move |a: T, b: T| a.saturating_sub(&b);
+
+    binary(lhs, rhs, lhs.data_type().clone(), op)
+}
+
+/// Overflowing subtraction of two primitive arrays. If the result from the
+/// subtraction is smaller than the possible number for this type, the result
+/// for the operation will be an array with overflowed values and a validity
+/// array indicating the overflowing elements from the array.
+///
+/// # Examples
+/// ```
+/// use polars_arrow::compute::arithmetics::basic::overflowing_sub;
+/// use polars_arrow::array::Int8Array;
+///
+/// let a = Int8Array::from(&[Some(1i8), Some(-100i8)]);
+/// let b = Int8Array::from(&[Some(1i8), Some(100i8)]);
+/// let (result, overflow) = overflowing_sub(&a, &b);
+/// let expected = Int8Array::from(&[Some(0i8), Some(56i8)]);
+/// assert_eq!(result, expected);
+/// ```
+pub fn overflowing_sub<T>(
+    lhs: &PrimitiveArray<T>,
+    rhs: &PrimitiveArray<T>,
+) -> (PrimitiveArray<T>, Bitmap)
+where
+    T: NativeArithmetics + OverflowingSub,
+{
+    let op = move |a: T, b: T| a.overflowing_sub(&b);
+
+    binary_with_bitmap(lhs, rhs, lhs.data_type().clone(), op)
+}
+
+// Implementation of ArraySub trait for PrimitiveArrays
+impl<T> ArraySub<PrimitiveArray<T>> for PrimitiveArray<T>
+where
+    T: NativeArithmetics + Sub<Output = T>,
+{
+    fn sub(&self, rhs: &PrimitiveArray<T>) -> Self {
+        sub(self, rhs)
+    }
+}
+
+impl<T> ArrayWrappingSub<PrimitiveArray<T>> for PrimitiveArray<T>
+where
+    T: NativeArithmetics + WrappingSub,
+{
+    fn wrapping_sub(&self, rhs: &PrimitiveArray<T>) -> Self {
+        wrapping_sub(self, rhs)
+    }
+}
+
+// Implementation of ArrayCheckedSub trait for PrimitiveArrays
+impl<T> ArrayCheckedSub<PrimitiveArray<T>> for PrimitiveArray<T>
+where
+    T: NativeArithmetics + CheckedSub,
+{
+    fn checked_sub(&self, rhs: &PrimitiveArray<T>) -> Self {
+        checked_sub(self, rhs)
+    }
+}
+
+// Implementation of ArraySaturatingSub trait for PrimitiveArrays
+impl<T> ArraySaturatingSub<PrimitiveArray<T>> for PrimitiveArray<T>
+where
+    T: NativeArithmetics + SaturatingSub,
+{
+    fn saturating_sub(&self, rhs: &PrimitiveArray<T>) -> Self {
+        saturating_sub(self, rhs)
+    }
+}
+
+// Implementation of ArrayOverflowingSub trait for PrimitiveArrays
+impl<T> ArrayOverflowingSub<PrimitiveArray<T>> for PrimitiveArray<T>
+where
+    T: NativeArithmetics + OverflowingSub,
+{
+    fn overflowing_sub(&self, rhs: &PrimitiveArray<T>) -> (Self, Bitmap) {
+        overflowing_sub(self, rhs)
+    }
+}
+
 /// Subtract a scalar T to a primitive array of type T.
 /// Panics if the subtraction of the values overflows.
+///
+/// # Examples
+/// ```
+/// use polars_arrow::compute::arithmetics::basic::sub_scalar;
+/// use polars_arrow::array::Int32Array;
+///
+/// let a = Int32Array::from(&[None, Some(6), None, Some(6)]);
+/// let result = sub_scalar(&a, &1i32);
+/// let expected = Int32Array::from(&[None, Some(5), None, Some(5)]);
+/// assert_eq!(result, expected)
+/// ```
 pub fn sub_scalar<T>(lhs: &PrimitiveArray<T>, rhs: &T) -> PrimitiveArray<T>
 where
     T: NativeArithmetics + Sub<Output = T>,
@@ -23,3 +202,136 @@ where
     let rhs = *rhs;
     unary(lhs, |a| a - rhs, lhs.data_type().clone())
 }
+
+/// Wrapping subtraction of a scalar T to a [`PrimitiveArray`] of type T.
+/// It wraps around at the boundary of the type if the result overflows.
+///
+/// # Examples
+/// ```
+/// use polars_arrow::compute::arithmetics::basic::wrapping_sub_scalar;
+/// use polars_arrow::array::Int8Array;
+///
+/// let a = Int8Array::from(&[None, Some(-100)]);
+/// let result = wrapping_sub_scalar(&a, &100i8);
+/// let expected = Int8Array::from(&[None, Some(56)]);
+/// assert_eq!(result, expected);
+/// ```
+pub fn wrapping_sub_scalar<T>(lhs: &PrimitiveArray<T>, rhs: &T) -> PrimitiveArray<T>
+where
+    T: NativeArithmetics + WrappingSub,
+{
+    unary(lhs, |a| a.wrapping_sub(rhs), lhs.data_type().clone())
+}
+
+/// Checked subtraction of a scalar T to a primitive array of type T. If the
+/// result from the subtraction overflows, then the validity for that index
+/// is changed to None.
+///
+/// # Examples
+/// ```
+/// use polars_arrow::compute::arithmetics::basic::checked_sub_scalar;
+/// use polars_arrow::array::Int8Array;
+///
+/// let a = Int8Array::from(&[None, Some(-100), None, Some(-100)]);
+/// let result = checked_sub_scalar(&a, &100i8);
+/// let expected = Int8Array::from(&[None, None, None, None]);
+/// assert_eq!(result, expected);
+/// ```
+pub fn checked_sub_scalar<T>(lhs: &PrimitiveArray<T>, rhs: &T) -> PrimitiveArray<T>
+where
+    T: NativeArithmetics + CheckedSub,
+{
+    let rhs = *rhs;
+    let op = move |a: T| a.checked_sub(&rhs);
+
+    unary_checked(lhs, op, lhs.data_type().clone())
+}
+
+/// Saturated subtraction of a scalar T to a primitive array of type T. If the
+/// result from the subtraction is smaller than the possible number for this
+/// type, then the result will be saturated.
+///
+/// # Examples
+/// ```
+/// use polars_arrow::compute::arithmetics::basic::saturating_sub_scalar;
+/// use polars_arrow::array::Int8Array;
+///
+/// let a = Int8Array::from(&[Some(-100i8)]);
+/// let result = saturating_sub_scalar(&a, &100i8);
+/// let expected = Int8Array::from(&[Some(-128i8)]);
+/// assert_eq!(result, expected);
+/// ```
+pub fn saturating_sub_scalar<T>(lhs: &PrimitiveArray<T>, rhs: &T) -> PrimitiveArray<T>
+where
+    T: NativeArithmetics + SaturatingSub,
+{
+    let rhs = *rhs;
+    let op = move |a: T| a.saturating_sub(&rhs);
+
+    unary(lhs, op, lhs.data_type().clone())
+}
+
+/// Overflowing subtraction of a scalar T to a primitive array of type T.
+/// If the result from the subtraction is smaller than the possible number for
+/// this type, then the result will be an array with overflowed values and a
+/// validity array indicating the overflowing elements from the array.
+///
+/// # Examples
+/// ```
+/// use polars_arrow::compute::arithmetics::basic::overflowing_sub_scalar;
+/// use polars_arrow::array::Int8Array;
+///
+/// let a = Int8Array::from(&[Some(1i8), Some(-100i8)]);
+/// let (result, overflow) = overflowing_sub_scalar(&a, &100i8);
+/// let expected = Int8Array::from(&[Some(-99i8), Some(56i8)]);
+/// assert_eq!(result, expected);
+/// ```
+pub fn overflowing_sub_scalar<T>(lhs: &PrimitiveArray<T>, rhs: &T) -> (PrimitiveArray<T>, Bitmap)
+where
+    T: NativeArithmetics + OverflowingSub,
+{
+    let rhs = *rhs;
+    let op = move |a: T| a.overflowing_sub(&rhs);
+
+    unary_with_bitmap(lhs, op, lhs.data_type().clone())
+}
+
+// Implementation of ArraySub trait for PrimitiveArrays with a scalar
+impl<T> ArraySub<T> for PrimitiveArray<T>
+where
+    T: NativeArithmetics + Sub<Output = T>,
+{
+    fn sub(&self, rhs: &T) -> Self {
+        sub_scalar(self, rhs)
+    }
+}
+
+// Implementation of ArrayCheckedSub trait for PrimitiveArrays with a scalar
+impl<T> ArrayCheckedSub<T> for PrimitiveArray<T>
+where
+    T: NativeArithmetics + CheckedSub,
+{
+    fn checked_sub(&self, rhs: &T) -> Self {
+        checked_sub_scalar(self, rhs)
+    }
+}
+
+// Implementation of ArraySaturatingSub trait for PrimitiveArrays with a scalar
+impl<T> ArraySaturatingSub<T> for PrimitiveArray<T>
+where
+    T: NativeArithmetics + SaturatingSub,
+{
+    fn saturating_sub(&self, rhs: &T) -> Self {
+        saturating_sub_scalar(self, rhs)
+    }
+}
+
+// Implementation of ArrayOverflowingSub trait for PrimitiveArrays with a scalar
+impl<T> ArrayOverflowingSub<T> for PrimitiveArray<T>
+where
+    T: NativeArithmetics + OverflowingSub,
+{
+    fn overflowing_sub(&self, rhs: &T) -> (Self, Bitmap) {
+        overflowing_sub_scalar(self, rhs)
+    }
+}
diff --git a/crates/polars-arrow/src/compute/arithmetics/decimal/add.rs b/crates/polars-arrow/src/compute/arithmetics/decimal/add.rs
new file mode 100644
index 000000000000..63f912e59e60
--- /dev/null
+++ b/crates/polars-arrow/src/compute/arithmetics/decimal/add.rs
@@ -0,0 +1,220 @@
+//! Defines the addition arithmetic kernels for [`PrimitiveArray`] representing decimals.
+use polars_error::{polars_bail, PolarsResult};
+
+use super::{adjusted_precision_scale, get_parameters, max_value, number_digits};
+use crate::array::PrimitiveArray;
+use crate::compute::arithmetics::{ArrayAdd, ArrayCheckedAdd, ArraySaturatingAdd};
+use crate::compute::arity::{binary, binary_checked};
+use crate::compute::utils::{check_same_len, combine_validities};
+use crate::datatypes::DataType;
+
+/// Adds two decimal [`PrimitiveArray`] with the same precision and scale.
+/// # Error
+/// Errors if the precision and scale are different.
+/// # Panic
+/// This function panics iff the added numbers result in a number larger than
+/// the possible number for the precision.
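+///
+/// # Examples
+/// A minimal, illustrative example (the operand values are arbitrary; with
+/// `DataType::Decimal(5, 2)` the `i128` value `1_00` encodes `1.00`):
+/// ```
+/// use polars_arrow::compute::arithmetics::decimal::add;
+/// use polars_arrow::array::PrimitiveArray;
+/// use polars_arrow::datatypes::DataType;
+///
+/// let a = PrimitiveArray::from([Some(1_00i128), Some(2_00i128), None]).to(DataType::Decimal(5, 2));
+/// let b = PrimitiveArray::from([Some(1_00i128), Some(1_00i128), None]).to(DataType::Decimal(5, 2));
+///
+/// let result = add(&a, &b);
+/// let expected = PrimitiveArray::from([Some(2_00i128), Some(3_00i128), None]).to(DataType::Decimal(5, 2));
+///
+/// assert_eq!(result, expected);
+/// ```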
+pub fn add(lhs: &PrimitiveArray<i128>, rhs: &PrimitiveArray<i128>) -> PrimitiveArray<i128> {
+    let (precision, _) = get_parameters(lhs.data_type(), rhs.data_type()).unwrap();
+
+    let max = max_value(precision);
+    let op = move |a, b| {
+        let res: i128 = a + b;
+
+        assert!(
+            res.abs() <= max,
+            "Overflow in addition for precision {precision}"
+        );
+
+        res
+    };
+
+    binary(lhs, rhs, lhs.data_type().clone(), op)
+}
+
+/// Saturated addition of two decimal primitive arrays with the same precision
+/// and scale. If the precision and scale are different, then an
+/// InvalidArgumentError is returned. If the result from the sum is larger than
+/// the possible number with the selected precision then the resulting number
+/// in the arrow array is the maximum number for the selected precision.
+///
+/// # Examples
+/// ```
+/// use polars_arrow::compute::arithmetics::decimal::saturating_add;
+/// use polars_arrow::array::PrimitiveArray;
+/// use polars_arrow::datatypes::DataType;
+///
+/// let a = PrimitiveArray::from([Some(99000i128), Some(11100i128), None, Some(22200i128)]).to(DataType::Decimal(5, 2));
+/// let b = PrimitiveArray::from([Some(01000i128), Some(22200i128), None, Some(11100i128)]).to(DataType::Decimal(5, 2));
+///
+/// let result = saturating_add(&a, &b);
+/// let expected = PrimitiveArray::from([Some(99999i128), Some(33300i128), None, Some(33300i128)]).to(DataType::Decimal(5, 2));
+///
+/// assert_eq!(result, expected);
+/// ```
+pub fn saturating_add(
+    lhs: &PrimitiveArray<i128>,
+    rhs: &PrimitiveArray<i128>,
+) -> PrimitiveArray<i128> {
+    let (precision, _) = get_parameters(lhs.data_type(), rhs.data_type()).unwrap();
+
+    let max = max_value(precision);
+    let op = move |a, b| {
+        let res: i128 = a + b;
+
+        if res.abs() > max {
+            if res > 0 {
+                max
+            } else {
+                -max
+            }
+        } else {
+            res
+        }
+    };
+
+    binary(lhs, rhs, lhs.data_type().clone(), op)
+}
+
+/// Checked addition of two decimal primitive arrays with the same precision
+/// and scale. If the precision and scale are different, then an
+/// InvalidArgumentError is returned. If the result from the sum is larger than
+/// the possible number with the selected precision (overflowing), then the
+/// validity for that index is changed to None.
+///
+/// # Examples
+/// ```
+/// use polars_arrow::compute::arithmetics::decimal::checked_add;
+/// use polars_arrow::array::PrimitiveArray;
+/// use polars_arrow::datatypes::DataType;
+///
+/// let a = PrimitiveArray::from([Some(99000i128), Some(11100i128), None, Some(22200i128)]).to(DataType::Decimal(5, 2));
+/// let b = PrimitiveArray::from([Some(01000i128), Some(22200i128), None, Some(11100i128)]).to(DataType::Decimal(5, 2));
+///
+/// let result = checked_add(&a, &b);
+/// let expected = PrimitiveArray::from([None, Some(33300i128), None, Some(33300i128)]).to(DataType::Decimal(5, 2));
+///
+/// assert_eq!(result, expected);
+/// ```
+pub fn checked_add(lhs: &PrimitiveArray<i128>, rhs: &PrimitiveArray<i128>) -> PrimitiveArray<i128> {
+    let (precision, _) = get_parameters(lhs.data_type(), rhs.data_type()).unwrap();
+
+    let max = max_value(precision);
+    let op = move |a, b| {
+        let result: i128 = a + b;
+
+        if result.abs() > max {
+            None
+        } else {
+            Some(result)
+        }
+    };
+
+    binary_checked(lhs, rhs, lhs.data_type().clone(), op)
+}
+
+// Implementation of ArrayAdd trait for PrimitiveArrays
+impl ArrayAdd<PrimitiveArray<i128>> for PrimitiveArray<i128> {
+    fn add(&self, rhs: &PrimitiveArray<i128>) -> Self {
+        add(self, rhs)
+    }
+}
+
+// Implementation of ArrayCheckedAdd trait for PrimitiveArrays
+impl ArrayCheckedAdd<PrimitiveArray<i128>> for PrimitiveArray<i128> {
+    fn checked_add(&self, rhs: &PrimitiveArray<i128>) -> Self {
+        checked_add(self, rhs)
+    }
+}
+
+// Implementation of ArraySaturatingAdd trait for PrimitiveArrays
+impl ArraySaturatingAdd<PrimitiveArray<i128>> for PrimitiveArray<i128> {
+    fn saturating_add(&self, rhs: &PrimitiveArray<i128>) -> Self {
+        saturating_add(self, rhs)
+    }
+}
+
+/// Adaptive addition of two decimal primitive arrays with different precision
+/// and scale. If the precision and scale are different, then the smallest scale
+/// and precision is adjusted to the largest precision and scale.
If during the +/// addition one of the results is larger than the max possible value, the +/// result precision is changed to the precision of the max value +/// +/// ```nocode +/// 11111.11 -> 7, 2 +/// 11111.111 -> 8, 3 +/// ------------------ +/// 22222.221 -> 8, 3 +/// ``` +/// # Examples +/// ``` +/// use polars_arrow::compute::arithmetics::decimal::adaptive_add; +/// use polars_arrow::array::PrimitiveArray; +/// use polars_arrow::datatypes::DataType; +/// +/// let a = PrimitiveArray::from([Some(11111_11i128)]).to(DataType::Decimal(7, 2)); +/// let b = PrimitiveArray::from([Some(11111_111i128)]).to(DataType::Decimal(8, 3)); +/// let result = adaptive_add(&a, &b).unwrap(); +/// let expected = PrimitiveArray::from([Some(22222_221i128)]).to(DataType::Decimal(8, 3)); +/// +/// assert_eq!(result, expected); +/// ``` +pub fn adaptive_add( + lhs: &PrimitiveArray, + rhs: &PrimitiveArray, +) -> PolarsResult> { + check_same_len(lhs, rhs)?; + + let (lhs_p, lhs_s, rhs_p, rhs_s) = + if let (DataType::Decimal(lhs_p, lhs_s), DataType::Decimal(rhs_p, rhs_s)) = + (lhs.data_type(), rhs.data_type()) + { + (*lhs_p, *lhs_s, *rhs_p, *rhs_s) + } else { + polars_bail!(ComputeError: "Incorrect data type for the array") + }; + + // The resulting precision is mutable because it could change while + // looping through the iterator + let (mut res_p, res_s, diff) = adjusted_precision_scale(lhs_p, lhs_s, rhs_p, rhs_s); + + let shift = 10i128.pow(diff as u32); + let mut max = max_value(res_p); + + let values = lhs + .values() + .iter() + .zip(rhs.values().iter()) + .map(|(l, r)| { + // Based on the array's scales one of the arguments in the sum has to be shifted + // to the left to match the final scale + let res = if lhs_s > rhs_s { + l + r * shift + } else { + l * shift + r + }; + + // The precision of the resulting array will change if one of the + // sums during the iteration produces a value bigger than the + // possible value for the initial precision + + // 99.9999 -> 6, 4 + // 00.0001 -> 6, 4 + // ----------------- + // 100.0000 -> 7, 4 + if res.abs() > max { + res_p = number_digits(res); + max = max_value(res_p); + } + res + }) + .collect::>(); + + let validity = combine_validities(lhs.validity(), rhs.validity()); + + Ok(PrimitiveArray::::new( + DataType::Decimal(res_p, res_s), + values.into(), + validity, + )) +} diff --git a/crates/polars-arrow/src/compute/arithmetics/decimal/div.rs b/crates/polars-arrow/src/compute/arithmetics/decimal/div.rs new file mode 100644 index 000000000000..6516717e3239 --- /dev/null +++ b/crates/polars-arrow/src/compute/arithmetics/decimal/div.rs @@ -0,0 +1,301 @@ +//! Defines the division arithmetic kernels for Decimal +//! `PrimitiveArrays`. + +use polars_error::{polars_bail, PolarsResult}; + +use super::{adjusted_precision_scale, get_parameters, max_value, number_digits}; +use crate::array::PrimitiveArray; +use crate::compute::arithmetics::{ArrayCheckedDiv, ArrayDiv}; +use crate::compute::arity::{binary, binary_checked, unary}; +use crate::compute::utils::{check_same_len, combine_validities}; +use crate::datatypes::DataType; +use crate::scalar::{PrimitiveScalar, Scalar}; + +/// Divide two decimal primitive arrays with the same precision and scale. If +/// the precision and scale is different, then an InvalidArgumentError is +/// returned. This function panics if the dividend is divided by 0 or None. +/// This function also panics if the division produces a number larger +/// than the possible number for the array precision. 
+///
+/// # Examples
+/// ```
+/// use polars_arrow::compute::arithmetics::decimal::div;
+/// use polars_arrow::array::PrimitiveArray;
+/// use polars_arrow::datatypes::DataType;
+///
+/// let a = PrimitiveArray::from([Some(1_00i128), Some(4_00i128), Some(6_00i128)]).to(DataType::Decimal(5, 2));
+/// let b = PrimitiveArray::from([Some(1_00i128), Some(2_00i128), Some(2_00i128)]).to(DataType::Decimal(5, 2));
+///
+/// let result = div(&a, &b);
+/// let expected = PrimitiveArray::from([Some(1_00i128), Some(2_00i128), Some(3_00i128)]).to(DataType::Decimal(5, 2));
+///
+/// assert_eq!(result, expected);
+/// ```
+pub fn div(lhs: &PrimitiveArray<i128>, rhs: &PrimitiveArray<i128>) -> PrimitiveArray<i128> {
+    let (precision, scale) = get_parameters(lhs.data_type(), rhs.data_type()).unwrap();
+
+    let scale = 10i128.pow(scale as u32);
+    let max = max_value(precision);
+    let op = move |a: i128, b: i128| {
+        // The division is done using the numbers without scale.
+        // The dividend is scaled up to maintain precision after the
+        // division
+
+        // 222.222 --> 222222000
+        // 123.456 -->    123456
+        // --------  ---------
+        //   1.800 <--      1800
+        let numeral: i128 = a * scale;
+
+        // The division can overflow if the dividend is divided
+        // by zero.
+        let res: i128 = numeral.checked_div(b).expect("Found division by zero");
+
+        assert!(
+            res.abs() <= max,
+            "Overflow in division for precision {precision}"
+        );
+
+        res
+    };
+
+    binary(lhs, rhs, lhs.data_type().clone(), op)
+}
+
+/// Divide a decimal [`PrimitiveArray`] by a [`PrimitiveScalar`] with the same
+/// precision and scale. If the precision and scale are different, then an
+/// InvalidArgumentError is returned. This function panics if the value is
+/// divided by zero or if the division produces a number larger than the
+/// possible number for the selected precision.
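+///
+/// # Examples
+/// A minimal illustration (operand values are arbitrary; the scalar `2.00` is
+/// constructed here via `PrimitiveScalar::new` with `DataType::Decimal(5, 2)`):
+/// ```
+/// use polars_arrow::compute::arithmetics::decimal::div_scalar;
+/// use polars_arrow::array::PrimitiveArray;
+/// use polars_arrow::datatypes::DataType;
+/// use polars_arrow::scalar::PrimitiveScalar;
+///
+/// let a = PrimitiveArray::from([Some(4_00i128), Some(6_00i128)]).to(DataType::Decimal(5, 2));
+/// let b = PrimitiveScalar::new(DataType::Decimal(5, 2), Some(2_00i128));
+///
+/// let result = div_scalar(&a, &b);
+/// let expected = PrimitiveArray::from([Some(2_00i128), Some(3_00i128)]).to(DataType::Decimal(5, 2));
+///
+/// assert_eq!(result, expected);
+/// ```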
+pub fn div_scalar(lhs: &PrimitiveArray<i128>, rhs: &PrimitiveScalar<i128>) -> PrimitiveArray<i128> {
+    let (precision, scale) = get_parameters(lhs.data_type(), rhs.data_type()).unwrap();
+
+    let rhs = if let Some(rhs) = *rhs.value() {
+        rhs
+    } else {
+        return PrimitiveArray::<i128>::new_null(lhs.data_type().clone(), lhs.len());
+    };
+
+    let scale = 10i128.pow(scale as u32);
+    let max = max_value(precision);
+
+    let op = move |a: i128| {
+        // The division is done using the numbers without scale.
+        // The dividend is scaled up to maintain precision after the
+        // division
+
+        // 222.222 --> 222222000
+        // 123.456 -->    123456
+        // --------  ---------
+        //   1.800 <--      1800
+        let numeral: i128 = a * scale;
+
+        // The division can overflow if the dividend is divided
+        // by zero.
+        let res: i128 = numeral.checked_div(rhs).expect("Found division by zero");
+
+        assert!(
+            res.abs() <= max,
+            "Overflow in division for precision {precision}"
+        );
+
+        res
+    };
+
+    unary(lhs, op, lhs.data_type().clone())
+}
+
+/// Saturated division of two decimal primitive arrays with the same
+/// precision and scale. If the precision and scale are different, then an
+/// InvalidArgumentError is returned. If the result from the division is
+/// larger than the possible number with the selected precision then the
+/// resulting number in the arrow array is the maximum number for the selected
+/// precision. The function panics if divided by zero.
+///
+/// # Examples
+/// ```
+/// use polars_arrow::compute::arithmetics::decimal::saturating_div;
+/// use polars_arrow::array::PrimitiveArray;
+/// use polars_arrow::datatypes::DataType;
+///
+/// let a = PrimitiveArray::from([Some(999_99i128), Some(4_00i128), Some(6_00i128)]).to(DataType::Decimal(5, 2));
+/// let b = PrimitiveArray::from([Some(000_01i128), Some(2_00i128), Some(2_00i128)]).to(DataType::Decimal(5, 2));
+///
+/// let result = saturating_div(&a, &b);
+/// let expected = PrimitiveArray::from([Some(999_99i128), Some(2_00i128), Some(3_00i128)]).to(DataType::Decimal(5, 2));
+///
+/// assert_eq!(result, expected);
+/// ```
+pub fn saturating_div(
+    lhs: &PrimitiveArray<i128>,
+    rhs: &PrimitiveArray<i128>,
+) -> PrimitiveArray<i128> {
+    let (precision, scale) = get_parameters(lhs.data_type(), rhs.data_type()).unwrap();
+
+    let scale = 10i128.pow(scale as u32);
+    let max = max_value(precision);
+
+    let op = move |a: i128, b: i128| {
+        let numeral: i128 = a * scale;
+
+        match numeral.checked_div(b) {
+            Some(res) => match res {
+                res if res.abs() > max => {
+                    if res > 0 {
+                        max
+                    } else {
+                        -max
+                    }
+                },
+                _ => res,
+            },
+            None => 0,
+        }
+    };
+
+    binary(lhs, rhs, lhs.data_type().clone(), op)
+}
+
+/// Checked division of two decimal primitive arrays with the same precision
+/// and scale. If the precision and scale are different, then an
+/// InvalidArgumentError is returned. If the divisor is zero, then the
+/// validity for that index is changed to None.
+///
+/// # Examples
+/// ```
+/// use polars_arrow::compute::arithmetics::decimal::checked_div;
+/// use polars_arrow::array::PrimitiveArray;
+/// use polars_arrow::datatypes::DataType;
+///
+/// let a = PrimitiveArray::from([Some(1_00i128), Some(4_00i128), Some(6_00i128)]).to(DataType::Decimal(5, 2));
+/// let b = PrimitiveArray::from([Some(000_00i128), None, Some(2_00i128)]).to(DataType::Decimal(5, 2));
+///
+/// let result = checked_div(&a, &b);
+/// let expected = PrimitiveArray::from([None, None, Some(3_00i128)]).to(DataType::Decimal(5, 2));
+///
+/// assert_eq!(result, expected);
+/// ```
+pub fn checked_div(lhs: &PrimitiveArray<i128>, rhs: &PrimitiveArray<i128>) -> PrimitiveArray<i128> {
+    let (precision, scale) = get_parameters(lhs.data_type(), rhs.data_type()).unwrap();
+
+    let scale = 10i128.pow(scale as u32);
+    let max = max_value(precision);
+
+    let op = move |a: i128, b: i128| {
+        let numeral: i128 = a * scale;
+
+        match numeral.checked_div(b) {
+            Some(res) => match res {
+                res if res.abs() > max => None,
+                _ => Some(res),
+            },
+            None => None,
+        }
+    };
+
+    binary_checked(lhs, rhs, lhs.data_type().clone(), op)
+}
+
+// Implementation of ArrayDiv trait for PrimitiveArrays
+impl ArrayDiv<PrimitiveArray<i128>> for PrimitiveArray<i128> {
+    fn div(&self, rhs: &PrimitiveArray<i128>) -> Self {
+        div(self, rhs)
+    }
+}
+
+// Implementation of ArrayCheckedDiv trait for PrimitiveArrays
+impl ArrayCheckedDiv<PrimitiveArray<i128>> for PrimitiveArray<i128> {
+    fn checked_div(&self, rhs: &PrimitiveArray<i128>) -> Self {
+        checked_div(self, rhs)
+    }
+}
+
+/// Adaptive division of two decimal primitive arrays with different precision
+/// and scale. If the precision and scale are different, then the smallest scale
+/// and precision is adjusted to the largest precision and scale. If during the
+/// division one of the results is larger than the max possible value, the
+/// result precision is changed to the precision of the max value. The function
+/// panics when divided by zero.
+///
+/// ```nocode
+///   1000.00   -> 7, 2
+///     10.0000 -> 6, 4
+/// -----------------
+///    100.0000 -> 9, 4
+/// ```
+/// # Examples
+/// ```
+/// use polars_arrow::compute::arithmetics::decimal::adaptive_div;
+/// use polars_arrow::array::PrimitiveArray;
+/// use polars_arrow::datatypes::DataType;
+///
+/// let a = PrimitiveArray::from([Some(1000_00i128)]).to(DataType::Decimal(7, 2));
+/// let b = PrimitiveArray::from([Some(10_0000i128)]).to(DataType::Decimal(6, 4));
+/// let result = adaptive_div(&a, &b).unwrap();
+/// let expected = PrimitiveArray::from([Some(100_0000i128)]).to(DataType::Decimal(9, 4));
+///
+/// assert_eq!(result, expected);
+/// ```
+pub fn adaptive_div(
+    lhs: &PrimitiveArray<i128>,
+    rhs: &PrimitiveArray<i128>,
+) -> PolarsResult<PrimitiveArray<i128>> {
+    check_same_len(lhs, rhs)?;
+
+    let (lhs_p, lhs_s, rhs_p, rhs_s) =
+        if let (DataType::Decimal(lhs_p, lhs_s), DataType::Decimal(rhs_p, rhs_s)) =
+            (lhs.data_type(), rhs.data_type())
+        {
+            (*lhs_p, *lhs_s, *rhs_p, *rhs_s)
+        } else {
+            polars_bail!(ComputeError: "Incorrect data type for the array")
+        };
+
+    // The resulting precision is mutable because it could change while
+    // looping through the iterator
+    let (mut res_p, res_s, diff) = adjusted_precision_scale(lhs_p, lhs_s, rhs_p, rhs_s);
+
+    let shift = 10i128.pow(diff as u32);
+    let shift_1 = 10i128.pow(res_s as u32);
+    let mut max = max_value(res_p);
+
+    let values = lhs
+        .values()
+        .iter()
+        .zip(rhs.values().iter())
+        .map(|(l, r)| {
+            let numeral: i128 = l * shift_1;
+
+            // Based on the array's scales one of the arguments in the sum has to be shifted
+            // to the left to match the final scale
+            let res = if lhs_s > rhs_s {
+                numeral.checked_div(r * shift)
+            } else {
+                (numeral * shift).checked_div(*r)
+            }
+            .expect("Found division by zero");
+
+            // The precision of the resulting array will change if one of the
+            // multiplications during the iteration produces a value bigger
+            // than the possible value for the initial precision
+
+            //  10.0000 -> 6, 4
+            //  00.1000 -> 6, 4
+            // -----------------
+            // 100.0000 -> 7, 4
+            if res.abs() > max {
+                res_p = number_digits(res);
+                max = max_value(res_p);
+            }
+
+            res
+        })
+        .collect::<Vec<_>>();
+
+    let validity = combine_validities(lhs.validity(), rhs.validity());
+
+    Ok(PrimitiveArray::<i128>::new(
+        DataType::Decimal(res_p, res_s),
+        values.into(),
+        validity,
+    ))
+}
diff --git a/crates/polars-arrow/src/compute/arithmetics/decimal/mod.rs b/crates/polars-arrow/src/compute/arithmetics/decimal/mod.rs
new file mode 100644
index 000000000000..d0cabb7d359a
--- /dev/null
+++ b/crates/polars-arrow/src/compute/arithmetics/decimal/mod.rs
@@ -0,0 +1,120 @@
+//! Defines the arithmetic kernels for Decimal `PrimitiveArrays`. The
+//! [`Decimal`](crate::datatypes::DataType::Decimal) type specifies the
+//! precision and scale parameters. These affect the arithmetic operations and
+//! need to be considered while doing operations with Decimal numbers.
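+//!
+//! As a minimal illustration of the encoding these kernels assume (the values
+//! here are arbitrary): with `DataType::Decimal(5, 2)` the `i128` value
+//! `1_23` represents the decimal number `1.23`, so adding `2_00` (`2.00`)
+//! yields `3_23` (`3.23`):
+//! ```
+//! use polars_arrow::compute::arithmetics::decimal::add;
+//! use polars_arrow::array::PrimitiveArray;
+//! use polars_arrow::datatypes::DataType;
+//!
+//! let a = PrimitiveArray::from([Some(1_23i128)]).to(DataType::Decimal(5, 2));
+//! let b = PrimitiveArray::from([Some(2_00i128)]).to(DataType::Decimal(5, 2));
+//! let expected = PrimitiveArray::from([Some(3_23i128)]).to(DataType::Decimal(5, 2));
+//! assert_eq!(add(&a, &b), expected);
+//! ```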
+
+mod add;
+pub use add::*;
+mod div;
+pub use div::*;
+mod mul;
+pub use mul::*;
+use polars_error::{polars_bail, PolarsResult};
+
+mod sub;
+pub use sub::*;
+
+use crate::datatypes::DataType;
+
+/// Maximum value that can exist with a selected precision
+#[inline]
+fn max_value(precision: usize) -> i128 {
+    10i128.pow(precision as u32) - 1
+}
+
+// Calculates the number of digits in an i128 number
+fn number_digits(num: i128) -> usize {
+    let mut num = num.abs();
+    let mut digit: i128 = 0;
+    let base = 10i128;
+
+    while num != 0 {
+        num /= base;
+        digit += 1;
+    }
+
+    digit as usize
+}
+
+fn get_parameters(lhs: &DataType, rhs: &DataType) -> PolarsResult<(usize, usize)> {
+    if let (DataType::Decimal(lhs_p, lhs_s), DataType::Decimal(rhs_p, rhs_s)) =
+        (lhs.to_logical_type(), rhs.to_logical_type())
+    {
+        if lhs_p == rhs_p && lhs_s == rhs_s {
+            Ok((*lhs_p, *lhs_s))
+        } else {
+            polars_bail!(InvalidOperation:
+                "Arrays must have the same precision and scale"
+            )
+        }
+    } else {
+        unreachable!()
+    }
+}
+
+/// Returns the adjusted precision and scale for the lhs and rhs precision and
+/// scale
+fn adjusted_precision_scale(
+    lhs_p: usize,
+    lhs_s: usize,
+    rhs_p: usize,
+    rhs_s: usize,
+) -> (usize, usize, usize) {
+    // The initial new precision and scale is based on the number of digits
+    // that the lhs and rhs numbers have before and after the point. The max
+    // number of digits before and after the point determines the final
+    // precision and scale of the result
+
+    // Digits before/after point
+    //                          before  after
+    //    11.1111 ->  5, 4 ->   2       4
+    // 11111.01   ->  7, 2 ->   5       2
+    // -----------------
+    // 11122.1211 ->  9, 4 ->   5       4
+    let lhs_digits_before = lhs_p - lhs_s;
+    let rhs_digits_before = rhs_p - rhs_s;
+
+    let res_digits_before = std::cmp::max(lhs_digits_before, rhs_digits_before);
+
+    let (res_s, diff) = if lhs_s > rhs_s {
+        (lhs_s, lhs_s - rhs_s)
+    } else {
+        (rhs_s, rhs_s - lhs_s)
+    };
+
+    let res_p = res_digits_before + res_s;
+
+    (res_p, res_s, diff)
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_max_value() {
+        assert_eq!(999, max_value(3));
+        assert_eq!(99999, max_value(5));
+        assert_eq!(999999, max_value(6));
+    }
+
+    #[test]
+    fn test_number_digits() {
+        assert_eq!(2, number_digits(12i128));
+        assert_eq!(3, number_digits(123i128));
+        assert_eq!(4, number_digits(1234i128));
+        assert_eq!(6, number_digits(123456i128));
+        assert_eq!(7, number_digits(1234567i128));
+        assert_eq!(7, number_digits(-1234567i128));
+        assert_eq!(3, number_digits(-123i128));
+    }
+
+    #[test]
+    fn test_adjusted_precision_scale() {
+        //    11.1111 ->  5, 4 -> 2  4
+        // 11111.01   ->  7, 2 -> 5  2
+        // -----------------
+        // 11122.1211 ->  9, 4 -> 5  4
+        assert_eq!((9, 4, 2), adjusted_precision_scale(5, 4, 7, 2))
+    }
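+
+    // A small additional check, added here for illustration: `get_parameters`
+    // accepts matching precision/scale pairs and rejects mismatched ones.
+    #[test]
+    fn test_get_parameters() {
+        assert_eq!(
+            (5, 2),
+            get_parameters(&DataType::Decimal(5, 2), &DataType::Decimal(5, 2)).unwrap()
+        );
+        assert!(get_parameters(&DataType::Decimal(5, 2), &DataType::Decimal(6, 2)).is_err());
+    }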
+}
diff --git a/crates/polars-arrow/src/compute/arithmetics/decimal/mul.rs b/crates/polars-arrow/src/compute/arithmetics/decimal/mul.rs
new file mode 100644
index 000000000000..698b47717ffb
--- /dev/null
+++ b/crates/polars-arrow/src/compute/arithmetics/decimal/mul.rs
@@ -0,0 +1,313 @@
+//! Defines the multiplication arithmetic kernels for Decimal
+//! `PrimitiveArrays`.
+
+use polars_error::{polars_bail, PolarsResult};
+
+use super::{adjusted_precision_scale, get_parameters, max_value, number_digits};
+use crate::array::PrimitiveArray;
+use crate::compute::arithmetics::{ArrayCheckedMul, ArrayMul, ArraySaturatingMul};
+use crate::compute::arity::{binary, binary_checked, unary};
+use crate::compute::utils::{check_same_len, combine_validities};
+use crate::datatypes::DataType;
+use crate::scalar::{PrimitiveScalar, Scalar};
+
+/// Multiply two decimal primitive arrays with the same precision and scale. If
+/// the precision and scale are different, then an InvalidArgumentError is
+/// returned. This function panics if the multiplied numbers result in a number
+/// larger than the possible number for the selected precision.
+///
+/// # Examples
+/// ```
+/// use polars_arrow::compute::arithmetics::decimal::mul;
+/// use polars_arrow::array::PrimitiveArray;
+/// use polars_arrow::datatypes::DataType;
+///
+/// let a = PrimitiveArray::from([Some(1_00i128), Some(1_00i128), None, Some(2_00i128)]).to(DataType::Decimal(5, 2));
+/// let b = PrimitiveArray::from([Some(1_00i128), Some(2_00i128), None, Some(2_00i128)]).to(DataType::Decimal(5, 2));
+///
+/// let result = mul(&a, &b);
+/// let expected = PrimitiveArray::from([Some(1_00i128), Some(2_00i128), None, Some(4_00i128)]).to(DataType::Decimal(5, 2));
+///
+/// assert_eq!(result, expected);
+/// ```
+pub fn mul(lhs: &PrimitiveArray<i128>, rhs: &PrimitiveArray<i128>) -> PrimitiveArray<i128> {
+    let (precision, scale) = get_parameters(lhs.data_type(), rhs.data_type()).unwrap();
+
+    let scale = 10i128.pow(scale as u32);
+    let max = max_value(precision);
+
+    let op = move |a: i128, b: i128| {
+        // The multiplication between i128 can overflow if they are
+        // very large numbers. For that reason a checked
+        // multiplication is used.
+        let res: i128 = a.checked_mul(b).expect("Major overflow for multiplication");
+
+        // The multiplication is done using the numbers without scale.
+        // The resulting scale of the value has to be corrected by
+        // dividing by (10^scale)
+
+        //   111.111 -->      111111
+        //   222.222 -->      222222
+        // --------          -------
+        // 24691.308 <-- 24691308642
+        let res = res / scale;
+
+        assert!(
+            res.abs() <= max,
+            "Overflow in multiplication for precision {precision}"
+        );
+
+        res
+    };
+
+    binary(lhs, rhs, lhs.data_type().clone(), op)
+}
+
+/// Multiply a decimal [`PrimitiveArray`] with a [`PrimitiveScalar`] with the same precision and scale. If
+/// the precision and scale are different, then an InvalidArgumentError is
+/// returned. This function panics if the multiplied numbers result in a number
+/// larger than the possible number for the selected precision.
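+///
+/// # Examples
+/// A minimal illustration (operand values are arbitrary; the scalar `2.00` is
+/// constructed here via `PrimitiveScalar::new` with `DataType::Decimal(5, 2)`):
+/// ```
+/// use polars_arrow::compute::arithmetics::decimal::mul_scalar;
+/// use polars_arrow::array::PrimitiveArray;
+/// use polars_arrow::datatypes::DataType;
+/// use polars_arrow::scalar::PrimitiveScalar;
+///
+/// let a = PrimitiveArray::from([Some(1_00i128), Some(3_00i128)]).to(DataType::Decimal(5, 2));
+/// let b = PrimitiveScalar::new(DataType::Decimal(5, 2), Some(2_00i128));
+///
+/// let result = mul_scalar(&a, &b);
+/// let expected = PrimitiveArray::from([Some(2_00i128), Some(6_00i128)]).to(DataType::Decimal(5, 2));
+///
+/// assert_eq!(result, expected);
+/// ```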
+ // The resulting scale of the value has to be corrected by + // dividing by (10^scale) + + // 111.111 --> 111111 + // 222.222 --> 222222 + // -------- ------- + // 24691.308 <-- 24691308642 + let res = res / scale; + + assert!( + res.abs() <= max, + "Overflow in multiplication presented for precision {precision}" + ); + + res + }; + + unary(lhs, op, lhs.data_type().clone()) +} + +/// Saturated multiplication of two decimal primitive arrays with the same +/// precision and scale. If the precision and scale is different, then an +/// InvalidArgumentError is returned. If the result from the multiplication is +/// larger than the possible number with the selected precision then the +/// resulted number in the arrow array is the maximum number for the selected +/// precision. +/// +/// # Examples +/// ``` +/// use polars_arrow::compute::arithmetics::decimal::saturating_mul; +/// use polars_arrow::array::PrimitiveArray; +/// use polars_arrow::datatypes::DataType; +/// +/// let a = PrimitiveArray::from([Some(999_99i128), Some(1_00i128), None, Some(2_00i128)]).to(DataType::Decimal(5, 2)); +/// let b = PrimitiveArray::from([Some(10_00i128), Some(2_00i128), None, Some(2_00i128)]).to(DataType::Decimal(5, 2)); +/// +/// let result = saturating_mul(&a, &b); +/// let expected = PrimitiveArray::from([Some(999_99i128), Some(2_00i128), None, Some(4_00i128)]).to(DataType::Decimal(5, 2)); +/// +/// assert_eq!(result, expected); +/// ``` +pub fn saturating_mul( + lhs: &PrimitiveArray, + rhs: &PrimitiveArray, +) -> PrimitiveArray { + let (precision, scale) = get_parameters(lhs.data_type(), rhs.data_type()).unwrap(); + + let scale = 10i128.pow(scale as u32); + let max = max_value(precision); + + let op = move |a: i128, b: i128| match a.checked_mul(b) { + Some(res) => { + let res = res / scale; + + match res { + res if res.abs() > max => { + if res > 0 { + max + } else { + -max + } + }, + _ => res, + } + }, + None => max, + }; + + binary(lhs, rhs, lhs.data_type().clone(), op) +} + +/// Checked multiplication of two decimal primitive arrays with the same +/// precision and scale. If the precision and scale is different, then an +/// InvalidArgumentError is returned. 
If the result from the mul is larger than +/// the possible number with the selected precision (overflowing), then the +/// validity for that index is changed to None +/// +/// # Examples +/// ``` +/// use polars_arrow::compute::arithmetics::decimal::checked_mul; +/// use polars_arrow::array::PrimitiveArray; +/// use polars_arrow::datatypes::DataType; +/// +/// let a = PrimitiveArray::from([Some(999_99i128), Some(1_00i128), None, Some(2_00i128)]).to(DataType::Decimal(5, 2)); +/// let b = PrimitiveArray::from([Some(10_00i128), Some(2_00i128), None, Some(2_00i128)]).to(DataType::Decimal(5, 2)); +/// +/// let result = checked_mul(&a, &b); +/// let expected = PrimitiveArray::from([None, Some(2_00i128), None, Some(4_00i128)]).to(DataType::Decimal(5, 2)); +/// +/// assert_eq!(result, expected); +/// ``` +pub fn checked_mul(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray { + let (precision, scale) = get_parameters(lhs.data_type(), rhs.data_type()).unwrap(); + + let scale = 10i128.pow(scale as u32); + let max = max_value(precision); + + let op = move |a: i128, b: i128| match a.checked_mul(b) { + Some(res) => { + let res = res / scale; + + match res { + res if res.abs() > max => None, + _ => Some(res), + } + }, + None => None, + }; + + binary_checked(lhs, rhs, lhs.data_type().clone(), op) +} + +// Implementation of ArrayMul trait for PrimitiveArrays +impl ArrayMul> for PrimitiveArray { + fn mul(&self, rhs: &PrimitiveArray) -> Self { + mul(self, rhs) + } +} + +// Implementation of ArrayCheckedMul trait for PrimitiveArrays +impl ArrayCheckedMul> for PrimitiveArray { + fn checked_mul(&self, rhs: &PrimitiveArray) -> Self { + checked_mul(self, rhs) + } +} + +// Implementation of ArraySaturatingMul trait for PrimitiveArrays +impl ArraySaturatingMul> for PrimitiveArray { + fn saturating_mul(&self, rhs: &PrimitiveArray) -> Self { + saturating_mul(self, rhs) + } +} + +/// Adaptive multiplication of two decimal primitive arrays with different +/// precision and scale. If the precision and scale is different, then the +/// smallest scale and precision is adjusted to the largest precision and +/// scale. 
If during the multiplication one of the results is larger than the +/// max possible value, the result precision is changed to the precision of the +/// max value +/// +/// ```nocode +/// 11111.0 -> 6, 1 +/// 10.002 -> 5, 3 +/// ----------------- +/// 111132.222 -> 9, 3 +/// ``` +/// # Examples +/// ``` +/// use polars_arrow::compute::arithmetics::decimal::adaptive_mul; +/// use polars_arrow::array::PrimitiveArray; +/// use polars_arrow::datatypes::DataType; +/// +/// let a = PrimitiveArray::from([Some(11111_0i128), Some(1_0i128)]).to(DataType::Decimal(6, 1)); +/// let b = PrimitiveArray::from([Some(10_002i128), Some(2_000i128)]).to(DataType::Decimal(5, 3)); +/// let result = adaptive_mul(&a, &b).unwrap(); +/// let expected = PrimitiveArray::from([Some(111132_222i128), Some(2_000i128)]).to(DataType::Decimal(9, 3)); +/// +/// assert_eq!(result, expected); +/// ``` +pub fn adaptive_mul( + lhs: &PrimitiveArray, + rhs: &PrimitiveArray, +) -> PolarsResult> { + check_same_len(lhs, rhs)?; + + let (lhs_p, lhs_s, rhs_p, rhs_s) = + if let (DataType::Decimal(lhs_p, lhs_s), DataType::Decimal(rhs_p, rhs_s)) = + (lhs.data_type(), rhs.data_type()) + { + (*lhs_p, *lhs_s, *rhs_p, *rhs_s) + } else { + polars_bail!(ComputeError: "Incorrect data type for the array") + }; + + // The resulting precision is mutable because it could change while + // looping through the iterator + let (mut res_p, res_s, diff) = adjusted_precision_scale(lhs_p, lhs_s, rhs_p, rhs_s); + + let shift = 10i128.pow(diff as u32); + let shift_1 = 10i128.pow(res_s as u32); + let mut max = max_value(res_p); + + let values = lhs + .values() + .iter() + .zip(rhs.values().iter()) + .map(|(l, r)| { + // Based on the array's scales one of the arguments in the sum has to be shifted + // to the left to match the final scale + let res = if lhs_s > rhs_s { + l.checked_mul(r * shift) + } else { + (l * shift).checked_mul(*r) + } + .expect("Mayor overflow for multiplication"); + + let res = res / shift_1; + + // The precision of the resulting array will change if one of the + // multiplications during the iteration produces a value bigger + // than the possible value for the initial precision + + // 10.0000 -> 6, 4 + // 10.0000 -> 6, 4 + // ----------------- + // 100.0000 -> 7, 4 + if res.abs() > max { + res_p = number_digits(res); + max = max_value(res_p); + } + + res + }) + .collect::>(); + + let validity = combine_validities(lhs.validity(), rhs.validity()); + + Ok(PrimitiveArray::::new( + DataType::Decimal(res_p, res_s), + values.into(), + validity, + )) +} diff --git a/crates/polars-arrow/src/compute/arithmetics/decimal/sub.rs b/crates/polars-arrow/src/compute/arithmetics/decimal/sub.rs new file mode 100644 index 000000000000..73840acc34b4 --- /dev/null +++ b/crates/polars-arrow/src/compute/arithmetics/decimal/sub.rs @@ -0,0 +1,237 @@ +//! Defines the subtract arithmetic kernels for Decimal `PrimitiveArrays`. + +use polars_error::{polars_bail, PolarsResult}; + +use super::{adjusted_precision_scale, get_parameters, max_value, number_digits}; +use crate::array::PrimitiveArray; +use crate::compute::arithmetics::{ArrayCheckedSub, ArraySaturatingSub, ArraySub}; +use crate::compute::arity::{binary, binary_checked}; +use crate::compute::utils::{check_same_len, combine_validities}; +use crate::datatypes::DataType; + +/// Subtract two decimal primitive arrays with the same precision and scale. If +/// the precision and scale is different, then an InvalidArgumentError is +/// returned. 
This function panics if the subtracted numbers result in a number +/// smaller than the possible number for the selected precision. +/// +/// # Examples +/// ``` +/// use polars_arrow::compute::arithmetics::decimal::sub; +/// use polars_arrow::array::PrimitiveArray; +/// use polars_arrow::datatypes::DataType; +/// +/// let a = PrimitiveArray::from([Some(1i128), Some(1i128), None, Some(2i128)]).to(DataType::Decimal(5, 2)); +/// let b = PrimitiveArray::from([Some(1i128), Some(2i128), None, Some(2i128)]).to(DataType::Decimal(5, 2)); +/// +/// let result = sub(&a, &b); +/// let expected = PrimitiveArray::from([Some(0i128), Some(-1i128), None, Some(0i128)]).to(DataType::Decimal(5, 2)); +/// +/// assert_eq!(result, expected); +/// ``` +pub fn sub(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray { + let (precision, _) = get_parameters(lhs.data_type(), rhs.data_type()).unwrap(); + + let max = max_value(precision); + + let op = move |a, b| { + let res: i128 = a - b; + + assert!( + res.abs() <= max, + "Overflow in subtract presented for precision {precision}" + ); + + res + }; + + binary(lhs, rhs, lhs.data_type().clone(), op) +} + +/// Saturated subtraction of two decimal primitive arrays with the same +/// precision and scale. If the precision and scale is different, then an +/// InvalidArgumentError is returned. If the result from the sum is smaller +/// than the possible number with the selected precision then the resulted +/// number in the arrow array is the minimum number for the selected precision. +/// +/// # Examples +/// ``` +/// use polars_arrow::compute::arithmetics::decimal::saturating_sub; +/// use polars_arrow::array::PrimitiveArray; +/// use polars_arrow::datatypes::DataType; +/// +/// let a = PrimitiveArray::from([Some(-99000i128), Some(11100i128), None, Some(22200i128)]).to(DataType::Decimal(5, 2)); +/// let b = PrimitiveArray::from([Some(01000i128), Some(22200i128), None, Some(11100i128)]).to(DataType::Decimal(5, 2)); +/// +/// let result = saturating_sub(&a, &b); +/// let expected = PrimitiveArray::from([Some(-99999i128), Some(-11100i128), None, Some(11100i128)]).to(DataType::Decimal(5, 2)); +/// +/// assert_eq!(result, expected); +/// ``` +pub fn saturating_sub( + lhs: &PrimitiveArray, + rhs: &PrimitiveArray, +) -> PrimitiveArray { + let (precision, _) = get_parameters(lhs.data_type(), rhs.data_type()).unwrap(); + + let max = max_value(precision); + + let op = move |a, b| { + let res: i128 = a - b; + + match res { + res if res.abs() > max => { + if res > 0 { + max + } else { + -max + } + }, + _ => res, + } + }; + + binary(lhs, rhs, lhs.data_type().clone(), op) +} + +// Implementation of ArraySub trait for PrimitiveArrays +impl ArraySub> for PrimitiveArray { + fn sub(&self, rhs: &PrimitiveArray) -> Self { + sub(self, rhs) + } +} + +// Implementation of ArrayCheckedSub trait for PrimitiveArrays +impl ArrayCheckedSub> for PrimitiveArray { + fn checked_sub(&self, rhs: &PrimitiveArray) -> Self { + checked_sub(self, rhs) + } +} + +// Implementation of ArraySaturatingSub trait for PrimitiveArrays +impl ArraySaturatingSub> for PrimitiveArray { + fn saturating_sub(&self, rhs: &PrimitiveArray) -> Self { + saturating_sub(self, rhs) + } +} +/// Checked subtract of two decimal primitive arrays with the same precision +/// and scale. If the precision and scale is different, then an +/// InvalidArgumentError is returned. 
If the result from the sub is larger than +/// the possible number with the selected precision (overflowing), then the +/// validity for that index is changed to None +/// +/// # Examples +/// ``` +/// use polars_arrow::compute::arithmetics::decimal::checked_sub; +/// use polars_arrow::array::PrimitiveArray; +/// use polars_arrow::datatypes::DataType; +/// +/// let a = PrimitiveArray::from([Some(-99000i128), Some(11100i128), None, Some(22200i128)]).to(DataType::Decimal(5, 2)); +/// let b = PrimitiveArray::from([Some(01000i128), Some(22200i128), None, Some(11100i128)]).to(DataType::Decimal(5, 2)); +/// +/// let result = checked_sub(&a, &b); +/// let expected = PrimitiveArray::from([None, Some(-11100i128), None, Some(11100i128)]).to(DataType::Decimal(5, 2)); +/// +/// assert_eq!(result, expected); +/// ``` +pub fn checked_sub(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray { + let (precision, _) = get_parameters(lhs.data_type(), rhs.data_type()).unwrap(); + + let max = max_value(precision); + + let op = move |a, b| { + let res: i128 = a - b; + + match res { + res if res.abs() > max => None, + _ => Some(res), + } + }; + + binary_checked(lhs, rhs, lhs.data_type().clone(), op) +} + +/// Adaptive subtract of two decimal primitive arrays with different precision +/// and scale. If the precision and scale is different, then the smallest scale +/// and precision is adjusted to the largest precision and scale. If during the +/// addition one of the results is smaller than the min possible value, the +/// result precision is changed to the precision of the min value +/// +/// ```nocode +/// 99.9999 -> 6, 4 +/// -00.0001 -> 6, 4 +/// ----------------- +/// 100.0000 -> 7, 4 +/// ``` +/// # Examples +/// ``` +/// use polars_arrow::compute::arithmetics::decimal::adaptive_sub; +/// use polars_arrow::array::PrimitiveArray; +/// use polars_arrow::datatypes::DataType; +/// +/// let a = PrimitiveArray::from([Some(99_9999i128)]).to(DataType::Decimal(6, 4)); +/// let b = PrimitiveArray::from([Some(-00_0001i128)]).to(DataType::Decimal(6, 4)); +/// let result = adaptive_sub(&a, &b).unwrap(); +/// let expected = PrimitiveArray::from([Some(100_0000i128)]).to(DataType::Decimal(7, 4)); +/// +/// assert_eq!(result, expected); +/// ``` +pub fn adaptive_sub( + lhs: &PrimitiveArray, + rhs: &PrimitiveArray, +) -> PolarsResult> { + check_same_len(lhs, rhs)?; + + let (lhs_p, lhs_s, rhs_p, rhs_s) = + if let (DataType::Decimal(lhs_p, lhs_s), DataType::Decimal(rhs_p, rhs_s)) = + (lhs.data_type(), rhs.data_type()) + { + (*lhs_p, *lhs_s, *rhs_p, *rhs_s) + } else { + polars_bail!(ComputeError: "Incorrect data type for the array") + }; + + // The resulting precision is mutable because it could change while + // looping through the iterator + let (mut res_p, res_s, diff) = adjusted_precision_scale(lhs_p, lhs_s, rhs_p, rhs_s); + + let shift = 10i128.pow(diff as u32); + let mut max = max_value(res_p); + + let values = lhs + .values() + .iter() + .zip(rhs.values().iter()) + .map(|(l, r)| { + // Based on the array's scales one of the arguments in the sum has to be shifted + // to the left to match the final scale + let res: i128 = if lhs_s > rhs_s { + l - r * shift + } else { + l * shift - r + }; + + // The precision of the resulting array will change if one of the + // subtraction during the iteration produces a value bigger than the + // possible value for the initial precision + + // -99.9999 -> 6, 4 + // 00.0001 -> 6, 4 + // ----------------- + // -100.0000 -> 7, 4 + if res.abs() > max { + res_p = number_digits(res); 
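+                // Refresh the overflow bound so the remaining rows are
+                // checked against the widened precision.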
+                max = max_value(res_p);
+            }
+
+            res
+        })
+        .collect::<Vec<_>>();
+
+    let validity = combine_validities(lhs.validity(), rhs.validity());
+
+    Ok(PrimitiveArray::<i128>::new(
+        DataType::Decimal(res_p, res_s),
+        values.into(),
+        validity,
+    ))
+}
diff --git a/crates/polars-arrow/src/compute/arithmetics/mod.rs b/crates/polars-arrow/src/compute/arithmetics/mod.rs
index 38883ee044cf..c25b265b1a35 100644
--- a/crates/polars-arrow/src/compute/arithmetics/mod.rs
+++ b/crates/polars-arrow/src/compute/arithmetics/mod.rs
@@ -1 +1,132 @@
+//! Defines basic arithmetic kernels for [`PrimitiveArray`](crate::array::PrimitiveArray)s.
+//!
+//! The arithmetics module is composed of basic arithmetic operations that can
+//! be performed on [`PrimitiveArray`](crate::array::PrimitiveArray).
+//!
+//! Whenever possible, each operation declares variations
+//! of the basic operation that offer different guarantees:
+//! * plain: panics on overflow and underflow.
+//! * checked: turns an overflow into a null.
+//! * saturating: clamps an overflow to the MAX or MIN value respectively.
+//! * overflowing: returns an extra [`Bitmap`] denoting whether the operation overflowed.
+//! * adaptive: for [`Decimal`](crate::datatypes::DataType::Decimal) only,
+//!   adjusts the precision and scale to make the resulting value fit.
+#[forbid(unsafe_code)]
+pub mod basic;
+#[cfg(feature = "compute_arithmetics_decimal")]
+pub mod decimal;
+
+use crate::bitmap::Bitmap;
+
+/// Defines basic addition operation for primitive arrays
+pub trait ArrayAdd<Rhs>: Sized {
+    /// Adds itself to `rhs`
+    fn add(&self, rhs: &Rhs) -> Self;
+}
+
+/// Defines wrapping addition operation for primitive arrays
+pub trait ArrayWrappingAdd<Rhs>: Sized {
+    /// Adds itself to `rhs` using wrapping addition
+    fn wrapping_add(&self, rhs: &Rhs) -> Self;
+}
+
+/// Defines checked addition operation for primitive arrays
+pub trait ArrayCheckedAdd<Rhs>: Sized {
+    /// Checked add
+    fn checked_add(&self, rhs: &Rhs) -> Self;
+}
+
+/// Defines saturating addition operation for primitive arrays
+pub trait ArraySaturatingAdd<Rhs>: Sized {
+    /// Saturating add
+    fn saturating_add(&self, rhs: &Rhs) -> Self;
+}
+
+/// Defines overflowing addition operation for primitive arrays
+pub trait ArrayOverflowingAdd<Rhs>: Sized {
+    /// Overflowing add
+    fn overflowing_add(&self, rhs: &Rhs) -> (Self, Bitmap);
+}
+
+/// Defines basic subtraction operation for primitive arrays
+pub trait ArraySub<Rhs>: Sized {
+    /// Subtraction
+    fn sub(&self, rhs: &Rhs) -> Self;
+}
+
+/// Defines wrapping subtraction operation for primitive arrays
+pub trait ArrayWrappingSub<Rhs>: Sized {
+    /// Wrapping subtraction
+    fn wrapping_sub(&self, rhs: &Rhs) -> Self;
+}
+
+/// Defines checked subtraction operation for primitive arrays
+pub trait ArrayCheckedSub<Rhs>: Sized {
+    /// Checked subtraction
+    fn checked_sub(&self, rhs: &Rhs) -> Self;
+}
+
+/// Defines saturating subtraction operation for primitive arrays
+pub trait ArraySaturatingSub<Rhs>: Sized {
+    /// Saturating subtraction
+    fn saturating_sub(&self, rhs: &Rhs) -> Self;
+}
+
+/// Defines overflowing subtraction operation for primitive arrays
+pub trait ArrayOverflowingSub<Rhs>: Sized {
+    /// Overflowing subtraction
+    fn overflowing_sub(&self, rhs: &Rhs) -> (Self, Bitmap);
+}
+
+/// Defines basic multiplication operation for primitive arrays
+pub trait ArrayMul<Rhs>: Sized {
+    /// Multiplication
+    fn mul(&self, rhs: &Rhs) -> Self;
+}
+
+/// Defines wrapping multiplication operation for primitive arrays
+pub trait ArrayWrappingMul<Rhs>: Sized {
+    /// Wrapping multiplication
+    fn wrapping_mul(&self, rhs: &Rhs) -> Self;
+}
+
+/// Defines checked multiplication operation for primitive arrays
+pub trait ArrayCheckedMul<Rhs>: Sized {
+    /// Checked multiplication
+    fn checked_mul(&self, rhs: &Rhs) -> Self;
+}
+
+/// Defines saturating multiplication operation for primitive arrays
+pub trait ArraySaturatingMul<Rhs>: Sized {
+    /// Saturating multiplication
+    fn saturating_mul(&self, rhs: &Rhs) -> Self;
+}
+
+/// Defines overflowing multiplication operation for primitive arrays
+pub trait ArrayOverflowingMul<Rhs>: Sized {
+    /// Overflowing multiplication
+    fn overflowing_mul(&self, rhs: &Rhs) -> (Self, Bitmap);
+}
+
+/// Defines basic division operation for primitive arrays
+pub trait ArrayDiv<Rhs>: Sized {
+    /// Division
+    fn div(&self, rhs: &Rhs) -> Self;
+}
+
+/// Defines checked division operation for primitive arrays
+pub trait ArrayCheckedDiv<Rhs>: Sized {
+    /// Checked division
+    fn checked_div(&self, rhs: &Rhs) -> Self;
+}
+
+/// Defines basic remainder operation for primitive arrays
+pub trait ArrayRem<Rhs>: Sized {
+    /// Remainder
+    fn rem(&self, rhs: &Rhs) -> Self;
+}
+
+/// Defines checked remainder operation for primitive arrays
+pub trait ArrayCheckedRem<Rhs>: Sized {
+    /// Checked remainder
+    fn checked_rem(&self, rhs: &Rhs) -> Self;
+}
diff --git a/crates/polars-arrow/src/compute/arithmetics/time.rs b/crates/polars-arrow/src/compute/arithmetics/time.rs
new file mode 100644
index 000000000000..0e3003e638b3
--- /dev/null
+++ b/crates/polars-arrow/src/compute/arithmetics/time.rs
@@ -0,0 +1,432 @@
+//! Defines the arithmetic kernels for adding a Duration to a Timestamp,
+//! Time32, Time64, Date32 and Date64.
+//!
+//! For the purposes of Arrow Implementations, adding this value to a Timestamp
+//! ("t1") naively (i.e. simply summing the two numbers) is acceptable even
+//! though in some cases the resulting Timestamp (t2) would not account for
+//! leap-seconds during the elapsed time between "t1" and "t2". Similarly,
+//! representing the difference between two Unix timestamps is acceptable, but
+//! would yield a value that is possibly a few seconds off from the true
+//! elapsed time.
+
+use std::ops::{Add, Sub};
+
+use num_traits::AsPrimitive;
+use polars_error::{polars_bail, PolarsResult};
+
+use crate::array::PrimitiveArray;
+use crate::compute::arity::{binary, unary};
+use crate::datatypes::{DataType, TimeUnit};
+use crate::scalar::{PrimitiveScalar, Scalar};
+use crate::temporal_conversions;
+use crate::types::{months_days_ns, NativeType};
+
+/// Creates the scale required to add or subtract a Duration to a time array
+/// (Timestamp, Time, or Date). The resulting scale always multiplies the rhs
+/// number (Duration) so it can be added to the lhs number (time array).
+fn create_scale(lhs: &DataType, rhs: &DataType) -> PolarsResult<f64> {
+    // Matching on both data types from both numbers to calculate the correct
+    // scale for the operation. The Timestamp, Time and Duration types each
+    // carry a TimeUnit in their data type. This unit is used to describe the
+    // addition of the duration. Date32 and Date64 have different rules for
+    // the scaling.
+    let scale = match (lhs, rhs) {
+        (DataType::Timestamp(timeunit_a, _), DataType::Duration(timeunit_b))
+        | (DataType::Time32(timeunit_a), DataType::Duration(timeunit_b))
+        | (DataType::Time64(timeunit_a), DataType::Duration(timeunit_b)) => {
+            // The scale is based on the TimeUnit that each of the numbers have.
+            temporal_conversions::timeunit_scale(*timeunit_a, *timeunit_b)
+        },
+        (DataType::Date32, DataType::Duration(timeunit)) => {
+            // Date32 represents the time elapsed since the UNIX epoch
+            // (1970-01-01) in days (32 bits). The duration value has to be
+            // scaled to days to be able to add the value to the Date.
+            temporal_conversions::timeunit_scale(TimeUnit::Second, *timeunit)
+                / temporal_conversions::SECONDS_IN_DAY as f64
+        },
+        (DataType::Date64, DataType::Duration(timeunit)) => {
+            // Date64 represents the time elapsed since the UNIX epoch
+            // (1970-01-01) in milliseconds (64 bits). The duration value has
+            // to be scaled to milliseconds to be able to add the value to the
+            // Date.
+            temporal_conversions::timeunit_scale(TimeUnit::Millisecond, *timeunit)
+        },
+        _ => {
+            polars_bail!(ComputeError:
+                "Incorrect data type for the arguments"
+            )
+        },
+    };
+
+    Ok(scale)
+}
+
+/// Adds a duration to a time array (Timestamp, Time and Date). The timeunit
+/// enum is used to scale both arrays correctly; adding seconds to seconds,
+/// or milliseconds to milliseconds.
+///
+/// # Examples
+/// ```
+/// use polars_arrow::compute::arithmetics::time::add_duration;
+/// use polars_arrow::array::PrimitiveArray;
+/// use polars_arrow::datatypes::{DataType, TimeUnit};
+///
+/// let timestamp = PrimitiveArray::from([
+///     Some(100000i64),
+///     Some(200000i64),
+///     None,
+///     Some(300000i64),
+/// ])
+/// .to(DataType::Timestamp(
+///     TimeUnit::Second,
+///     Some("America/New_York".to_string()),
+/// ));
+///
+/// let duration = PrimitiveArray::from([Some(10i64), Some(20i64), None, Some(30i64)])
+///     .to(DataType::Duration(TimeUnit::Second));
+///
+/// let result = add_duration(&timestamp, &duration);
+/// let expected = PrimitiveArray::from([
+///     Some(100010i64),
+///     Some(200020i64),
+///     None,
+///     Some(300030i64),
+/// ])
+/// .to(DataType::Timestamp(
+///     TimeUnit::Second,
+///     Some("America/New_York".to_string()),
+/// ));
+///
+/// assert_eq!(result, expected);
+/// ```
+pub fn add_duration<T>(
+    time: &PrimitiveArray<T>,
+    duration: &PrimitiveArray<i64>,
+) -> PrimitiveArray<T>
+where
+    f64: AsPrimitive<T>,
+    T: NativeType + Add<T, Output = T>,
+{
+    let scale = create_scale(time.data_type(), duration.data_type()).unwrap();
+
+    // Closure for the binary operation. The closure contains the scale
+    // required to add a duration to the time array.
+    let op = move |a: T, b: i64| a + (b as f64 * scale).as_();
+
+    binary(time, duration, time.data_type().clone(), op)
+}
+
+/// Adds a duration to a time array (Timestamp, Time and Date). The timeunit
+/// enum is used to scale both arrays correctly; adding seconds to seconds,
+/// or milliseconds to milliseconds.
+pub fn add_duration_scalar<T>(
+    time: &PrimitiveArray<T>,
+    duration: &PrimitiveScalar<i64>,
+) -> PrimitiveArray<T>
+where
+    f64: AsPrimitive<T>,
+    T: NativeType + Add<T, Output = T>,
+{
+    let scale = create_scale(time.data_type(), duration.data_type()).unwrap();
+    let duration = if let Some(duration) = *duration.value() {
+        duration
+    } else {
+        return PrimitiveArray::<T>::new_null(time.data_type().clone(), time.len());
+    };
+
+    // Closure for the unary operation. The closure contains the scale
+    // required to add a duration to the time array.
+    let op = move |a: T| a + (duration as f64 * scale).as_();
+
+    unary(time, op, time.data_type().clone())
+}
+
+/// Subtracts a duration from a time array (Timestamp, Time and Date). The timeunit
+/// enum is used to scale both arrays correctly; subtracting seconds from seconds,
+/// or milliseconds from milliseconds.
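+///
+/// The two time units need not match: the duration is rescaled to the time
+/// array's unit before it is subtracted. A rough sketch (the values are
+/// illustrative only, not taken from a test):
+///
+/// ```nocode
+/// time:      100_000 (seconds)
+/// duration:    2_000 (milliseconds)  ->  * 1e-3  ->  2 seconds
+/// result:     99_998 (seconds)
+/// ```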
+///
+/// # Examples
+/// ```
+/// use polars_arrow::compute::arithmetics::time::subtract_duration;
+/// use polars_arrow::array::PrimitiveArray;
+/// use polars_arrow::datatypes::{DataType, TimeUnit};
+///
+/// let timestamp = PrimitiveArray::from([
+///     Some(100000i64),
+///     Some(200000i64),
+///     None,
+///     Some(300000i64),
+/// ])
+/// .to(DataType::Timestamp(
+///     TimeUnit::Second,
+///     Some("America/New_York".to_string()),
+/// ));
+///
+/// let duration = PrimitiveArray::from([Some(10i64), Some(20i64), None, Some(30i64)])
+///     .to(DataType::Duration(TimeUnit::Second));
+///
+/// let result = subtract_duration(&timestamp, &duration);
+/// let expected = PrimitiveArray::from([
+///     Some(99990i64),
+///     Some(199980i64),
+///     None,
+///     Some(299970i64),
+/// ])
+/// .to(DataType::Timestamp(
+///     TimeUnit::Second,
+///     Some("America/New_York".to_string()),
+/// ));
+///
+/// assert_eq!(result, expected);
+///
+/// ```
+pub fn subtract_duration<T>(
+    time: &PrimitiveArray<T>,
+    duration: &PrimitiveArray<i64>,
+) -> PrimitiveArray<T>
+where
+    f64: AsPrimitive<T>,
+    T: NativeType + Sub<T, Output = T>,
+{
+    let scale = create_scale(time.data_type(), duration.data_type()).unwrap();
+
+    // Closure for the binary operation. The closure contains the scale
+    // required to subtract a duration from the time array.
+    let op = move |a: T, b: i64| a - (b as f64 * scale).as_();
+
+    binary(time, duration, time.data_type().clone(), op)
+}
+
+/// Subtracts a duration from a time array (Timestamp, Time and Date). The timeunit
+/// enum is used to scale both arrays correctly; subtracting seconds from seconds,
+/// or milliseconds from milliseconds.
+pub fn sub_duration_scalar<T>(
+    time: &PrimitiveArray<T>,
+    duration: &PrimitiveScalar<i64>,
+) -> PrimitiveArray<T>
+where
+    f64: AsPrimitive<T>,
+    T: NativeType + Sub<T, Output = T>,
+{
+    let scale = create_scale(time.data_type(), duration.data_type()).unwrap();
+    let duration = if let Some(duration) = *duration.value() {
+        duration
+    } else {
+        return PrimitiveArray::<T>::new_null(time.data_type().clone(), time.len());
+    };
+
+    let op = move |a: T| a - (duration as f64 * scale).as_();
+
+    unary(time, op, time.data_type().clone())
+}
+
+/// Calculates the difference between two timestamps returning an array of type
+/// Duration. The timeunit enum is used to scale both arrays correctly;
+/// subtracting seconds from seconds, or milliseconds from milliseconds.
+///
+/// # Examples
+/// ```
+/// use polars_arrow::compute::arithmetics::time::subtract_timestamps;
+/// use polars_arrow::array::PrimitiveArray;
+/// use polars_arrow::datatypes::{DataType, TimeUnit};
+/// let timestamp_a = PrimitiveArray::from([
+///     Some(100_010i64),
+///     Some(200_020i64),
+///     None,
+///     Some(300_030i64),
+/// ])
+/// .to(DataType::Timestamp(TimeUnit::Second, None));
+///
+/// let timestamp_b = PrimitiveArray::from([
+///     Some(100_000i64),
+///     Some(200_000i64),
+///     None,
+///     Some(300_000i64),
+/// ])
+/// .to(DataType::Timestamp(TimeUnit::Second, None));
+///
+/// let expected = PrimitiveArray::from([Some(10i64), Some(20i64), None, Some(30i64)])
+///     .to(DataType::Duration(TimeUnit::Second));
+///
+/// let result = subtract_timestamps(&timestamp_a, &timestamp_b).unwrap();
+/// assert_eq!(result, expected);
+/// ```
+pub fn subtract_timestamps(
+    lhs: &PrimitiveArray<i64>,
+    rhs: &PrimitiveArray<i64>,
+) -> PolarsResult<PrimitiveArray<i64>> {
+    // Matching on both data types from both arrays.
+    // Both timestamps carry a TimeUnit enum in their data type.
+    // This enum is used to adjust the scale between the timestamps.
+    match (lhs.data_type(), rhs.data_type()) {
+        // Naive timestamp comparison. It doesn't take into account timezones
+        // from the Timestamp timeunit.
+        (DataType::Timestamp(timeunit_a, None), DataType::Timestamp(timeunit_b, None)) => {
+            // Closure for the binary operation. The closure contains the scale
+            // required to calculate the difference between the timestamps.
+            let scale = temporal_conversions::timeunit_scale(*timeunit_a, *timeunit_b);
+            let op = move |a, b| a - (b as f64 * scale) as i64;
+
+            Ok(binary(lhs, rhs, DataType::Duration(*timeunit_a), op))
+        },
+        _ => polars_bail!(ComputeError:
+            "Incorrect data type for the arguments"
+        ),
+    }
+}
+
+/// Calculates the difference between two timestamps as [`DataType::Duration`] with the same time scale.
+pub fn sub_timestamps_scalar(
+    lhs: &PrimitiveArray<i64>,
+    rhs: &PrimitiveScalar<i64>,
+) -> PolarsResult<PrimitiveArray<i64>> {
+    let (scale, timeunit_a) =
+        if let (DataType::Timestamp(timeunit_a, None), DataType::Timestamp(timeunit_b, None)) =
+            (lhs.data_type(), rhs.data_type())
+        {
+            (
+                temporal_conversions::timeunit_scale(*timeunit_a, *timeunit_b),
+                timeunit_a,
+            )
+        } else {
+            polars_bail!(InvalidOperation:
+                "sub_timestamps_scalar requires both arguments to be timestamps without timezone"
+            )
+        };
+
+    let rhs = if let Some(value) = *rhs.value() {
+        value
+    } else {
+        return Ok(PrimitiveArray::<i64>::new_null(
+            lhs.data_type().clone(),
+            lhs.len(),
+        ));
+    };
+
+    let op = move |a| a - (rhs as f64 * scale) as i64;
+
+    Ok(unary(lhs, op, DataType::Duration(*timeunit_a)))
+}
+
+/// Adds an interval to a [`DataType::Timestamp`].
+pub fn add_interval(
+    timestamp: &PrimitiveArray<i64>,
+    interval: &PrimitiveArray<months_days_ns>,
+) -> PolarsResult<PrimitiveArray<i64>> {
+    match timestamp.data_type().to_logical_type() {
+        DataType::Timestamp(time_unit, Some(timezone_str)) => {
+            let time_unit = *time_unit;
+            let timezone = temporal_conversions::parse_offset(timezone_str);
+            match timezone {
+                Ok(timezone) => Ok(binary(
+                    timestamp,
+                    interval,
+                    timestamp.data_type().clone(),
+                    |timestamp, interval| {
+                        temporal_conversions::add_interval(
+                            timestamp, time_unit, interval, &timezone,
+                        )
+                    },
+                )),
+                #[cfg(feature = "chrono-tz")]
+                Err(_) => {
+                    let timezone = temporal_conversions::parse_offset_tz(timezone_str)?;
+                    Ok(binary(
+                        timestamp,
+                        interval,
+                        timestamp.data_type().clone(),
+                        |timestamp, interval| {
+                            temporal_conversions::add_interval(
+                                timestamp, time_unit, interval, &timezone,
+                            )
+                        },
+                    ))
+                },
+                #[cfg(not(feature = "chrono-tz"))]
+                _ => polars_bail!(InvalidOperation:
+                    "timezone \"{}\" cannot be parsed (feature chrono-tz is not active)",
+                    timezone_str
+                ),
+            }
+        },
+        DataType::Timestamp(time_unit, None) => {
+            let time_unit = *time_unit;
+            Ok(binary(
+                timestamp,
+                interval,
+                timestamp.data_type().clone(),
+                |timestamp, interval| {
+                    temporal_conversions::add_naive_interval(timestamp, time_unit, interval)
+                },
+            ))
+        },
+        _ => polars_bail!(InvalidOperation:
+            "Adding an interval is only supported for `DataType::Timestamp`"
+        ),
+    }
+}
+
+/// Adds an interval to a [`DataType::Timestamp`].
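+/// This is the scalar counterpart of [`add_interval`]: a null interval
+/// short-circuits to an all-null array with the timestamp's data type.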
+pub fn add_interval_scalar(
+    timestamp: &PrimitiveArray<i64>,
+    interval: &PrimitiveScalar<months_days_ns>,
+) -> PolarsResult<PrimitiveArray<i64>> {
+    let interval = if let Some(interval) = *interval.value() {
+        interval
+    } else {
+        return Ok(PrimitiveArray::<i64>::new_null(
+            timestamp.data_type().clone(),
+            timestamp.len(),
+        ));
+    };
+
+    match timestamp.data_type().to_logical_type() {
+        DataType::Timestamp(time_unit, Some(timezone_str)) => {
+            let time_unit = *time_unit;
+            let timezone = temporal_conversions::parse_offset(timezone_str);
+            match timezone {
+                Ok(timezone) => Ok(unary(
+                    timestamp,
+                    |timestamp| {
+                        temporal_conversions::add_interval(
+                            timestamp, time_unit, interval, &timezone,
+                        )
+                    },
+                    timestamp.data_type().clone(),
+                )),
+                #[cfg(feature = "chrono-tz")]
+                Err(_) => {
+                    let timezone = temporal_conversions::parse_offset_tz(timezone_str)?;
+                    Ok(unary(
+                        timestamp,
+                        |timestamp| {
+                            temporal_conversions::add_interval(
+                                timestamp, time_unit, interval, &timezone,
+                            )
+                        },
+                        timestamp.data_type().clone(),
+                    ))
+                },
+                #[cfg(not(feature = "chrono-tz"))]
+                _ => polars_bail!(InvalidOperation:
+                    "timezone \"{}\" cannot be parsed (feature chrono-tz is not active)",
+                    timezone_str
+                ),
+            }
+        },
+        DataType::Timestamp(time_unit, None) => {
+            let time_unit = *time_unit;
+            Ok(unary(
+                timestamp,
+                |timestamp| {
+                    temporal_conversions::add_naive_interval(timestamp, time_unit, interval)
+                },
+                timestamp.data_type().clone(),
+            ))
+        },
+        _ => polars_bail!(InvalidOperation:
+            "Adding an interval is only supported for `DataType::Timestamp`"
+        ),
+    }
+}
diff --git a/crates/polars-arrow/src/compute/cast/mod.rs b/crates/polars-arrow/src/compute/cast/mod.rs
index 23edbd1c9056..d33525285c7d 100644
--- a/crates/polars-arrow/src/compute/cast/mod.rs
+++ b/crates/polars-arrow/src/compute/cast/mod.rs
@@ -520,12 +520,7 @@ pub fn cast(
             // Safety: offsets _are_ monotonically increasing
             let offsets = unsafe { Offsets::new_unchecked(offsets) };
 
-            let list_array = ListArray::<i64>::new(
-                to_type.clone(),
-                offsets.into(),
-                values,
-                array.validity().cloned(),
-            );
+            let list_array = ListArray::<i64>::new(to_type.clone(), offsets.into(), values, None);
 
             Ok(Box::new(list_array))
         },
@@ -585,11 +580,9 @@ pub fn cast(
             LargeUtf8 => Ok(Box::new(utf8_to_large_utf8(
                 array.as_any().downcast_ref().unwrap(),
             ))),
-            Timestamp(time_unit, None) => {
-                utf8_to_naive_timestamp_dyn::<i32>(array, time_unit.to_owned())
-            },
-            Timestamp(time_unit, Some(time_zone)) => {
-                utf8_to_timestamp_dyn::<i32>(array, time_zone.clone(), time_unit.to_owned())
+            Timestamp(tu, None) => utf8_to_naive_timestamp_dyn::<i32>(array, tu.to_owned()),
+            Timestamp(tu, Some(tz)) => {
+                utf8_to_timestamp_dyn::<i32>(array, tz.clone(), tu.to_owned())
             },
             _ => polars_bail!(InvalidOperation:
                 "casting from {from_type:?} to {to_type:?} not supported",
@@ -614,11 +607,9 @@ pub fn cast(
                 to_type.clone(),
             )
            .boxed()),
-            Timestamp(time_unit, None) => {
-                utf8_to_naive_timestamp_dyn::<i64>(array, time_unit.to_owned())
-            },
-            Timestamp(time_unit, Some(time_zone)) => {
-                utf8_to_timestamp_dyn::<i64>(array, time_zone.clone(), time_unit.to_owned())
+            Timestamp(tu, None) => utf8_to_naive_timestamp_dyn::<i64>(array, tu.to_owned()),
+            Timestamp(tu, Some(tz)) => {
+                utf8_to_timestamp_dyn::<i64>(array, tz.clone(), tu.to_owned())
             },
             _ => polars_bail!(InvalidOperation:
                 "casting from {from_type:?} to {to_type:?} not supported",
diff --git a/crates/polars-arrow/src/compute/cast/utf8_to.rs b/crates/polars-arrow/src/compute/cast/utf8_to.rs
index 4b6ac51fafd5..85a252544e5e 100644
--- a/crates/polars-arrow/src/compute/cast/utf8_to.rs
+++ b/crates/polars-arrow/src/compute/cast/utf8_to.rs
@@ -112,10 +112,10 @@ pub fn utf8_to_dictionary<O: Offset>(
 
 pub(super) fn utf8_to_naive_timestamp_dyn<O: Offset>(
     from: &dyn Array,
-    time_unit: TimeUnit,
+    tu: TimeUnit,
 ) -> PolarsResult<Box<dyn Array>> {
     let from = from.as_any().downcast_ref().unwrap();
-    Ok(Box::new(utf8_to_naive_timestamp::<O>(from, time_unit)))
+    Ok(Box::new(utf8_to_naive_timestamp::<O>(from, tu)))
 }
 
 /// [`crate::temporal_conversions::utf8_to_timestamp`] applied for RFC3339 formatting
@@ -129,10 +129,10 @@ pub fn utf8_to_naive_timestamp<O: Offset>(
 pub(super) fn utf8_to_timestamp_dyn<O: Offset>(
     from: &dyn Array,
     timezone: String,
-    time_unit: TimeUnit,
+    tu: TimeUnit,
 ) -> PolarsResult<Box<dyn Array>> {
     let from = from.as_any().downcast_ref().unwrap();
-    utf8_to_timestamp::<O>(from, timezone, time_unit)
+    utf8_to_timestamp::<O>(from, timezone, tu)
         .map(Box::new)
         .map(|x| x as Box<dyn Array>)
 }
@@ -141,9 +141,9 @@ pub(super) fn utf8_to_timestamp_dyn<O: Offset>(
 pub fn utf8_to_timestamp<O: Offset>(
     from: &Utf8Array<O>,
     timezone: String,
-    time_unit: TimeUnit,
+    tu: TimeUnit,
 ) -> PolarsResult<PrimitiveArray<i64>> {
-    utf8_to_timestamp_(from, RFC3339, timezone, time_unit)
+    utf8_to_timestamp_(from, RFC3339, timezone, tu)
 }
 
 /// Conversion of utf8
diff --git a/crates/polars-arrow/src/compute/if_then_else.rs b/crates/polars-arrow/src/compute/if_then_else.rs
index 9433f431fb19..292f4e484f81 100644
--- a/crates/polars-arrow/src/compute/if_then_else.rs
+++ b/crates/polars-arrow/src/compute/if_then_else.rs
@@ -6,6 +6,24 @@ use crate::bitmap::utils::SlicesIterator;
 
 /// Returns the values from `lhs` if the predicate is `true` or from the `rhs` if the predicate is false
 /// Returns `None` if the predicate is `None`.
+/// # Example
+/// ```rust
+/// # use polars_arrow::error::Result;
+/// use polars_arrow::compute::if_then_else::if_then_else;
+/// use polars_arrow::array::{Int32Array, BooleanArray};
+///
+/// # fn main() -> Result<()> {
+/// let lhs = Int32Array::from_slice(&[1, 2, 3]);
+/// let rhs = Int32Array::from_slice(&[4, 5, 6]);
+/// let predicate = BooleanArray::from(&[Some(true), None, Some(false)]);
+/// let result = if_then_else(&predicate, &lhs, &rhs)?;
+///
+/// let expected = Int32Array::from(&[Some(1), None, Some(6)]);
+///
+/// assert_eq!(expected, result.as_ref());
+/// # Ok(())
+/// # }
+/// ```
 pub fn if_then_else(
     predicate: &BooleanArray,
     lhs: &dyn Array,
diff --git a/crates/polars-arrow/src/ffi/schema.rs b/crates/polars-arrow/src/ffi/schema.rs
index ded8215ca3ce..ebbf5b8f6c76 100644
--- a/crates/polars-arrow/src/ffi/schema.rs
+++ b/crates/polars-arrow/src/ffi/schema.rs
@@ -145,10 +145,6 @@ impl ArrowSchema {
         }
     }
 
-    pub fn is_null(&self) -> bool {
-        self.private_data.is_null()
-    }
-
     /// returns the format of this schema.
     pub(crate) fn format(&self) -> &str {
         assert!(!self.format.is_null());
diff --git a/crates/polars-arrow/src/io/ipc/mod.rs b/crates/polars-arrow/src/io/ipc/mod.rs
index 39ad7753359f..6ac9c3011b79 100644
--- a/crates/polars-arrow/src/io/ipc/mod.rs
+++ b/crates/polars-arrow/src/io/ipc/mod.rs
@@ -25,6 +25,54 @@
 //! the case of the `File` variant it also implements [`Seek`](std::io::Seek). In
 //! practice it means that `File`s can be arbitrarily accessed while `Stream`s are only
 //! read in certain order - the one they were written in (first in, first out).
+//!
+//! # Examples
+//! Read and write to a file:
+//! ```
+//! use polars_arrow::io::ipc::{{read::{FileReader, read_file_metadata}}, {write::{FileWriter, WriteOptions}}};
+//! # use std::fs::File;
+//! # use polars_arrow::datatypes::{Field, Schema, DataType};
+//! # use polars_arrow::array::{Int32Array, Array};
+//! # use polars_arrow::chunk::Chunk;
+//! # use polars_arrow::error::Error;
+//! // Setup the writer
+//! let path = "example.arrow".to_string();
+//! let mut file = File::create(&path)?;
+//! let x_coord = Field::new("x", DataType::Int32, false);
+//! let y_coord = Field::new("y", DataType::Int32, false);
+//! let schema = Schema::from(vec![x_coord, y_coord]);
+//! let options = WriteOptions {compression: None};
+//! let mut writer = FileWriter::try_new(file, schema, None, options)?;
+//!
+//! // Setup the data
+//! let x_data = Int32Array::from_slice([-1i32, 1]);
+//! let y_data = Int32Array::from_slice([1i32, -1]);
+//! let chunk = Chunk::try_new(vec![x_data.boxed(), y_data.boxed()])?;
+//!
+//! // Write the messages and finalize the stream
+//! for _ in 0..5 {
+//!     writer.write(&chunk, None)?;
+//! }
+//! writer.finish()?;
+//!
+//! // Fetch some of the data and get the reader back
+//! let mut reader = File::open(&path)?;
+//! let metadata = read_file_metadata(&mut reader)?;
+//! let mut reader = FileReader::new(reader, metadata, None, None);
+//! let row1 = reader.next().unwrap();  // [[-1, 1], [1, -1]]
+//! let row2 = reader.next().unwrap();  // [[-1, 1], [1, -1]]
+//! let mut reader = reader.into_inner();
+//! // Do more stuff with the reader, like seeking ahead.
+//! # Ok::<(), Error>(())
+//! ```
+//!
+//! For further information and examples please consult the
+//! [user guide](https://jorgecarleitao.github.io/polars_arrow/io/index.html).
+//! For even more examples check the `examples` folder in the main repository
+//! ([1](https://github.com/jorgecarleitao/polars_arrow/blob/main/examples/ipc_file_read.rs),
+//! [2](https://github.com/jorgecarleitao/polars_arrow/blob/main/examples/ipc_file_write.rs),
+//! [3](https://github.com/jorgecarleitao/polars_arrow/tree/main/examples/ipc_pyarrow)).
+
 mod compression;
 mod endianness;
diff --git a/crates/polars-arrow/src/io/ipc/write/file_async.rs b/crates/polars-arrow/src/io/ipc/write/file_async.rs
index 5ed1350a65ff..cad67f35fdea 100644
--- a/crates/polars-arrow/src/io/ipc/write/file_async.rs
+++ b/crates/polars-arrow/src/io/ipc/write/file_async.rs
@@ -21,6 +21,46 @@ type WriteOutput<W> = (usize, Option<Block>, Vec<Block>, Option<W>);
 ///
 /// The file header is automatically written before writing the first chunk, and the file footer is
 /// automatically written when the sink is closed.
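+/// Note that closing the sink is what writes the footer: if the sink is
+/// dropped without being closed, the bytes written so far lack a footer and
+/// will not form a complete IPC file.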
+/// +/// # Examples +/// +/// ``` +/// use futures::{SinkExt, TryStreamExt, io::Cursor}; +/// use polars_arrow::array::{Array, Int32Array}; +/// use polars_arrow::datatypes::{DataType, Field, Schema}; +/// use polars_arrow::chunk::Chunk; +/// use polars_arrow::io::ipc::write::file_async::FileSink; +/// use polars_arrow::io::ipc::read::file_async::{read_file_metadata_async, FileStream}; +/// # futures::executor::block_on(async move { +/// let schema = Schema::from(vec![ +/// Field::new("values", DataType::Int32, true), +/// ]); +/// +/// let mut buffer = Cursor::new(vec![]); +/// let mut sink = FileSink::new( +/// &mut buffer, +/// schema, +/// None, +/// Default::default(), +/// ); +/// +/// // Write chunks to file +/// for i in 0..3 { +/// let values = Int32Array::from(&[Some(i), None]); +/// let chunk = Chunk::new(vec![values.boxed()]); +/// sink.feed(chunk.into()).await?; +/// } +/// sink.close().await?; +/// drop(sink); +/// +/// // Read chunks from file +/// buffer.set_position(0); +/// let metadata = read_file_metadata_async(&mut buffer).await?; +/// let mut stream = FileStream::new(buffer, metadata, None, None); +/// let chunks = stream.try_collect::>().await?; +/// # polars_arrow::error::Result::Ok(()) +/// # }).unwrap(); +/// ``` pub struct FileSink<'a, W: AsyncWrite + Unpin + Send + 'a> { writer: Option, task: Option>>>, diff --git a/crates/polars-arrow/src/io/ipc/write/stream_async.rs b/crates/polars-arrow/src/io/ipc/write/stream_async.rs index 49305c2ab383..7e8d056ce52b 100644 --- a/crates/polars-arrow/src/io/ipc/write/stream_async.rs +++ b/crates/polars-arrow/src/io/ipc/write/stream_async.rs @@ -17,6 +17,37 @@ use crate::datatypes::*; /// A sink that writes array [`chunks`](crate::chunk::Chunk) as an IPC stream. /// /// The stream header is automatically written before writing the first chunk. 
+/// +/// # Examples +/// +/// ``` +/// use futures::SinkExt; +/// use polars_arrow::array::{Array, Int32Array}; +/// use polars_arrow::datatypes::{DataType, Field, Schema}; +/// use polars_arrow::chunk::Chunk; +/// # use polars_arrow::io::ipc::write::stream_async::StreamSink; +/// # futures::executor::block_on(async move { +/// let schema = Schema::from(vec![ +/// Field::new("values", DataType::Int32, true), +/// ]); +/// +/// let mut buffer = vec![]; +/// let mut sink = StreamSink::new( +/// &mut buffer, +/// &schema, +/// None, +/// Default::default(), +/// ); +/// +/// for i in 0..3 { +/// let values = Int32Array::from(&[Some(i), None]); +/// let chunk = Chunk::new(vec![values.boxed()]); +/// sink.feed(chunk.into()).await?; +/// } +/// sink.close().await?; +/// # polars_arrow::error::Result::Ok(()) +/// # }).unwrap(); +/// ``` pub struct StreamSink<'a, W: AsyncWrite + Unpin + Send + 'a> { writer: Option, task: Option>>>, diff --git a/crates/polars-arrow/src/io/parquet/read/deserialize/nested_utils.rs b/crates/polars-arrow/src/io/parquet/read/deserialize/nested_utils.rs index 482d5117a7da..9466b93cb7dc 100644 --- a/crates/polars-arrow/src/io/parquet/read/deserialize/nested_utils.rs +++ b/crates/polars-arrow/src/io/parquet/read/deserialize/nested_utils.rs @@ -499,10 +499,15 @@ where D: NestedDecoder<'a>, { // front[a1, a2, a3, ...]back - if *remaining == 0 && items.is_empty() { - return MaybeNext::None; + if items.len() > 1 { + return MaybeNext::Some(Ok(items.pop_front().unwrap())); + } + if *remaining == 0 { + return match items.pop_front() { + Some(decoded) => MaybeNext::Some(Ok(decoded)), + None => MaybeNext::None, + }; } - match iter.next() { Err(e) => MaybeNext::Some(Err(e.into())), Ok(None) => { @@ -536,9 +541,7 @@ where Err(e) => return MaybeNext::Some(Err(e)), }; - // this comparison is strictly greater to ensure the contents of the - // row are fully read. - if !items.is_empty() + if (items.len() == 1) && items.front().unwrap().0.len() > chunk_size.unwrap_or(usize::MAX) { MaybeNext::Some(Ok(items.pop_front().unwrap())) diff --git a/crates/polars-arrow/src/io/parquet/write/sink.rs b/crates/polars-arrow/src/io/parquet/write/sink.rs index d8d2734ce461..284e2a6f7639 100644 --- a/crates/polars-arrow/src/io/parquet/write/sink.rs +++ b/crates/polars-arrow/src/io/parquet/write/sink.rs @@ -18,6 +18,47 @@ use crate::datatypes::Schema; /// /// Any values in the sink's `metadata` field will be written to the file's footer /// when the sink is closed. 
+/// +/// # Examples +/// +/// ``` +/// use futures::SinkExt; +/// use polars_arrow::array::{Array, Int32Array}; +/// use polars_arrow::datatypes::{DataType, Field, Schema}; +/// use polars_arrow::chunk::Chunk; +/// use polars_arrow::io::parquet::write::{Encoding, WriteOptions, CompressionOptions, Version}; +/// # use polars_arrow::io::parquet::write::FileSink; +/// # futures::executor::block_on(async move { +/// +/// let schema = Schema::from(vec![ +/// Field::new("values", DataType::Int32, true), +/// ]); +/// let encoding = vec![vec![Encoding::Plain]]; +/// let options = WriteOptions { +/// write_statistics: true, +/// compression: CompressionOptions::Uncompressed, +/// version: Version::V2, +/// data_pagesize_limit: None, +/// }; +/// +/// let mut buffer = vec![]; +/// let mut sink = FileSink::try_new( +/// &mut buffer, +/// schema, +/// encoding, +/// options, +/// )?; +/// +/// for i in 0..3 { +/// let values = Int32Array::from(&[Some(i), None]); +/// let chunk = Chunk::new(vec![values.boxed()]); +/// sink.feed(chunk).await?; +/// } +/// sink.metadata.insert(String::from("key"), Some(String::from("value"))); +/// sink.close().await?; +/// # polars_arrow::error::Result::Ok(()) +/// # }).unwrap(); +/// ``` pub struct FileSink<'a, W: AsyncWrite + Send + Unpin> { writer: Option>, task: Option>>>>, diff --git a/crates/polars-arrow/src/legacy/kernels/ewm/mod.rs b/crates/polars-arrow/src/legacy/kernels/ewm/mod.rs index 5984106f1521..8f45bbbef2fb 100644 --- a/crates/polars-arrow/src/legacy/kernels/ewm/mod.rs +++ b/crates/polars-arrow/src/legacy/kernels/ewm/mod.rs @@ -2,12 +2,9 @@ mod average; mod variance; pub use average::*; -#[cfg(feature = "serde")] -use serde::{Deserialize, Serialize}; pub use variance::*; -#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -#[derive(Debug, Copy, Clone, PartialEq)] +#[derive(Debug, Copy, Clone)] #[must_use] pub struct EWMOptions { pub alpha: f64, diff --git a/crates/polars-core/Cargo.toml b/crates/polars-core/Cargo.toml index 8ec0afe8a936..012c5687d959 100644 --- a/crates/polars-core/Cargo.toml +++ b/crates/polars-core/Cargo.toml @@ -89,12 +89,17 @@ reinterpret = [] take_opt_iter = [] # allow group_by operation on list type group_by_list = [] +# cumsum, cummin, etc. 
+cum_agg = [] # rolling window functions rolling_window = [] +diff = [] +pct_change = ["diff"] moment = [] diagonal_concat = [] horizontal_concat = [] abs = [] +ewma = [] dataframe_arithmetic = [] product = [] unique_counts = [] @@ -138,13 +143,16 @@ docs-selection = [ "temporal", "random", "zip_with", + "round_series", "checked_arithmetic", "is_first_distinct", "is_last_distinct", "asof_join", "dot_product", "row_hash", + "cum_agg", "rolling_window", + "diff", "moment", "dtype-categorical", "dtype-decimal", diff --git a/crates/polars-core/src/chunked_array/builder/binary.rs b/crates/polars-core/src/chunked_array/builder/binary.rs index bed05a434ba1..119dc461c7ed 100644 --- a/crates/polars-core/src/chunked_array/builder/binary.rs +++ b/crates/polars-core/src/chunked_array/builder/binary.rs @@ -1,5 +1,3 @@ -use polars_error::constants::LENGTH_LIMIT_MSG; - use super::*; pub struct BinaryChunkedBuilder { @@ -42,8 +40,7 @@ impl BinaryChunkedBuilder { pub fn finish(mut self) -> BinaryChunked { let arr = self.builder.as_box(); - let length = IdxSize::try_from(arr.len()).expect(LENGTH_LIMIT_MSG); - let null_count = arr.null_count() as IdxSize; + let length = arr.len() as IdxSize; ChunkedArray { field: Arc::new(self.field), @@ -51,7 +48,6 @@ impl BinaryChunkedBuilder { phantom: PhantomData, bit_settings: Default::default(), length, - null_count, } } diff --git a/crates/polars-core/src/chunked_array/builder/boolean.rs b/crates/polars-core/src/chunked_array/builder/boolean.rs index 407bc3abcf53..655d94ff1a7d 100644 --- a/crates/polars-core/src/chunked_array/builder/boolean.rs +++ b/crates/polars-core/src/chunked_array/builder/boolean.rs @@ -21,14 +21,14 @@ impl ChunkedBuilder for BooleanChunkedBuilder { fn finish(mut self) -> BooleanChunked { let arr = self.array_builder.as_box(); + let length = arr.len() as IdxSize; let mut ca = ChunkedArray { field: Arc::new(self.field), chunks: vec![arr], phantom: PhantomData, bit_settings: Default::default(), - length: 0, - null_count: 0, + length, }; ca.compute_len(); ca diff --git a/crates/polars-core/src/chunked_array/builder/primitive.rs b/crates/polars-core/src/chunked_array/builder/primitive.rs index eae7977612fe..f5314a5fb62a 100644 --- a/crates/polars-core/src/chunked_array/builder/primitive.rs +++ b/crates/polars-core/src/chunked_array/builder/primitive.rs @@ -27,13 +27,13 @@ where fn finish(mut self) -> ChunkedArray { let arr = self.array_builder.as_box(); + let length = arr.len() as IdxSize; let mut ca = ChunkedArray { field: Arc::new(self.field), chunks: vec![arr], phantom: PhantomData, bit_settings: Default::default(), - length: 0, - null_count: 0, + length, }; ca.compute_len(); ca diff --git a/crates/polars-core/src/chunked_array/builder/utf8.rs b/crates/polars-core/src/chunked_array/builder/utf8.rs index 1a1c793563ed..49f933c790ed 100644 --- a/crates/polars-core/src/chunked_array/builder/utf8.rs +++ b/crates/polars-core/src/chunked_array/builder/utf8.rs @@ -41,14 +41,14 @@ impl Utf8ChunkedBuilder { pub fn finish(mut self) -> Utf8Chunked { let arr = self.builder.as_box(); + let length = arr.len() as IdxSize; let mut ca = ChunkedArray { field: Arc::new(self.field), chunks: vec![arr], phantom: PhantomData, bit_settings: Default::default(), - length: 0, - null_count: 0, + length, }; ca.compute_len(); ca diff --git a/crates/polars-core/src/chunked_array/cast.rs b/crates/polars-core/src/chunked_array/cast.rs index 2c216a69731a..8daddeac1d81 100644 --- a/crates/polars-core/src/chunked_array/cast.rs +++ b/crates/polars-core/src/chunked_array/cast.rs @@ 
-203,21 +203,21 @@ impl ChunkCast for Utf8Chunked { Ok(out) }, #[cfg(feature = "dtype-datetime")] - DataType::Datetime(time_unit, time_zone) => { - let out = match time_zone { + DataType::Datetime(tu, tz) => { + let out = match tz { #[cfg(feature = "timezones")] - Some(time_zone) => { - validate_time_zone(time_zone)?; + Some(tz) => { + validate_time_zone(tz)?; let result = cast_chunks( &self.chunks, - &Datetime(time_unit.to_owned(), Some(time_zone.clone())), + &Datetime(tu.to_owned(), Some(tz.clone())), true, )?; Series::try_from((self.name(), result)) }, _ => { let result = - cast_chunks(&self.chunks, &Datetime(time_unit.to_owned(), None), true)?; + cast_chunks(&self.chunks, &Datetime(tu.to_owned(), None), true)?; Series::try_from((self.name(), result)) }, }; @@ -365,11 +365,7 @@ impl ChunkCast for ListChunked { } unsafe fn cast_unchecked(&self, data_type: &DataType) -> PolarsResult { - use DataType::*; - match data_type { - List(child_type) => cast_list_unchecked(self, child_type), - _ => self.cast(data_type), - } + self.cast(data_type) } } @@ -418,8 +414,6 @@ impl ChunkCast for ArrayChunked { // Returns inner data type. This is needed because a cast can instantiate the dtype inner // values for instance with categoricals fn cast_list(ca: &ListChunked, child_type: &DataType) -> PolarsResult<(ArrayRef, DataType)> { - // We still rechunk because we must bubble up a single data-type - // TODO!: consider a version that works on chunks and merges the data-types and arrays. let ca = ca.rechunk(); let arr = ca.downcast_iter().next().unwrap(); // safety: inner dtype is passed correctly @@ -443,32 +437,6 @@ fn cast_list(ca: &ListChunked, child_type: &DataType) -> PolarsResult<(ArrayRef, Ok((Box::new(new_arr), inner_dtype)) } -unsafe fn cast_list_unchecked(ca: &ListChunked, child_type: &DataType) -> PolarsResult { - // TODO! add chunked, but this must correct for list offsets. - let ca = ca.rechunk(); - let arr = ca.downcast_iter().next().unwrap(); - // safety: inner dtype is passed correctly - let s = unsafe { - Series::from_chunks_and_dtype_unchecked("", vec![arr.values().clone()], &ca.inner_dtype()) - }; - let new_inner = s.cast_unchecked(child_type)?; - let new_values = new_inner.array_ref(0).clone(); - - let data_type = ListArray::::default_datatype(new_values.data_type().clone()); - let new_arr = ListArray::::new( - data_type, - arr.offsets().clone(), - new_values, - arr.validity().cloned(), - ); - Ok(ListChunked::from_chunks_and_dtype_unchecked( - ca.name(), - vec![Box::new(new_arr)], - DataType::List(Box::new(child_type.clone())), - ) - .into_series()) -} - // Returns inner data type. 
This is needed because a cast can instantiate the dtype inner // values for instance with categoricals #[cfg(feature = "dtype-array")] diff --git a/crates/polars-core/src/chunked_array/from.rs b/crates/polars-core/src/chunked_array/from.rs index c384dec3e241..b20ea1cde3ca 100644 --- a/crates/polars-core/src/chunked_array/from.rs +++ b/crates/polars-core/src/chunked_array/from.rs @@ -1,5 +1,3 @@ -use polars_error::constants::LENGTH_LIMIT_MSG; - use super::*; #[allow(clippy::all)] @@ -145,12 +143,10 @@ where ); let mut length = 0; - let mut null_count = 0; let chunks = chunks .into_iter() .map(|x| { length += x.len(); - null_count += x.null_count(); Box::new(x) as Box }) .collect(); @@ -160,8 +156,7 @@ where chunks, phantom: PhantomData, bit_settings: Default::default(), - length: length.try_into().expect(LENGTH_LIMIT_MSG), - null_count: null_count as IdxSize, + length: length.try_into().unwrap(), } } @@ -189,7 +184,6 @@ where phantom: PhantomData, bit_settings: Default::default(), length: 0, - null_count: 0, }; out.compute_len(); out @@ -219,7 +213,6 @@ where phantom: PhantomData, bit_settings: Default::default(), length: 0, - null_count: 0, }; out.compute_len(); out @@ -242,7 +235,6 @@ where phantom: PhantomData, bit_settings, length: 0, - null_count: 0, }; out.compute_len(); if !keep_sorted { @@ -266,7 +258,6 @@ where phantom: PhantomData, bit_settings: Default::default(), length: 0, - null_count: 0, }; out.compute_len(); out @@ -282,8 +273,12 @@ where Self::with_chunk(name, to_primitive::(v, None)) } - /// Create a new ChunkedArray from a Vec and a validity mask. - pub fn from_vec_validity(name: &str, values: Vec, buffer: Option) -> Self { + /// Nullify values in slice with an existing null bitmap + pub fn new_from_owned_with_null_bitmap( + name: &str, + values: Vec, + buffer: Option, + ) -> Self { let arr = to_array::(values, buffer); let mut out = ChunkedArray { field: Arc::new(Field::new(name, T::get_dtype())), diff --git a/crates/polars-core/src/chunked_array/list/iterator.rs b/crates/polars-core/src/chunked_array/list/iterator.rs index 4c51fdb271c7..2dc2c7eb8559 100644 --- a/crates/polars-core/src/chunked_array/list/iterator.rs +++ b/crates/polars-core/src/chunked_array/list/iterator.rs @@ -208,14 +208,15 @@ impl ListChunked { .map(|(opt_s, opt_v)| { let out = f(opt_s, opt_v); match out { - Some(out) => { - fast_explode &= !out.is_empty(); + Some(out) if out.is_empty() => { + fast_explode = false; Some(out) }, None => { fast_explode = false; out }, + _ => out, } }) .collect_trusted() @@ -228,51 +229,6 @@ impl ListChunked { out } - pub fn try_zip_and_apply_amortized<'a, T, I, F>( - &'a self, - ca: &'a ChunkedArray, - mut f: F, - ) -> PolarsResult - where - T: PolarsDataType, - &'a ChunkedArray: IntoIterator, - I: TrustedLen>>, - F: FnMut( - Option>, - Option>, - ) -> PolarsResult>, - { - if self.is_empty() { - return Ok(self.clone()); - } - let mut fast_explode = self.null_count() == 0; - // SAFETY: unstable series never lives longer than the iterator. - let mut out: ListChunked = unsafe { - self.amortized_iter() - .zip(ca) - .map(|(opt_s, opt_v)| { - let out = f(opt_s, opt_v)?; - match out { - Some(out) => { - fast_explode &= !out.is_empty(); - Ok(Some(out)) - }, - None => { - fast_explode = false; - Ok(out) - }, - } - }) - .collect::>()? - }; - - out.rename(self.name()); - if fast_explode { - out.set_fast_explode(); - } - Ok(out) - } - /// Apply a closure `F` elementwise. 
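+    /// The closure is passed an amortized (reused) `Series` view of each list
+    /// element rather than a freshly allocated `Series` per row.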
#[must_use] pub fn apply_amortized<'a, F>(&'a self, mut f: F) -> Self diff --git a/crates/polars-core/src/chunked_array/logical/categorical/from.rs b/crates/polars-core/src/chunked_array/logical/categorical/from.rs index 1c7e028fdf87..e268144c9ddd 100644 --- a/crates/polars-core/src/chunked_array/logical/categorical/from.rs +++ b/crates/polars-core/src/chunked_array/logical/categorical/from.rs @@ -7,7 +7,7 @@ use crate::using_string_cache; impl From<&CategoricalChunked> for DictionaryArray { fn from(ca: &CategoricalChunked) -> Self { - let keys = ca.physical().rechunk(); + let keys = ca.logical().rechunk(); let keys = keys.downcast_iter().next().unwrap(); let map = &**ca.get_rev_map(); let dtype = ArrowDataType::Dictionary( @@ -42,7 +42,7 @@ impl From<&CategoricalChunked> for DictionaryArray { } impl From<&CategoricalChunked> for DictionaryArray { fn from(ca: &CategoricalChunked) -> Self { - let keys = ca.physical().rechunk(); + let keys = ca.logical().rechunk(); let keys = keys.downcast_iter().next().unwrap(); let map = &**ca.get_rev_map(); let dtype = ArrowDataType::Dictionary( diff --git a/crates/polars-core/src/chunked_array/logical/categorical/mod.rs b/crates/polars-core/src/chunked_array/logical/categorical/mod.rs index d1671e6e7f05..d66e6318b5ef 100644 --- a/crates/polars-core/src/chunked_array/logical/categorical/mod.rs +++ b/crates/polars-core/src/chunked_array/logical/categorical/mod.rs @@ -23,7 +23,7 @@ bitflags! { #[derive(Clone)] pub struct CategoricalChunked { - physical: Logical, + logical: Logical, /// 1st bit: original local categorical /// meaning that n_unique is the same as the cat map length /// 2nd bit: use lexical sorting @@ -32,7 +32,7 @@ pub struct CategoricalChunked { impl CategoricalChunked { pub(crate) fn field(&self) -> Field { - let name = self.physical().name(); + let name = self.logical().name(); Field::new(name, self.dtype().clone()) } @@ -40,29 +40,23 @@ impl CategoricalChunked { self.len() == 0 } - #[inline] pub fn len(&self) -> usize { - self.physical.len() - } - - #[inline] - pub fn null_count(&self) -> usize { - self.physical.null_count() + self.logical.len() } pub fn name(&self) -> &str { - self.physical.name() + self.logical.name() } // TODO: Rename this /// Get a reference to the physical array (the categories). - pub fn physical(&self) -> &UInt32Chunked { - &self.physical + pub fn logical(&self) -> &UInt32Chunked { + &self.logical } /// Get a mutable reference to the physical array (the categories). - pub(crate) fn physical_mut(&mut self) -> &mut UInt32Chunked { - &mut self.physical + pub(crate) fn logical_mut(&mut self) -> &mut UInt32Chunked { + &mut self.logical } /// Convert a categorical column to its local representation. @@ -78,7 +72,7 @@ impl CategoricalChunked { // if all physical map keys are equal to their values, // we can skip the apply and only update the rev_map let local_ca = self - .physical() + .logical() .apply(|opt_v| opt_v.map(|v| *physical_map.get(&v).unwrap())); let mut out = @@ -90,12 +84,12 @@ impl CategoricalChunked { } pub(crate) fn get_flags(&self) -> Settings { - self.physical().get_flags() + self.logical().get_flags() } /// Set flags for the Chunked Array pub(crate) fn set_flags(&mut self, flags: Settings) { - self.physical_mut().set_flags(flags) + self.logical_mut().set_flags(flags) } /// Build a categorical from an original RevMap. That means that the number of categories in the `RevMapping == self.unique().len()`. 
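For orientation, a short sketch (assuming the polars-core prelude and the
`dtype-categorical` feature) of how the accessor renamed in these hunks is
reached from a `Series`:

    use polars_core::prelude::*;

    // Cast utf8 values to a categorical, then read the u32 keys back
    // through the renamed `logical()` accessor.
    let s = Series::new("s", &["foo", "bar", "foo"])
        .cast(&DataType::Categorical(None))
        .unwrap();
    let ca = s.categorical().unwrap();
    let keys: Vec<Option<u32>> = ca.logical().into_iter().collect();
    assert_eq!(keys, &[Some(0), Some(1), Some(0)]);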
@@ -111,7 +105,7 @@ impl CategoricalChunked { let mut bit_settings = BitSettings::default(); bit_settings.insert(BitSettings::ORIGINAL); Self { - physical: logical, + logical, bit_settings, } } @@ -141,7 +135,7 @@ impl CategoricalChunked { let mut logical = Logical::::new_logical::(idx); logical.2 = Some(DataType::Categorical(Some(rev_map))); Self { - physical: logical, + logical, bit_settings: Default::default(), } } @@ -149,14 +143,14 @@ impl CategoricalChunked { /// # Safety /// The existing index values must be in bounds of the new [`RevMapping`]. pub(crate) unsafe fn set_rev_map(&mut self, rev_map: Arc, keep_fast_unique: bool) { - self.physical.2 = Some(DataType::Categorical(Some(rev_map))); + self.logical.2 = Some(DataType::Categorical(Some(rev_map))); if !keep_fast_unique { self.set_fast_unique(false) } } pub(crate) fn can_fast_unique(&self) -> bool { - self.bit_settings.contains(BitSettings::ORIGINAL) && self.physical.chunks.len() == 1 + self.bit_settings.contains(BitSettings::ORIGINAL) && self.logical.chunks.len() == 1 } pub(crate) fn set_fast_unique(&mut self, toggle: bool) { @@ -169,7 +163,7 @@ impl CategoricalChunked { /// Get a reference to the mapping of categorical types to the string values. pub fn get_rev_map(&self) -> &Arc { - if let DataType::Categorical(Some(rev_map)) = &self.physical.2.as_ref().unwrap() { + if let DataType::Categorical(Some(rev_map)) = &self.logical.2.as_ref().unwrap() { rev_map } else { panic!("implementation error") @@ -178,7 +172,7 @@ impl CategoricalChunked { /// Create an `[Iterator]` that iterates over the `&str` values of the `[CategoricalChunked]`. pub fn iter_str(&self) -> CatIter<'_> { - let iter = self.physical().into_iter(); + let iter = self.logical().into_iter(); CatIter { rev: self.get_rev_map(), iter, @@ -188,7 +182,7 @@ impl CategoricalChunked { impl LogicalType for CategoricalChunked { fn dtype(&self) -> &DataType { - self.physical.2.as_ref().unwrap() + self.logical.2.as_ref().unwrap() } fn get_any_value(&self, i: usize) -> PolarsResult> { @@ -197,7 +191,7 @@ impl LogicalType for CategoricalChunked { } unsafe fn get_any_value_unchecked(&self, i: usize) -> AnyValue<'_> { - match self.physical.0.get_unchecked(i) { + match self.logical.0.get_unchecked(i) { Some(i) => AnyValue::Categorical(i, self.get_rev_map(), SyncPtr::new_null()), None => AnyValue::Null, } @@ -209,16 +203,16 @@ impl LogicalType for CategoricalChunked { let mapping = &**self.get_rev_map(); let mut builder = - Utf8ChunkedBuilder::new(self.physical.name(), self.len(), self.len() * 5); + Utf8ChunkedBuilder::new(self.logical.name(), self.len(), self.len() * 5); let f = |idx: u32| mapping.get(idx); - if !self.physical.has_validity() { - self.physical + if !self.logical.has_validity() { + self.logical .into_no_null_iter() .for_each(|idx| builder.append_value(f(idx))); } else { - self.physical.into_iter().for_each(|opt_idx| { + self.logical.into_iter().for_each(|opt_idx| { builder.append_option(opt_idx.map(f)); }); } @@ -228,13 +222,13 @@ impl LogicalType for CategoricalChunked { }, DataType::UInt32 => { let ca = unsafe { - UInt32Chunked::from_chunks(self.physical.name(), self.physical.chunks.clone()) + UInt32Chunked::from_chunks(self.logical.name(), self.logical.chunks.clone()) }; Ok(ca.into_series()) }, #[cfg(feature = "dtype-categorical")] DataType::Categorical(_) => Ok(self.clone().into_series()), - _ => self.physical.cast(dtype), + _ => self.logical.cast(dtype), } } } diff --git a/crates/polars-core/src/chunked_array/logical/categorical/ops/append.rs 
b/crates/polars-core/src/chunked_array/logical/categorical/ops/append.rs index 190c46dbf352..6385cfba3a1b 100644 --- a/crates/polars-core/src/chunked_array/logical/categorical/ops/append.rs +++ b/crates/polars-core/src/chunked_array/logical/categorical/ops/append.rs @@ -1,23 +1,13 @@ -use polars_error::constants::LENGTH_LIMIT_MSG; - use super::*; use crate::chunked_array::ops::append::new_chunks; use crate::series::IsSorted; impl CategoricalChunked { - fn set_lengths(&mut self, other: &Self) { - let length_self = &mut self.physical_mut().length; - *length_self = length_self - .checked_add(other.len() as IdxSize) - .expect(LENGTH_LIMIT_MSG); - self.physical_mut().null_count += other.null_count() as IdxSize; - } - pub fn append(&mut self, other: &Self) -> PolarsResult<()> { - if self.physical.null_count() == self.len() && other.physical.null_count() == other.len() { + if self.logical.null_count() == self.len() && other.logical.null_count() == other.len() { let len = self.len(); - self.set_lengths(other); - new_chunks(&mut self.physical.chunks, &other.physical().chunks, len); + self.logical_mut().length += other.len() as IdxSize; + new_chunks(&mut self.logical.chunks, &other.logical().chunks, len); return Ok(()); } let is_local_different_source = @@ -33,10 +23,10 @@ impl CategoricalChunked { let new_rev_map = self._merge_categorical_map(other)?; unsafe { self.set_rev_map(new_rev_map, false) }; - self.set_lengths(other); - new_chunks(&mut self.physical.chunks, &other.physical().chunks, len); + self.logical_mut().length += other.len() as IdxSize; + new_chunks(&mut self.logical.chunks, &other.logical().chunks, len); } - self.physical.set_sorted_flag(IsSorted::Not); + self.logical.set_sorted_flag(IsSorted::Not); Ok(()) } } diff --git a/crates/polars-core/src/chunked_array/logical/categorical/ops/unique.rs b/crates/polars-core/src/chunked_array/logical/categorical/ops/unique.rs index 42a448ecbbd8..9ac7d32ae749 100644 --- a/crates/polars-core/src/chunked_array/logical/categorical/ops/unique.rs +++ b/crates/polars-core/src/chunked_array/logical/categorical/ops/unique.rs @@ -7,10 +7,10 @@ impl CategoricalChunked { if self.can_fast_unique() { let ca = match &**cat_map { RevMapping::Local(a) => { - UInt32Chunked::from_iter_values(self.physical().name(), 0..(a.len() as u32)) + UInt32Chunked::from_iter_values(self.logical().name(), 0..(a.len() as u32)) }, RevMapping::Global(map, _, _) => { - UInt32Chunked::from_iter_values(self.physical().name(), map.keys().copied()) + UInt32Chunked::from_iter_values(self.logical().name(), map.keys().copied()) }, }; // safety: @@ -22,7 +22,7 @@ impl CategoricalChunked { Ok(out) } } else { - let ca = self.physical().unique()?; + let ca = self.logical().unique()?; // safety: // we only removed some indexes so we are still in bounds unsafe { @@ -38,14 +38,14 @@ impl CategoricalChunked { if self.can_fast_unique() { Ok(self.get_rev_map().len()) } else { - self.physical().n_unique() + self.logical().n_unique() } } pub fn value_counts(&self) -> PolarsResult { - let groups = self.physical().group_tuples(true, false).unwrap(); - let physical_values = unsafe { - self.physical() + let groups = self.logical().group_tuples(true, false).unwrap(); + let logical_values = unsafe { + self.logical() .clone() .into_series() .agg_first(&groups) @@ -55,7 +55,7 @@ impl CategoricalChunked { }; let mut values = self.clone(); - *values.physical_mut() = physical_values; + *values.logical_mut() = logical_values; let mut counts = groups.group_count(); counts.rename("counts"); diff --git 
a/crates/polars-core/src/chunked_array/logical/categorical/ops/zip.rs b/crates/polars-core/src/chunked_array/logical/categorical/ops/zip.rs index 8ece943cb0fc..7fcad0c73cbe 100644 --- a/crates/polars-core/src/chunked_array/logical/categorical/ops/zip.rs +++ b/crates/polars-core/src/chunked_array/logical/categorical/ops/zip.rs @@ -10,10 +10,10 @@ impl CategoricalChunked { RevMapping::Local(rev_map) => { // the logic for merging the rev maps will concatenate utf8 arrays // to make sure the indexes still make sense we need to offset the right hand side - self.physical() - .zip_with(mask, &(other.physical() + rev_map.len() as u32))? + self.logical() + .zip_with(mask, &(other.logical() + rev_map.len() as u32))? }, - _ => self.physical().zip_with(mask, other.physical())?, + _ => self.logical().zip_with(mask, other.logical())?, }; let new_state = self._merge_categorical_map(other)?; diff --git a/crates/polars-core/src/chunked_array/mod.rs b/crates/polars-core/src/chunked_array/mod.rs index 2ae1f7fd0c3a..fa79caea754c 100644 --- a/crates/polars-core/src/chunked_array/mod.rs +++ b/crates/polars-core/src/chunked_array/mod.rs @@ -140,7 +140,6 @@ pub struct ChunkedArray { phantom: PhantomData, pub(crate) bit_settings: Settings, length: IdxSize, - null_count: IdxSize, } bitflags! { @@ -304,7 +303,6 @@ impl ChunkedArray { /// /// # Safety /// The caller must ensure to not change the [`DataType`] or `length` of any of the chunks. - /// And the `null_count` remains correct. #[inline] pub unsafe fn chunks_mut(&mut self) -> &mut Vec { &mut self.chunks @@ -315,6 +313,12 @@ impl ChunkedArray { self.chunks.len() == 1 && self.null_count() == 0 } + /// Count the null values. + #[inline] + pub fn null_count(&self) -> usize { + self.chunks.iter().map(|arr| arr.null_count()).sum() + } + /// Create a new [`ChunkedArray`] from self, where the chunks are replaced. 
/// /// # Safety @@ -606,7 +610,6 @@ impl Clone for ChunkedArray { phantom: PhantomData, bit_settings: self.bit_settings, length: self.length, - null_count: self.null_count, } } } @@ -830,7 +833,7 @@ pub(crate) mod test { let ca = Utf8Chunked::new("", &[Some("foo"), None, Some("bar"), Some("ham")]); let ca = ca.cast(&DataType::Categorical(None)).unwrap(); let ca = ca.categorical().unwrap(); - let v: Vec<_> = ca.physical().into_iter().collect(); + let v: Vec<_> = ca.logical().into_iter().collect(); assert_eq!(v, &[Some(0), None, Some(1), Some(2)]); } diff --git a/crates/polars-core/src/chunked_array/object/builder.rs b/crates/polars-core/src/chunked_array/object/builder.rs index a6f8b9072c98..351cdc58a383 100644 --- a/crates/polars-core/src/chunked_array/object/builder.rs +++ b/crates/polars-core/src/chunked_array/object/builder.rs @@ -59,10 +59,6 @@ where let null_bitmap: Option = self.bitmask_builder.into(); let len = self.values.len(); - let null_count = null_bitmap - .as_ref() - .map(|validity| validity.unset_bits()) - .unwrap_or(0) as IdxSize; let arr = Box::new(ObjectArray { values: Arc::new(self.values), @@ -76,7 +72,6 @@ where phantom: PhantomData, bit_settings: Default::default(), length: len as IdxSize, - null_count, } } } @@ -141,7 +136,6 @@ where phantom: PhantomData, bit_settings: Default::default(), length: len as IdxSize, - null_count: 0, } } diff --git a/crates/polars-core/src/chunked_array/ops/append.rs b/crates/polars-core/src/chunked_array/ops/append.rs index 027ccb09d168..c14405ed377d 100644 --- a/crates/polars-core/src/chunked_array/ops/append.rs +++ b/crates/polars-core/src/chunked_array/ops/append.rs @@ -19,20 +19,16 @@ where T: PolarsDataType, for<'a> T::Physical<'a>: TotalOrd, { - // TODO: attempt to maintain sortedness better in case of nulls. - - // If either is empty, copy the sorted flag from the other. - if ca.is_empty() { + // If either is empty (or completely null), copy the sorted flag from the other. + if ca.len() == ca.null_count() { ca.set_sorted_flag(other.is_sorted_flag()); return; } - if other.is_empty() { + if other.len() == other.null_count() { return; } - // Both need to be sorted, in the same order, if the order is maintained. - // TODO: rework sorted flags, ascending and descending are not mutually - // exclusive for all-equal/all-null arrays. + // Both need to be sorted, in the same order. let ls = ca.is_sorted_flag(); let rs = other.is_sorted_flag(); if ls != rs || ls == IsSorted::Not || rs == IsSorted::Not { @@ -42,23 +38,12 @@ where // Check the order is maintained. let still_sorted = { - // To prevent potential quadratic append behavior we do not find - // the last non-null element in ca. - if let Some(left) = ca.last() { - if let Some(right_idx) = other.first_non_null() { - let right = other.get(right_idx).unwrap(); - if ca.is_sorted_ascending_flag() { - left.tot_le(&right) - } else { - left.tot_ge(&right) - } - } else { - // Right is only nulls, trivially sorted. - true - } + let left = ca.get(ca.last_non_null().unwrap()).unwrap(); + let right = other.get(other.first_non_null().unwrap()).unwrap(); + if ca.is_sorted_ascending_flag() { + left.tot_le(&right) } else { - // Last element in left is null, pessimistically assume not sorted. 
- false + left.tot_ge(&right) } }; if !still_sorted { @@ -78,7 +63,6 @@ where update_sorted_flag_before_append::(self, other); let len = self.len(); self.length += other.length; - self.null_count += other.null_count; new_chunks(&mut self.chunks, &other.chunks, len); } } @@ -91,7 +75,6 @@ impl ListChunked { let len = self.len(); self.length += other.length; - self.null_count += other.null_count; new_chunks(&mut self.chunks, &other.chunks, len); self.set_sorted_flag(IsSorted::Not); if !other._can_fast_explode() { @@ -110,7 +93,6 @@ impl ArrayChunked { let len = self.len(); self.length += other.length; - self.null_count += other.null_count; new_chunks(&mut self.chunks, &other.chunks, len); self.set_sorted_flag(IsSorted::Not); Ok(()) @@ -123,7 +105,6 @@ impl ObjectChunked { pub fn append(&mut self, other: &Self) { let len = self.len(); self.length += other.length; - self.null_count += other.null_count; self.set_sorted_flag(IsSorted::Not); new_chunks(&mut self.chunks, &other.chunks, len); } diff --git a/crates/polars-core/src/chunked_array/ops/apply.rs b/crates/polars-core/src/chunked_array/ops/apply.rs index 1254363eaa75..093e6c172d95 100644 --- a/crates/polars-core/src/chunked_array/ops/apply.rs +++ b/crates/polars-core/src/chunked_array/ops/apply.rs @@ -220,7 +220,6 @@ impl ChunkedArray { .for_each(|arr| arrow::compute::arity_assign::unary(arr, f)) }; // can be in any order now - self.compute_len(); self.set_sorted_flag(IsSorted::Not); } } diff --git a/crates/polars-core/src/chunked_array/ops/chunkops.rs b/crates/polars-core/src/chunked_array/ops/chunkops.rs index b4cc3b6c5ec2..076dc6476702 100644 --- a/crates/polars-core/src/chunked_array/ops/chunkops.rs +++ b/crates/polars-core/src/chunked_array/ops/chunkops.rs @@ -55,17 +55,10 @@ fn slice( impl ChunkedArray { /// Get the length of the ChunkedArray - #[inline] pub fn len(&self) -> usize { self.length as usize } - /// Count the null values. - #[inline] - pub fn null_count(&self) -> usize { - self.null_count as usize - } - /// Check if ChunkedArray is empty. 
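// Aside (illustrative): with the cached method removed in this hunk, the null
// count is recomputed on demand by summing over the chunks, equivalent to
// `ca.chunks().iter().map(|arr| arr.null_count()).sum::<usize>()`:
//
//     let ca = Int32Chunked::new("a", &[Some(1), None, Some(3)]);
//     assert_eq!(ca.null_count(), 1);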
    pub fn is_empty(&self) -> bool {
        self.len() == 0
 
@@ -81,11 +74,6 @@ impl<T: PolarsDataType> ChunkedArray<T> {
             }
         }
         self.length = IdxSize::try_from(inner(&self.chunks)).expect(LENGTH_LIMIT_MSG);
-        self.null_count = self
-            .chunks
-            .iter()
-            .map(|arr| arr.null_count())
-            .sum::<usize>() as IdxSize;
 
         if self.length <= 1 {
             self.set_sorted_flag(IsSorted::Ascending)
diff --git a/crates/polars-core/src/chunked_array/ops/compare_inner.rs b/crates/polars-core/src/chunked_array/ops/compare_inner.rs
index 7d710d92361a..c62d291e8de2 100644
--- a/crates/polars-core/src/chunked_array/ops/compare_inner.rs
+++ b/crates/polars-core/src/chunked_array/ops/compare_inner.rs
@@ -155,7 +155,7 @@ impl<'a> GetInner for GlobalCategorical<'a> {
 #[cfg(feature = "dtype-categorical")]
 impl<'a> IntoPartialOrdInner<'a> for &'a CategoricalChunked {
     fn into_partial_ord_inner(self) -> Box<dyn PartialOrdInner + 'a> {
-        let cats = self.physical();
+        let cats = self.logical();
         match &**self.get_rev_map() {
             RevMapping::Global(p1, p2, _) => Box::new(GlobalCategorical { p1, p2, cats }),
             RevMapping::Local(rev_map) => Box::new(LocalCategorical { rev_map, cats }),
diff --git a/crates/polars-core/src/chunked_array/ops/cum_agg.rs b/crates/polars-core/src/chunked_array/ops/cum_agg.rs
new file mode 100644
index 000000000000..a1b0f2e65ee0
--- /dev/null
+++ b/crates/polars-core/src/chunked_array/ops/cum_agg.rs
@@ -0,0 +1,176 @@
+use std::iter::FromIterator;
+use std::ops::{Add, AddAssign, Mul};
+
+use num_traits::Bounded;
+
+use crate::prelude::*;
+use crate::utils::CustomIterTools;
+
+fn det_max<T>(state: &mut T, v: Option<T>) -> Option<Option<T>>
+where
+    T: Copy + PartialOrd + AddAssign + Add<Output = T>,
+{
+    match v {
+        Some(v) => {
+            if v > *state {
+                *state = v
+            }
+            Some(Some(*state))
+        },
+        None => Some(None),
+    }
+}
+
+fn det_min<T>(state: &mut T, v: Option<T>) -> Option<Option<T>>
+where
+    T: Copy + PartialOrd + AddAssign + Add<Output = T>,
+{
+    match v {
+        Some(v) => {
+            if v < *state {
+                *state = v
+            }
+            Some(Some(*state))
+        },
+        None => Some(None),
+    }
+}
+
+fn det_sum<T>(state: &mut Option<T>, v: Option<T>) -> Option<Option<T>>
+where
+    T: Copy + PartialOrd + AddAssign + Add<Output = T>,
+{
+    match (*state, v) {
+        (Some(state_inner), Some(v)) => {
+            *state = Some(state_inner + v);
+            Some(*state)
+        },
+        (None, Some(v)) => {
+            *state = Some(v);
+            Some(*state)
+        },
+        (_, None) => Some(None),
+    }
+}
+
+fn det_prod<T>(state: &mut Option<T>, v: Option<T>) -> Option<Option<T>>
+where
+    T: Copy + PartialOrd + Mul<Output = T>,
+{
+    match (*state, v) {
+        (Some(state_inner), Some(v)) => {
+            *state = Some(state_inner * v);
+            Some(*state)
+        },
+        (None, Some(v)) => {
+            *state = Some(v);
+            Some(*state)
+        },
+        (_, None) => Some(None),
+    }
+}
+
+impl<T> ChunkCumAgg<T> for ChunkedArray<T>
+where
+    T: PolarsNumericType,
+    ChunkedArray<T>: FromIterator<Option<T::Native>>,
+{
+    fn cummax(&self, reverse: bool) -> ChunkedArray<T> {
+        let init = Bounded::min_value();
+
+        let mut ca: Self = match reverse {
+            false => self.into_iter().scan(init, det_max).collect_trusted(),
+            true => self
+                .into_iter()
+                .rev()
+                .scan(init, det_max)
+                .collect_reversed(),
+        };
+
+        ca.rename(self.name());
+        ca
+    }
+
+    fn cummin(&self, reverse: bool) -> ChunkedArray<T> {
+        let init = Bounded::max_value();
+        let mut ca: Self = match reverse {
+            false => self.into_iter().scan(init, det_min).collect_trusted(),
+            true => self
+                .into_iter()
+                .rev()
+                .scan(init, det_min)
+                .collect_reversed(),
+        };
+
+        ca.rename(self.name());
+        ca
+    }
+
+    fn cumsum(&self, reverse: bool) -> ChunkedArray<T> {
+        let init = None;
+        let mut ca: Self = match reverse {
+            false => self.into_iter().scan(init, det_sum).collect_trusted(),
+            true => self
+                .into_iter()
+                .rev()
+                .scan(init, det_sum)
+                .collect_reversed(),
+        };
+
+        ca.rename(self.name());
+        ca
+    }
+
+    fn cumprod(&self, reverse: bool) -> ChunkedArray<T> {
+        let init = None;
+        let mut ca: Self = match reverse {
+            false => self.into_iter().scan(init, det_prod).collect_trusted(),
+            true => self
+                .into_iter()
+                .rev()
+                .scan(init, det_prod)
+                .collect_reversed(),
+        };
+
+        ca.rename(self.name());
+        ca
+    }
+}
+
+#[cfg(test)]
+mod test {
+    use crate::prelude::*;
+
+    #[test]
+    #[cfg(feature = "dtype-u8")]
+    fn test_cummax() {
+        let ca = UInt8Chunked::new("foo", &[None, Some(1), Some(3), None, Some(1)]);
+        let out = ca.cummax(true);
+        assert_eq!(Vec::from(&out), &[None, Some(3), Some(3), None, Some(1)]);
+        let out = ca.cummax(false);
+        assert_eq!(Vec::from(&out), &[None, Some(1), Some(3), None, Some(3)]);
+    }
+
+    #[test]
+    #[cfg(feature = "dtype-u8")]
+    fn test_cummin() {
+        let ca = UInt8Chunked::new("foo", &[None, Some(1), Some(3), None, Some(2)]);
+        let out = ca.cummin(true);
+        assert_eq!(Vec::from(&out), &[None, Some(1), Some(2), None, Some(2)]);
+        let out = ca.cummin(false);
+        assert_eq!(Vec::from(&out), &[None, Some(1), Some(1), None, Some(1)]);
+    }
+
+    #[test]
+    fn test_cumsum() {
+        let ca = Int32Chunked::new("foo", &[None, Some(1), Some(3), None, Some(1)]);
+        let out = ca.cumsum(true);
+        assert_eq!(Vec::from(&out), &[None, Some(5), Some(4), None, Some(1)]);
+        let out = ca.cumsum(false);
+        assert_eq!(Vec::from(&out), &[None, Some(1), Some(4), None, Some(5)]);
+
+        // just check if the trait bounds allow for floats
+        let ca = Float32Chunked::new("foo", &[None, Some(1.0), Some(3.0), None, Some(1.0)]);
+        let _out = ca.cumsum(false);
+    }
+}
diff --git a/crates/polars-core/src/chunked_array/ops/mod.rs b/crates/polars-core/src/chunked_array/ops/mod.rs
index 6a54b4ae9b5a..9a89da39968d 100644
--- a/crates/polars-core/src/chunked_array/ops/mod.rs
+++ b/crates/polars-core/src/chunked_array/ops/mod.rs
@@ -16,6 +16,8 @@ pub mod arity;
 mod bit_repr;
 pub(crate) mod chunkops;
 pub(crate) mod compare_inner;
+#[cfg(feature = "cum_agg")]
+mod cum_agg;
 #[cfg(feature = "dtype-decimal")]
 mod decimal;
 pub(crate) mod downcast;
@@ -89,6 +91,26 @@ pub trait ChunkAnyValue {
     fn get_any_value(&self, index: usize) -> PolarsResult<AnyValue>;
 }
 
+#[cfg(feature = "cum_agg")]
+pub trait ChunkCumAgg<T: PolarsDataType> {
+    /// Get an array with the cumulative max computed at every element
+    fn cummax(&self, _reverse: bool) -> ChunkedArray<T> {
+        panic!("operation cummax not supported for this dtype")
+    }
+    /// Get an array with the cumulative min computed at every element
+    fn cummin(&self, _reverse: bool) -> ChunkedArray<T> {
+        panic!("operation cummin not supported for this dtype")
+    }
+    /// Get an array with the cumulative sum computed at every element
+    fn cumsum(&self, _reverse: bool) -> ChunkedArray<T> {
+        panic!("operation cumsum not supported for this dtype")
+    }
+    /// Get an array with the cumulative product computed at every element
+    fn cumprod(&self, _reverse: bool) -> ChunkedArray<T> {
+        panic!("operation cumprod not supported for this dtype")
+    }
+}
+
 /// Explode/ flatten a List or Utf8 Series
 pub trait ChunkExplode {
     fn explode(&self) -> PolarsResult<Series> {
diff --git a/crates/polars-core/src/chunked_array/ops/sort/arg_sort_multiple.rs b/crates/polars-core/src/chunked_array/ops/sort/arg_sort_multiple.rs
index d74dcfca91f2..9f0d5a4ebcf2 100644
--- a/crates/polars-core/src/chunked_array/ops/sort/arg_sort_multiple.rs
+++ b/crates/polars-core/src/chunked_array/ops/sort/arg_sort_multiple.rs
@@ -75,7 +75,7 @@ pub fn _get_rows_encoded_compat_array(by: &Series) -> PolarsResult<ArrayRef> {
             if
ca.uses_lexical_ordering() { by.to_arrow(0) } else { - ca.physical().chunks[0].clone() + ca.logical().chunks[0].clone() } }, _ => by.to_arrow(0), diff --git a/crates/polars-core/src/chunked_array/ops/sort/categorical.rs b/crates/polars-core/src/chunked_array/ops/sort/categorical.rs index de25449f976e..337dee580b2a 100644 --- a/crates/polars-core/src/chunked_array/ops/sort/categorical.rs +++ b/crates/polars-core/src/chunked_array/ops/sort/categorical.rs @@ -31,7 +31,7 @@ impl CategoricalChunked { if self.uses_lexical_ordering() { let mut vals = self - .physical() + .logical() .into_no_null_iter() .zip(self.iter_str()) .collect_trusted::>(); @@ -57,7 +57,7 @@ impl CategoricalChunked { ) }; } - let cats = self.physical().sort_with(options); + let cats = self.logical().sort_with(options); // safety: // we only reordered the indexes so we are still in bounds unsafe { @@ -84,11 +84,11 @@ impl CategoricalChunked { self.name(), iters, options, - self.physical().null_count(), + self.logical().null_count(), self.len(), ) } else { - self.physical().arg_sort(options) + self.logical().arg_sort(options) } } @@ -96,7 +96,7 @@ impl CategoricalChunked { pub(crate) fn arg_sort_multiple(&self, options: &SortMultipleOptions) -> PolarsResult { if self.uses_lexical_ordering() { - args_validate(self.physical(), &options.other, &options.descending)?; + args_validate(self.logical(), &options.other, &options.descending)?; let mut count: IdxSize = 0; // we use bytes to save a monomorphisized str impl @@ -112,7 +112,7 @@ impl CategoricalChunked { arg_sort_multiple_impl(vals, options) } else { - self.physical().arg_sort_multiple(options) + self.logical().arg_sort_multiple(options) } } } diff --git a/crates/polars-core/src/chunked_array/upstream_traits.rs b/crates/polars-core/src/chunked_array/upstream_traits.rs index fac284c4615e..af24444fdf14 100644 --- a/crates/polars-core/src/chunked_array/upstream_traits.rs +++ b/crates/polars-core/src/chunked_array/upstream_traits.rs @@ -30,7 +30,6 @@ impl Default for ChunkedArray { phantom: PhantomData, bit_settings: Default::default(), length: 0, - null_count: 0, } } } @@ -331,7 +330,6 @@ impl FromIterator> for ObjectChunked { phantom: PhantomData, bit_settings: Default::default(), length: 0, - null_count: 0, }; out.compute_len(); out diff --git a/crates/polars-core/src/fmt.rs b/crates/polars-core/src/fmt.rs index 6dec1221081e..24758ca320aa 100644 --- a/crates/polars-core/src/fmt.rs +++ b/crates/polars-core/src/fmt.rs @@ -964,46 +964,36 @@ fn fmt_struct(f: &mut Formatter<'_>, vals: &[AnyValue]) -> fmt::Result { impl Series { pub fn fmt_list(&self) -> String { - if self.is_empty() { - return "[]".to_owned(); - } - - let max_items = std::env::var(FMT_TABLE_CELL_LIST_LEN) - .as_deref() - .unwrap_or("") - .parse() - .map_or(3, |n: i64| if n < 0 { self.len() } else { n as usize }); + match self.len() { + 0 => "[]".to_string(), + 1 => format!("[{}]", self.get(0).unwrap()), + 2 => format!("[{}, {}]", self.get(0).unwrap(), self.get(1).unwrap()), + 3 => format!( + "[{}, {}, {}]", + self.get(0).unwrap(), + self.get(1).unwrap(), + self.get(2).unwrap() + ), + _ => { + let max_items = std::env::var(FMT_TABLE_CELL_LIST_LEN) + .as_deref() + .unwrap_or("") + .parse() + .map_or(3, |n: i64| if n < 0 { self.len() } else { n as usize }); - match max_items { - 0 => "[…]".to_owned(), - _ if max_items >= self.len() => { let mut result = "[".to_owned(); - for i in 0..self.len() { - let item = self.get(i).unwrap(); + for (i, item) in self.iter().enumerate() { write!(result, "{item}").unwrap(); - // 
this will always leave a trailing ", " after the last item - // but for long lists, this is faster than checking against the length each time - result.push_str(", "); - } - // remove trailing ", " and replace with closing brace - result.pop(); - result.pop(); - result.push(']'); - result - }, - _ => { - let mut result = "[".to_owned(); + if i != self.len() - 1 { + result.push_str(", "); + } - for (i, item) in self.iter().enumerate() { - if i == max_items.saturating_sub(1) { + if i == max_items - 2 { result.push_str("… "); write!(result, "{}", self.get(self.len() - 1).unwrap()).unwrap(); break; - } else { - write!(result, "{item}").unwrap(); - result.push_str(", "); } } result.push(']'); @@ -1146,7 +1136,7 @@ mod test { ListPrimitiveChunkedBuilder::::new("a", 10, 10, DataType::Int32); builder.append_opt_slice(Some(&[1, 2, 3, 4, 5, 6])); builder.append_opt_slice(None); - let list_long = builder.finish().into_series(); + let list = builder.finish().into_series(); assert_eq!( r#"shape: (2,) @@ -1155,7 +1145,7 @@ Series: 'a' [list[i32]] [1, 2, … 6] null ]"#, - format!("{:?}", list_long) + format!("{:?}", list) ); std::env::set_var("POLARS_FMT_TABLE_CELL_LIST_LEN", "10"); @@ -1167,116 +1157,8 @@ Series: 'a' [list[i32]] [1, 2, 3, 4, 5, 6] null ]"#, - format!("{:?}", list_long) - ); - - std::env::set_var("POLARS_FMT_TABLE_CELL_LIST_LEN", "-1"); - - assert_eq!( - r#"shape: (2,) -Series: 'a' [list[i32]] -[ - [1, 2, 3, 4, 5, 6] - null -]"#, - format!("{:?}", list_long) - ); - - std::env::set_var("POLARS_FMT_TABLE_CELL_LIST_LEN", "0"); - - assert_eq!( - r#"shape: (2,) -Series: 'a' [list[i32]] -[ - […] - null -]"#, - format!("{:?}", list_long) - ); - - std::env::set_var("POLARS_FMT_TABLE_CELL_LIST_LEN", "1"); - - assert_eq!( - r#"shape: (2,) -Series: 'a' [list[i32]] -[ - [… 6] - null -]"#, - format!("{:?}", list_long) - ); - - std::env::set_var("POLARS_FMT_TABLE_CELL_LIST_LEN", "4"); - - assert_eq!( - r#"shape: (2,) -Series: 'a' [list[i32]] -[ - [1, 2, 3, … 6] - null -]"#, - format!("{:?}", list_long) - ); - - let mut builder = - ListPrimitiveChunkedBuilder::::new("a", 10, 10, DataType::Int32); - builder.append_opt_slice(Some(&[1])); - builder.append_opt_slice(None); - let list_short = builder.finish().into_series(); - - std::env::set_var("POLARS_FMT_TABLE_CELL_LIST_LEN", ""); - - assert_eq!( - r#"shape: (2,) -Series: 'a' [list[i32]] -[ - [1] - null -]"#, - format!("{:?}", list_short) - ); - - std::env::set_var("POLARS_FMT_TABLE_CELL_LIST_LEN", "0"); - - assert_eq!( - r#"shape: (2,) -Series: 'a' [list[i32]] -[ - […] - null -]"#, - format!("{:?}", list_short) - ); - - std::env::set_var("POLARS_FMT_TABLE_CELL_LIST_LEN", "-1"); - - assert_eq!( - r#"shape: (2,) -Series: 'a' [list[i32]] -[ - [1] - null -]"#, - format!("{:?}", list_short) - ); - - let mut builder = - ListPrimitiveChunkedBuilder::::new("a", 10, 10, DataType::Int32); - builder.append_opt_slice(Some(&[])); - builder.append_opt_slice(None); - let list_empty = builder.finish().into_series(); - - std::env::set_var("POLARS_FMT_TABLE_CELL_LIST_LEN", ""); - - assert_eq!( - r#"shape: (2,) -Series: 'a' [list[i32]] -[ - [] - null -]"#, - format!("{:?}", list_empty) - ); + format!("{:?}", list) + ) } #[test] diff --git a/crates/polars-core/src/frame/group_by/perfect.rs b/crates/polars-core/src/frame/group_by/perfect.rs index 2d050a157d1d..9e44c883315d 100644 --- a/crates/polars-core/src/frame/group_by/perfect.rs +++ b/crates/polars-core/src/frame/group_by/perfect.rs @@ -196,7 +196,7 @@ impl CategoricalChunked { if self.is_empty() { return 
GroupsProxy::Idx(GroupsIdx::new(vec![], vec![], true)); } - let cats = self.physical(); + let cats = self.logical(); let mut out = match &**rev_map { RevMapping::Local(cached) => { @@ -208,7 +208,7 @@ impl CategoricalChunked { // but on huge tables, this can be > 2x faster cats.group_tuples_perfect(cached.len() - 1, multithreaded, 0) } else { - self.physical().group_tuples(multithreaded, sorted).unwrap() + self.logical().group_tuples(multithreaded, sorted).unwrap() } }, RevMapping::Global(_mapping, _cached, _) => { @@ -216,7 +216,7 @@ impl CategoricalChunked { // the problem is that the global categories are not guaranteed packed together // so we might need to deref them first to local ones, but that might be more // expensive than just hashing (benchmark first) - self.physical().group_tuples(multithreaded, sorted).unwrap() + self.logical().group_tuples(multithreaded, sorted).unwrap() }, }; if sorted { diff --git a/crates/polars-core/src/series/comparison.rs b/crates/polars-core/src/series/comparison.rs index f65a6e502e64..e274770e88fe 100644 --- a/crates/polars-core/src/series/comparison.rs +++ b/crates/polars-core/src/series/comparison.rs @@ -126,7 +126,7 @@ impl ChunkCompare<&Series> for Series { #[cfg(feature = "dtype-categorical")] (Categorical(Some(rev_map_l)), Categorical(Some(rev_map_r)), _, _) => { if rev_map_l.same_src(rev_map_r) { - let rhs = rhs.categorical().unwrap().physical(); + let rhs = rhs.categorical().unwrap().logical(); // first check the rev-map if rhs.len() == 1 && rhs.null_count() == 0 { @@ -136,7 +136,7 @@ impl ChunkCompare<&Series> for Series { } } - self.categorical().unwrap().physical().equal(rhs) + self.categorical().unwrap().logical().equal(rhs) } else { polars_bail!( ComputeError: @@ -182,7 +182,7 @@ impl ChunkCompare<&Series> for Series { #[cfg(feature = "dtype-categorical")] (Categorical(Some(rev_map_l)), Categorical(Some(rev_map_r)), _, _) => { if rev_map_l.same_src(rev_map_r) { - let rhs = rhs.categorical().unwrap().physical(); + let rhs = rhs.categorical().unwrap().logical(); // first check the rev-map if rhs.len() == 1 && rhs.null_count() == 0 { @@ -192,7 +192,7 @@ impl ChunkCompare<&Series> for Series { } } - self.categorical().unwrap().physical().equal_missing(rhs) + self.categorical().unwrap().logical().equal_missing(rhs) } else { polars_bail!( ComputeError: @@ -238,7 +238,7 @@ impl ChunkCompare<&Series> for Series { #[cfg(feature = "dtype-categorical")] (Categorical(Some(rev_map_l)), Categorical(Some(rev_map_r)), _, _) => { if rev_map_l.same_src(rev_map_r) { - let rhs = rhs.categorical().unwrap().physical(); + let rhs = rhs.categorical().unwrap().logical(); // first check the rev-map if rhs.len() == 1 && rhs.null_count() == 0 { @@ -248,7 +248,7 @@ impl ChunkCompare<&Series> for Series { } } - self.categorical().unwrap().physical().not_equal(rhs) + self.categorical().unwrap().logical().not_equal(rhs) } else { polars_bail!( ComputeError: @@ -294,7 +294,7 @@ impl ChunkCompare<&Series> for Series { #[cfg(feature = "dtype-categorical")] (Categorical(Some(rev_map_l)), Categorical(Some(rev_map_r)), _, _) => { if rev_map_l.same_src(rev_map_r) { - let rhs = rhs.categorical().unwrap().physical(); + let rhs = rhs.categorical().unwrap().logical(); // first check the rev-map if rhs.len() == 1 && rhs.null_count() == 0 { @@ -304,10 +304,7 @@ impl ChunkCompare<&Series> for Series { } } - self.categorical() - .unwrap() - .physical() - .not_equal_missing(rhs) + self.categorical().unwrap().logical().not_equal_missing(rhs) } else { polars_bail!( ComputeError: diff 
--git a/crates/polars-core/src/series/implementations/categorical.rs b/crates/polars-core/src/series/implementations/categorical.rs index 1544f5ee1f8c..a3b096e54e4c 100644 --- a/crates/polars-core/src/series/implementations/categorical.rs +++ b/crates/polars-core/src/series/implementations/categorical.rs @@ -35,7 +35,7 @@ impl SeriesWrap { where F: Fn(&UInt32Chunked) -> UInt32Chunked, { - let cats = apply(self.0.physical()); + let cats = apply(self.0.logical()); self.finish_with_state(keep_fast_unique, cats) } @@ -47,14 +47,14 @@ impl SeriesWrap { where F: for<'b> Fn(&'a UInt32Chunked) -> PolarsResult, { - let cats = apply(self.0.physical())?; + let cats = apply(self.0.logical())?; Ok(self.finish_with_state(keep_fast_unique, cats)) } } impl private::PrivateSeries for SeriesWrap { fn compute_len(&mut self) { - self.0.physical_mut().compute_len() + self.0.logical_mut().compute_len() } fn _field(&self) -> Cow { Cow::Owned(self.0.field()) @@ -78,7 +78,7 @@ impl private::PrivateSeries for SeriesWrap { } unsafe fn equal_element(&self, idx_self: usize, idx_other: usize, other: &Series) -> bool { - self.0.physical().equal_element(idx_self, idx_other, other) + self.0.logical().equal_element(idx_self, idx_other, other) } #[cfg(feature = "zip_with")] @@ -91,24 +91,24 @@ impl private::PrivateSeries for SeriesWrap { if self.0.uses_lexical_ordering() { (&self.0).into_partial_ord_inner() } else { - self.0.physical().into_partial_ord_inner() + self.0.logical().into_partial_ord_inner() } } fn vec_hash(&self, random_state: RandomState, buf: &mut Vec) -> PolarsResult<()> { - self.0.physical().vec_hash(random_state, buf)?; + self.0.logical().vec_hash(random_state, buf)?; Ok(()) } fn vec_hash_combine(&self, build_hasher: RandomState, hashes: &mut [u64]) -> PolarsResult<()> { - self.0.physical().vec_hash_combine(build_hasher, hashes)?; + self.0.logical().vec_hash_combine(build_hasher, hashes)?; Ok(()) } #[cfg(feature = "algorithm_group_by")] unsafe fn agg_list(&self, groups: &GroupsProxy) -> Series { // we cannot cast and dispatch as the inner type of the list would be incorrect - let list = self.0.physical().agg_list(groups); + let list = self.0.logical().agg_list(groups); let mut list = list.list().unwrap().clone(); list.to_logical(self.dtype().clone()); list.into_series() @@ -122,7 +122,7 @@ impl private::PrivateSeries for SeriesWrap { } #[cfg(not(feature = "performant"))] { - self.0.physical().group_tuples(multithreaded, sorted) + self.0.logical().group_tuples(multithreaded, sorted) } } @@ -133,24 +133,24 @@ impl private::PrivateSeries for SeriesWrap { impl SeriesTrait for SeriesWrap { fn rename(&mut self, name: &str) { - self.0.physical_mut().rename(name); + self.0.logical_mut().rename(name); } fn chunk_lengths(&self) -> ChunkIdIter { - self.0.physical().chunk_id() + self.0.logical().chunk_id() } fn name(&self) -> &str { - self.0.physical().name() + self.0.logical().name() } fn chunks(&self) -> &Vec { - self.0.physical().chunks() + self.0.logical().chunks() } unsafe fn chunks_mut(&mut self) -> &mut Vec { - self.0.physical_mut().chunks_mut() + self.0.logical_mut().chunks_mut() } fn shrink_to_fit(&mut self) { - self.0.physical_mut().shrink_to_fit() + self.0.logical_mut().shrink_to_fit() } fn slice(&self, offset: i64, length: usize) -> Series { @@ -166,7 +166,7 @@ impl SeriesTrait for SeriesWrap { fn extend(&mut self, other: &Series) -> PolarsResult<()> { polars_ensure!(self.0.dtype() == other.dtype(), extend); let other = other.categorical()?; - self.0.physical_mut().extend(other.physical()); + 
self.0.logical_mut().extend(other.logical()); let new_rev_map = self.0._merge_categorical_map(other)?; // SAFETY // rev_maps are merged @@ -181,13 +181,13 @@ impl SeriesTrait for SeriesWrap { #[cfg(feature = "chunked_ids")] unsafe fn _take_chunked_unchecked(&self, by: &[ChunkId], sorted: IsSorted) -> Series { - let cats = self.0.physical().take_chunked_unchecked(by, sorted); + let cats = self.0.logical().take_chunked_unchecked(by, sorted); self.finish_with_state(false, cats).into_series() } #[cfg(feature = "chunked_ids")] unsafe fn _take_opt_chunked_unchecked(&self, by: &[Option]) -> Series { - let cats = self.0.physical().take_opt_chunked_unchecked(by); + let cats = self.0.logical().take_opt_chunked_unchecked(by); self.finish_with_state(false, cats).into_series() } @@ -246,11 +246,11 @@ impl SeriesTrait for SeriesWrap { } fn null_count(&self) -> usize { - self.0.physical().null_count() + self.0.logical().null_count() } fn has_validity(&self) -> bool { - self.0.physical().has_validity() + self.0.logical().has_validity() } #[cfg(feature = "algorithm_group_by")] @@ -265,15 +265,15 @@ impl SeriesTrait for SeriesWrap { #[cfg(feature = "algorithm_group_by")] fn arg_unique(&self) -> PolarsResult { - self.0.physical().arg_unique() + self.0.logical().arg_unique() } fn is_null(&self) -> BooleanChunked { - self.0.physical().is_null() + self.0.logical().is_null() } fn is_not_null(&self) -> BooleanChunked { - self.0.physical().is_not_null() + self.0.logical().is_not_null() } fn reverse(&self) -> Series { @@ -281,7 +281,7 @@ impl SeriesTrait for SeriesWrap { } fn as_single_ptr(&mut self) -> PolarsResult { - self.0.physical_mut().as_single_ptr() + self.0.logical_mut().as_single_ptr() } fn shift(&self, periods: i64) -> Series { @@ -289,29 +289,29 @@ impl SeriesTrait for SeriesWrap { } fn _sum_as_series(&self) -> Series { - CategoricalChunked::full_null(self.0.physical().name(), 1).into_series() + CategoricalChunked::full_null(self.0.logical().name(), 1).into_series() } fn max_as_series(&self) -> Series { - CategoricalChunked::full_null(self.0.physical().name(), 1).into_series() + CategoricalChunked::full_null(self.0.logical().name(), 1).into_series() } fn min_as_series(&self) -> Series { - CategoricalChunked::full_null(self.0.physical().name(), 1).into_series() + CategoricalChunked::full_null(self.0.logical().name(), 1).into_series() } fn median_as_series(&self) -> Series { - CategoricalChunked::full_null(self.0.physical().name(), 1).into_series() + CategoricalChunked::full_null(self.0.logical().name(), 1).into_series() } fn var_as_series(&self, _ddof: u8) -> Series { - CategoricalChunked::full_null(self.0.physical().name(), 1).into_series() + CategoricalChunked::full_null(self.0.logical().name(), 1).into_series() } fn std_as_series(&self, _ddof: u8) -> Series { - CategoricalChunked::full_null(self.0.physical().name(), 1).into_series() + CategoricalChunked::full_null(self.0.logical().name(), 1).into_series() } fn quantile_as_series( &self, _quantile: f64, _interpol: QuantileInterpolOptions, ) -> PolarsResult { - Ok(CategoricalChunked::full_null(self.0.physical().name(), 1).into_series()) + Ok(CategoricalChunked::full_null(self.0.logical().name(), 1).into_series()) } fn clone_inner(&self) -> Arc { @@ -324,6 +324,6 @@ impl private::PrivateSeriesNumeric for SeriesWrap { false } fn bit_repr_small(&self) -> UInt32Chunked { - self.0.physical().clone() + self.0.logical().clone() } } diff --git a/crates/polars-core/src/series/implementations/dates_time.rs 
b/crates/polars-core/src/series/implementations/dates_time.rs index 831495f19e9f..120f7a1e11b4 100644 --- a/crates/polars-core/src/series/implementations/dates_time.rs +++ b/crates/polars-core/src/series/implementations/dates_time.rs @@ -53,6 +53,17 @@ macro_rules! impl_dyn_series { .into_series() } + #[cfg(feature = "cum_agg")] + fn _cummax(&self, reverse: bool) -> Series { + self.0.cummax(reverse).$into_logical().into_series() + } + + #[cfg(feature = "cum_agg")] + fn _cummin(&self, reverse: bool) -> Series { + self.0.cummin(reverse).$into_logical().into_series() + } + + #[cfg(feature = "zip_with")] fn zip_with_same_type( &self, diff --git a/crates/polars-core/src/series/implementations/datetime.rs b/crates/polars-core/src/series/implementations/datetime.rs index 071da77dc53d..59b3bce8a5e2 100644 --- a/crates/polars-core/src/series/implementations/datetime.rs +++ b/crates/polars-core/src/series/implementations/datetime.rs @@ -49,6 +49,22 @@ impl private::PrivateSeries for SeriesWrap { .into_series() } + #[cfg(feature = "cum_agg")] + fn _cummax(&self, reverse: bool) -> Series { + self.0 + .cummax(reverse) + .into_datetime(self.0.time_unit(), self.0.time_zone().clone()) + .into_series() + } + + #[cfg(feature = "cum_agg")] + fn _cummin(&self, reverse: bool) -> Series { + self.0 + .cummin(reverse) + .into_datetime(self.0.time_unit(), self.0.time_zone().clone()) + .into_series() + } + #[cfg(feature = "zip_with")] fn zip_with_same_type(&self, mask: &BooleanChunked, other: &Series) -> PolarsResult { let other = other.to_physical_repr().into_owned(); diff --git a/crates/polars-core/src/series/implementations/duration.rs b/crates/polars-core/src/series/implementations/duration.rs index c3dca8662f0a..834c6c57c181 100644 --- a/crates/polars-core/src/series/implementations/duration.rs +++ b/crates/polars-core/src/series/implementations/duration.rs @@ -44,6 +44,22 @@ impl private::PrivateSeries for SeriesWrap { .into_series() } + #[cfg(feature = "cum_agg")] + fn _cummax(&self, reverse: bool) -> Series { + self.0 + .cummax(reverse) + .into_duration(self.0.time_unit()) + .into_series() + } + + #[cfg(feature = "cum_agg")] + fn _cummin(&self, reverse: bool) -> Series { + self.0 + .cummin(reverse) + .into_duration(self.0.time_unit()) + .into_series() + } + fn _set_flags(&mut self, flags: Settings) { self.0.deref_mut().set_flags(flags) } diff --git a/crates/polars-core/src/series/implementations/floats.rs b/crates/polars-core/src/series/implementations/floats.rs index 0e4949ad4ae3..92b7cb018def 100644 --- a/crates/polars-core/src/series/implementations/floats.rs +++ b/crates/polars-core/src/series/implementations/floats.rs @@ -40,6 +40,16 @@ macro_rules! impl_dyn_series { self.0.explode_by_offsets(offsets) } + #[cfg(feature = "cum_agg")] + fn _cummax(&self, reverse: bool) -> Series { + self.0.cummax(reverse).into_series() + } + + #[cfg(feature = "cum_agg")] + fn _cummin(&self, reverse: bool) -> Series { + self.0.cummin(reverse).into_series() + } + unsafe fn equal_element( &self, idx_self: usize, diff --git a/crates/polars-core/src/series/implementations/mod.rs b/crates/polars-core/src/series/implementations/mod.rs index 86cb1bc8efe5..15dc5dfd86f6 100644 --- a/crates/polars-core/src/series/implementations/mod.rs +++ b/crates/polars-core/src/series/implementations/mod.rs @@ -105,6 +105,16 @@ macro_rules! 
impl_dyn_series { self.0.explode_by_offsets(offsets) } + #[cfg(feature = "cum_agg")] + fn _cummax(&self, reverse: bool) -> Series { + self.0.cummax(reverse).into_series() + } + + #[cfg(feature = "cum_agg")] + fn _cummin(&self, reverse: bool) -> Series { + self.0.cummin(reverse).into_series() + } + unsafe fn equal_element( &self, idx_self: usize, diff --git a/crates/polars-core/src/series/into.rs b/crates/polars-core/src/series/into.rs index e718bca831fc..3f97222d17a7 100644 --- a/crates/polars-core/src/series/into.rs +++ b/crates/polars-core/src/series/into.rs @@ -59,7 +59,7 @@ impl Series { #[cfg(feature = "dtype-categorical")] DataType::Categorical(_) => { let ca = self.categorical().unwrap(); - let arr = ca.physical().chunks()[chunk_idx].clone(); + let arr = ca.logical().chunks()[chunk_idx].clone(); // SAFETY: categoricals are always u32's. let cats = unsafe { UInt32Chunked::from_chunks("", vec![arr]) }; diff --git a/crates/polars-core/src/series/mod.rs b/crates/polars-core/src/series/mod.rs index 0cce71998a2a..2ec370820c67 100644 --- a/crates/polars-core/src/series/mod.rs +++ b/crates/polars-core/src/series/mod.rs @@ -181,7 +181,6 @@ impl Series { /// # Safety /// The caller must ensure the length and the data types of `ArrayRef` does not change. - /// And that the null_count is updated (e.g. with a `compute_len()`) pub unsafe fn chunks_mut(&mut self) -> &mut Vec { #[allow(unused_mut)] let mut ca = self._get_inner_mut(); @@ -255,11 +254,6 @@ impl Series { Ok(self) } - /// Redo a length and null_count compute - pub fn compute_len(&mut self) { - self._get_inner_mut().compute_len() - } - /// Extend the memory backed by this array with the values from `other`. /// /// See [`ChunkedArray::extend`] and [`ChunkedArray::append`]. @@ -624,6 +618,94 @@ impl Series { } } + /// Get an array with the cumulative max computed at every element. + pub fn cummax(&self, _reverse: bool) -> Series { + #[cfg(feature = "cum_agg")] + { + self._cummax(_reverse) + } + #[cfg(not(feature = "cum_agg"))] + { + panic!("activate 'cum_agg' feature") + } + } + + /// Get an array with the cumulative min computed at every element. + pub fn cummin(&self, _reverse: bool) -> Series { + #[cfg(feature = "cum_agg")] + { + self._cummin(_reverse) + } + #[cfg(not(feature = "cum_agg"))] + { + panic!("activate 'cum_agg' feature") + } + } + + /// Get an array with the cumulative sum computed at every element + /// + /// If the [`DataType`] is one of `{Int8, UInt8, Int16, UInt16}` the `Series` is + /// first cast to `Int64` to prevent overflow issues. 
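// Aside (illustrative, assumes the `cum_agg` and `dtype-i8` features): the
// widening described above in practice; Int8 input accumulates as Int64:
//
//     let s = Series::new("x", &[100i8, 100, 100]);
//     let out = s.cumsum(false);
//     assert_eq!(out.dtype(), &DataType::Int64);
//     assert_eq!(out.i64().unwrap().get(2), Some(300));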
+ #[allow(unused_variables)] + pub fn cumsum(&self, reverse: bool) -> Series { + #[cfg(feature = "cum_agg")] + { + use DataType::*; + match self.dtype() { + Boolean => self.cast(&DataType::UInt32).unwrap().cumsum(reverse), + Int8 | UInt8 | Int16 | UInt16 => { + let s = self.cast(&Int64).unwrap(); + s.cumsum(reverse) + }, + Int32 => self.i32().unwrap().cumsum(reverse).into_series(), + UInt32 => self.u32().unwrap().cumsum(reverse).into_series(), + UInt64 => self.u64().unwrap().cumsum(reverse).into_series(), + Int64 => self.i64().unwrap().cumsum(reverse).into_series(), + Float32 => self.f32().unwrap().cumsum(reverse).into_series(), + Float64 => self.f64().unwrap().cumsum(reverse).into_series(), + #[cfg(feature = "dtype-duration")] + Duration(tu) => { + let ca = self.to_physical_repr(); + let ca = ca.i64().unwrap(); + ca.cumsum(reverse).cast(&Duration(*tu)).unwrap() + }, + dt => panic!("cumsum not supported for dtype: {dt:?}"), + } + } + #[cfg(not(feature = "cum_agg"))] + { + panic!("activate 'cum_agg' feature") + } + } + + /// Get an array with the cumulative product computed at every element. + /// + /// If the [`DataType`] is one of `{Int8, UInt8, Int16, UInt16, Int32, UInt32}` the `Series` is + /// first cast to `Int64` to prevent overflow issues. + #[allow(unused_variables)] + pub fn cumprod(&self, reverse: bool) -> Series { + #[cfg(feature = "cum_agg")] + { + use DataType::*; + match self.dtype() { + Boolean => self.cast(&DataType::Int64).unwrap().cumprod(reverse), + Int8 | UInt8 | Int16 | UInt16 | Int32 | UInt32 => { + let s = self.cast(&Int64).unwrap(); + s.cumprod(reverse) + }, + Int64 => self.i64().unwrap().cumprod(reverse).into_series(), + UInt64 => self.u64().unwrap().cumprod(reverse).into_series(), + Float32 => self.f32().unwrap().cumprod(reverse).into_series(), + Float64 => self.f64().unwrap().cumprod(reverse).into_series(), + dt => panic!("cumprod not supported for dtype: {dt:?}"), + } + } + #[cfg(not(feature = "cum_agg"))] + { + panic!("activate 'cum_agg' feature") + } + } + /// Get the product of an array. /// /// If the [`DataType`] is one of `{Int8, UInt8, Int16, UInt16}` the `Series` is @@ -894,8 +976,7 @@ impl Series { /// Packs every element into a list. 
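// A matching sketch for `cumprod` above (illustrative, assumes `cum_agg`):
// {Int8..UInt32} inputs are widened to Int64, so products that would overflow
// an Int32 stay exact:
//
//     let s = Series::new("x", &[2000i32, 2000, 2000]);
//     let out = s.cumprod(false);
//     assert_eq!(out.dtype(), &DataType::Int64);
//     assert_eq!(out.i64().unwrap().get(2), Some(8_000_000_000));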
    pub fn as_list(&self) -> ListChunked {
         let s = self.rechunk();
-        // don't use `to_arrow` as we need the physical types
-        let values = s.chunks()[0].clone();
+        let values = s.to_arrow(0);
 
         let offsets = (0i64..(s.len() as i64 + 1)).collect::<Vec<_>>();
         let offsets = unsafe { Offsets::new_unchecked(offsets) };
@@ -1050,4 +1131,13 @@ mod test {
         let _ = series.slice(-6, 2);
         let _ = series.slice(4, 2);
     }
+
+    #[test]
+    #[cfg(feature = "round_series")]
+    fn test_round_series() {
+        let series = Series::new("a", &[1.003, 2.23222, 3.4352]);
+        let out = series.round(2).unwrap();
+        let ca = out.f64().unwrap();
+        assert_eq!(ca.get(0), Some(1.0));
+    }
 }
diff --git a/crates/polars-core/src/series/ops/diff.rs b/crates/polars-core/src/series/ops/diff.rs
new file mode 100644
index 000000000000..b3452ccade81
--- /dev/null
+++ b/crates/polars-core/src/series/ops/diff.rs
@@ -0,0 +1,24 @@
+use crate::prelude::*;
+use crate::series::ops::NullBehavior;
+
+impl Series {
+    pub fn diff(&self, n: i64, null_behavior: NullBehavior) -> PolarsResult<Series> {
+        use DataType::*;
+        let s = match self.dtype() {
+            UInt8 => self.cast(&Int16).unwrap(),
+            UInt16 => self.cast(&Int32).unwrap(),
+            UInt32 | UInt64 => self.cast(&Int64).unwrap(),
+            _ => self.clone(),
+        };
+
+        match null_behavior {
+            NullBehavior::Ignore => Ok(&s - &s.shift(n)),
+            NullBehavior::Drop => {
+                polars_ensure!(n > 0, InvalidOperation: "only positive integer allowed if nulls are dropped in 'diff' operation");
+                let n = n as usize;
+                let len = s.len() - n;
+                Ok(&self.slice(n as i64, len) - &s.slice(0, len))
+            },
+        }
+    }
+}
diff --git a/crates/polars-core/src/series/ops/ewm.rs b/crates/polars-core/src/series/ops/ewm.rs
new file mode 100644
index 000000000000..388a44eb4a2d
--- /dev/null
+++ b/crates/polars-core/src/series/ops/ewm.rs
@@ -0,0 +1,104 @@
+use std::convert::TryFrom;
+
+pub use arrow::legacy::kernels::ewm::EWMOptions;
+use arrow::legacy::kernels::ewm::{ewm_mean, ewm_std, ewm_var};
+
+use crate::prelude::*;
+
+fn check_alpha(alpha: f64) -> PolarsResult<()> {
+    polars_ensure!((0.0..=1.0).contains(&alpha), ComputeError: "alpha must be in [0; 1]");
+    Ok(())
+}
+
+impl Series {
+    pub fn ewm_mean(&self, options: EWMOptions) -> PolarsResult<Series> {
+        check_alpha(options.alpha)?;
+        match self.dtype() {
+            DataType::Float32 => {
+                let xs = self.f32().unwrap();
+                let result = ewm_mean(
+                    xs,
+                    options.alpha as f32,
+                    options.adjust,
+                    options.min_periods,
+                    options.ignore_nulls,
+                );
+                Series::try_from((self.name(), Box::new(result) as ArrayRef))
+            },
+            DataType::Float64 => {
+                let xs = self.f64().unwrap();
+                let result = ewm_mean(
+                    xs,
+                    options.alpha,
+                    options.adjust,
+                    options.min_periods,
+                    options.ignore_nulls,
+                );
+                Series::try_from((self.name(), Box::new(result) as ArrayRef))
+            },
+            _ => self.cast(&DataType::Float64)?.ewm_mean(options),
+        }
+    }
+
+    pub fn ewm_std(&self, options: EWMOptions) -> PolarsResult<Series> {
+        check_alpha(options.alpha)?;
+        match self.dtype() {
+            DataType::Float32 => {
+                let xs = self.f32().unwrap();
+                let result = ewm_std(
+                    xs,
+                    options.alpha as f32,
+                    options.adjust,
+                    options.bias,
+                    options.min_periods,
+                    options.ignore_nulls,
+                );
+                Series::try_from((self.name(), Box::new(result) as ArrayRef))
+            },
+            DataType::Float64 => {
+                let xs = self.f64().unwrap();
+                let result = ewm_std(
+                    xs,
+                    options.alpha,
+                    options.adjust,
+                    options.bias,
+                    options.min_periods,
+                    options.ignore_nulls,
+                );
+                Series::try_from((self.name(), Box::new(result) as ArrayRef))
+            },
+            _ => self.cast(&DataType::Float64)?.ewm_std(options),
+        }
+    }
+
+    pub fn ewm_var(&self, options: EWMOptions) -> PolarsResult<Series> {
+        check_alpha(options.alpha)?;
+        match self.dtype() {
+            DataType::Float32 => {
+                let xs = self.f32().unwrap();
+                let result = ewm_var(
+                    xs,
+                    options.alpha as f32,
+                    options.adjust,
+                    options.bias,
+                    options.min_periods,
+                    options.ignore_nulls,
+                );
+                Series::try_from((self.name(), Box::new(result) as ArrayRef))
+            },
+            DataType::Float64 => {
+                let xs = self.f64().unwrap();
+                let result = ewm_var(
+                    xs,
+                    options.alpha,
+                    options.adjust,
+                    options.bias,
+                    options.min_periods,
+                    options.ignore_nulls,
+                );
+                Series::try_from((self.name(), Box::new(result) as ArrayRef))
+            },
+            _ => self.cast(&DataType::Float64)?.ewm_var(options),
+        }
+    }
+}
diff --git a/crates/polars-core/src/series/ops/mod.rs b/crates/polars-core/src/series/ops/mod.rs
index 650e5cbecaa4..8730c012b163 100644
--- a/crates/polars-core/src/series/ops/mod.rs
+++ b/crates/polars-core/src/series/ops/mod.rs
@@ -1,8 +1,16 @@
+#[cfg(feature = "diff")]
+pub mod diff;
 mod downcast;
+#[cfg(feature = "ewma")]
+mod ewm;
 mod extend;
 #[cfg(feature = "moment")]
 pub mod moment;
 mod null;
+#[cfg(feature = "pct_change")]
+pub mod pct_change;
+#[cfg(feature = "round_series")]
+mod round;
 mod to_list;
 mod unique;
 #[cfg(feature = "serde")]
diff --git a/crates/polars-core/src/series/ops/pct_change.rs b/crates/polars-core/src/series/ops/pct_change.rs
new file mode 100644
index 000000000000..4135a0427bae
--- /dev/null
+++ b/crates/polars-core/src/series/ops/pct_change.rs
@@ -0,0 +1,48 @@
+use crate::prelude::*;
+use crate::series::ops::NullBehavior;
+
+impl Series {
+    pub fn pct_change(&self, n: i64) -> PolarsResult<Series> {
+        match self.dtype() {
+            DataType::Float64 | DataType::Float32 => {},
+            _ => return self.cast(&DataType::Float64)?.pct_change(n),
+        }
+        let nn = self.fill_null(FillNullStrategy::Forward(None))?;
+        nn.diff(n, NullBehavior::Ignore)?.divide(&nn.shift(n))
+    }
+}
+
+#[cfg(test)]
+mod test {
+    use super::*;
+
+    #[test]
+    fn test_nulls() -> PolarsResult<()> {
+        let s = Series::new("", &[Some(1), None, Some(2), None, Some(3)]);
+        assert_eq!(
+            s.pct_change(1)?,
+            Series::new("", &[None, Some(0.0f64), Some(1.0), Some(0.), Some(0.5)])
+        );
+        Ok(())
+    }
+
+    #[test]
+    fn test_same() -> PolarsResult<()> {
+        let s = Series::new("", &[Some(1), Some(1), Some(1)]);
+        assert_eq!(
+            s.pct_change(1)?,
+            Series::new("", &[None, Some(0.0f64), Some(0.0)])
+        );
+        Ok(())
+    }
+
+    #[test]
+    fn test_two_periods() -> PolarsResult<()> {
+        let s = Series::new("", &[Some(1), Some(2), Some(4), Some(8), Some(16)]);
+        assert_eq!(
+            s.pct_change(2)?,
+            Series::new("", &[None, None, Some(3.0f64), Some(3.0), Some(3.0)])
+        );
+        Ok(())
+    }
+}
diff --git a/crates/polars-ops/src/series/ops/round.rs b/crates/polars-core/src/series/ops/round.rs
similarity index 55%
rename from crates/polars-ops/src/series/ops/round.rs
rename to crates/polars-core/src/series/ops/round.rs
index 9d090099a42a..37abe7797941 100644
--- a/crates/polars-ops/src/series/ops/round.rs
+++ b/crates/polars-core/src/series/ops/round.rs
@@ -1,17 +1,14 @@
 use num_traits::pow::Pow;
-use polars_core::prelude::*;
 
-use crate::series::ops::SeriesSealed;
+use crate::prelude::*;
 
-pub trait RoundSeries: SeriesSealed {
+impl Series {
     /// Round underlying floating point array to given decimal.
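// Aside (illustrative): the semantics this move preserves; the value is
// multiplied, rounded, then divided back, with f32 input computed on f64
// multipliers before casting back:
//
//     let s = Series::new("a", &[1.2345f64, 2.999]);
//     let out = s.round(2).unwrap();
//     assert_eq!(out.f64().unwrap().get(0), Some(1.23));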
- fn round(&self, decimals: u32) -> PolarsResult { - let s = self.as_series(); - - if let Ok(ca) = s.f32() { - return if decimals == 0 { + pub fn round(&self, decimals: u32) -> PolarsResult { + if let Ok(ca) = self.f32() { + if decimals == 0 { let s = ca.apply_values(|val| val.round()).into_series(); - Ok(s) + return Ok(s); } else { // Note we do the computation on f64 floats to not lose precision // when the computation is done, we cast to f32 @@ -19,66 +16,47 @@ pub trait RoundSeries: SeriesSealed { let s = ca .apply_values(|val| ((val as f64 * multiplier).round() / multiplier) as f32) .into_series(); - Ok(s) - }; + return Ok(s); + } } - if let Ok(ca) = s.f64() { - return if decimals == 0 { + if let Ok(ca) = self.f64() { + if decimals == 0 { let s = ca.apply_values(|val| val.round()).into_series(); - Ok(s) + return Ok(s); } else { let multiplier = 10.0.pow(decimals as f64); let s = ca .apply_values(|val| (val * multiplier).round() / multiplier) .into_series(); - Ok(s) - }; + return Ok(s); + } } - polars_bail!(opq = round, s.dtype()); + polars_bail!(opq = round, self.dtype()); } /// Floor underlying floating point array to the lowest integers smaller or equal to the float value. - fn floor(&self) -> PolarsResult { - let s = self.as_series(); - - if let Ok(ca) = s.f32() { + pub fn floor(&self) -> PolarsResult { + if let Ok(ca) = self.f32() { let s = ca.apply_values(|val| val.floor()).into_series(); return Ok(s); } - if let Ok(ca) = s.f64() { + if let Ok(ca) = self.f64() { let s = ca.apply_values(|val| val.floor()).into_series(); return Ok(s); } - polars_bail!(opq = floor, s.dtype()); + polars_bail!(opq = floor, self.dtype()); } /// Ceil underlying floating point array to the highest integers smaller or equal to the float value. - fn ceil(&self) -> PolarsResult { - let s = self.as_series(); - - if let Ok(ca) = s.f32() { + pub fn ceil(&self) -> PolarsResult { + if let Ok(ca) = self.f32() { let s = ca.apply_values(|val| val.ceil()).into_series(); return Ok(s); } - if let Ok(ca) = s.f64() { + if let Ok(ca) = self.f64() { let s = ca.apply_values(|val| val.ceil()).into_series(); return Ok(s); } - polars_bail!(opq = ceil, s.dtype()); - } -} - -impl RoundSeries for Series {} - -#[cfg(test)] -mod test { - use super::*; - - #[test] - fn test_round_series() { - let series = Series::new("a", &[1.003, 2.23222, 3.4352]); - let out = series.round(2).unwrap(); - let ca = out.f64().unwrap(); - assert_eq!(ca.get(0), Some(1.0)); + polars_bail!(opq = ceil, self.dtype()); } } diff --git a/crates/polars-core/src/series/series_trait.rs b/crates/polars-core/src/series/series_trait.rs index f7ab23947447..a99a59dc7999 100644 --- a/crates/polars-core/src/series/series_trait.rs +++ b/crates/polars-core/src/series/series_trait.rs @@ -88,6 +88,18 @@ pub(crate) mod private { invalid_operation_panic!(explode_by_offsets, self) } + /// Get an array with the cumulative max computed at every element + #[cfg(feature = "cum_agg")] + fn _cummax(&self, _reverse: bool) -> Series { + panic!("operation cummax not supported for this dtype") + } + + /// Get an array with the cumulative min computed at every element + #[cfg(feature = "cum_agg")] + fn _cummin(&self, _reverse: bool) -> Series { + panic!("operation cummin not supported for this dtype") + } + unsafe fn equal_element( &self, _idx_self: usize, diff --git a/crates/polars-core/src/utils/mod.rs b/crates/polars-core/src/utils/mod.rs index c85c807096f5..43ee28b644f9 100644 --- a/crates/polars-core/src/utils/mod.rs +++ b/crates/polars-core/src/utils/mod.rs @@ -879,7 +879,6 @@ pub 
fn coalesce_nulls<'a, T: PolarsDataType>( *arr_b = arr_b.with_validity(arr.validity().cloned()) } } - b.compute_len(); (Cow::Owned(a), Cow::Owned(b)) } else { (Cow::Borrowed(a), Cow::Borrowed(b)) @@ -900,8 +899,6 @@ pub fn coalesce_nulls_series(a: &Series, b: &Series) -> (Series, Series) { *arr_a = arr_a.with_validity(validity.clone()); *arr_b = arr_b.with_validity(validity); } - a.compute_len(); - b.compute_len(); (a, b) } else { (a.clone(), b.clone()) diff --git a/crates/polars-error/Cargo.toml b/crates/polars-error/Cargo.toml index 60e4800f073f..689e755e9a20 100644 --- a/crates/polars-error/Cargo.toml +++ b/crates/polars-error/Cargo.toml @@ -11,8 +11,8 @@ description = "Error definitions for the Polars DataFrame library" [dependencies] arrow-format = { version = "0.8.1", optional = true } avro-schema = { workspace = true, optional = true } -object_store = { workspace = true, optional = true } -parquet2 = { workspace = true, optional = true } +object_store = { workspace = true, default-features = false, optional = true } +parquet2 = { workspace = true, optional = true, default-features = false } regex = { workspace = true, optional = true } simdutf8 = { workspace = true } thiserror = { workspace = true } diff --git a/crates/polars-ffi/src/lib.rs b/crates/polars-ffi/src/lib.rs index 1be2c01c477b..699d5e7a7fd5 100644 --- a/crates/polars-ffi/src/lib.rs +++ b/crates/polars-ffi/src/lib.rs @@ -24,22 +24,6 @@ pub struct SeriesExport { private_data: *mut std::os::raw::c_void, } -impl SeriesExport { - pub fn empty() -> Self { - Self { - field: std::ptr::null_mut(), - arrays: std::ptr::null_mut(), - len: 0, - release: None, - private_data: std::ptr::null_mut(), - } - } - - pub fn is_null(&self) -> bool { - self.private_data.is_null() - } -} - impl Drop for SeriesExport { fn drop(&mut self) { if let Some(release) = self.release { @@ -97,7 +81,11 @@ pub unsafe fn import_series(e: SeriesExport) -> PolarsResult { }) .collect::>>()?; - Series::try_from((field.name.as_str(), chunks)) + Ok(Series::from_chunks_and_dtype_unchecked( + &field.name, + chunks, + &(&field.data_type).into(), + )) } /// # Safety diff --git a/crates/polars-io/Cargo.toml b/crates/polars-io/Cargo.toml index 18e69c0e9647..26819a2c58f0 100644 --- a/crates/polars-io/Cargo.toml +++ b/crates/polars-io/Cargo.toml @@ -10,7 +10,7 @@ description = "IO related logic for the Polars DataFrame library" [dependencies] polars-core = { workspace = true } -polars-error = { workspace = true } +polars-error = { workspace = true, default-features = false } polars-json = { workspace = true, optional = true } polars-time = { workspace = true, features = [], optional = true } polars-utils = { workspace = true } diff --git a/crates/polars-io/src/cloud/object_store_setup.rs b/crates/polars-io/src/cloud/object_store_setup.rs index 5826c48f8bd1..91119dbbf248 100644 --- a/crates/polars-io/src/cloud/object_store_setup.rs +++ b/crates/polars-io/src/cloud/object_store_setup.rs @@ -4,7 +4,7 @@ use tokio::sync::RwLock; use super::*; -type CacheKey = (String, Option); +type CacheKey = (CloudType, Option); /// A very simple cache that only stores a single object-store. 
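/// (After this patch the key is the parsed `CloudType` plus the cloud options
/// rather than the raw URL string; roughly, per the hunk below,
/// `let key = (CloudType::from_str(url)?, options.cloned());`.)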
/// This greatly reduces the query times as multiple object stores (when reading many small files) @@ -22,13 +22,21 @@ fn err_missing_feature(feature: &str, scheme: &str) -> BuildResult { "feature '{}' must be enabled in order to use '{}' cloud urls", feature, scheme, ); } +#[cfg(any(feature = "azure", feature = "aws", feature = "gcp"))] +fn err_missing_configuration(feature: &str, scheme: &str) -> BuildResult { + polars_bail!( + ComputeError: + "configuration '{}' must be provided in order to use '{}' cloud urls", feature, scheme, + ); +} /// Build an [`ObjectStore`] based on the URL and passed in url. Return the cloud location and an implementation of the object store. pub async fn build_object_store(url: &str, options: Option<&CloudOptions>) -> BuildResult { let cloud_location = CloudLocation::new(url)?; + let cloud_type = CloudType::from_str(url)?; let options = options.cloned(); - let key = (url.to_string(), options); + let key = (cloud_type, options); { let cache = OBJECT_STORE_CACHE.read().await; @@ -39,13 +47,7 @@ pub async fn build_object_store(url: &str, options: Option<&CloudOptions>) -> Bu } } - let cloud_type = CloudType::from_str(url)?; - let options = key - .1 - .as_ref() - .map(Cow::Borrowed) - .unwrap_or_else(|| Cow::Owned(Default::default())); - let store = match cloud_type { + let store = match key.0 { CloudType::File => { let local = LocalFileSystem::new(); Ok::<_, PolarsError>(Arc::new(local) as Arc) @@ -53,6 +55,11 @@ pub async fn build_object_store(url: &str, options: Option<&CloudOptions>) -> Bu CloudType::Aws => { #[cfg(feature = "aws")] { + let options = key + .1 + .as_ref() + .map(Cow::Borrowed) + .unwrap_or_else(|| Cow::Owned(Default::default())); let store = options.build_aws(url).await?; Ok::<_, PolarsError>(Arc::new(store) as Arc) } @@ -61,9 +68,12 @@ pub async fn build_object_store(url: &str, options: Option<&CloudOptions>) -> Bu }, CloudType::Gcp => { #[cfg(feature = "gcp")] - { - let store = options.build_gcp(url)?; - Ok::<_, PolarsError>(Arc::new(store) as Arc) + match key.1.as_ref() { + Some(options) => { + let store = options.build_gcp(url)?; + Ok::<_, PolarsError>(Arc::new(store) as Arc) + }, + _ => return err_missing_configuration("gcp", &cloud_location.scheme), } #[cfg(not(feature = "gcp"))] return err_missing_feature("gcp", &cloud_location.scheme); @@ -71,9 +81,12 @@ pub async fn build_object_store(url: &str, options: Option<&CloudOptions>) -> Bu CloudType::Azure => { { #[cfg(feature = "azure")] - { - let store = options.build_azure(url)?; - Ok::<_, PolarsError>(Arc::new(store) as Arc) + match key.1.as_ref() { + Some(options) => { + let store = options.build_azure(url)?; + Ok::<_, PolarsError>(Arc::new(store) as Arc) + }, + _ => return err_missing_configuration("azure", &cloud_location.scheme), } } #[cfg(not(feature = "azure"))] diff --git a/crates/polars-io/src/parquet/read.rs b/crates/polars-io/src/parquet/read.rs index 4df47cf9e658..2586ce0b25d8 100644 --- a/crates/polars-io/src/parquet/read.rs +++ b/crates/polars-io/src/parquet/read.rs @@ -129,13 +129,6 @@ impl ParquetReader { self } - /// Set the [`Schema`] if already known. This must be exactly the same as - /// the schema in the file itself. - pub fn with_schema(mut self, schema: Option) -> Self { - self.schema = schema; - self - } - /// [`Schema`] of the file. 
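/// Editor's sketch of the remaining entry point (illustrative; the file name
/// is hypothetical, and with `with_schema` gone the schema is always read back
/// from the file itself):
///
/// ```ignore
/// let mut reader = ParquetReader::new(std::fs::File::open("data.parquet")?);
/// let schema = reader.schema()?;
/// ```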
pub fn schema(&mut self) -> PolarsResult { match &self.schema { diff --git a/crates/polars-io/src/parquet/read_impl.rs b/crates/polars-io/src/parquet/read_impl.rs index 51ec4669f25c..5827f44a57e1 100644 --- a/crates/polars-io/src/parquet/read_impl.rs +++ b/crates/polars-io/src/parquet/read_impl.rs @@ -96,15 +96,10 @@ pub(super) fn array_iter_to_series( } /// Materializes hive partitions. -/// We have a special num_rows arg, as df can be empty when a projection contains -/// only hive partition columns. -/// Safety: num_rows equals the height of the df when the df height is non-zero. -fn materialize_hive_partitions( - df: &mut DataFrame, - hive_partition_columns: Option<&[Series]>, - num_rows: usize, -) { +fn materialize_hive_partitions(df: &mut DataFrame, hive_partition_columns: Option<&[Series]>) { if let Some(hive_columns) = hive_partition_columns { + let num_rows = df.height(); + for s in hive_columns { unsafe { df.with_column_unchecked(s.new_from_index(0, num_rows)) }; } @@ -196,7 +191,6 @@ fn rg_to_dfs_optionally_par_over_columns( assert!(std::env::var("POLARS_PANIC_IF_PARQUET_PARSED").is_err()) } - let projection_height = (*remaining_rows).min(md.num_rows()); let chunk_size = md.num_rows(); let columns = if let ParallelStrategy::Columns = parallel { POOL.install(|| { @@ -206,7 +200,7 @@ fn rg_to_dfs_optionally_par_over_columns( column_idx_to_series( *column_i, md, - projection_height, + *remaining_rows, schema, store, chunk_size, @@ -218,26 +212,20 @@ fn rg_to_dfs_optionally_par_over_columns( projection .iter() .map(|column_i| { - column_idx_to_series( - *column_i, - md, - projection_height, - schema, - store, - chunk_size, - ) + column_idx_to_series(*column_i, md, *remaining_rows, schema, store, chunk_size) }) .collect::>>()? }; - *remaining_rows -= projection_height; + *remaining_rows = + remaining_rows.saturating_sub(file_metadata.row_groups[rg_idx].num_rows()); let mut df = DataFrame::new_no_checks(columns); if let Some(rc) = &row_count { df.with_row_count_mut(&rc.name, Some(*previous_row_count + rc.offset)); } + materialize_hive_partitions(&mut df, hive_partition_columns); - materialize_hive_partitions(&mut df, hive_partition_columns, projection_height); apply_predicate(&mut df, predicate, true)?; *previous_row_count += current_row_count; @@ -277,17 +265,17 @@ fn rg_to_dfs_par_over_rg( let row_count_start = *previous_row_count; let num_rows = rg_md.num_rows(); *previous_row_count += num_rows as IdxSize; - let projection_height = (*remaining_rows).min(num_rows); - *remaining_rows -= projection_height; + let local_limit = *remaining_rows; + *remaining_rows = remaining_rows.saturating_sub(num_rows); - (rg_idx, rg_md, projection_height, row_count_start) + (rg_idx, rg_md, local_limit, row_count_start) }) .collect::>(); let dfs = row_groups .into_par_iter() - .map(|(rg_idx, md, projection_height, row_count_start)| { - if projection_height == 0 + .map(|(rg_idx, md, local_limit, row_count_start)| { + if local_limit == 0 || use_statistics && !read_this_row_group(predicate, &file_metadata.row_groups[rg_idx], schema)? 
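// (Editor's note on these read_impl hunks: the revert trades the exact
// per-row-group accounting `projection_height = (*remaining_rows).min(md.num_rows())`
// for `remaining_rows.saturating_sub(num_rows)`, so each row group is read
// with whatever limit remained before it, and hive partition columns are sized
// from `df.height()` again instead of an explicit `num_rows` argument.)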
{ @@ -303,14 +291,7 @@ fn rg_to_dfs_par_over_rg( let columns = projection .iter() .map(|column_i| { - column_idx_to_series( - *column_i, - md, - projection_height, - schema, - store, - chunk_size, - ) + column_idx_to_series(*column_i, md, local_limit, schema, store, chunk_size) }) .collect::>>()?; @@ -319,8 +300,8 @@ fn rg_to_dfs_par_over_rg( if let Some(rc) = &row_count { df.with_row_count_mut(&rc.name, Some(row_count_start as IdxSize + rc.offset)); } + materialize_hive_partitions(&mut df, hive_partition_columns); - materialize_hive_partitions(&mut df, hive_partition_columns, projection_height); apply_predicate(&mut df, predicate, false)?; Ok(Some(df)) diff --git a/crates/polars-lazy/Cargo.toml b/crates/polars-lazy/Cargo.toml index 554e8eeda37d..7c6a7770f8d0 100644 --- a/crates/polars-lazy/Cargo.toml +++ b/crates/polars-lazy/Cargo.toml @@ -10,7 +10,7 @@ description = "Lazy query engine for the Polars DataFrame library" [dependencies] arrow = { workspace = true } -polars-core = { workspace = true, features = ["lazy", "zip_with", "random"] } +polars-core = { workspace = true, features = ["lazy", "zip_with", "random"], default-features = false } polars-io = { workspace = true, features = ["lazy"] } polars-json = { workspace = true, optional = true } polars-ops = { workspace = true } @@ -138,7 +138,6 @@ fused = ["polars-plan/fused", "polars-ops/fused"] list_sets = ["polars-plan/list_sets", "polars-ops/list_sets"] list_any_all = ["polars-ops/list_any_all", "polars-plan/list_any_all"] list_drop_nulls = ["polars-ops/list_drop_nulls", "polars-plan/list_drop_nulls"] -list_sample = ["polars-ops/list_sample", "polars-plan/list_sample"] cutqcut = ["polars-plan/cutqcut", "polars-ops/cutqcut"] rle = ["polars-plan/rle", "polars-ops/rle"] extract_groups = ["polars-plan/extract_groups"] diff --git a/crates/polars-lazy/src/frame/mod.rs b/crates/polars-lazy/src/frame/mod.rs index 20292101113e..f08a16d28bf7 100644 --- a/crates/polars-lazy/src/frame/mod.rs +++ b/crates/polars-lazy/src/frame/mod.rs @@ -31,7 +31,7 @@ use polars_plan::global::FETCH_ROWS; #[cfg(any(feature = "ipc", feature = "parquet", feature = "csv"))] use polars_plan::logical_plan::collect_fingerprints; use polars_plan::logical_plan::optimize; -use polars_plan::utils::expr_output_name; +use polars_plan::utils::expr_to_leaf_column_names; use smartstring::alias::String as SmartString; use crate::fallible; @@ -1182,7 +1182,7 @@ impl LazyFrame { JoinBuilder::new(self) } - /// Add or replace a column, given as an expression, to a DataFrame. + /// Add a column, given as an expression, to a DataFrame. /// /// # Example /// @@ -1214,7 +1214,7 @@ impl LazyFrame { Self::from_logical_plan(lp, opt_state) } - /// Add or replace multiple columns, given as expressions, to a DataFrame. + /// Add multiple columns, given as expressions, to a DataFrame. /// /// # Example /// @@ -1239,7 +1239,7 @@ impl LazyFrame { ) } - /// Add or replace multiple columns to a DataFrame, but evaluate them sequentially. + /// Add multiple columns to a DataFrame, but evaluate them sequentially. 
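/// Editor's illustration (hypothetical column names), sketching the sequential
/// counterpart of `with_columns`:
///
/// ```ignore
/// let ldf = df.lazy().with_columns_seq([
///     (col("a") * lit(2)).alias("a2"),
///     (col("b") + lit(1)).alias("b1"),
/// ]); // the expressions are evaluated one after another
/// ```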
pub fn with_columns_seq>(self, exprs: E) -> LazyFrame { let exprs = exprs.as_ref().to_vec(); self.with_columns_impl( @@ -1674,10 +1674,10 @@ impl LazyGroupBy { let keys = self .keys .iter() - .filter_map(|expr| expr_output_name(expr).ok()) + .flat_map(|k| expr_to_leaf_column_names(k).into_iter()) .collect::>(); - self.agg([col("*").exclude(&keys).head(n)]) + self.agg([col("*").exclude(&keys).head(n).keep_name()]) .explode([col("*").exclude(&keys)]) } @@ -1686,10 +1686,10 @@ impl LazyGroupBy { let keys = self .keys .iter() - .filter_map(|expr| expr_output_name(expr).ok()) + .flat_map(|k| expr_to_leaf_column_names(k).into_iter()) .collect::>(); - self.agg([col("*").exclude(&keys).tail(n)]) + self.agg([col("*").exclude(&keys).tail(n).keep_name()]) .explode([col("*").exclude(&keys)]) } diff --git a/crates/polars-lazy/src/physical_plan/executors/join.rs b/crates/polars-lazy/src/physical_plan/executors/join.rs index 5898aea109ca..fa84d46e7a84 100644 --- a/crates/polars-lazy/src/physical_plan/executors/join.rs +++ b/crates/polars-lazy/src/physical_plan/executors/join.rs @@ -1,5 +1,3 @@ -use polars_ops::frame::DataFrameJoinOps; - use super::*; pub struct JoinExec { diff --git a/crates/polars-lazy/src/physical_plan/executors/scan/csv.rs b/crates/polars-lazy/src/physical_plan/executors/scan/csv.rs index bac591b84f86..8542432b31bc 100644 --- a/crates/polars-lazy/src/physical_plan/executors/scan/csv.rs +++ b/crates/polars-lazy/src/physical_plan/executors/scan/csv.rs @@ -54,7 +54,7 @@ impl CsvExec { impl Executor for CsvExec { fn execute(&mut self, state: &mut ExecutionState) -> PolarsResult { let finger_print = FileFingerPrint { - paths: Arc::new([self.path.clone()]), + path: self.path.clone(), predicate: self .predicate .as_ref() diff --git a/crates/polars-lazy/src/physical_plan/executors/scan/ipc.rs b/crates/polars-lazy/src/physical_plan/executors/scan/ipc.rs index 5256252d3a5d..e5ee49c06a16 100644 --- a/crates/polars-lazy/src/physical_plan/executors/scan/ipc.rs +++ b/crates/polars-lazy/src/physical_plan/executors/scan/ipc.rs @@ -34,7 +34,7 @@ impl IpcExec { impl Executor for IpcExec { fn execute(&mut self, state: &mut ExecutionState) -> PolarsResult { let finger_print = FileFingerPrint { - paths: Arc::new([self.path.clone()]), + path: self.path.clone(), predicate: self .predicate .as_ref() diff --git a/crates/polars-lazy/src/physical_plan/executors/scan/parquet.rs b/crates/polars-lazy/src/physical_plan/executors/scan/parquet.rs index 9f99c8580870..143f3eeaaa43 100644 --- a/crates/polars-lazy/src/physical_plan/executors/scan/parquet.rs +++ b/crates/polars-lazy/src/physical_plan/executors/scan/parquet.rs @@ -57,7 +57,6 @@ impl ParquetExec { if let Some(file) = file { ParquetReader::new(file) - .with_schema(Some(self.file_info.reader_schema.clone())) .with_n_rows(n_rows) .read_parallel(self.options.parallel) .with_row_count(mem::take(&mut self.file_options.row_count)) @@ -73,7 +72,7 @@ impl ParquetExec { let reader = ParquetAsyncReader::from_uri( &self.path.to_string_lossy(), self.cloud_options.as_ref(), - Some(self.file_info.reader_schema.clone()), + Some(self.file_info.schema.clone()), self.metadata.clone(), ) .await? 
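// (Editor's note on the hunk below: the scan-cache fingerprint reverts from a
// path list to a single `path`, matching the one-file-per-exec model restored
// by this series, roughly `FileFingerPrint { path: self.path.clone(), .. }`;
// the remaining fields are elided here just as in the surrounding hunks.)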
@@ -100,7 +99,7 @@ impl ParquetExec { impl Executor for ParquetExec { fn execute(&mut self, state: &mut ExecutionState) -> PolarsResult { let finger_print = FileFingerPrint { - paths: Arc::new([self.path.clone()]), + path: self.path.clone(), predicate: self .predicate .as_ref() diff --git a/crates/polars-lazy/src/physical_plan/expressions/aggregation.rs b/crates/polars-lazy/src/physical_plan/expressions/aggregation.rs index be48122eb520..cd22b26d6a21 100644 --- a/crates/polars-lazy/src/physical_plan/expressions/aggregation.rs +++ b/crates/polars-lazy/src/physical_plan/expressions/aggregation.rs @@ -11,8 +11,6 @@ use polars_core::prelude::*; use polars_core::utils::NoNull; #[cfg(feature = "dtype-struct")] use polars_core::POOL; -#[cfg(feature = "propagate_nans")] -use polars_ops::prelude::nan_propagating_aggregate; use crate::physical_plan::state::ExecutionState; use crate::physical_plan::PartitionedAggregation; diff --git a/crates/polars-lazy/src/physical_plan/expressions/binary.rs b/crates/polars-lazy/src/physical_plan/expressions/binary.rs index ef20e7abdc33..b6cf6fcd177f 100644 --- a/crates/polars-lazy/src/physical_plan/expressions/binary.rs +++ b/crates/polars-lazy/src/physical_plan/expressions/binary.rs @@ -3,8 +3,6 @@ use std::sync::Arc; use polars_core::frame::group_by::GroupsProxy; use polars_core::prelude::*; use polars_core::POOL; -#[cfg(feature = "round_series")] -use polars_ops::prelude::floor_div_series; use crate::physical_plan::state::ExecutionState; use crate::prelude::*; diff --git a/crates/polars-lazy/src/physical_plan/expressions/sort.rs b/crates/polars-lazy/src/physical_plan/expressions/sort.rs index 3fdabd22ee3e..0047a4d9e118 100644 --- a/crates/polars-lazy/src/physical_plan/expressions/sort.rs +++ b/crates/polars-lazy/src/physical_plan/expressions/sort.rs @@ -4,7 +4,6 @@ use arrow::legacy::utils::CustomIterTools; use polars_core::frame::group_by::GroupsProxy; use polars_core::prelude::*; use polars_core::POOL; -use polars_ops::chunked_array::ListNameSpaceImpl; use rayon::prelude::*; use crate::physical_plan::state::ExecutionState; diff --git a/crates/polars-lazy/src/physical_plan/expressions/window.rs b/crates/polars-lazy/src/physical_plan/expressions/window.rs index e74d4c800210..4f98e63ad7d4 100644 --- a/crates/polars-lazy/src/physical_plan/expressions/window.rs +++ b/crates/polars-lazy/src/physical_plan/expressions/window.rs @@ -11,7 +11,6 @@ use polars_core::{downcast_as_macro_arg_physical, POOL}; use polars_ops::frame::join::{ default_join_ids, private_left_join_multiple_keys, ChunkJoinOptIds, JoinValidation, }; -use polars_ops::frame::SeriesJoin; use polars_utils::format_smartstring; use polars_utils::sort::perfect_sort; use polars_utils::sync::SyncPtr; diff --git a/crates/polars-lazy/src/physical_plan/planner/expr.rs b/crates/polars-lazy/src/physical_plan/planner/expr.rs index 72b0b80c5d3c..0e614c4cb683 100644 --- a/crates/polars-lazy/src/physical_plan/planner/expr.rs +++ b/crates/polars-lazy/src/physical_plan/planner/expr.rs @@ -603,11 +603,10 @@ where let mut iter = chunks.into_iter(); let first = iter.next().unwrap(); - let dtype = first.dtype(); - let out = iter.fold(first.to_physical_repr().into_owned(), |mut acc, s| { - acc.append(&s.to_physical_repr()).unwrap(); + let out = iter.fold(first, |mut acc, s| { + acc.append(&s).unwrap(); acc }); - unsafe { f(out.cast_unchecked(dtype).unwrap()).map(Some) } + f(out).map(Some) } diff --git a/crates/polars-lazy/src/physical_plan/planner/lp.rs b/crates/polars-lazy/src/physical_plan/planner/lp.rs index 
1e2683d4bdd4..45af5e1f3a11 100644 --- a/crates/polars-lazy/src/physical_plan/planner/lp.rs +++ b/crates/polars-lazy/src/physical_plan/planner/lp.rs @@ -150,20 +150,16 @@ pub fn create_physical_plan( match logical_plan { #[cfg(feature = "python")] PythonScan { options, .. } => Ok(Box::new(executors::PythonScanExec { options })), - Sink { payload, .. } => match payload { - SinkType::Memory => { - polars_bail!(InvalidOperation: "memory sink not supported in the standard engine") - }, - SinkType::File { file_type, .. } => { - polars_bail!(InvalidOperation: + Sink { payload, .. } => { + match payload { + SinkType::Memory => panic!("Memory Sink not supported in the standard engine."), + SinkType::File{file_type, ..} => panic!( "sink_{file_type:?} not yet supported in standard engine. Use 'collect().write_parquet()'" - ) - }, - #[cfg(feature = "cloud")] - SinkType::Cloud { .. } => { - polars_bail!(InvalidOperation: "cloud sink not supported in standard engine.") - }, - }, + ), + #[cfg(feature = "cloud")] + SinkType::Cloud{..} => panic!("Cloud Sink not supported in standard engine.") + } + } Union { inputs, options } => { let inputs = inputs .into_iter() @@ -193,7 +189,7 @@ pub fn create_physical_plan( ))) }, Scan { - paths, + path, file_info, output_schema, scan_type, @@ -217,48 +213,39 @@ pub fn create_physical_plan( #[cfg(feature = "csv")] FileScan::Csv { options: csv_options, - } => { - assert_eq!(paths.len(), 1); - let path = paths[0].clone(); - Ok(Box::new(executors::CsvExec { - path, - schema: file_info.schema, - options: csv_options, - predicate, - file_options, - })) - }, + } => Ok(Box::new(executors::CsvExec { + path, + schema: file_info.schema, + options: csv_options, + predicate, + file_options, + })), #[cfg(feature = "ipc")] - FileScan::Ipc { options } => { - assert_eq!(paths.len(), 1); - let path = paths[0].clone(); - Ok(Box::new(executors::IpcExec { - path, - schema: file_info.schema, - predicate, - options, - file_options, - })) - }, + FileScan::Ipc { options } => Ok(Box::new(executors::IpcExec { + path, + schema: file_info.schema, + predicate, + options, + file_options, + })), #[cfg(feature = "parquet")] FileScan::Parquet { options, cloud_options, - metadata, + metadata + } => Ok(Box::new(executors::ParquetExec::new( + path, + file_info, + predicate, + options, + cloud_options, + file_options, + metadata + ))), + FileScan::Anonymous { + function, + .. } => { - assert_eq!(paths.len(), 1); - let path = paths[0].clone(); - Ok(Box::new(executors::ParquetExec::new( - path, - file_info, - predicate, - options, - cloud_options, - file_options, - metadata, - ))) - }, - FileScan::Anonymous { function, .. 
} => { Ok(Box::new(executors::AnonymousScanExec { function, predicate, @@ -267,7 +254,8 @@ pub fn create_physical_plan( output_schema, predicate_has_windows: state.has_windows, })) - }, + + } } }, Projection { diff --git a/crates/polars-lazy/src/physical_plan/streaming/convert_alp.rs b/crates/polars-lazy/src/physical_plan/streaming/convert_alp.rs index e3d7125b9e27..e23357f58d40 100644 --- a/crates/polars-lazy/src/physical_plan/streaming/convert_alp.rs +++ b/crates/polars-lazy/src/physical_plan/streaming/convert_alp.rs @@ -357,11 +357,6 @@ pub(crate) fn insert_streaming_nodes( DataType::Object(_) => false, #[cfg(feature = "dtype-categorical")] DataType::Categorical(_) => string_cache, - DataType::List(inner) => allowed_dtype(inner, string_cache), - #[cfg(feature = "dtype-struct")] - DataType::Struct(fields) => fields - .iter() - .all(|fld| allowed_dtype(fld.data_type(), string_cache)), _ => true, } } diff --git a/crates/polars-lazy/src/prelude.rs b/crates/polars-lazy/src/prelude.rs index a5baeda1ea78..81c31fe943db 100644 --- a/crates/polars-lazy/src/prelude.rs +++ b/crates/polars-lazy/src/prelude.rs @@ -1,6 +1,5 @@ +pub(crate) use polars_ops::prelude::*; pub use polars_ops::prelude::{JoinArgs, JoinType, JoinValidation}; -#[cfg(feature = "rank")] -pub use polars_ops::prelude::{RankMethod, RankOptions}; pub use polars_plan::logical_plan::{ AnonymousScan, AnonymousScanOptions, Literal, LiteralValue, LogicalPlan, Null, NULL, }; diff --git a/crates/polars-lazy/src/utils.rs b/crates/polars-lazy/src/utils.rs index fac410b109fb..e8fa1ed4df79 100644 --- a/crates/polars-lazy/src/utils.rs +++ b/crates/polars-lazy/src/utils.rs @@ -6,15 +6,13 @@ use polars_plan::prelude::*; /// Get a set of the data source paths in this LogicalPlan pub(crate) fn agg_source_paths( root_lp: Node, - acc_paths: &mut PlHashSet, + paths: &mut PlHashSet, lp_arena: &Arena, ) { lp_arena.iter(root_lp).for_each(|(_, lp)| { use ALogicalPlan::*; - if let Scan { paths, .. } = lp { - for path in paths.as_ref() { - acc_paths.insert(path.clone()); - } + if let Scan { path, .. 
} = lp { + paths.insert(path.clone()); } }) } diff --git a/crates/polars-ops/Cargo.toml b/crates/polars-ops/Cargo.toml index 1651ab8afb38..c4a4e35bc3b1 100644 --- a/crates/polars-ops/Cargo.toml +++ b/crates/polars-ops/Cargo.toml @@ -9,10 +9,10 @@ repository = { workspace = true } description = "More operations on Polars data structures" [dependencies] -polars-core = { workspace = true, features = ["algorithm_group_by"] } +polars-core = { workspace = true, features = ["algorithm_group_by"], default-features = false } polars-error = { workspace = true } polars-json = { workspace = true, optional = true } -polars-utils = { workspace = true } +polars-utils = { workspace = true, default-features = false } ahash = { workspace = true } argminmax = { version = "0.6.1", default-features = false, features = ["float"] } @@ -82,8 +82,7 @@ to_dummies = [] interpolate = [] list_to_struct = ["polars-core/dtype-struct"] list_count = [] -diff = [] -pct_change = ["diff"] +diff = ["polars-core/diff"] strings = ["polars-core/strings"] string_justify = ["polars-core/strings"] string_from_radix = ["polars-core/strings"] @@ -107,11 +106,8 @@ list_take = [] list_sets = [] list_any_all = [] list_drop_nulls = [] -list_sample = [] extract_groups = ["dtype-struct", "polars-core/regex"] is_in = ["polars-core/reinterpret"] convert_index = [] repeat_by = [] peaks = [] -cum_agg = [] -ewma = [] diff --git a/crates/polars-ops/src/chunked_array/interpolate.rs b/crates/polars-ops/src/chunked_array/interpolate.rs index 6c2fe4949f4c..b06c06574960 100644 --- a/crates/polars-ops/src/chunked_array/interpolate.rs +++ b/crates/polars-ops/src/chunked_array/interpolate.rs @@ -242,34 +242,26 @@ mod test { fn test_interpolate() { let ca = UInt32Chunked::new("", &[Some(1), None, None, Some(4), Some(5)]); let out = interpolate(&ca.into_series(), InterpolationMethod::Linear); - let out = out.f64().unwrap(); + let out = out.u32().unwrap(); assert_eq!( Vec::from(out), - &[Some(1.0), Some(2.0), Some(3.0), Some(4.0), Some(5.0)] + &[Some(1), Some(2), Some(3), Some(4), Some(5)] ); let ca = UInt32Chunked::new("", &[None, Some(1), None, None, Some(4), Some(5)]); let out = interpolate(&ca.into_series(), InterpolationMethod::Linear); - let out = out.f64().unwrap(); + let out = out.u32().unwrap(); assert_eq!( Vec::from(out), - &[None, Some(1.0), Some(2.0), Some(3.0), Some(4.0), Some(5.0)] + &[None, Some(1), Some(2), Some(3), Some(4), Some(5)] ); let ca = UInt32Chunked::new("", &[None, Some(1), None, None, Some(4), Some(5), None]); let out = interpolate(&ca.into_series(), InterpolationMethod::Linear); - let out = out.f64().unwrap(); + let out = out.u32().unwrap(); assert_eq!( Vec::from(out), - &[ - None, - Some(1.0), - Some(2.0), - Some(3.0), - Some(4.0), - Some(5.0), - None - ] + &[None, Some(1), Some(2), Some(3), Some(4), Some(5), None] ); let ca = UInt32Chunked::new("", &[None, Some(1), None, None, Some(4), Some(5), None]); let out = interpolate(&ca.into_series(), InterpolationMethod::Nearest); @@ -284,11 +276,8 @@ mod test { fn test_interpolate_decreasing_unsigned() { let ca = UInt32Chunked::new("", &[Some(4), None, None, Some(1)]); let out = interpolate(&ca.into_series(), InterpolationMethod::Linear); - let out = out.f64().unwrap(); - assert_eq!( - Vec::from(out), - &[Some(4.0), Some(3.0), Some(2.0), Some(1.0)] - ) + let out = out.u32().unwrap(); + assert_eq!(Vec::from(out), &[Some(4), Some(3), Some(2), Some(1)]) } #[test] diff --git a/crates/polars-ops/src/chunked_array/list/namespace.rs b/crates/polars-ops/src/chunked_array/list/namespace.rs 
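// (Editor's aside on the interpolate tests just above: the revert keeps linear
// interpolation of integer inputs in the integer domain instead of widening to
// f64. A minimal standalone sketch of that integer lerp, assuming evenly
// spaced fill positions; the helper name is hypothetical:)
fn lerp_u32(lo: u32, hi: u32, step: i64, steps: i64) -> u32 {
    // midpoint fill as the reverted tests expect, e.g. 1..4 -> [2, 3]
    (lo as i64 + (hi as i64 - lo as i64) * step / steps) as u32
}
// lerp_u32(1, 4, 1, 3) == 2 and lerp_u32(4, 1, 1, 3) == 3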
index 374a81e5d2a6..f9f92a75ad85 100644 --- a/crates/polars-ops/src/chunked_array/list/namespace.rs +++ b/crates/polars-ops/src/chunked_array/list/namespace.rs @@ -17,8 +17,6 @@ use super::*; use crate::chunked_array::list::any_all::*; use crate::chunked_array::list::min_max::{list_max_function, list_min_function}; use crate::chunked_array::list::sum_mean::sum_with_nulls; -#[cfg(feature = "diff")] -use crate::prelude::diff; use crate::prelude::list::sum_mean::{mean_list_numerical, sum_list_numerical}; use crate::series::ArgAgg; @@ -258,7 +256,7 @@ pub trait ListNameSpaceImpl: AsList { #[cfg(feature = "diff")] fn lst_diff(&self, n: i64, null_behavior: NullBehavior) -> PolarsResult { let ca = self.as_list(); - ca.try_apply_amortized(|s| diff(s.as_ref(), n, null_behavior)) + ca.try_apply_amortized(|s| s.as_ref().diff(n, null_behavior)) } fn lst_shift(&self, periods: &Series) -> PolarsResult { @@ -404,86 +402,6 @@ pub trait ListNameSpaceImpl: AsList { list_ca.apply_amortized(|s| s.as_ref().drop_nulls()) } - #[cfg(feature = "list_sample")] - fn lst_sample_n( - &self, - n: &Series, - with_replacement: bool, - shuffle: bool, - seed: Option, - ) -> PolarsResult { - let ca = self.as_list(); - - let n_s = n.cast(&IDX_DTYPE)?; - let n = n_s.idx()?; - - let out = match n.len() { - 1 => { - if let Some(n) = n.get(0) { - ca.try_apply_amortized(|s| { - s.as_ref() - .sample_n(n as usize, with_replacement, shuffle, seed) - }) - } else { - Ok(ListChunked::full_null_with_dtype( - ca.name(), - ca.len(), - &ca.inner_dtype(), - )) - } - }, - _ => ca.try_zip_and_apply_amortized(n, |opt_s, opt_n| match (opt_s, opt_n) { - (Some(s), Some(n)) => s - .as_ref() - .sample_n(n as usize, with_replacement, shuffle, seed) - .map(Some), - _ => Ok(None), - }), - }; - out.map(|ok| self.same_type(ok)) - } - - #[cfg(feature = "list_sample")] - fn lst_sample_fraction( - &self, - fraction: &Series, - with_replacement: bool, - shuffle: bool, - seed: Option, - ) -> PolarsResult { - let ca = self.as_list(); - - let fraction_s = fraction.cast(&DataType::Float64)?; - let fraction = fraction_s.f64()?; - - let out = match fraction.len() { - 1 => { - if let Some(fraction) = fraction.get(0) { - ca.try_apply_amortized(|s| { - let n = (s.as_ref().len() as f64 * fraction) as usize; - s.as_ref().sample_n(n, with_replacement, shuffle, seed) - }) - } else { - Ok(ListChunked::full_null_with_dtype( - ca.name(), - ca.len(), - &ca.inner_dtype(), - )) - } - }, - _ => ca.try_zip_and_apply_amortized(fraction, |opt_s, opt_n| match (opt_s, opt_n) { - (Some(s), Some(fraction)) => { - let n = (s.as_ref().len() as f64 * fraction) as usize; - s.as_ref() - .sample_n(n, with_replacement, shuffle, seed) - .map(Some) - }, - _ => Ok(None), - }), - }; - out.map(|ok| self.same_type(ok)) - } - fn lst_concat(&self, other: &[Series]) -> PolarsResult { let ca = self.as_list(); let other_len = other.len(); diff --git a/crates/polars-ops/src/frame/join/asof/groups.rs b/crates/polars-ops/src/frame/join/asof/groups.rs index 7bd4fe1a2db4..a8e31f69af50 100644 --- a/crates/polars-ops/src/frame/join/asof/groups.rs +++ b/crates/polars-ops/src/frame/join/asof/groups.rs @@ -774,20 +774,9 @@ pub trait AsofJoinBy: IntoDf { suffix: Option<&str>, slice: Option<(i64, usize)>, ) -> PolarsResult { - let (self_sliced_slot, other_sliced_slot); // Keeps temporaries alive. 
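// (Editor's note on this asof-join-by hunk: the revert moves slicing back to
// the join output, i.e. `left = left.slice(offset, len)` together with
// `right_join_tuples = slice_slice(right_join_tuples, offset, len)`, instead
// of pre-slicing both input frames before matching.)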
- let (self_df, other_df); - if let Some((offset, len)) = slice { - self_sliced_slot = self.to_df().slice(offset, len); - other_sliced_slot = other.slice(offset, len); - self_df = &self_sliced_slot; - other_df = &other_sliced_slot; - } else { - self_df = self.to_df(); - other_df = other; - } - + let self_df = self.to_df(); let left_asof = self_df.column(left_on)?.to_physical_repr(); - let right_asof = other_df.column(right_on)?.to_physical_repr(); + let right_asof = other.column(right_on)?.to_physical_repr(); let right_asof_name = right_asof.name(); let left_asof_name = left_asof.name(); @@ -798,7 +787,7 @@ pub trait AsofJoinBy: IntoDf { )?; let mut left_by = self_df.select(left_by)?; - let mut right_by = other_df.select(right_by)?; + let mut right_by = other.select(right_by)?; unsafe { for (l, r) in left_by @@ -837,7 +826,7 @@ pub trait AsofJoinBy: IntoDf { drop_these.push(right_asof_name); } - let cols = other_df + let cols = other .get_columns() .iter() .filter_map(|s| { @@ -848,15 +837,19 @@ pub trait AsofJoinBy: IntoDf { } }) .collect(); - let proj_other_df = DataFrame::new_no_checks(cols); + let other = DataFrame::new_no_checks(cols); - let left = self_df.clone(); - let right_join_tuples = &*right_join_tuples; + let mut left = self_df.clone(); + let mut right_join_tuples = &*right_join_tuples; + + if let Some((offset, len)) = slice { + left = left.slice(offset, len); + right_join_tuples = slice_slice(right_join_tuples, offset, len); + } // SAFETY: join tuples are in bounds. - let right_df = unsafe { - proj_other_df.take_unchecked(&right_join_tuples.iter().copied().collect_ca("")) - }; + let right_df = + unsafe { other.take_unchecked(&right_join_tuples.iter().copied().collect_ca("")) }; _finish_join(left, right_df, suffix) } diff --git a/crates/polars-ops/src/frame/join/hash_join/mod.rs b/crates/polars-ops/src/frame/join/hash_join/mod.rs index b6f4df2cba6e..04aaef862338 100644 --- a/crates/polars-ops/src/frame/join/hash_join/mod.rs +++ b/crates/polars-ops/src/frame/join/hash_join/mod.rs @@ -132,11 +132,19 @@ pub trait JoinDispatch: IntoDf { ) -> PolarsResult { let ca_self = self.to_df(); let (left_idx, right_idx) = ids; - let materialize_left = - || unsafe { ca_self._create_left_df_from_slice(&left_idx, true, true) }; + let materialize_left = || { + let mut left_idx = &*left_idx; + if let Some((offset, len)) = args.slice { + left_idx = slice_slice(left_idx, offset, len); + } + unsafe { ca_self._create_left_df_from_slice(left_idx, true, true) } + }; let materialize_right = || { - let right_idx = &*right_idx; + let mut right_idx = &*right_idx; + if let Some((offset, len)) = args.slice { + right_idx = slice_slice(right_idx, offset, len); + } unsafe { other.take_unchecked(&right_idx.iter().copied().collect_ca("")) } }; let (df_left, df_right) = POOL.join(materialize_left, materialize_right); @@ -153,38 +161,39 @@ pub trait JoinDispatch: IntoDf { ) -> PolarsResult { let ca_self = self.to_df(); let suffix = &args.suffix; + let slice = args.slice; let (left_idx, right_idx) = ids; let materialize_left = || match left_idx { - ChunkJoinIds::Left(left_idx) => unsafe { + ChunkJoinIds::Left(left_idx) => { let mut left_idx = &*left_idx; - if let Some((offset, len)) = args.slice { + if let Some((offset, len)) = slice { left_idx = slice_slice(left_idx, offset, len); } - ca_self._create_left_df_from_slice(left_idx, true, true) + unsafe { ca_self._create_left_df_from_slice(left_idx, true, true) } }, - ChunkJoinIds::Right(left_idx) => unsafe { + ChunkJoinIds::Right(left_idx) => { let mut left_idx = 
&*left_idx; - if let Some((offset, len)) = args.slice { + if let Some((offset, len)) = slice { left_idx = slice_slice(left_idx, offset, len); } - ca_self.create_left_df_chunked(left_idx, true) + unsafe { ca_self.create_left_df_chunked(left_idx, true) } }, }; let materialize_right = || match right_idx { - ChunkJoinOptIds::Left(right_idx) => unsafe { + ChunkJoinOptIds::Left(right_idx) => { let mut right_idx = &*right_idx; - if let Some((offset, len)) = args.slice { + if let Some((offset, len)) = slice { right_idx = slice_slice(right_idx, offset, len); } - other.take_unchecked(&right_idx.iter().copied().collect_ca("")) + unsafe { other.take_unchecked(&right_idx.iter().copied().collect_ca("")) } }, - ChunkJoinOptIds::Right(right_idx) => unsafe { + ChunkJoinOptIds::Right(right_idx) => { let mut right_idx = &*right_idx; - if let Some((offset, len)) = args.slice { + if let Some((offset, len)) = slice { right_idx = slice_slice(right_idx, offset, len); } - other._take_opt_chunked_unchecked(right_idx) + unsafe { other._take_opt_chunked_unchecked(right_idx) } }, }; let (df_left, df_right) = POOL.join(materialize_left, materialize_right); @@ -204,17 +213,9 @@ pub trait JoinDispatch: IntoDf { #[cfg(feature = "dtype-categorical")] _check_categorical_src(s_left.dtype(), s_right.dtype())?; + // ensure that the chunks are aligned otherwise we go OOB let mut left = ca_self.clone(); let mut s_left = s_left.clone(); - // Eagerly limit left if possible. - if let Some((offset, len)) = args.slice { - if offset == 0 { - left = left.slice(0, len); - s_left = s_left.slice(0, len); - } - } - - // Ensure that the chunks are aligned otherwise we go OOB. let mut right = other.clone(); let mut s_right = s_right.clone(); if left.should_rechunk() { diff --git a/crates/polars-ops/src/frame/join/hash_join/zip_outer.rs b/crates/polars-ops/src/frame/join/hash_join/zip_outer.rs index 57cf99db752e..ba23f32c2910 100644 --- a/crates/polars-ops/src/frame/join/hash_join/zip_outer.rs +++ b/crates/polars-ops/src/frame/join/hash_join/zip_outer.rs @@ -17,11 +17,11 @@ pub(crate) unsafe fn zip_outer_join_column( let new_rev_map = left_column ._merge_categorical_map(right_column.categorical().unwrap()) .unwrap(); - let left = left_column.physical(); + let left = left_column.logical(); let right = right_column .categorical() .unwrap() - .physical() + .logical() .clone() .into_series(); diff --git a/crates/polars-ops/src/frame/join/mod.rs b/crates/polars-ops/src/frame/join/mod.rs index 0b92ec715960..17d3f59b833f 100644 --- a/crates/polars-ops/src/frame/join/mod.rs +++ b/crates/polars-ops/src/frame/join/mod.rs @@ -52,13 +52,12 @@ pub trait DataFrameJoinOps: IntoDf { /// /// ```no_run /// # use polars_core::prelude::*; - /// # use polars_ops::prelude::*; /// let df1: DataFrame = df!("Fruit" => &["Apple", "Banana", "Pear"], /// "Phosphorus (mg/100g)" => &[11, 22, 12])?; /// let df2: DataFrame = df!("Name" => &["Apple", "Banana", "Pear"], /// "Potassium (mg/100g)" => &[107, 358, 115])?; /// - /// let df3: DataFrame = df1.join(&df2, ["Fruit"], ["Name"], JoinArgs::new(JoinType::Inner))?; + /// let df3: DataFrame = df1.join(&df2, ["Fruit"], ["Name"], JoinType::Inner, None)?; /// assert_eq!(df3.shape(), (3, 3)); /// println!("{}", df3); /// # Ok::<(), PolarsError>(()) @@ -189,7 +188,7 @@ pub trait DataFrameJoinOps: IntoDf { _check_categorical_src(l.dtype(), r.dtype())? } - // Single keys. 
+ // Single keys if selected_left.len() == 1 { let s_left = left_df.column(selected_left[0].name())?; let s_right = other.column(selected_right[0].name())?; @@ -256,13 +255,12 @@ pub trait DataFrameJoinOps: IntoDf { } new.unwrap() } - - // Make sure that we don't have logical types. - // We don't overwrite the original selected as that might be used to create a column in the new df. + // make sure that we don't have logical types. + // we don't overwrite the original selected as that might be used to create a column in the new df let selected_left_physical = _to_physical_and_bit_repr(&selected_left); let selected_right_physical = _to_physical_and_bit_repr(&selected_right); - // Multiple keys. + // multiple keys match args.how { JoinType::Inner => { let left = DataFrame::new_no_checks(selected_left_physical); @@ -292,11 +290,8 @@ pub trait DataFrameJoinOps: IntoDf { JoinType::Left => { let mut left = DataFrame::new_no_checks(selected_left_physical); let mut right = DataFrame::new_no_checks(selected_right_physical); - - if let Some((offset, len)) = args.slice { - left = left.slice(offset, len); - } let ids = _left_join_multiple_keys(&mut left, &mut right, None, None); + left_df._finish_left_join(ids, &remove_selected(other, &selected_right), args) }, JoinType::Outer => { @@ -374,7 +369,6 @@ pub trait DataFrameJoinOps: IntoDf { /// /// ``` /// # use polars_core::prelude::*; - /// # use polars_ops::prelude::*; /// fn join_dfs(left: &DataFrame, right: &DataFrame) -> PolarsResult { /// left.inner_join(right, ["join_column_left"], ["join_column_right"]) /// } @@ -397,7 +391,6 @@ pub trait DataFrameJoinOps: IntoDf { /// /// ```no_run /// # use polars_core::prelude::*; - /// # use polars_ops::prelude::*; /// let df1: DataFrame = df!("Wavelength (nm)" => &[480.0, 650.0, 577.0, 1201.0, 100.0])?; /// let df2: DataFrame = df!("Color" => &["Blue", "Yellow", "Red"], /// "Wavelength nm" => &[480.0, 577.0, 650.0])?; @@ -440,7 +433,6 @@ pub trait DataFrameJoinOps: IntoDf { /// /// ``` /// # use polars_core::prelude::*; - /// # use polars_ops::prelude::*; /// fn join_dfs(left: &DataFrame, right: &DataFrame) -> PolarsResult { /// left.outer_join(right, ["join_column_left"], ["join_column_right"]) /// } diff --git a/crates/polars-ops/src/series/ops/approx_algo/hyperloglogplus.rs b/crates/polars-ops/src/series/ops/approx_algo/hyperloglogplus.rs index d507a1fcf20c..7df61317d9bc 100644 --- a/crates/polars-ops/src/series/ops/approx_algo/hyperloglogplus.rs +++ b/crates/polars-ops/src/series/ops/approx_algo/hyperloglogplus.rs @@ -9,7 +9,6 @@ //! # Examples //! //! ``` -//! # use polars_ops::prelude::*; //! let mut hllp = HyperLogLog::new(); //! hllp.add(&12345); //! 
hllp.add(&23456); diff --git a/crates/polars-ops/src/series/ops/cum_agg.rs b/crates/polars-ops/src/series/ops/cum_agg.rs deleted file mode 100644 index 70fb1a0bfcf6..000000000000 --- a/crates/polars-ops/src/series/ops/cum_agg.rs +++ /dev/null @@ -1,230 +0,0 @@ -use std::iter::FromIterator; -use std::ops::{Add, AddAssign, Mul}; - -use num_traits::Bounded; -use polars_core::prelude::*; -use polars_core::utils::{CustomIterTools, NoNull}; -use polars_core::with_match_physical_numeric_polars_type; - -fn det_max(state: &mut T, v: Option) -> Option> -where - T: Copy + PartialOrd + AddAssign + Add, -{ - match v { - Some(v) => { - if v > *state { - *state = v - } - Some(Some(*state)) - }, - None => Some(None), - } -} - -fn det_min(state: &mut T, v: Option) -> Option> -where - T: Copy + PartialOrd + AddAssign + Add, -{ - match v { - Some(v) => { - if v < *state { - *state = v - } - Some(Some(*state)) - }, - None => Some(None), - } -} - -fn det_sum(state: &mut Option, v: Option) -> Option> -where - T: Copy + PartialOrd + AddAssign + Add, -{ - match (*state, v) { - (Some(state_inner), Some(v)) => { - *state = Some(state_inner + v); - Some(*state) - }, - (None, Some(v)) => { - *state = Some(v); - Some(*state) - }, - (_, None) => Some(None), - } -} - -fn det_prod(state: &mut Option, v: Option) -> Option> -where - T: Copy + PartialOrd + Mul, -{ - match (*state, v) { - (Some(state_inner), Some(v)) => { - *state = Some(state_inner * v); - Some(*state) - }, - (None, Some(v)) => { - *state = Some(v); - Some(*state) - }, - (_, None) => Some(None), - } -} - -fn cummax_numeric(ca: &ChunkedArray, reverse: bool) -> ChunkedArray -where - T: PolarsNumericType, - ChunkedArray: FromIterator>, -{ - let init = Bounded::min_value(); - - let out: ChunkedArray = match reverse { - false => ca.into_iter().scan(init, det_max).collect_trusted(), - true => ca.into_iter().rev().scan(init, det_max).collect_reversed(), - }; - out.with_name(ca.name()) -} - -fn cummin_numeric(ca: &ChunkedArray, reverse: bool) -> ChunkedArray -where - T: PolarsNumericType, - ChunkedArray: FromIterator>, -{ - let init = Bounded::max_value(); - let out: ChunkedArray = match reverse { - false => ca.into_iter().scan(init, det_min).collect_trusted(), - true => ca.into_iter().rev().scan(init, det_min).collect_reversed(), - }; - out.with_name(ca.name()) -} - -fn cumsum_numeric(ca: &ChunkedArray, reverse: bool) -> ChunkedArray -where - T: PolarsNumericType, - ChunkedArray: FromIterator>, -{ - let init = None; - let out: ChunkedArray = match reverse { - false => ca.into_iter().scan(init, det_sum).collect_trusted(), - true => ca.into_iter().rev().scan(init, det_sum).collect_reversed(), - }; - out.with_name(ca.name()) -} - -fn cumprod_numeric(ca: &ChunkedArray, reverse: bool) -> ChunkedArray -where - T: PolarsNumericType, - ChunkedArray: FromIterator>, -{ - let init = None; - let out: ChunkedArray = match reverse { - false => ca.into_iter().scan(init, det_prod).collect_trusted(), - true => ca.into_iter().rev().scan(init, det_prod).collect_reversed(), - }; - out.with_name(ca.name()) -} - -/// Get an array with the cumulative product computed at every element. -/// -/// If the [`DataType`] is one of `{Int8, UInt8, Int16, UInt16, Int32, UInt32}` the `Series` is -/// first cast to `Int64` to prevent overflow issues. 
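// (Editor's sketch of the `Iterator::scan` folding used by the kernels removed
// in this file, reduced to plain iterators; the values are illustrative:)
let xs = [Some(1i64), None, Some(3)];
let cumsum: Vec<Option<i64>> = xs
    .iter()
    .scan(None::<i64>, |state, v| {
        Some(match v {
            // the running total is kept across nulls, matching `det_sum`
            Some(v) => { *state = Some(state.unwrap_or(0) + v); *state },
            None => None,
        })
    })
    .collect();
assert_eq!(cumsum, vec![Some(1), None, Some(4)]);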
-pub fn cumprod(s: &Series, reverse: bool) -> PolarsResult { - use DataType::*; - let out = match s.dtype() { - Boolean | Int8 | UInt8 | Int16 | UInt16 | Int32 | UInt32 => { - let s = s.cast(&Int64)?; - cumprod_numeric(s.i64()?, reverse).into_series() - }, - Int64 => cumprod_numeric(s.i64()?, reverse).into_series(), - UInt64 => cumprod_numeric(s.u64()?, reverse).into_series(), - Float32 => cumprod_numeric(s.f32()?, reverse).into_series(), - Float64 => cumprod_numeric(s.f64()?, reverse).into_series(), - dt => polars_bail!(opq = cumprod, dt), - }; - Ok(out) -} - -/// Get an array with the cumulative sum computed at every element -/// -/// If the [`DataType`] is one of `{Int8, UInt8, Int16, UInt16}` the `Series` is -/// first cast to `Int64` to prevent overflow issues. -pub fn cumsum(s: &Series, reverse: bool) -> PolarsResult { - use DataType::*; - let out = match s.dtype() { - Boolean => { - let s = s.cast(&UInt32)?; - cumsum_numeric(s.u32()?, reverse).into_series() - }, - Int8 | UInt8 | Int16 | UInt16 => { - let s = s.cast(&Int64)?; - cumsum_numeric(s.i64()?, reverse).into_series() - }, - Int32 => cumsum_numeric(s.i32()?, reverse).into_series(), - UInt32 => cumsum_numeric(s.u32()?, reverse).into_series(), - Int64 => cumsum_numeric(s.i64()?, reverse).into_series(), - UInt64 => cumsum_numeric(s.u64()?, reverse).into_series(), - Float32 => cumsum_numeric(s.f32()?, reverse).into_series(), - Float64 => cumsum_numeric(s.f64()?, reverse).into_series(), - #[cfg(feature = "dtype-duration")] - Duration(tu) => { - let s = s.to_physical_repr(); - let ca = s.i64()?; - cumsum_numeric(ca, reverse).cast(&Duration(*tu))? - }, - dt => polars_bail!(opq = cumsum, dt), - }; - Ok(out) -} - -/// Get an array with the cumulative min computed at every element. -pub fn cummin(s: &Series, reverse: bool) -> PolarsResult { - let original_type = s.dtype(); - let s = s.to_physical_repr(); - match s.dtype() { - dt if dt.is_numeric() => { - with_match_physical_numeric_polars_type!(s.dtype(), |$T| { - let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref(); - let out = cummin_numeric(ca, reverse).into_series(); - if original_type.is_logical(){ - out.cast(original_type) - }else{ - Ok(out) - } - }) - }, - dt => polars_bail!(opq = cummin, dt), - } -} - -/// Get an array with the cumulative max computed at every element. 
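// (Cross-reference: the cumulative kernels deleted here move behind the
// private `_cummax`/`_cummin` hooks added to the series trait earlier in this
// series. Note that the error arm of the removed `cummax` below bails with
// `opq = cummin`, a pre-existing copy/paste slip in the deleted code.)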
-pub fn cummax(s: &Series, reverse: bool) -> PolarsResult { - let original_type = s.dtype(); - let s = s.to_physical_repr(); - match s.dtype() { - dt if dt.is_numeric() => { - with_match_physical_numeric_polars_type!(s.dtype(), |$T| { - let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref(); - let out = cummax_numeric(ca, reverse).into_series(); - if original_type.is_logical(){ - out.cast(original_type) - }else{ - Ok(out) - } - }) - }, - dt => polars_bail!(opq = cummin, dt), - } -} - -pub fn cumcount(s: &Series, reverse: bool) -> PolarsResult { - if reverse { - let ca: NoNull = (0u32..s.len() as u32).rev().collect(); - let mut ca = ca.into_inner(); - ca.rename(s.name()); - Ok(ca.into_series()) - } else { - let ca: NoNull = (0u32..s.len() as u32).collect(); - let mut ca = ca.into_inner(); - ca.rename(s.name()); - Ok(ca.into_series()) - } -} diff --git a/crates/polars-ops/src/series/ops/diff.rs b/crates/polars-ops/src/series/ops/diff.rs deleted file mode 100644 index 8fa28768609e..000000000000 --- a/crates/polars-ops/src/series/ops/diff.rs +++ /dev/null @@ -1,22 +0,0 @@ -use polars_core::prelude::*; -use polars_core::series::ops::NullBehavior; - -pub fn diff(s: &Series, n: i64, null_behavior: NullBehavior) -> PolarsResult { - use DataType::*; - let s = match s.dtype() { - UInt8 => s.cast(&Int16)?, - UInt16 => s.cast(&Int32)?, - UInt32 | UInt64 => s.cast(&Int64)?, - _ => s.clone(), - }; - - match null_behavior { - NullBehavior::Ignore => Ok(&s - &s.shift(n)), - NullBehavior::Drop => { - polars_ensure!(n > 0, InvalidOperation: "only positive integer allowed if nulls are dropped in 'diff' operation"); - let n = n as usize; - let len = s.len() - n; - Ok(&s.slice(n as i64, len) - &s.slice(0, len)) - }, - } -} diff --git a/crates/polars-ops/src/series/ops/ewm.rs b/crates/polars-ops/src/series/ops/ewm.rs deleted file mode 100644 index 6f4458777306..000000000000 --- a/crates/polars-ops/src/series/ops/ewm.rs +++ /dev/null @@ -1,103 +0,0 @@ -use std::convert::TryFrom; - -pub use arrow::legacy::kernels::ewm::EWMOptions; -use arrow::legacy::kernels::ewm::{ - ewm_mean as kernel_ewm_mean, ewm_std as kernel_ewm_std, ewm_var as kernel_ewm_var, -}; -use polars_core::prelude::*; - -fn check_alpha(alpha: f64) -> PolarsResult<()> { - polars_ensure!((0.0..=1.0).contains(&alpha), ComputeError: "alpha must be in [0; 1]"); - Ok(()) -} - -pub fn ewm_mean(s: &Series, options: EWMOptions) -> PolarsResult { - check_alpha(options.alpha)?; - match s.dtype() { - DataType::Float32 => { - let xs = s.f32().unwrap(); - let result = kernel_ewm_mean( - xs, - options.alpha as f32, - options.adjust, - options.min_periods, - options.ignore_nulls, - ); - Series::try_from((s.name(), Box::new(result) as ArrayRef)) - }, - DataType::Float64 => { - let xs = s.f64().unwrap(); - let result = kernel_ewm_mean( - xs, - options.alpha, - options.adjust, - options.min_periods, - options.ignore_nulls, - ); - Series::try_from((s.name(), Box::new(result) as ArrayRef)) - }, - _ => ewm_mean(&s.cast(&DataType::Float64)?, options), - } -} - -pub fn ewm_std(s: &Series, options: EWMOptions) -> PolarsResult { - check_alpha(options.alpha)?; - match s.dtype() { - DataType::Float32 => { - let xs = s.f32().unwrap(); - let result = kernel_ewm_std( - xs, - options.alpha as f32, - options.adjust, - options.bias, - options.min_periods, - options.ignore_nulls, - ); - Series::try_from((s.name(), Box::new(result) as ArrayRef)) - }, - DataType::Float64 => { - let xs = s.f64().unwrap(); - let result = kernel_ewm_std( - xs, - options.alpha, - options.adjust, - 
options.bias, - options.min_periods, - options.ignore_nulls, - ); - Series::try_from((s.name(), Box::new(result) as ArrayRef)) - }, - _ => ewm_std(&s.cast(&DataType::Float64)?, options), - } -} - -pub fn ewm_var(s: &Series, options: EWMOptions) -> PolarsResult { - check_alpha(options.alpha)?; - match s.dtype() { - DataType::Float32 => { - let xs = s.f32().unwrap(); - let result = kernel_ewm_var( - xs, - options.alpha as f32, - options.adjust, - options.bias, - options.min_periods, - options.ignore_nulls, - ); - Series::try_from((s.name(), Box::new(result) as ArrayRef)) - }, - DataType::Float64 => { - let xs = s.f64().unwrap(); - let result = kernel_ewm_var( - xs, - options.alpha, - options.adjust, - options.bias, - options.min_periods, - options.ignore_nulls, - ); - Series::try_from((s.name(), Box::new(result) as ArrayRef)) - }, - _ => ewm_var(&s.cast(&DataType::Float64)?, options), - } -} diff --git a/crates/polars-ops/src/series/ops/is_in.rs b/crates/polars-ops/src/series/ops/is_in.rs index 9c1e7f1e4ee6..cbc7f822eeda 100644 --- a/crates/polars-ops/src/series/ops/is_in.rs +++ b/crates/polars-ops/src/series/ops/is_in.rs @@ -343,7 +343,7 @@ pub fn is_in(s: &Series, other: &Series) -> PolarsResult { use crate::frame::join::_check_categorical_src; _check_categorical_src(s.dtype(), other.dtype())?; let ca = s.categorical().unwrap(); - let ca = ca.physical(); + let ca = ca.logical(); is_in_numeric(ca, &other.to_physical_repr()) }, #[cfg(feature = "dtype-struct")] diff --git a/crates/polars-ops/src/series/ops/mod.rs b/crates/polars-ops/src/series/ops/mod.rs index ea7f4a9a455e..d4c10d7fd078 100644 --- a/crates/polars-ops/src/series/ops/mod.rs +++ b/crates/polars-ops/src/series/ops/mod.rs @@ -3,14 +3,8 @@ mod approx_algo; mod approx_unique; mod arg_min_max; mod clip; -#[cfg(feature = "cum_agg")] -mod cum_agg; #[cfg(feature = "cutqcut")] mod cut; -#[cfg(feature = "diff")] -mod diff; -#[cfg(feature = "ewma")] -mod ewm; #[cfg(feature = "round_series")] mod floor_divide; #[cfg(feature = "fused")] @@ -28,16 +22,12 @@ mod is_last_distinct; mod is_unique; #[cfg(feature = "log")] mod log; -#[cfg(feature = "pct_change")] -mod pct_change; #[cfg(feature = "rank")] mod rank; #[cfg(feature = "rle")] mod rle; #[cfg(feature = "rolling_window")] mod rolling; -#[cfg(feature = "round_series")] -mod round; #[cfg(feature = "search_sorted")] mod search_sorted; #[cfg(feature = "to_dummies")] @@ -49,14 +39,8 @@ pub use approx_algo::*; pub use approx_unique::*; pub use arg_min_max::ArgAgg; pub use clip::*; -#[cfg(feature = "cum_agg")] -pub use cum_agg::*; #[cfg(feature = "cutqcut")] pub use cut::*; -#[cfg(feature = "diff")] -pub use diff::*; -#[cfg(feature = "ewma")] -pub use ewm::*; #[cfg(feature = "round_series")] pub use floor_divide::*; #[cfg(feature = "fused")] @@ -74,8 +58,6 @@ pub use is_last_distinct::*; pub use is_unique::*; #[cfg(feature = "log")] pub use log::*; -#[cfg(feature = "pct_change")] -pub use pct_change::*; use polars_core::prelude::*; #[cfg(feature = "rank")] pub use rank::*; @@ -83,8 +65,6 @@ pub use rank::*; pub use rle::*; #[cfg(feature = "rolling_window")] pub use rolling::*; -#[cfg(feature = "round_series")] -pub use round::*; #[cfg(feature = "search_sorted")] pub use search_sorted::*; #[cfg(feature = "to_dummies")] diff --git a/crates/polars-ops/src/series/ops/pct_change.rs b/crates/polars-ops/src/series/ops/pct_change.rs deleted file mode 100644 index 56c7af142e9b..000000000000 --- a/crates/polars-ops/src/series/ops/pct_change.rs +++ /dev/null @@ -1,25 +0,0 @@ -use polars_core::prelude::*; 
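// (Editor's worked example for the `pct_change` removed here, which divides
// diff(n) by shift(n) after a forward fill; for [1, 2, 4] and n = 1:
// diff  -> [None, 1, 2], shift -> [None, 1, 2], ratio -> [None, 1.0, 1.0].
// One caveat: the replacement `Series::diff` added earlier in this series
// subtracts `self.slice(..)` rather than the upcast `s.slice(..)` in its
// `NullBehavior::Drop` arm, which diverges from the helper deleted just above
// for unsigned input dtypes.)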
-use polars_core::series::ops::NullBehavior; - -use crate::prelude::diff; - -pub fn pct_change(s: &Series, n: &Series) -> PolarsResult { - polars_ensure!( - n.len() == 1, - ComputeError: "n must be a single value." - ); - - match s.dtype() { - DataType::Float64 | DataType::Float32 => {}, - _ => return pct_change(&s.cast(&DataType::Float64)?, n), - } - - let fill_null_s = s.fill_null(FillNullStrategy::Forward(None))?; - - let n_s = n.cast(&DataType::Int64)?; - if let Some(n) = n_s.i64()?.get(0) { - diff(&fill_null_s, n, NullBehavior::Ignore)?.divide(&fill_null_s.shift(n)) - } else { - Ok(Series::full_null(s.name(), s.len(), s.dtype())) - } -} diff --git a/crates/polars-ops/src/series/ops/rank.rs b/crates/polars-ops/src/series/ops/rank.rs index c251aaa6922e..41f9b4ca8eb9 100644 --- a/crates/polars-ops/src/series/ops/rank.rs +++ b/crates/polars-ops/src/series/ops/rank.rs @@ -6,13 +6,10 @@ use rand::prelude::SliceRandom; use rand::prelude::*; #[cfg(feature = "random")] use rand::{rngs::SmallRng, SeedableRng}; -#[cfg(feature = "serde")] -use serde::{Deserialize, Serialize}; use crate::prelude::SeriesSealed; -#[derive(Copy, Clone, Debug, PartialEq)] -#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Copy, Clone)] pub enum RankMethod { Average, Min, @@ -24,8 +21,7 @@ pub enum RankMethod { } // We might want to add a `nulls_last` or `null_behavior` field. -#[derive(Copy, Clone, Debug, PartialEq)] -#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Copy, Clone)] pub struct RankOptions { pub method: RankMethod, pub descending: bool, @@ -113,7 +109,7 @@ fn rank(s: &Series, method: RankMethod, descending: bool, seed: Option) -> rank += 1; } } - IdxCa::from_vec_validity(s.name(), out, validity).into_series() + IdxCa::new_from_owned_with_null_bitmap(s.name(), out, validity).into_series() } else { let sorted_values = unsafe { s.take_unchecked(&sort_idx_ca) }; let not_consecutive_same = sorted_values @@ -136,7 +132,7 @@ fn rank(s: &Series, method: RankMethod, descending: bool, seed: Option) -> rank += 1; } }); - IdxCa::from_vec_validity(s.name(), out, validity).into_series() + IdxCa::new_from_owned_with_null_bitmap(s.name(), out, validity).into_series() }, Average => unsafe { let mut out = vec![0.0; s.len()]; @@ -149,7 +145,8 @@ fn rank(s: &Series, method: RankMethod, descending: bool, seed: Option) -> *out.get_unchecked_mut(*i as usize) = avg; } }); - Float64Chunked::from_vec_validity(s.name(), out, validity).into_series() + Float64Chunked::new_from_owned_with_null_bitmap(s.name(), out, validity) + .into_series() }, Min => unsafe { let mut out = vec![0 as IdxSize; s.len()]; @@ -159,7 +156,7 @@ fn rank(s: &Series, method: RankMethod, descending: bool, seed: Option) -> } rank += ties.len() as IdxSize; }); - IdxCa::from_vec_validity(s.name(), out, validity).into_series() + IdxCa::new_from_owned_with_null_bitmap(s.name(), out, validity).into_series() }, Max => unsafe { let mut out = vec![0 as IdxSize; s.len()]; @@ -169,7 +166,7 @@ fn rank(s: &Series, method: RankMethod, descending: bool, seed: Option) -> *out.get_unchecked_mut(*i as usize) = rank - 1; } }); - IdxCa::from_vec_validity(s.name(), out, validity).into_series() + IdxCa::new_from_owned_with_null_bitmap(s.name(), out, validity).into_series() }, Dense => unsafe { let mut out = vec![0 as IdxSize; s.len()]; @@ -179,7 +176,7 @@ fn rank(s: &Series, method: RankMethod, descending: bool, seed: Option) -> } rank += 1; }); - IdxCa::from_vec_validity(s.name(), out, validity).into_series() + 
IdxCa::new_from_owned_with_null_bitmap(s.name(), out, validity).into_series() }, Ordinal => unreachable!(), } diff --git a/crates/polars-pipe/src/executors/sources/parquet.rs b/crates/polars-pipe/src/executors/sources/parquet.rs index c487d34e8d89..b8afaf1b8208 100644 --- a/crates/polars-pipe/src/executors/sources/parquet.rs +++ b/crates/polars-pipe/src/executors/sources/parquet.rs @@ -80,7 +80,7 @@ impl ParquetSource { ParquetAsyncReader::from_uri( &uri, self.cloud_options.as_ref(), - Some(self.file_info.reader_schema.clone()), + Some(self.file_info.schema.clone()), self.metadata.clone(), ) .await? @@ -102,7 +102,6 @@ impl ParquetSource { let file = std::fs::File::open(path).unwrap(); ParquetReader::new(file) - .with_schema(Some(self.file_info.reader_schema.clone())) .with_n_rows(file_options.n_rows) .with_row_count(file_options.row_count) .with_projection(projection) diff --git a/crates/polars-pipe/src/pipeline/convert.rs b/crates/polars-pipe/src/pipeline/convert.rs index 04fb4c287e62..0c2d48ffa89e 100644 --- a/crates/polars-pipe/src/pipeline/convert.rs +++ b/crates/polars-pipe/src/pipeline/convert.rs @@ -68,7 +68,7 @@ where Ok(Box::new(sources::DataFrameSource::from_df(df)) as Box) }, Scan { - paths, + path, file_info, file_options, predicate, @@ -87,9 +87,8 @@ where FileScan::Csv { options: csv_options, } => { - assert_eq!(paths.len(), 1); let src = sources::CsvSource::new( - paths[0].clone(), + path, file_info.schema, csv_options, file_options, @@ -103,9 +102,8 @@ where cloud_options, metadata, } => { - assert_eq!(paths.len(), 1); let src = sources::ParquetSource::new( - paths[0].clone(), + path, parquet_options, cloud_options, metadata, diff --git a/crates/polars-plan/Cargo.toml b/crates/polars-plan/Cargo.toml index d70a90914835..89811eca50eb 100644 --- a/crates/polars-plan/Cargo.toml +++ b/crates/polars-plan/Cargo.toml @@ -13,10 +13,10 @@ doctest = false [dependencies] libloading = { version = "0.8.0", optional = true } -polars-core = { workspace = true, features = ["lazy", "zip_with", "random"] } +polars-core = { workspace = true, features = ["lazy", "zip_with", "random"], default-features = false } polars-ffi = { workspace = true, optional = true } -polars-io = { workspace = true, features = ["lazy"] } -polars-ops = { workspace = true, features = ["zip_with"] } +polars-io = { workspace = true, features = ["lazy"], default-features = false } +polars-ops = { workspace = true, features = ["zip_with"], default-features = false } polars-time = { workspace = true, optional = true } polars-utils = { workspace = true } @@ -89,7 +89,7 @@ extract_jsonpath = ["polars-ops/extract_jsonpath"] approx_unique = ["polars-ops/approx_unique"] is_in = ["polars-ops/is_in"] repeat_by = ["polars-ops/repeat_by"] -round_series = ["polars-ops/round_series"] +round_series = ["polars-core/round_series"] is_first_distinct = ["polars-core/is_first_distinct", "polars-ops/is_first_distinct"] is_last_distinct = ["polars-core/is_last_distinct", "polars-ops/is_last_distinct"] is_unique = ["polars-ops/is_unique"] @@ -98,7 +98,7 @@ asof_join = ["polars-core/asof_join", "polars-time", "polars-ops/asof_join"] concat_str = [] range = [] mode = ["polars-ops/mode"] -cum_agg = ["polars-ops/cum_agg"] +cum_agg = ["polars-core/cum_agg"] interpolate = ["polars-ops/interpolate"] rolling_window = [ "polars-core/rolling_window", @@ -107,13 +107,13 @@ rolling_window = [ "polars-time/rolling_window", ] rank = ["polars-ops/rank"] -diff = ["polars-ops/diff"] -pct_change = ["polars-ops/pct_change"] +diff = ["polars-core/diff", 
"polars-ops/diff"] +pct_change = ["polars-core/pct_change"] moment = ["polars-core/moment", "polars-ops/moment"] abs = ["polars-core/abs"] random = ["polars-core/random"] dynamic_group_by = ["polars-core/dynamic_group_by"] -ewma = ["polars-ops/ewma"] +ewma = ["polars-core/ewma"] dot_diagram = [] unique_counts = ["polars-core/unique_counts"] log = ["polars-ops/log"] @@ -136,7 +136,6 @@ fused = ["polars-ops/fused"] list_sets = ["polars-ops/list_sets"] list_any_all = ["polars-ops/list_any_all"] list_drop_nulls = ["polars-ops/list_drop_nulls"] -list_sample = ["polars-ops/list_sample"] cutqcut = ["polars-ops/cutqcut"] rle = ["polars-ops/rle"] extract_groups = ["regex", "dtype-struct", "polars-ops/extract_groups"] diff --git a/crates/polars-plan/src/dot.rs b/crates/polars-plan/src/dot.rs index 0c3bb9ce9085..a581eb7dafe6 100644 --- a/crates/polars-plan/src/dot.rs +++ b/crates/polars-plan/src/dot.rs @@ -1,6 +1,5 @@ -use std::borrow::Cow; use std::fmt::{Display, Write}; -use std::path::PathBuf; +use std::path::Path; use polars_core::prelude::*; @@ -151,9 +150,9 @@ impl LogicalPlan { count, } => { let fmt = if *count == usize::MAX { - Cow::Borrowed("CACHE") + "CACHE".to_string() } else { - Cow::Owned(format!("CACHE: {}times", *count)) + format!("CACHE: {}times", *count) }; let current_node = DotNode { branch: *cache_id, @@ -182,7 +181,7 @@ impl LogicalPlan { acc_str, prev_node, "PYTHON", - &[], + Path::new(""), options.with_columns.as_ref().map(|s| s.as_slice()), options.schema.len(), &options.predicate, @@ -313,7 +312,7 @@ impl LogicalPlan { } }, Scan { - paths, + path, file_info, predicate, scan_type, @@ -325,7 +324,7 @@ impl LogicalPlan { acc_str, prev_node, name, - paths.as_ref(), + path.as_ref(), options.with_columns.as_ref().map(|cols| cols.as_slice()), file_info.schema.len(), predicate, @@ -410,7 +409,7 @@ impl LogicalPlan { acc_str: &mut String, prev_node: DotNode, name: &str, - path: &[PathBuf], + path: &Path, with_columns: Option<&[String]>, total_columns: usize, predicate: &Option
<Expr>
, @@ -423,20 +422,13 @@ impl LogicalPlan { n_columns_fmt = format!("{}", columns.len()); } - let fmt = if path.len() == 1 { - path[0].to_string_lossy() - } else { - Cow::Owned(format!( - "{} files: first file: {}", - path.len(), - path[0].to_string_lossy() - )) - }; - let pred = fmt_predicate(predicate.as_ref()); let fmt = format!( "{name} SCAN {};\nπ {}/{};\nσ {}", - fmt, n_columns_fmt, total_columns, pred, + path.to_string_lossy(), + n_columns_fmt, + total_columns, + pred, ); let current_node = DotNode { branch, diff --git a/crates/polars-plan/src/dsl/function_expr/cum.rs b/crates/polars-plan/src/dsl/function_expr/cum.rs index 0ba1a0f6281d..d8ac6434809b 100644 --- a/crates/polars-plan/src/dsl/function_expr/cum.rs +++ b/crates/polars-plan/src/dsl/function_expr/cum.rs @@ -1,23 +1,33 @@ use super::*; pub(super) fn cumcount(s: &Series, reverse: bool) -> PolarsResult { - polars_ops::prelude::cumcount(s, reverse) + if reverse { + let ca: NoNull = (0u32..s.len() as u32).rev().collect(); + let mut ca = ca.into_inner(); + ca.rename(s.name()); + Ok(ca.into_series()) + } else { + let ca: NoNull = (0u32..s.len() as u32).collect(); + let mut ca = ca.into_inner(); + ca.rename(s.name()); + Ok(ca.into_series()) + } } pub(super) fn cumsum(s: &Series, reverse: bool) -> PolarsResult { - polars_ops::prelude::cumsum(s, reverse) + Ok(s.cumsum(reverse)) } pub(super) fn cumprod(s: &Series, reverse: bool) -> PolarsResult { - polars_ops::prelude::cumprod(s, reverse) + Ok(s.cumprod(reverse)) } pub(super) fn cummin(s: &Series, reverse: bool) -> PolarsResult { - polars_ops::prelude::cummin(s, reverse) + Ok(s.cummin(reverse)) } pub(super) fn cummax(s: &Series, reverse: bool) -> PolarsResult { - polars_ops::prelude::cummax(s, reverse) + Ok(s.cummax(reverse)) } pub(super) mod dtypes { diff --git a/crates/polars-plan/src/dsl/function_expr/dispatch.rs b/crates/polars-plan/src/dsl/function_expr/dispatch.rs index eb00120e970f..bbe11390fb01 100644 --- a/crates/polars-plan/src/dsl/function_expr/dispatch.rs +++ b/crates/polars-plan/src/dsl/function_expr/dispatch.rs @@ -15,12 +15,7 @@ pub(super) fn approx_n_unique(s: &Series) -> PolarsResult { #[cfg(feature = "diff")] pub(super) fn diff(s: &Series, n: i64, null_behavior: NullBehavior) -> PolarsResult { - polars_ops::prelude::diff(s, n, null_behavior) -} - -#[cfg(feature = "pct_change")] -pub(super) fn pct_change(s: &[Series]) -> PolarsResult { - polars_ops::prelude::pct_change(&s[0], &s[1]) + s.diff(n, null_behavior) } #[cfg(feature = "interpolate")] @@ -76,32 +71,3 @@ pub(super) fn max_horizontal(s: &mut [Series]) -> PolarsResult> { pub(super) fn min_horizontal(s: &mut [Series]) -> PolarsResult> { polars_ops::prelude::min_horizontal(s) } - -pub(super) fn drop_nulls(s: &Series) -> PolarsResult { - Ok(s.drop_nulls()) -} - -#[cfg(feature = "mode")] -pub(super) fn mode(s: &Series) -> PolarsResult { - mode::mode(s) -} - -#[cfg(feature = "moment")] -pub(super) fn skew(s: &Series, bias: bool) -> PolarsResult { - s.skew(bias).map(|opt_v| Series::new(s.name(), &[opt_v])) -} - -#[cfg(feature = "moment")] -pub(super) fn kurtosis(s: &Series, fisher: bool, bias: bool) -> PolarsResult { - s.kurtosis(fisher, bias) - .map(|opt_v| Series::new(s.name(), &[opt_v])) -} - -pub(super) fn arg_unique(s: &Series) -> PolarsResult { - s.arg_unique().map(|ok| ok.into_series()) -} - -#[cfg(feature = "rank")] -pub(super) fn rank(s: &Series, options: RankOptions, seed: Option) -> PolarsResult { - Ok(s.rank(options, seed)) -} diff --git a/crates/polars-plan/src/dsl/function_expr/ewm.rs 
b/crates/polars-plan/src/dsl/function_expr/ewm.rs deleted file mode 100644 index b824ca3013e9..000000000000 --- a/crates/polars-plan/src/dsl/function_expr/ewm.rs +++ /dev/null @@ -1,13 +0,0 @@ -use super::*; - -pub(super) fn ewm_mean(s: &Series, options: EWMOptions) -> PolarsResult { - polars_ops::prelude::ewm_mean(s, options) -} - -pub(super) fn ewm_std(s: &Series, options: EWMOptions) -> PolarsResult { - polars_ops::prelude::ewm_std(s, options) -} - -pub(super) fn ewm_var(s: &Series, options: EWMOptions) -> PolarsResult { - polars_ops::prelude::ewm_var(s, options) -} diff --git a/crates/polars-plan/src/dsl/function_expr/list.rs b/crates/polars-plan/src/dsl/function_expr/list.rs index bb238d09bd75..35155902b788 100644 --- a/crates/polars-plan/src/dsl/function_expr/list.rs +++ b/crates/polars-plan/src/dsl/function_expr/list.rs @@ -11,13 +11,6 @@ pub enum ListFunction { Contains, #[cfg(feature = "list_drop_nulls")] DropNulls, - #[cfg(feature = "list_sample")] - Sample { - is_fraction: bool, - with_replacement: bool, - shuffle: bool, - seed: Option, - }, Slice, Shift, Get, @@ -59,14 +52,6 @@ impl Display for ListFunction { Contains => "contains", #[cfg(feature = "list_drop_nulls")] DropNulls => "drop_nulls", - #[cfg(feature = "list_sample")] - Sample { is_fraction, .. } => { - if *is_fraction { - "sample_fraction" - } else { - "sample_n" - } - }, Slice => "slice", Shift => "shift", Get => "get", @@ -122,32 +107,6 @@ pub(super) fn drop_nulls(s: &Series) -> PolarsResult { Ok(list.lst_drop_nulls().into_series()) } -#[cfg(feature = "list_sample")] -pub(super) fn sample_n( - s: &[Series], - with_replacement: bool, - shuffle: bool, - seed: Option, -) -> PolarsResult { - let list = s[0].list()?; - let n = &s[1]; - list.lst_sample_n(n, with_replacement, shuffle, seed) - .map(|ok| ok.into_series()) -} - -#[cfg(feature = "list_sample")] -pub(super) fn sample_fraction( - s: &[Series], - with_replacement: bool, - shuffle: bool, - seed: Option, -) -> PolarsResult { - let list = s[0].list()?; - let fraction = &s[1]; - list.lst_sample_fraction(fraction, with_replacement, shuffle, seed) - .map(|ok| ok.into_series()) -} - fn check_slice_arg_shape(slice_len: usize, ca_len: usize, name: &str) -> PolarsResult<()> { polars_ensure!( slice_len == ca_len, diff --git a/crates/polars-plan/src/dsl/function_expr/mod.rs b/crates/polars-plan/src/dsl/function_expr/mod.rs index 3ca6b5f2c7c9..7f51397f8dfe 100644 --- a/crates/polars-plan/src/dsl/function_expr/mod.rs +++ b/crates/polars-plan/src/dsl/function_expr/mod.rs @@ -14,13 +14,10 @@ mod clip; mod coerce; mod concat; mod correlation; -#[cfg(feature = "cum_agg")] mod cum; #[cfg(feature = "temporal")] mod datetime; mod dispatch; -#[cfg(feature = "ewma")] -mod ewm; mod fill_null; #[cfg(feature = "fused")] mod fused; @@ -138,19 +135,6 @@ pub enum FunctionExpr { periods: i64, }, DropNans, - DropNulls, - #[cfg(feature = "mode")] - Mode, - #[cfg(feature = "moment")] - Skew(bool), - #[cfg(feature = "moment")] - Kurtosis(bool, bool), - ArgUnique, - #[cfg(feature = "rank")] - Rank { - options: RankOptions, - seed: Option, - }, #[cfg(feature = "round_series")] Clip { has_min: bool, @@ -166,23 +150,18 @@ pub enum FunctionExpr { #[cfg(feature = "top_k")] TopK(bool), Shift(i64), - #[cfg(feature = "cum_agg")] Cumcount { reverse: bool, }, - #[cfg(feature = "cum_agg")] Cumsum { reverse: bool, }, - #[cfg(feature = "cum_agg")] Cumprod { reverse: bool, }, - #[cfg(feature = "cum_agg")] Cummin { reverse: bool, }, - #[cfg(feature = "cum_agg")] Cummax { reverse: bool, }, @@ -203,8 +182,6 @@ 
pub enum FunctionExpr { ShrinkType, #[cfg(feature = "diff")] Diff(i64, NullBehavior), - #[cfg(feature = "pct_change")] - PctChange, #[cfg(feature = "interpolate")] Interpolate(InterpolationMethod), #[cfg(feature = "log")] @@ -269,15 +246,9 @@ pub enum FunctionExpr { }, SetSortedFlag(IsSorted), #[cfg(feature = "ffi_plugin")] - /// Creating this node is unsafe - /// This will lead to calls over FFI> FfiPlugin { - /// Shared library. lib: Arc, - /// Identifier in the shared lib. symbol: Arc, - /// Pickle serialized keyword arguments. - kwargs: Arc<[u8]>, }, BackwardFill { limit: FillNullLimit, @@ -288,18 +259,6 @@ pub enum FunctionExpr { SumHorizontal, MaxHorizontal, MinHorizontal, - #[cfg(feature = "ewma")] - EwmMean { - options: EWMOptions, - }, - #[cfg(feature = "ewma")] - EwmStd { - options: EWMOptions, - }, - #[cfg(feature = "ewma")] - EwmVar { - options: EWMOptions, - }, } impl Hash for FunctionExpr { @@ -336,12 +295,7 @@ impl Hash for FunctionExpr { #[cfg(feature = "dtype-categorical")] FunctionExpr::Categorical(f) => f.hash(state), #[cfg(feature = "ffi_plugin")] - FunctionExpr::FfiPlugin { - lib, - symbol, - kwargs, - } => { - kwargs.hash(state); + FunctionExpr::FfiPlugin { lib, symbol } => { lib.hash(state); symbol.hash(state); }, @@ -385,16 +339,6 @@ impl Display for FunctionExpr { RollingSkew { .. } => "rolling_skew", ShiftAndFill { .. } => "shift_and_fill", DropNans => "drop_nans", - DropNulls => "drop_nulls", - #[cfg(feature = "mode")] - Mode => "mode", - #[cfg(feature = "moment")] - Skew(_) => "skew", - #[cfg(feature = "moment")] - Kurtosis(..) => "kurtosis", - ArgUnique => "arg_unique", - #[cfg(feature = "rank")] - Rank { .. } => "rank", #[cfg(feature = "round_series")] Clip { has_min, has_max } => match (has_min, has_max) { (true, true) => "clip", @@ -416,15 +360,10 @@ impl Display for FunctionExpr { } }, Shift(_) => "shift", - #[cfg(feature = "cum_agg")] Cumcount { .. } => "cumcount", - #[cfg(feature = "cum_agg")] Cumsum { .. } => "cumsum", - #[cfg(feature = "cum_agg")] Cumprod { .. } => "cumprod", - #[cfg(feature = "cum_agg")] Cummin { .. } => "cummin", - #[cfg(feature = "cum_agg")] Cummax { .. } => "cummax", #[cfg(feature = "dtype-struct")] ValueCounts { .. } => "value_counts", @@ -440,8 +379,6 @@ impl Display for FunctionExpr { ShrinkType => "shrink_dtype", #[cfg(feature = "diff")] Diff(_, _) => "diff", - #[cfg(feature = "pct_change")] - PctChange => "pct_change", #[cfg(feature = "interpolate")] Interpolate(_) => "interpolate", #[cfg(feature = "log")] @@ -496,12 +433,6 @@ impl Display for FunctionExpr { SumHorizontal => "sum_horizontal", MaxHorizontal => "max_horizontal", MinHorizontal => "min_horizontal", - #[cfg(feature = "ewma")] - EwmMean { .. } => "ewm_mean", - #[cfg(feature = "ewma")] - EwmStd { .. } => "ewm_std", - #[cfg(feature = "ewma")] - EwmVar { .. 
} => "ewm_var", }; write!(f, "{s}") } @@ -649,20 +580,10 @@ impl From for SpecialEq> { map_as_slice!(shift_and_fill::shift_and_fill, periods) }, DropNans => map_owned!(nan::drop_nans), - DropNulls => map!(dispatch::drop_nulls), #[cfg(feature = "round_series")] Clip { has_min, has_max } => { map_as_slice!(clip::clip, has_min, has_max) }, - #[cfg(feature = "mode")] - Mode => map!(dispatch::mode), - #[cfg(feature = "moment")] - Skew(bias) => map!(dispatch::skew, bias), - #[cfg(feature = "moment")] - Kurtosis(fisher, bias) => map!(dispatch::kurtosis, fisher, bias), - ArgUnique => map!(dispatch::arg_unique), - #[cfg(feature = "rank")] - Rank { options, seed } => map!(dispatch::rank, options, seed), ListExpr(lf) => { use ListFunction::*; match lf { @@ -671,19 +592,6 @@ impl From for SpecialEq> { Contains => wrap!(list::contains), #[cfg(feature = "list_drop_nulls")] DropNulls => map!(list::drop_nulls), - #[cfg(feature = "list_sample")] - Sample { - is_fraction, - with_replacement, - shuffle, - seed, - } => { - if is_fraction { - map_as_slice!(list::sample_fraction, with_replacement, shuffle, seed) - } else { - map_as_slice!(list::sample_n, with_replacement, shuffle, seed) - } - }, Slice => wrap!(list::slice), Shift => map_as_slice!(list::shift), Get => wrap!(list::get), @@ -740,15 +648,10 @@ impl From for SpecialEq> { map_as_slice!(top_k, descending) }, Shift(periods) => map!(dispatch::shift, periods), - #[cfg(feature = "cum_agg")] Cumcount { reverse } => map!(cum::cumcount, reverse), - #[cfg(feature = "cum_agg")] Cumsum { reverse } => map!(cum::cumsum, reverse), - #[cfg(feature = "cum_agg")] Cumprod { reverse } => map!(cum::cumprod, reverse), - #[cfg(feature = "cum_agg")] Cummin { reverse } => map!(cum::cummin, reverse), - #[cfg(feature = "cum_agg")] Cummax { reverse } => map!(cum::cummax, reverse), #[cfg(feature = "dtype-struct")] ValueCounts { sort, parallel } => map!(dispatch::value_counts, sort, parallel), @@ -764,8 +667,6 @@ impl From for SpecialEq> { ShrinkType => map_owned!(shrink_type::shrink), #[cfg(feature = "diff")] Diff(n, null_behavior) => map!(dispatch::diff, n, null_behavior), - #[cfg(feature = "pct_change")] - PctChange => map_as_slice!(dispatch::pct_change), #[cfg(feature = "interpolate")] Interpolate(method) => { map!(dispatch::interpolate, method) @@ -846,29 +747,14 @@ impl From for SpecialEq> { }, SetSortedFlag(sorted) => map!(dispatch::set_sorted_flag, sorted), #[cfg(feature = "ffi_plugin")] - FfiPlugin { - lib, - symbol, - kwargs, - } => unsafe { - map_as_slice!( - plugin::call_plugin, - lib.as_ref(), - symbol.as_ref(), - kwargs.as_ref() - ) + FfiPlugin { lib, symbol, .. 
} => unsafe { + map_as_slice!(plugin::call_plugin, lib.as_ref(), symbol.as_ref()) }, BackwardFill { limit } => map!(dispatch::backward_fill, limit), ForwardFill { limit } => map!(dispatch::forward_fill, limit), SumHorizontal => map_as_slice!(dispatch::sum_horizontal), MaxHorizontal => wrap!(dispatch::max_horizontal), MinHorizontal => wrap!(dispatch::min_horizontal), - #[cfg(feature = "ewma")] - EwmMean { options } => map!(ewm::ewm_mean, options), - #[cfg(feature = "ewma")] - EwmStd { options } => map!(ewm::ewm_std, options), - #[cfg(feature = "ewma")] - EwmVar { options } => map!(ewm::ewm_var, options), } } } diff --git a/crates/polars-plan/src/dsl/function_expr/plugin.rs b/crates/polars-plan/src/dsl/function_expr/plugin.rs index 85fea0edf7b8..6c8113a54aac 100644 --- a/crates/polars-plan/src/dsl/function_expr/plugin.rs +++ b/crates/polars-plan/src/dsl/function_expr/plugin.rs @@ -1,4 +1,3 @@ -use std::ffi::CString; use std::sync::RwLock; use arrow::ffi::{import_field_from_c, ArrowSchema}; @@ -31,59 +30,24 @@ fn get_lib(lib: &str) -> PolarsResult<&'static Library> { } } -unsafe fn retrieve_error_msg(lib: &Library) -> CString { - let symbol: libloading::Symbol *mut std::os::raw::c_char> = - lib.get(b"get_last_error_message\0").unwrap(); - let msg_ptr = symbol(); - CString::from_raw(msg_ptr) -} - -pub(super) unsafe fn call_plugin( - s: &[Series], - lib: &str, - symbol: &str, - kwargs: &[u8], -) -> PolarsResult { +pub(super) unsafe fn call_plugin(s: &[Series], lib: &str, symbol: &str) -> PolarsResult { let lib = get_lib(lib)?; - // *const SeriesExport: pointer to Box - // * usize: length of that pointer - // *const u8: pointer to &[u8] - // usize: length of the u8 slice - // *mut SeriesExport: pointer where return value should be written. let symbol: libloading::Symbol< - unsafe extern "C" fn(*const SeriesExport, usize, *const u8, usize, *mut SeriesExport), + unsafe extern "C" fn(*const SeriesExport, usize) -> SeriesExport, > = lib.get(symbol.as_bytes()).unwrap(); + let n_args = s.len(); + let input = s.iter().map(export_series).collect::>(); - let input_len = s.len(); let slice_ptr = input.as_ptr(); + let out = symbol(slice_ptr, n_args); - let kwargs_ptr = kwargs.as_ptr(); - let kwargs_len = kwargs.len(); - - let mut return_value = SeriesExport::empty(); - let return_value_ptr = &mut return_value as *mut SeriesExport; - symbol( - slice_ptr, - input_len, - kwargs_ptr, - kwargs_len, - return_value_ptr, - ); - - // The inputs get dropped when the ffi side calls the drop callback. 
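
The plugin call in this hunk is reverted from an out-parameter ABI back to returning a `SeriesExport` by value: the host passes a pointer to the exported inputs plus their count, and the symbol returns one export directly. A self-contained toy model of that by-value shape, with a hypothetical `Export` struct standing in for the real polars-ffi types (a sketch, not a working plugin):

    // Toy model of the by-value ABI: pointer + length in, one value out.
    // `Export` is a hypothetical stand-in for `SeriesExport`.
    #[repr(C)]
    pub struct Export {
        len: usize,
    }

    pub unsafe extern "C" fn plugin_symbol(inputs: *const Export, n_args: usize) -> Export {
        let inputs = std::slice::from_raw_parts(inputs, n_args);
        // a real symbol would import each input series, compute, and export the result
        Export { len: inputs.iter().map(|e| e.len).sum() }
    }

    fn main() {
        let inputs = [Export { len: 2 }, Export { len: 3 }];
        // mirrors `symbol(slice_ptr, n_args)` in the hunk above
        let out = unsafe { plugin_symbol(inputs.as_ptr(), inputs.len()) };
        assert_eq!(out.len, 5);
    }

The variant being deleted here additionally threaded a serialized kwargs buffer and wrote its result through an out-pointer with a separate error-message channel; the by-value form is the simpler ABI this series reverts to.
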
for e in input { std::mem::forget(e); } - if !return_value.is_null() { - import_series(return_value) - } else { - let msg = retrieve_error_msg(lib); - let msg = msg.to_string_lossy(); - polars_bail!(ComputeError: "the plugin failed with message: {}", msg) - } + import_series(out) } pub(super) unsafe fn plugin_field( @@ -93,12 +57,8 @@ pub(super) unsafe fn plugin_field( ) -> PolarsResult { let lib = get_lib(lib)?; - // *const ArrowSchema: pointer to heap Box - // usize: length of the boxed slice - // *mut ArrowSchema: pointer where the return value can be written - let symbol: libloading::Symbol< - unsafe extern "C" fn(*const ArrowSchema, usize, *mut ArrowSchema), - > = lib.get(symbol.as_bytes()).unwrap(); + let symbol: libloading::Symbol ArrowSchema> = + lib.get(symbol.as_bytes()).unwrap(); // we deallocate the fields buffer let fields = fields @@ -108,18 +68,8 @@ pub(super) unsafe fn plugin_field( .into_boxed_slice(); let n_args = fields.len(); let slice_ptr = fields.as_ptr(); + let out = symbol(slice_ptr, n_args); - let mut return_value = ArrowSchema::empty(); - let return_value_ptr = &mut return_value as *mut ArrowSchema; - symbol(slice_ptr, n_args, return_value_ptr); - - if !return_value.is_null() { - let arrow_field = import_field_from_c(&return_value)?; - let out = Field::from(&arrow_field); - Ok(out) - } else { - let msg = retrieve_error_msg(lib); - let msg = msg.to_string_lossy(); - polars_bail!(ComputeError: "the plugin failed with message: {}", msg) - } + let arrow_field = import_field_from_c(&out)?; + Ok(Field::from(&arrow_field)) } diff --git a/crates/polars-plan/src/dsl/function_expr/schema.rs b/crates/polars-plan/src/dsl/function_expr/schema.rs index af99f7f81b52..3d8996e74431 100644 --- a/crates/polars-plan/src/dsl/function_expr/schema.rs +++ b/crates/polars-plan/src/dsl/function_expr/schema.rs @@ -47,21 +47,8 @@ impl FunctionExpr { RollingSkew { .. } => mapper.map_to_float_dtype(), ShiftAndFill { .. } => mapper.with_same_dtype(), DropNans => mapper.with_same_dtype(), - DropNulls => mapper.with_same_dtype(), #[cfg(feature = "round_series")] Clip { .. } => mapper.with_same_dtype(), - #[cfg(feature = "mode")] - Mode => mapper.with_same_dtype(), - #[cfg(feature = "moment")] - Skew(_) => mapper.with_dtype(DataType::Float64), - #[cfg(feature = "moment")] - Kurtosis(..) => mapper.with_dtype(DataType::Float64), - ArgUnique => mapper.with_dtype(IDX_DTYPE), - #[cfg(feature = "rank")] - Rank { options, .. } => mapper.with_dtype(match options.method { - RankMethod::Average => DataType::Float64, - _ => IDX_DTYPE, - }), ListExpr(l) => { use ListFunction::*; match l { @@ -70,8 +57,6 @@ impl FunctionExpr { Contains => mapper.with_dtype(DataType::Boolean), #[cfg(feature = "list_drop_nulls")] DropNulls => mapper.with_same_dtype(), - #[cfg(feature = "list_sample")] - Sample { .. } => mapper.with_same_dtype(), Slice => mapper.with_same_dtype(), Shift => mapper.with_same_dtype(), Get => mapper.map_to_list_inner_dtype(), @@ -137,15 +122,10 @@ impl FunctionExpr { Boolean(func) => func.get_field(mapper), #[cfg(feature = "dtype-categorical")] Categorical(func) => func.get_field(mapper), - #[cfg(feature = "cum_agg")] Cumcount { .. } => mapper.with_dtype(IDX_DTYPE), - #[cfg(feature = "cum_agg")] Cumsum { .. } => mapper.map_dtype(cum::dtypes::cumsum), - #[cfg(feature = "cum_agg")] Cumprod { .. } => mapper.map_dtype(cum::dtypes::cumprod), - #[cfg(feature = "cum_agg")] Cummin { .. } => mapper.with_same_dtype(), - #[cfg(feature = "cum_agg")] Cummax { .. 
} => mapper.with_same_dtype(), #[cfg(feature = "approx_unique")] ApproxNUnique => mapper.with_dtype(IDX_DTYPE), @@ -162,11 +142,6 @@ impl FunctionExpr { DataType::UInt8 => DataType::Int16, dt => dt.clone(), }), - #[cfg(feature = "pct_change")] - PctChange => mapper.map_dtype(|dt| match dt { - DataType::Float64 | DataType::Float32 => dt.clone(), - _ => DataType::Float64, - }), #[cfg(feature = "interpolate")] Interpolate(method) => match method { InterpolationMethod::Linear => mapper.map_numeric_to_float_dtype(), @@ -259,7 +234,7 @@ impl FunctionExpr { Random { .. } => mapper.with_same_dtype(), SetSortedFlag(_) => mapper.with_same_dtype(), #[cfg(feature = "ffi_plugin")] - FfiPlugin { lib, symbol, .. } => unsafe { + FfiPlugin { lib, symbol } => unsafe { plugin::plugin_field(fields, lib, &format!("__polars_field_{}", symbol.as_ref())) }, BackwardFill { .. } => mapper.with_same_dtype(), @@ -267,12 +242,6 @@ impl FunctionExpr { SumHorizontal => mapper.map_to_supertype(), MaxHorizontal => mapper.map_to_supertype(), MinHorizontal => mapper.map_to_supertype(), - #[cfg(feature = "ewma")] - EwmMean { .. } => mapper.map_to_float_dtype(), - #[cfg(feature = "ewma")] - EwmStd { .. } => mapper.map_to_float_dtype(), - #[cfg(feature = "ewma")] - EwmVar { .. } => mapper.map_to_float_dtype(), } } } diff --git a/crates/polars-plan/src/dsl/functions/temporal.rs b/crates/polars-plan/src/dsl/functions/temporal.rs index b2eb3af5a229..6ae7a8ee0c5c 100644 --- a/crates/polars-plan/src/dsl/functions/temporal.rs +++ b/crates/polars-plan/src/dsl/functions/temporal.rs @@ -156,7 +156,6 @@ pub fn datetime(args: DatetimeArgs) -> Expr { /// their default value of `lit(0)`, as demonstrated below. /// /// ``` -/// # use polars_plan::prelude::*; /// let args = DurationArgs { /// days: lit(5), /// hours: col("num_hours"), @@ -166,7 +165,6 @@ pub fn datetime(args: DatetimeArgs) -> Expr { /// ``` /// If you prefer builder syntax, `with_*` methods are also available. /// ``` -/// # use polars_plan::prelude::*; /// let args = DurationArgs::new().with_weeks(lit(42)).with_hours(lit(84)); /// ``` #[derive(Debug, Clone)] diff --git a/crates/polars-plan/src/dsl/list.rs b/crates/polars-plan/src/dsl/list.rs index c8741dba73ff..6e9bde5b68eb 100644 --- a/crates/polars-plan/src/dsl/list.rs +++ b/crates/polars-plan/src/dsl/list.rs @@ -34,48 +34,6 @@ impl ListNameSpace { .map_private(FunctionExpr::ListExpr(ListFunction::DropNulls)) } - #[cfg(feature = "list_sample")] - pub fn sample_n( - self, - n: Expr, - with_replacement: bool, - shuffle: bool, - seed: Option, - ) -> Expr { - self.0.map_many_private( - FunctionExpr::ListExpr(ListFunction::Sample { - is_fraction: false, - with_replacement, - shuffle, - seed, - }), - &[n], - false, - false, - ) - } - - #[cfg(feature = "list_sample")] - pub fn sample_fraction( - self, - fraction: Expr, - with_replacement: bool, - shuffle: bool, - seed: Option, - ) -> Expr { - self.0.map_many_private( - FunctionExpr::ListExpr(ListFunction::Sample { - is_fraction: true, - with_replacement, - shuffle, - seed, - }), - &[fraction], - false, - false, - ) - } - /// Return the number of elements in each list. /// /// Null values are treated like regular elements in this context. diff --git a/crates/polars-plan/src/dsl/mod.rs b/crates/polars-plan/src/dsl/mod.rs index 05dca3120e10..3b335e6d5475 100644 --- a/crates/polars-plan/src/dsl/mod.rs +++ b/crates/polars-plan/src/dsl/mod.rs @@ -2,6 +2,8 @@ //! Domain specific language for the Lazy API. 
#[cfg(feature = "rolling_window")] use polars_core::utils::ensure_sorted_arg; +#[cfg(feature = "mode")] +use polars_ops::chunked_array::mode::mode; #[cfg(feature = "dtype-categorical")] pub mod cat; #[cfg(feature = "dtype-categorical")] @@ -52,7 +54,7 @@ use polars_core::prelude::*; #[cfg(feature = "diff")] use polars_core::series::ops::NullBehavior; use polars_core::series::IsSorted; -use polars_core::utils::try_get_supertype; +use polars_core::utils::{try_get_supertype, NoNull}; #[cfg(feature = "rolling_window")] use polars_time::prelude::SeriesOpsTime; pub(crate) use selector::Selector; @@ -181,7 +183,7 @@ impl Expr { /// Drop null values. pub fn drop_nulls(self) -> Self { - self.apply_private(FunctionExpr::DropNulls) + self.apply(|s| Ok(Some(s.drop_nulls())), GetOutput::same_type()) } /// Drop NaN values. @@ -343,7 +345,11 @@ impl Expr { /// Get the first index of unique values of this expression. pub fn arg_unique(self) -> Self { - self.apply_private(FunctionExpr::ArgUnique) + self.apply( + |s: Series| s.arg_unique().map(|ca| Some(ca.into_series())), + GetOutput::from_type(IDX_DTYPE), + ) + .with_fmt("arg_unique") } /// Get the index value that has the minimum value. @@ -647,6 +653,7 @@ impl Expr { options: FunctionOptions { collect_groups: ApplyOptions::ApplyGroups, fmt_str: "", + auto_explode: true, ..Default::default() }, } @@ -736,31 +743,26 @@ impl Expr { } /// Cumulatively count values from 0 to len. - #[cfg(feature = "cum_agg")] pub fn cumcount(self, reverse: bool) -> Self { self.apply_private(FunctionExpr::Cumcount { reverse }) } /// Get an array with the cumulative sum computed at every element. - #[cfg(feature = "cum_agg")] pub fn cumsum(self, reverse: bool) -> Self { self.apply_private(FunctionExpr::Cumsum { reverse }) } /// Get an array with the cumulative product computed at every element. - #[cfg(feature = "cum_agg")] pub fn cumprod(self, reverse: bool) -> Self { self.apply_private(FunctionExpr::Cumprod { reverse }) } /// Get an array with the cumulative min computed at every element. - #[cfg(feature = "cum_agg")] pub fn cummin(self, reverse: bool) -> Self { self.apply_private(FunctionExpr::Cummin { reverse }) } /// Get an array with the cumulative max computed at every element. - #[cfg(feature = "cum_agg")] pub fn cummax(self, reverse: bool) -> Self { self.apply_private(FunctionExpr::Cummax { reverse }) } @@ -1135,13 +1137,13 @@ impl Expr { #[cfg(feature = "mode")] /// Compute the mode(s) of this column. This is the most occurring value. pub fn mode(self) -> Expr { - self.apply_private(FunctionExpr::Mode) + self.apply(|s| mode(&s).map(Some), GetOutput::same_type()) + .with_fmt("mode") } /// Keep the original root name /// /// ```rust,no_run - /// # use polars_core::prelude::*; /// # use polars_plan::prelude::*; /// fn example(df: LazyFrame) -> LazyFrame { /// df.select([ @@ -1182,6 +1184,21 @@ impl Expr { /// Exclude a column from a wildcard/regex selection. /// /// You may also use regexes in the exclude as long as they start with `^` and end with `$`/ + /// + /// # Example + /// + /// ```rust + /// use polars_core::prelude::*; + /// use polars_lazy::prelude::*; + /// + /// // Select all columns except foo. + /// fn example(df: DataFrame) -> LazyFrame { + /// df.lazy() + /// .select(&[ + /// col("*").exclude(&["foo"]) + /// ]) + /// } + /// ``` pub fn exclude(self, columns: impl IntoVec) -> Expr { let v = columns .into_vec() @@ -1462,7 +1479,14 @@ impl Expr { #[cfg(feature = "rank")] /// Assign ranks to data, dealing with ties appropriately. 
pub fn rank(self, options: RankOptions, seed: Option) -> Expr { - self.apply_private(FunctionExpr::Rank { options, seed }) + self.apply( + move |s| Ok(Some(s.rank(options, seed))), + GetOutput::map_field(move |fld| match options.method { + RankMethod::Average => Field::new(fld.name(), DataType::Float64), + _ => Field::new(fld.name(), IDX_DTYPE), + }), + ) + .with_fmt("rank") } #[cfg(feature = "cutqcut")] @@ -1541,8 +1565,16 @@ impl Expr { #[cfg(feature = "pct_change")] /// Computes percentage change between values. - pub fn pct_change(self, n: Expr) -> Expr { - self.apply_many_private(FunctionExpr::PctChange, &[n], false, false) + pub fn pct_change(self, n: i64) -> Expr { + use DataType::*; + self.apply( + move |s| s.pct_change(n).map(Some), + GetOutput::map_dtype(|dt| match dt { + Float64 | Float32 => dt.clone(), + _ => Float64, + }), + ) + .with_fmt("pct_change") } #[cfg(feature = "moment")] @@ -1556,11 +1588,19 @@ impl Expr { /// /// see: [scipy](https://github.com/scipy/scipy/blob/47bb6febaa10658c72962b9615d5d5aa2513fa3a/scipy/stats/stats.py#L1024) pub fn skew(self, bias: bool) -> Expr { - self.apply_private(FunctionExpr::Skew(bias)) - .with_function_options(|mut options| { - options.auto_explode = true; - options - }) + self.apply( + move |s| { + s.skew(bias) + .map(|opt_v| Series::new(s.name(), &[opt_v])) + .map(Some) + }, + GetOutput::from_type(DataType::Float64), + ) + .with_function_options(|mut options| { + options.fmt_str = "skew"; + options.auto_explode = true; + options + }) } #[cfg(feature = "moment")] @@ -1572,11 +1612,18 @@ impl Expr { /// If bias is False then the kurtosis is calculated using k statistics to /// eliminate bias coming from biased moment estimators. pub fn kurtosis(self, fisher: bool, bias: bool) -> Expr { - self.apply_private(FunctionExpr::Kurtosis(fisher, bias)) - .with_function_options(|mut options| { - options.auto_explode = true; - options - }) + self.apply( + move |s| { + s.kurtosis(fisher, bias) + .map(|opt_v| Some(Series::new(s.name(), &[opt_v]))) + }, + GetOutput::from_type(DataType::Float64), + ) + .with_function_options(|mut options| { + options.fmt_str = "kurtosis"; + options.auto_explode = true; + options + }) } /// Get maximal value that could be hold by this dtype. @@ -1619,19 +1666,43 @@ impl Expr { #[cfg(feature = "ewma")] /// Calculate the exponentially-weighted moving average. pub fn ewm_mean(self, options: EWMOptions) -> Self { - self.apply_private(FunctionExpr::EwmMean { options }) + use DataType::*; + self.apply( + move |s| s.ewm_mean(options).map(Some), + GetOutput::map_dtype(|dt| match dt { + Float64 | Float32 => dt.clone(), + _ => Float64, + }), + ) + .with_fmt("ewm_mean") } #[cfg(feature = "ewma")] /// Calculate the exponentially-weighted moving standard deviation. pub fn ewm_std(self, options: EWMOptions) -> Self { - self.apply_private(FunctionExpr::EwmStd { options }) + use DataType::*; + self.apply( + move |s| s.ewm_std(options).map(Some), + GetOutput::map_dtype(|dt| match dt { + Float64 | Float32 => dt.clone(), + _ => Float64, + }), + ) + .with_fmt("ewm_std") } #[cfg(feature = "ewma")] /// Calculate the exponentially-weighted moving variance. pub fn ewm_var(self, options: EWMOptions) -> Self { - self.apply_private(FunctionExpr::EwmVar { options }) + use DataType::*; + self.apply( + move |s| s.ewm_var(options).map(Some), + GetOutput::map_dtype(|dt| match dt { + Float64 | Float32 => dt.clone(), + _ => Float64, + }), + ) + .with_fmt("ewm_var") } /// Returns whether any of the values in the column are `true`. 
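
The Expr methods restored in this hunk (`cumsum`, `pct_change`, `rank`, `ewm_mean`, and friends) are ordinary lazy adapters. A small usage sketch against the signatures shown above; the `EWMOptions` field set is taken from the kernel arguments earlier in this patch, and the snippet assumes the `polars` prelude re-exports these items (untested):

    use polars::prelude::*;

    fn main() -> PolarsResult<()> {
        let df = df!("x" => &[1i64, 2, 4, 8])?;
        let out = df
            .lazy()
            .select([
                col("x").cumsum(false).alias("running"), // cumulative sum, front to back
                col("x").pct_change(1).alias("pct"),     // (x[t] - x[t-1]) / x[t-1]
                col("x")
                    .ewm_mean(EWMOptions {
                        alpha: 0.5,
                        adjust: true,
                        bias: false,
                        min_periods: 1,
                        ignore_nulls: true,
                    })
                    .alias("ewm"),
            ])
            .collect()?;
        println!("{out}");
        Ok(())
    }
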
diff --git a/crates/polars-plan/src/logical_plan/alp.rs b/crates/polars-plan/src/logical_plan/alp.rs index 1c63851a0844..d6a96e2394a1 100644 --- a/crates/polars-plan/src/logical_plan/alp.rs +++ b/crates/polars-plan/src/logical_plan/alp.rs @@ -30,7 +30,7 @@ pub enum ALogicalPlan { predicate: Node, }, Scan { - paths: Arc<[PathBuf]>, + path: PathBuf, file_info: FileInfo, predicate: Option, /// schema of the projected file @@ -293,7 +293,7 @@ impl ALogicalPlan { options: *options, }, Scan { - paths, + path, file_info, output_schema, predicate, @@ -305,7 +305,7 @@ impl ALogicalPlan { new_predicate = exprs.pop() } Scan { - paths: paths.clone(), + path: path.clone(), file_info: file_info.clone(), output_schema: output_schema.clone(), file_options: options.clone(), diff --git a/crates/polars-plan/src/logical_plan/builder.rs b/crates/polars-plan/src/logical_plan/builder.rs index b736c6cc8a98..6ccfd66afb93 100644 --- a/crates/polars-plan/src/logical_plan/builder.rs +++ b/crates/polars-plan/src/logical_plan/builder.rs @@ -119,7 +119,7 @@ impl LogicalPlanBuilder { }; Ok(LogicalPlan::Scan { - paths: Arc::new([]), + path: "".into(), file_info, predicate: None, file_options, @@ -201,7 +201,7 @@ impl LogicalPlanBuilder { hive_partitioning, }; Ok(LogicalPlan::Scan { - paths: Arc::new([path]), + path, file_info, file_options: options, predicate: None, @@ -253,7 +253,7 @@ impl LogicalPlanBuilder { hive_partitioning: false, }; Ok(LogicalPlan::Scan { - paths: Arc::new([path]), + path, file_info, file_options, predicate: None, @@ -299,8 +299,6 @@ impl LogicalPlanBuilder { } })?; - let paths = Arc::new([path]); - let mut magic_nr = [0u8; 2]; let res = file.read_exact(&mut magic_nr); if raise_if_empty { @@ -364,7 +362,7 @@ impl LogicalPlanBuilder { hive_partitioning: false, }; Ok(LogicalPlan::Scan { - paths, + path, file_info, file_options: options, predicate: None, diff --git a/crates/polars-plan/src/logical_plan/conversion.rs b/crates/polars-plan/src/logical_plan/conversion.rs index 2aecd7a693fc..f1910f2be2a9 100644 --- a/crates/polars-plan/src/logical_plan/conversion.rs +++ b/crates/polars-plan/src/logical_plan/conversion.rs @@ -168,13 +168,13 @@ pub fn to_alp( let v = match lp { LogicalPlan::Scan { file_info, - paths, + path, predicate, scan_type, file_options: options, } => ALogicalPlan::Scan { file_info, - paths, + path, output_schema: None, predicate: predicate.map(|expr| to_aexpr(expr, expr_arena)), scan_type, @@ -597,14 +597,14 @@ impl ALogicalPlan { }; match lp { ALogicalPlan::Scan { - paths, + path, file_info, predicate, scan_type, output_schema: _, file_options: options, } => LogicalPlan::Scan { - paths, + path, file_info, predicate: predicate.map(|n| node_to_expr(n, expr_arena)), scan_type, diff --git a/crates/polars-plan/src/logical_plan/format.rs b/crates/polars-plan/src/logical_plan/format.rs index 3f6631163716..ae7e4e48efd6 100644 --- a/crates/polars-plan/src/logical_plan/format.rs +++ b/crates/polars-plan/src/logical_plan/format.rs @@ -1,7 +1,7 @@ use std::borrow::Cow; use std::fmt; use std::fmt::{Debug, Display, Formatter, Write}; -use std::path::PathBuf; +use std::path::Path; use crate::prelude::*; @@ -9,7 +9,7 @@ use crate::prelude::*; fn write_scan( f: &mut Formatter, name: &str, - path: &[PathBuf], + path: &Path, indent: usize, n_columns: i64, total_columns: usize, @@ -19,17 +19,7 @@ fn write_scan( if indent != 0 { writeln!(f)?; } - let path_fmt = if path.len() == 1 { - path[0].to_string_lossy() - } else { - Cow::Owned(format!( - "{} files: first file: {}", - path.len(), - 
path[0].to_string_lossy() - )) - }; - - write!(f, "{:indent$}{} SCAN {}", "", name, path_fmt)?; + write!(f, "{:indent$}{} SCAN {}", "", name, path.display())?; if n_columns > 0 { write!( f, @@ -68,7 +58,7 @@ impl LogicalPlan { write_scan( f, "PYTHON", - &[], + Path::new(""), sub_indent, n_columns, total_columns, @@ -101,7 +91,7 @@ impl LogicalPlan { input._format(f, sub_indent) }, Scan { - paths, + path, file_info, predicate, scan_type, @@ -116,7 +106,7 @@ impl LogicalPlan { write_scan( f, scan_type.into(), - paths, + path, sub_indent, n_columns, file_info.schema.len(), diff --git a/crates/polars-plan/src/logical_plan/mod.rs b/crates/polars-plan/src/logical_plan/mod.rs index d394327e41f4..ecb6d1e8917b 100644 --- a/crates/polars-plan/src/logical_plan/mod.rs +++ b/crates/polars-plan/src/logical_plan/mod.rs @@ -156,7 +156,7 @@ pub enum LogicalPlan { count: usize, }, Scan { - paths: Arc<[PathBuf]>, + path: PathBuf, file_info: FileInfo, predicate: Option, file_options: FileScanOptions, diff --git a/crates/polars-plan/src/logical_plan/optimizer/cse.rs b/crates/polars-plan/src/logical_plan/optimizer/cse.rs index dce16147e60a..11e4c45ca925 100644 --- a/crates/polars-plan/src/logical_plan/optimizer/cse.rs +++ b/crates/polars-plan/src/logical_plan/optimizer/cse.rs @@ -117,13 +117,13 @@ fn lp_node_equal(a: &ALogicalPlan, b: &ALogicalPlan, expr_arena: &Arena) ) => Arc::ptr_eq(left_df, right_df), ( Scan { - paths: path_left, + path: path_left, predicate: predicate_left, scan_type: scan_type_left, .. }, Scan { - paths: path_right, + path: path_right, predicate: predicate_right, scan_type: scan_type_right, .. diff --git a/crates/polars-plan/src/logical_plan/optimizer/file_caching.rs b/crates/polars-plan/src/logical_plan/optimizer/file_caching.rs index 23791d3dd6b0..92e47d12c303 100644 --- a/crates/polars-plan/src/logical_plan/optimizer/file_caching.rs +++ b/crates/polars-plan/src/logical_plan/optimizer/file_caching.rs @@ -1,4 +1,4 @@ -use std::path::PathBuf; +use std::path::{Path, PathBuf}; use std::sync::Arc; use polars_core::datatypes::PlHashMap; @@ -9,14 +9,14 @@ use crate::prelude::*; #[derive(Hash, Eq, PartialEq, Clone, Debug)] pub struct FileFingerPrint { - pub paths: Arc<[PathBuf]>, + pub path: PathBuf, pub predicate: Option, pub slice: (usize, Option), } #[allow(clippy::type_complexity)] fn process_with_columns( - paths: &Arc<[PathBuf]>, + path: &Path, with_columns: Option<&Vec>, predicate: Option, slice: (usize, Option), @@ -25,7 +25,7 @@ fn process_with_columns( ) { let cols = file_count_and_column_union .entry(FileFingerPrint { - paths: paths.clone(), + path: path.into(), predicate, slice, }) @@ -59,7 +59,7 @@ pub fn collect_fingerprints( use ALogicalPlan::*; match lp_arena.get(root) { Scan { - paths, + path, file_options: options, predicate, scan_type, @@ -68,7 +68,7 @@ pub fn collect_fingerprints( let slice = (scan_type.skip_rows(), options.n_rows); let predicate = predicate.map(|node| node_to_expr(node, expr_arena)); let fp = FileFingerPrint { - paths: paths.clone(), + path: path.clone(), predicate, slice, }; @@ -96,7 +96,7 @@ pub fn find_column_union_and_fingerprints( use ALogicalPlan::*; match lp_arena.get(root) { Scan { - paths, + path, file_options: options, predicate, file_info, @@ -106,7 +106,7 @@ pub fn find_column_union_and_fingerprints( let slice = (scan_type.skip_rows(), options.n_rows); let predicate = predicate.map(|node| node_to_expr(node, expr_arena)); process_with_columns( - paths, + path, options.with_columns.as_deref(), predicate, slice, @@ -204,7 +204,7 @@ impl FileCacher 
{ let lp = lp_arena.take(root); match lp { ALogicalPlan::Scan { - paths, + path, file_info, predicate, output_schema, @@ -213,7 +213,7 @@ impl FileCacher { } => { let predicate_expr = predicate.map(|node| node_to_expr(node, expr_arena)); let finger_print = FileFingerPrint { - paths, + path, predicate: predicate_expr, slice: (scan_type.skip_rows(), options.n_rows), }; @@ -230,7 +230,7 @@ impl FileCacher { options.with_columns = with_columns; let lp = ALogicalPlan::Scan { - paths: finger_print.paths.clone(), + path: finger_print.path.clone(), file_info, output_schema, predicate, diff --git a/crates/polars-plan/src/logical_plan/optimizer/predicate_pushdown/mod.rs b/crates/polars-plan/src/logical_plan/optimizer/predicate_pushdown/mod.rs index 615d3f6dcc8a..eec0ddaff940 100644 --- a/crates/polars-plan/src/logical_plan/optimizer/predicate_pushdown/mod.rs +++ b/crates/polars-plan/src/logical_plan/optimizer/predicate_pushdown/mod.rs @@ -225,7 +225,7 @@ impl<'a> PredicatePushDown<'a> { Ok(lp) } Scan { - paths, + path, file_info, predicate, scan_type, @@ -235,13 +235,12 @@ impl<'a> PredicatePushDown<'a> { let local_predicates = partition_by_full_context(&mut acc_predicates, expr_arena); let predicate = predicate_at_scan(acc_predicates, predicate, expr_arena); - // TODO! this still assumes a single file. Fix hive partitioning for multiple files if let (Some(hive_part_stats), Some(predicate)) = (file_info.hive_parts.as_deref(), predicate) { if let Some(io_expr) = self.hive_partition_eval.unwrap()(predicate, expr_arena) { if let Some(stats_evaluator) = io_expr.as_stats_evaluator() { if !stats_evaluator.should_read(hive_part_stats.get_statistics())? { if self.verbose { - eprintln!("hive partitioning: skipped: {}", paths[0].display()) + eprintln!("hive partitioning: skipped: {}", path.display()) } let schema = output_schema.as_ref().unwrap_or(&file_info.schema); let df = DataFrame::from(schema.as_ref()); @@ -268,7 +267,7 @@ impl<'a> PredicatePushDown<'a> { let lp = if do_optimization { Scan { - paths, + path, file_info, predicate, file_options: options, @@ -277,7 +276,7 @@ impl<'a> PredicatePushDown<'a> { } } else { let lp = Scan { - paths, + path, file_info, predicate: None, file_options: options, diff --git a/crates/polars-plan/src/logical_plan/optimizer/predicate_pushdown/utils.rs b/crates/polars-plan/src/logical_plan/optimizer/predicate_pushdown/utils.rs index 9d83ec2edfa4..58179b8aae8b 100644 --- a/crates/polars-plan/src/logical_plan/optimizer/predicate_pushdown/utils.rs +++ b/crates/polars-plan/src/logical_plan/optimizer/predicate_pushdown/utils.rs @@ -213,7 +213,7 @@ fn rename_predicate_columns_due_to_aliased_projection( ) -> LoopBehavior { let projection_aexpr = expr_arena.get(projection_node); if let AExpr::Alias(_, alias_name) = projection_aexpr { - let alias_name = alias_name.clone(); + let alias_name = alias_name.as_ref(); let projection_leaves = aexpr_to_leaf_names(projection_node, expr_arena); // this means the leaf is a literal @@ -223,10 +223,9 @@ fn rename_predicate_columns_due_to_aliased_projection( // if this alias refers to one of the predicates in the upper nodes // we rename the column of the predicate before we push it downwards. 
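
The comment above describes the rename-then-push case; from the user side it arises whenever a filter is written against a projection alias. A behavioral sketch of such a query (the optimizer must rewrite `bar` back to the leaf name `foo` before the predicate can move below the select):

    use polars::prelude::*;

    fn main() -> PolarsResult<()> {
        let df = df!("foo" => &[1i64, 2, 3])?;
        let out = df
            .lazy()
            .select([col("foo").alias("bar")])
            // the predicate refers to the alias; pushdown renames it to `foo`
            .filter(col("bar").gt(lit(1)))
            .collect()?;
        assert_eq!(out.height(), 2); // rows with foo > 1: 2 and 3
        Ok(())
    }
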
- if let Some(predicate) = acc_predicates.remove(&alias_name) { + if let Some(predicate) = acc_predicates.remove(alias_name) { if projection_maybe_boundary { local_predicates.push(predicate); - remove_predicate_refers_to_alias(acc_predicates, local_predicates, &alias_name); return LoopBehavior::Continue; } if projection_leaves.len() == 1 { @@ -241,34 +240,26 @@ fn rename_predicate_columns_due_to_aliased_projection( // on this projected column so we do filter locally. local_predicates.push(predicate) } - } - - remove_predicate_refers_to_alias(acc_predicates, local_predicates, &alias_name); - } - LoopBehavior::Nothing -} + } else { + // we could not find the alias name + // that could still mean that a predicate that is a complicated binary expression + // refers to the aliased name. If we find it, we remove it for now + // TODO! rename the expression. + let mut remove_names = vec![]; + for (composed_name, _) in acc_predicates.iter() { + if key_has_name(composed_name, alias_name) { + remove_names.push(composed_name.clone()); + break; + } + } -/// we could not find the alias name -/// that could still mean that a predicate that is a complicated binary expression -/// refers to the aliased name. If we find it, we remove it for now -/// TODO! rename the expression. -fn remove_predicate_refers_to_alias( - acc_predicates: &mut PlHashMap, Node>, - local_predicates: &mut Vec, - alias_name: &str, -) { - let mut remove_names = vec![]; - for (composed_name, _) in acc_predicates.iter() { - if key_has_name(composed_name, alias_name) { - remove_names.push(composed_name.clone()); - break; + for composed_name in remove_names { + let predicate = acc_predicates.remove(&composed_name).unwrap(); + local_predicates.push(predicate) + } } } - - for composed_name in remove_names { - let predicate = acc_predicates.remove(&composed_name).unwrap(); - local_predicates.push(predicate) - } + LoopBehavior::Nothing } /// Implementation for both Hstack and Projection diff --git a/crates/polars-plan/src/logical_plan/optimizer/projection_pushdown/mod.rs b/crates/polars-plan/src/logical_plan/optimizer/projection_pushdown/mod.rs index b1ed7963aab6..3ff672683211 100644 --- a/crates/polars-plan/src/logical_plan/optimizer/projection_pushdown/mod.rs +++ b/crates/polars-plan/src/logical_plan/optimizer/projection_pushdown/mod.rs @@ -377,7 +377,7 @@ impl ProjectionPushDown { Ok(PythonScan { options, predicate }) }, Scan { - paths, + path, file_info, scan_type, predicate, @@ -421,7 +421,7 @@ impl ProjectionPushDown { } let lp = Scan { - paths, + path, file_info, output_schema, scan_type, diff --git a/crates/polars-plan/src/logical_plan/optimizer/projection_pushdown/projection.rs b/crates/polars-plan/src/logical_plan/optimizer/projection_pushdown/projection.rs index 8cf418011888..e49ba5cb5642 100644 --- a/crates/polars-plan/src/logical_plan/optimizer/projection_pushdown/projection.rs +++ b/crates/polars-plan/src/logical_plan/optimizer/projection_pushdown/projection.rs @@ -8,31 +8,6 @@ fn is_count(node: Node, expr_arena: &Arena) -> bool { } } -/// In this function we check a double projection case -/// df -/// .select(col("foo").alias("bar")) -/// .select(col("bar") -/// -/// In this query, bar cannot pass this projection, as it would not exist in DF. -/// THE ORDER IS IMPORTANT HERE! -/// this removes projection names, so any checks to upstream names should -/// be done before this branch. 
-fn check_double_projection( - expr: &Node, - expr_arena: &mut Arena, - acc_projections: &mut Vec, - projected_names: &mut PlHashSet>, -) { - for (_, ae) in (&*expr_arena).iter(*expr) { - if let AExpr::Alias(_, name) = ae { - if projected_names.remove(name) { - acc_projections - .retain(|expr| !aexpr_to_leaf_names(*expr, expr_arena).contains(name)); - } - } - } -} - #[allow(clippy::too_many_arguments)] pub(super) fn process_projection( proj_pd: &mut ProjectionPushDown, @@ -54,14 +29,6 @@ pub(super) fn process_projection( // simply select the first column let (first_name, _) = input_schema.try_get_at_index(0)?; let expr = expr_arena.add(AExpr::Column(Arc::from(first_name.as_str()))); - if !acc_projections.is_empty() { - check_double_projection( - &exprs[0], - expr_arena, - &mut acc_projections, - &mut projected_names, - ); - } add_expr_to_accumulated(expr, &mut acc_projections, &mut projected_names, expr_arena); local_projection.push(exprs[0]); } else { @@ -81,7 +48,24 @@ pub(super) fn process_projection( continue; } - check_double_projection(e, expr_arena, &mut acc_projections, &mut projected_names); + // in this branch we check a double projection case + // df + // .select(col("foo").alias("bar")) + // .select(col("bar") + // + // In this query, bar cannot pass this projection, as it would not exist in DF. + // THE ORDER IS IMPORTANT HERE! + // this removes projection names, so any checks to upstream names should + // be done before this branch. + for (_, ae) in (&*expr_arena).iter(*e) { + if let AExpr::Alias(_, name) = ae { + if projected_names.remove(name) { + acc_projections.retain(|expr| { + !aexpr_to_leaf_names(*expr, expr_arena).contains(name) + }); + } + } + } } // do local as we still need the effect of the projection // e.g. a projection is more than selecting a column, it can diff --git a/crates/polars-plan/src/logical_plan/optimizer/slice_pushdown_lp.rs b/crates/polars-plan/src/logical_plan/optimizer/slice_pushdown_lp.rs index 63e5a9fdd2af..66887f25ee62 100644 --- a/crates/polars-plan/src/logical_plan/optimizer/slice_pushdown_lp.rs +++ b/crates/polars-plan/src/logical_plan/optimizer/slice_pushdown_lp.rs @@ -121,7 +121,7 @@ impl SlicePushDown { } #[cfg(feature = "csv")] (Scan { - paths, + path, file_info, output_schema, file_options: mut options, @@ -132,7 +132,7 @@ impl SlicePushDown { csv_options.skip_rows += state.offset as usize; let lp = Scan { - paths, + path, file_info, output_schema, scan_type: FileScan::Csv {options: csv_options}, @@ -143,7 +143,7 @@ impl SlicePushDown { }, // TODO! we currently skip slice pushdown if there is a predicate. (Scan { - paths, + path, file_info, output_schema, file_options: mut options, @@ -152,7 +152,7 @@ impl SlicePushDown { }, Some(state)) if state.offset == 0 && predicate.is_none() => { options.n_rows = Some(state.len as usize); let lp = Scan { - paths, + path, file_info, output_schema, predicate, diff --git a/crates/polars-plan/src/logical_plan/projection.rs b/crates/polars-plan/src/logical_plan/projection.rs index 3c9832e4b4fa..436c31b354c2 100644 --- a/crates/polars-plan/src/logical_plan/projection.rs +++ b/crates/polars-plan/src/logical_plan/projection.rs @@ -4,7 +4,6 @@ use polars_core::utils::get_supertype; use super::*; use crate::prelude::function_expr::FunctionExpr; -use crate::utils::expr_output_name; /// This replace the wildcard Expr with a Column Expr. It also removes the Exclude Expr from the /// expression chain. 
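
For reference, the double-projection case inlined back into `process_projection` above guards queries like the following, where `bar` only exists after the first select and therefore must not be pushed past it (behavioral sketch, untested):

    use polars::prelude::*;

    fn main() -> PolarsResult<()> {
        let df = df!("foo" => &[10i64, 20])?;
        // `bar` is not in the source frame; it is created by the first
        // projection, so `col("bar")` cannot be pushed below it.
        let out = df
            .lazy()
            .select([col("foo").alias("bar")])
            .select([col("bar")])
            .collect()?;
        assert_eq!(out.get_column_names(), ["bar"]);
        Ok(())
    }
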
@@ -355,9 +354,21 @@ fn prepare_excluded( } // exclude group_by keys - for expr in keys.iter() { - if let Ok(name) = expr_output_name(expr) { - exclude.insert(name.clone()); + for mut expr in keys.iter() { + // Allow a number of aliases of a column expression, still exclude column from aggregation + loop { + match expr { + Expr::Column(name) => { + exclude.insert(name.clone()); + break; + }, + Expr::Alias(e, _) => { + expr = e; + }, + _ => { + break; + }, + } } } Ok(exclude) diff --git a/crates/polars-plan/src/logical_plan/schema.rs b/crates/polars-plan/src/logical_plan/schema.rs index 223ec2df7fd6..e97377c41e75 100644 --- a/crates/polars-plan/src/logical_plan/schema.rs +++ b/crates/polars-plan/src/logical_plan/schema.rs @@ -45,9 +45,6 @@ impl LogicalPlan { #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct FileInfo { pub schema: SchemaRef, - // Stores the schema used for the reader, as the main schema can contain - // extra hive columns. - pub reader_schema: SchemaRef, // - known size // - estimated size pub row_estimation: (Option, usize), @@ -57,8 +54,7 @@ pub struct FileInfo { impl FileInfo { pub fn new(schema: SchemaRef, row_estimation: (Option, usize)) -> Self { Self { - schema: schema.clone(), - reader_schema: schema.clone(), + schema, row_estimation, hive_parts: None, } diff --git a/crates/polars-sql/Cargo.toml b/crates/polars-sql/Cargo.toml index b03757ec4435..8e2feb97ba53 100644 --- a/crates/polars-sql/Cargo.toml +++ b/crates/polars-sql/Cargo.toml @@ -12,7 +12,7 @@ description = "SQL transpiler for Polars. Converts SQL to Polars logical plans" arrow = { workspace = true } polars-core = { workspace = true } polars-error = { workspace = true } -polars-lazy = { workspace = true, features = ["strings", "cross_join", "trigonometry", "abs", "round_series", "log", "regex", "is_in", "meta", "cum_agg", "dtype-date"] } +polars-lazy = { workspace = true, features = ["strings", "cross_join", "trigonometry", "abs", "round_series", "log", "regex", "is_in", "meta", "cum_agg"] } polars-plan = { workspace = true } rand = { workspace = true } diff --git a/crates/polars-sql/src/functions.rs b/crates/polars-sql/src/functions.rs index 84f6987c3e0b..70f9cabe7669 100644 --- a/crates/polars-sql/src/functions.rs +++ b/crates/polars-sql/src/functions.rs @@ -2,8 +2,8 @@ use polars_core::prelude::{polars_bail, polars_err, PolarsResult}; use polars_lazy::dsl::Expr; use polars_plan::dsl::{coalesce, count, when}; use polars_plan::logical_plan::LiteralValue; +use polars_plan::prelude::lit; use polars_plan::prelude::LiteralValue::Null; -use polars_plan::prelude::{lit, StrptimeOptions}; use sqlparser::ast::{ Expr as SqlExpr, Function as SQLFunction, FunctionArg, FunctionArgExpr, Value as SqlValue, WindowSpec, WindowType, @@ -217,16 +217,6 @@ pub(crate) enum PolarsSqlFunctions { /// ``` Radians, - // ---- - // Date Functions - // ---- - /// SQL 'date' function - /// ```sql - /// SELECT DATE('2021-03-15') from df; - /// SELECT DATE('2021-03', '%Y-%m') from df; - /// ``` - Date, - // ---- // String functions // ---- @@ -481,7 +471,6 @@ impl PolarsSqlFunctions { "cot", "cotd", "count", - "date", "degrees", "ends_with", "exp", @@ -570,11 +559,6 @@ impl PolarsSqlFunctions { "nullif" => Self::NullIf, "coalesce" => Self::Coalesce, - // ---- - // Date functions - // ---- - "date" => Self::Date, - // ---- // String functions // ---- @@ -734,14 +718,6 @@ impl SqlFunctionVisitor<'_> { }), _ => polars_bail!(InvalidOperation:"Invalid number of arguments for RegexpLike: {}",function.args.len()), }, - Date 
=> match function.args.len() { - 1 => self.visit_unary(|e| e.str().to_date(StrptimeOptions::default())), - 2 => self.visit_binary(|e, fmt| e.str().to_date(fmt)), - _ => polars_bail!(InvalidOperation: - "Invalid number of arguments for Date: {}", - function.args.len() - ), - }, RTrim => match function.args.len() { 1 => self.visit_unary(|e| e.str().strip_chars_end(lit(Null))), 2 => self.visit_binary(|e, s| e.str().strip_chars_end(s)), @@ -1100,24 +1076,6 @@ impl FromSqlExpr for String { } } -impl FromSqlExpr for StrptimeOptions { - fn from_sql_expr(expr: &SqlExpr, _: &mut SQLContext) -> PolarsResult - where - Self: Sized, - { - match expr { - SqlExpr::Value(v) => match v { - SqlValue::SingleQuotedString(s) => Ok(StrptimeOptions { - format: Some(s.clone()), - ..StrptimeOptions::default() - }), - _ => polars_bail!(ComputeError: "can't parse literal {:?}", v), - }, - _ => polars_bail!(ComputeError: "can't parse literal {:?}", expr), - } - } -} - impl FromSqlExpr for Expr { fn from_sql_expr(expr: &SqlExpr, ctx: &mut SQLContext) -> PolarsResult where diff --git a/crates/polars-time/Cargo.toml b/crates/polars-time/Cargo.toml index abd23e909aa3..77e8e861a6a1 100644 --- a/crates/polars-time/Cargo.toml +++ b/crates/polars-time/Cargo.toml @@ -10,7 +10,7 @@ description = "Time related code for the Polars DataFrame library" [dependencies] arrow = { workspace = true, features = ["compute", "temporal"] } -polars-core = { workspace = true, features = ["dtype-datetime", "dtype-duration", "dtype-time", "dtype-date"] } +polars-core = { workspace = true, default-features = false, features = ["dtype-datetime", "dtype-duration", "dtype-time", "dtype-date"] } polars-error = { workspace = true } polars-ops = { workspace = true } polars-utils = { workspace = true } diff --git a/crates/polars/Cargo.toml b/crates/polars/Cargo.toml index 165fea5806f0..19de558f8177 100644 --- a/crates/polars/Cargo.toml +++ b/crates/polars/Cargo.toml @@ -14,7 +14,7 @@ description = "DataFrame library based on Apache Arrow" polars-algo = { workspace = true, optional = true } polars-core = { workspace = true } polars-io = { workspace = true, optional = true } -polars-lazy = { workspace = true, optional = true } +polars-lazy = { workspace = true, default-features = false, optional = true } polars-ops = { workspace = true } polars-sql = { workspace = true, optional = true } polars-time = { workspace = true, optional = true } @@ -113,7 +113,7 @@ sort_multiple = ["polars-core/sort_multiple"] approx_unique = ["polars-lazy?/approx_unique", "polars-ops/approx_unique"] is_in = ["polars-lazy?/is_in"] zip_with = ["polars-core/zip_with", "polars-ops/zip_with"] -round_series = ["polars-ops/round_series", "polars-lazy?/round_series"] +round_series = ["polars-core/round_series", "polars-lazy?/round_series", "polars-ops/round_series"] checked_arithmetic = ["polars-core/checked_arithmetic"] repeat_by = ["polars-ops/repeat_by", "polars-lazy?/repeat_by"] is_first_distinct = ["polars-lazy?/is_first_distinct", "polars-ops/is_first_distinct"] @@ -139,12 +139,12 @@ string_encoding = ["polars-ops/string_encoding", "polars-core/strings"] binary_encoding = ["polars-ops/binary_encoding"] group_by_list = ["polars-core/group_by_list", "polars-ops/group_by_list"] lazy_regex = ["polars-lazy?/regex"] -cum_agg = ["polars-ops/cum_agg", "polars-lazy?/cum_agg"] +cum_agg = ["polars-core/cum_agg", "polars-core/cum_agg"] rolling_window = ["polars-core/rolling_window", "polars-lazy?/rolling_window", "polars-time/rolling_window"] interpolate = ["polars-ops/interpolate", 
"polars-lazy?/interpolate"] rank = ["polars-lazy?/rank", "polars-ops/rank"] -diff = ["polars-ops/diff", "polars-lazy?/diff"] -pct_change = ["polars-ops/pct_change", "polars-lazy?/pct_change"] +diff = ["polars-core/diff", "polars-lazy?/diff", "polars-ops/diff"] +pct_change = ["polars-core/pct_change", "polars-lazy?/pct_change"] moment = ["polars-core/moment", "polars-lazy?/moment", "polars-ops/moment"] range = ["polars-lazy?/range"] true_div = ["polars-lazy?/true_div"] @@ -152,7 +152,7 @@ diagonal_concat = ["polars-core/diagonal_concat", "polars-lazy?/diagonal_concat" horizontal_concat = ["polars-core/horizontal_concat"] abs = ["polars-core/abs", "polars-lazy?/abs"] dynamic_group_by = ["polars-core/dynamic_group_by", "polars-lazy?/dynamic_group_by"] -ewma = ["polars-ops/ewma", "polars-lazy?/ewma"] +ewma = ["polars-core/ewma", "polars-lazy?/ewma"] dot_diagram = ["polars-lazy?/dot_diagram"] dataframe_arithmetic = ["polars-core/dataframe_arithmetic"] product = ["polars-core/product"] @@ -190,7 +190,6 @@ fused = ["polars-ops/fused", "polars-lazy?/fused"] list_sets = ["polars-lazy?/list_sets"] list_any_all = ["polars-lazy?/list_any_all"] list_drop_nulls = ["polars-lazy?/list_drop_nulls"] -list_sample = ["polars-lazy?/list_sample"] cutqcut = ["polars-lazy?/cutqcut"] rle = ["polars-lazy?/rle"] extract_groups = ["polars-lazy?/extract_groups"] diff --git a/crates/polars/src/lib.rs b/crates/polars/src/lib.rs index df7da1257ff1..0eaf13c040ce 100644 --- a/crates/polars/src/lib.rs +++ b/crates/polars/src/lib.rs @@ -278,11 +278,11 @@ //! - `fmt` - Activate [`DataFrame`] formatting //! //! [`UInt64Chunked`]: crate::datatypes::UInt64Chunked -//! [`cumsum`]: polars_ops::prelude::cumsum -//! [`cummin`]: polars_ops::prelude::cummin -//! [`cummax`]: polars_ops::prelude::cummax +//! [`cumsum`]: crate::series::Series::cumsum +//! [`cummin`]: crate::series::Series::cummin +//! [`cummax`]: crate::series::Series::cummax //! [`rolling_mean`]: crate::series::Series#method.rolling_mean -//! [`diff`]: polars_ops::prelude::diff +//! [`diff`]: crate::series::Series::diff //! [`List`]: crate::datatypes::DataType::List //! [`Struct`]: crate::datatypes::DataType::Struct //! diff --git a/docs/_build/scripts/people.py b/docs/_build/scripts/people.py index 72ba55c37f56..10186549d4d8 100644 --- a/docs/_build/scripts/people.py +++ b/docs/_build/scripts/people.py @@ -6,7 +6,8 @@ auth = Auth.Token(token) if token else None g = Github(auth=auth) -ICON_TEMPLATE = '{login}' +ICON_TEMPLATE = "[![{login}]({avatar_url}){{.contributor_icon}}]({html_url})" + def get_people_md(): repo = g.get_repo("pola-rs/polars") diff --git a/docs/index.md b/docs/index.md index c5c3dfb1bfbf..2621ba4ee11d 100644 --- a/docs/index.md +++ b/docs/index.md @@ -52,6 +52,10 @@ See the results in h2oai's [db-benchmark](https://duckdblabs.github.io/db-benchm {{code_block('home/example','example',['scan_csv','filter','group_by','collect'])}} +## Sponsors + +[](https://www.xomnia.com/)   [](https://www.jetbrains.com) + ## Community `Polars` has a very active community with frequent releases (approximately weekly). 
Below are some of the top contributors to the project: diff --git a/docs/requirements.txt b/docs/requirements.txt index 35d0629394c0..2c317b06415b 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -3,7 +3,7 @@ pyarrow graphviz matplotlib -mkdocs-material==9.4.6 +mkdocs-material==9.2.5 mkdocs-macros-plugin==1.0.4 -markdown-exec[ansi]==1.7.0 -PyGithub==2.1.1 +markdown-exec[ansi]==1.6.0 +PyGithub==1.59.1 diff --git a/docs/src/python/user-guide/expressions/lists.py b/docs/src/python/user-guide/expressions/lists.py index de4b97fc8d87..5703a01a5518 100644 --- a/docs/src/python/user-guide/expressions/lists.py +++ b/docs/src/python/user-guide/expressions/lists.py @@ -97,10 +97,7 @@ pl.Series("Array_1", [[1, 3], [2, 5]]), pl.Series("Array_2", [[1, 7, 3], [8, 1, 0]]), ], - schema={ - "Array_1": pl.Array(inner=pl.Int64, width=2), - "Array_2": pl.Array(inner=pl.Int64, width=3), - }, + schema={"Array_1": pl.Array(2, pl.Int64), "Array_2": pl.Array(3, pl.Int64)}, ) print(array_df) # --8<-- [end:array_df] diff --git a/docs/src/python/user-guide/expressions/operators.py b/docs/src/python/user-guide/expressions/operators.py index 92bf57952332..6f617487c81e 100644 --- a/docs/src/python/user-guide/expressions/operators.py +++ b/docs/src/python/user-guide/expressions/operators.py @@ -34,7 +34,7 @@ # --8<-- [start:logical] df_logical = df.select( (pl.col("nrs") > 1).alias("nrs > 1"), - (pl.col("random") <= 0.5).alias("random <= .5"), + (pl.col("random") <= 0.5).alias("random < .5"), (pl.col("nrs") != 1).alias("nrs != 1"), (pl.col("nrs") == 1).alias("nrs == 1"), ((pl.col("random") <= 0.5) & (pl.col("nrs") > 1)).alias("and_expr"), # and diff --git a/docs/src/python/user-guide/sql/intro.py b/docs/src/python/user-guide/sql/intro.py index 143ec75c4f76..3b59ac9e70d1 100644 --- a/docs/src/python/user-guide/sql/intro.py +++ b/docs/src/python/user-guide/sql/intro.py @@ -39,7 +39,7 @@ # --8<-- [end:execute] # --8<-- [start:prepare_multiple_sources] -with open("docs/data/products_categories.json", "w") as temp_file: +with open("products_categories.json", "w") as temp_file: json_data = """{"product_id": 1, "category": "Category 1"} {"product_id": 2, "category": "Category 1"} {"product_id": 3, "category": "Category 2"} @@ -48,7 +48,7 @@ temp_file.write(json_data) -with open("docs/data/products_masterdata.csv", "w") as temp_file: +with open("products_masterdata.csv", "w") as temp_file: csv_data = """product_id,product_name 1,Product A 2,Product B @@ -73,19 +73,19 @@ # sales_data is a Pandas DataFrame with schema {'product_id': Int64, 'sales': Int64} ctx = pl.SQLContext( - products_masterdata=pl.scan_csv("docs/data/products_masterdata.csv"), - products_categories=pl.scan_ndjson("docs/data/products_categories.json"), + products_masterdata=pl.scan_csv("products_masterdata.csv"), + products_categories=pl.scan_ndjson("products_categories.json"), sales_data=pl.from_pandas(sales_data), eager_execution=True, ) query = """ -SELECT +SELECT product_id, product_name, category, sales -FROM +FROM products_masterdata LEFT JOIN products_categories USING (product_id) LEFT JOIN sales_data USING (product_id) @@ -95,6 +95,6 @@ # --8<-- [end:execute_multiple_sources] # --8<-- [start:clean_multiple_sources] -os.remove("docs/data/products_categories.json") -os.remove("docs/data/products_masterdata.csv") +os.remove("products_categories.json") +os.remove("products_masterdata.csv") # --8<-- [end:clean_multiple_sources] diff --git a/docs/user-guide/concepts/lazy-vs-eager.md b/docs/user-guide/concepts/lazy-vs-eager.md index 
987d07aa8807..1b84a0272aa5 100644
--- a/docs/user-guide/concepts/lazy-vs-eager.md
+++ b/docs/user-guide/concepts/lazy-vs-eager.md
@@ -6,7 +6,7 @@
 
 In this example we use the eager API to:
 
-1. Read the iris [dataset](https://archive.ics.uci.edu/dataset/53/iris).
+1. Read the iris [dataset](https://archive.ics.uci.edu/ml/datasets/iris).
 1. Filter the dataset based on sepal length
 1. Calculate the mean of the sepal width per species
diff --git a/docs/user-guide/expressions/plugins.md b/docs/user-guide/expressions/plugins.md
deleted file mode 100644
index 3c40ef2b8cf4..000000000000
--- a/docs/user-guide/expressions/plugins.md
+++ /dev/null
@@ -1,237 +0,0 @@
-# Expression plugins
-
-Expression plugins are the preferred way to create user defined functions. They allow you to compile a rust function
-and register that as an expression into the polars library. The polars engine will dynamically link your function at runtime
-and your expression will run almost as fast as native expressions. Note that this works without any interference of python
-and thus no GIL contention.
-
-They will benefit from the same benefits default expressions have:
-
-- Optimization
-- Parallelism
-- Rust native performance
-
-To get started we will see what is needed to create a custom expression.
-
-## Our first custom expression: Pig Latin
-
-For our first expression we are going to create a pig latin converter. Pig latin is a silly language where in every word
-the first letter is removed, added to the back and finally "ay" is added. So the word "pig" would convert to "igpay".
-
-We could of course already do that with expressions, e.g. `col("name").str.slice(1) + col("name").str.slice(0, 1) + "ay"`,
-but a specialized function for this would perform better and allows us to learn about the plugins.
-
-### Setting up
-
-We start with a new library as the following `Cargo.toml` file
-
-```toml
-[package]
-name = "expression_lib"
-version = "0.1.0"
-edition = "2021"
-
-[lib]
-name = "expression_lib"
-crate-type = ["cdylib"]
-
-[dependencies]
-polars = { version = "*" }
-pyo3 = { version = "*", features = ["extension-module"] }
-pyo3-polars = { version = "*", features = ["derive"] }
-serde = { version = "*", features = ["derive"] }
-```
-
-### Writing the expression
-
-In this library we create a helper function that converts a `&str` to pig-latin, and we create the function that we will
-expose as an expression. To expose a function we must add the `#[polars_expr(output_type=DataType)]` attribute and the function
-must always accept `inputs: &[Series]` as its first argument.
-
-```rust
-use polars::prelude::*;
-use pyo3_polars::derive::polars_expr;
-use std::fmt::Write;
-
-fn pig_latin_str(value: &str, output: &mut String) {
-    if let Some(first_char) = value.chars().next() {
-        write!(output, "{}{}ay", &value[1..], first_char).unwrap()
-    }
-}
-
-#[polars_expr(output_type=Utf8)]
-fn pig_latinnify(inputs: &[Series]) -> PolarsResult<Series> {
-    let ca = inputs[0].utf8()?;
-    let out: Utf8Chunked = ca.apply_to_buffer(pig_latin_str);
-    Ok(out.into_series())
-}
-```
-
-This is all that is needed on the rust side. On the python side we must setup a folder with the same name as defined in
-the `Cargo.toml`, in this case "expression_lib". We will create a folder in the same directory as our rust `src` folder
-named `expression_lib` and we create an `expression_lib/__init__.py`.
-
-Then we create a new class `Language` that will hold the expressions for our new `expr.language` namespace. The function
-name of our expression can be registered. Note that it is important that this name is correct, otherwise the main polars
-package cannot resolve the function name. Furthermore we can set additional keyword arguments that explain to polars how
-this expression behaves. In this case we tell polars that this function is elementwise. This allows polars to run this
-expression in batches. Whereas for other operations this would not be allowed, think for instance of a sort, or a slice.
-
-```python
-import polars as pl
-from polars.type_aliases import IntoExpr
-from polars.utils.udfs import _get_shared_lib_location
-
-# boilerplate needed to inform polars of the location of binary wheel.
-lib = _get_shared_lib_location(__file__)
-
-@pl.api.register_expr_namespace("language")
-class Language:
-    def __init__(self, expr: pl.Expr):
-        self._expr = expr
-
-    def pig_latinnify(self) -> pl.Expr:
-        return self._expr._register_plugin(
-            lib=lib,
-            symbol="pig_latinnify",
-            is_elementwise=True,
-        )
-```
-
-We can then compile this library in our environment by installing `maturin` and running `maturin develop --release`.
-
-And that's it. Our expression is ready to use!
-
-```python
-import polars as pl
-from expression_lib import Language
-
-df = pl.DataFrame(
-    {
-        "convert": ["pig", "latin", "is", "silly"],
-    }
-)
-
-
-out = df.with_columns(
-    pig_latin=pl.col("convert").language.pig_latinnify(),
-)
-```
-
-## Accepting kwargs
-
-If you want to accept `kwargs` (keyword arguments) in a polars expression, all you have to do is define a rust `struct`
-and make sure that it derives `serde::Deserialize`.
-
-```rust
-/// Provide your own kwargs struct with the proper schema and accept that type
-/// in your plugin expression.
-#[derive(Deserialize)]
-pub struct MyKwargs {
-    float_arg: f64,
-    integer_arg: i64,
-    string_arg: String,
-    boolean_arg: bool,
-}
-
-/// If you want to accept `kwargs`, you define a `kwargs` argument
-/// on the second position in your plugin. You can provide any custom struct that is deserializable
-/// with the pickle protocol (on the rust side).
-#[polars_expr(output_type=Utf8)]
-fn append_kwargs(input: &[Series], kwargs: MyKwargs) -> PolarsResult<Series> {
-    let input = &input[0];
-    let input = input.cast(&DataType::Utf8)?;
-    let ca = input.utf8().unwrap();
-
-    Ok(ca
-        .apply_to_buffer(|val, buf| {
-            write!(
-                buf,
-                "{}-{}-{}-{}-{}",
-                val, kwargs.float_arg, kwargs.integer_arg, kwargs.string_arg, kwargs.boolean_arg
-            )
-            .unwrap()
-        })
-        .into_series())
-}
-```
-
-On the python side the kwargs can be passed when we register the plugin.
-
-```python
-@pl.api.register_expr_namespace("my_expr")
-class MyCustomExpr:
-    def __init__(self, expr: pl.Expr):
-        self._expr = expr
-
-    def append_args(
-        self,
-        float_arg: float,
-        integer_arg: int,
-        string_arg: str,
-        boolean_arg: bool,
-    ) -> pl.Expr:
-        """
-        This example shows how arguments other than `Series` can be used.
-        """
-        return self._expr._register_plugin(
-            lib=lib,
-            args=[],
-            kwargs={
-                "float_arg": float_arg,
-                "integer_arg": integer_arg,
-                "string_arg": string_arg,
-                "boolean_arg": boolean_arg,
-            },
-            symbol="append_kwargs",
-            is_elementwise=True,
-        )
-```
-
-## Output data types
-
-Output data types of course don't have to be fixed. They often depend on the input types of an expression. To accommodate
-this you can provide the `#[polars_expr()]` macro with an `output_type_func` argument that points to a function. This
-function can map input fields `&[Field]` to an output `Field` (name and data type).
-
-In the snippet below is an example where we use the utility `FieldsMapper` to help with this mapping.
-
-```rust
-use polars_plan::dsl::FieldsMapper;
-
-fn haversine_output(input_fields: &[Field]) -> PolarsResult<Field> {
-    FieldsMapper::new(input_fields).map_to_float_dtype()
-}
-
-#[polars_expr(output_type_func=haversine_output)]
-fn haversine(inputs: &[Series]) -> PolarsResult<Series> {
-    let out = match inputs[0].dtype() {
-        DataType::Float32 => {
-            let start_lat = inputs[0].f32().unwrap();
-            let start_long = inputs[1].f32().unwrap();
-            let end_lat = inputs[2].f32().unwrap();
-            let end_long = inputs[3].f32().unwrap();
-            crate::distances::naive_haversine(start_lat, start_long, end_lat, end_long)?
-                .into_series()
-        }
-        DataType::Float64 => {
-            let start_lat = inputs[0].f64().unwrap();
-            let start_long = inputs[1].f64().unwrap();
-            let end_lat = inputs[2].f64().unwrap();
-            let end_long = inputs[3].f64().unwrap();
-            crate::distances::naive_haversine(start_lat, start_long, end_lat, end_long)?
-                .into_series()
-        }
-        _ => polars_bail!(InvalidOperation: "only supported for float types"),
-    };
-    Ok(out)
-}
-```
-
-That's all you need to know to get started. Take a look at this [repo](https://github.com/pola-rs/pyo3-polars/tree/main/example/derive_expression) to see how this all fits together.
-
-## Community plugins
-
-Here is a curated (non-exhaustive) list of community-implemented plugins.
-
-- [polars-business](https://github.com/MarcoGorelli/polars-business) Polars extension offering utilities for business day operations
diff --git a/docs/user-guide/expressions/user-defined-functions.md b/docs/user-guide/expressions/user-defined-functions.md
index 3dd43f035f85..dd83cb13c382 100644
--- a/docs/user-guide/expressions/user-defined-functions.md
+++ b/docs/user-guide/expressions/user-defined-functions.md
@@ -1,4 +1,4 @@
-# User-defined functions (Python)
+# User-defined functions
 
 !!! warning "Not updated for Python Polars `0.19.0`"
 
diff --git a/docs/user-guide/io/cloud-storage.md b/docs/user-guide/io/cloud-storage.md
index ba686a5a0f11..a10226a99e65 100644
--- a/docs/user-guide/io/cloud-storage.md
+++ b/docs/user-guide/io/cloud-storage.md
@@ -32,7 +32,7 @@ Polars can scan a Parquet file in lazy mode from cloud storage. We may need to p
 
 This query creates a `LazyFrame` without downloading the file. In the `LazyFrame` we have access to file metadata such as the schema. Polars uses the `object_store.rs` library internally to manage the interface with the cloud storage providers and so no extra dependencies are required in Python to scan a cloud Parquet file.
 
-If we create a lazy query with [predicate and projection pushdowns](../lazy/optimizations.md), the query optimizer will apply them before the file is downloaded. This can significantly reduce the amount of data that needs to be downloaded. The query evaluation is triggered by calling `collect`.
+If we create a lazy query with [predicate and projection pushdowns](../lazy/optimizations.md), the query optimizer will apply them before the file is downloaded. This can significantly reduce the amount of data that needs to be downloaded. The query evaluation is triggered by calling `collect`.
 
 {{code_block('user-guide/io/cloud-storage','scan_parquet_query',[])}}
 
diff --git a/docs/user-guide/migration/pandas.md b/docs/user-guide/migration/pandas.md
index a9a039a7b1d4..d6674ac43f06 100644
--- a/docs/user-guide/migration/pandas.md
+++ b/docs/user-guide/migration/pandas.md
@@ -147,20 +147,19 @@ called `hundredXValue` where the `value` column is multiplied by 100.
In `Pandas` this would be: ```python -df.assign( - tenXValue=lambda df_: df_.value * 10, - hundredXValue=lambda df_: df_.value * 100 -) +df["tenXValue"] = df["value"] * 10 +df["hundredXValue"] = df["value"] * 100 ``` These column assignments are executed sequentially. -In `Polars` we add columns to `df` using the `.with_columns` method: +In `Polars` we add columns to `df` using the `.with_columns` method and name them with +the `.alias` method: ```python df.with_columns( - tenXValue=pl.col("value") * 10, - hundredXValue=pl.col("value") * 100, + (pl.col("value") * 10).alias("tenXValue"), + (pl.col("value") * 100).alias("hundredXValue"), ) ``` @@ -175,7 +174,7 @@ the values in column `a` based on a condition. When the value in column `c` is e In `Pandas` this would be: ```python -df.assign(a=lambda df_: df_.a.where(df_.c != 2, df_.b)) +df.loc[df["c"] == 2, "a"] = df.loc[df["c"] == 2, "b"] ``` while in `Polars` this would be: @@ -188,17 +187,21 @@ df.with_columns( ) ``` -`Polars` can compute every branch of an `if -> then -> otherwise` in +The `Polars` way is pure in that the original `DataFrame` is not modified. The `mask` is +also not computed twice as in `Pandas` (you could prevent this in `Pandas`, but that +would require setting a temporary variable). + +Additionally `Polars` can compute every branch of an `if -> then -> otherwise` in parallel. This is valuable, when the branches get more expensive to compute. #### Filtering We want to filter the dataframe `df` with housing data based on some criteria. -In `Pandas` you filter the dataframe by passing Boolean expressions to the `query` method: +In `Pandas` you filter the dataframe by passing Boolean expressions to the `loc` method: ```python -df.query('m2_living > 2500 and price < 300000') +df.loc[(df['sqft_living'] > 2500) & (df['price'] < 300000)] ``` while in `Polars` you call the `filter` method: diff --git a/docs/user-guide/transformations/time-series/timezones.md b/docs/user-guide/transformations/time-series/timezones.md index de5046d4cafd..a12b97c68dd9 100644 --- a/docs/user-guide/transformations/time-series/timezones.md +++ b/docs/user-guide/transformations/time-series/timezones.md @@ -12,13 +12,13 @@ hide: The `Datetime` datatype can have a time zone associated with it. Examples of valid time zones are: -- `None`: no time zone, also known as "time zone naive". -- `UTC`: Coordinated Universal Time. +- `None`: no time zone, also known as "time zone naive"; +- `UTC`: Coordinated Universal Time; - `Asia/Kathmandu`: time zone in "area/location" format. See the [list of tz database time zones](https://en.wikipedia.org/wiki/List_of_tz_database_time_zones) - to see what's available. - -Caution: Fixed offsets such as +02:00, should not be used for handling time zones. It's advised to use the "Area/Location" format mentioned above, as it can manage timezones more effectively. + to see what's available; +- `+01:00`: fixed offsets. May be useful when parsing, but you almost certainly want the "Area/Location" + format above instead as it will deal with irregularities such as DST (Daylight Saving Time) for you. Note that, because a `Datetime` can only have a single time zone, it is impossible to have a column with multiple time zones. If you are parsing data @@ -27,8 +27,8 @@ them all to a common time zone (`UTC`), see [parsing dates and times](parsing.md The main methods for setting and converting between time zones are: -- `dt.convert_time_zone`: convert from one time zone to another. -- `dt.replace_time_zone`: set/unset/change time zone. 
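For readers skimming this hunk, a minimal sketch of the two methods the bullets above describe; the timestamps and column names here are illustrative only:

```python
import polars as pl
from datetime import datetime

df = pl.DataFrame({"ts": [datetime(2021, 3, 27, 3, 0)]})
df = df.with_columns(
    # interpret the naive timestamps as UTC wall time
    pl.col("ts").dt.replace_time_zone("UTC").alias("ts_utc")
).with_columns(
    # convert the UTC timestamps to a named "Area/Location" time zone
    pl.col("ts_utc").dt.convert_time_zone("Asia/Kathmandu").alias("ts_ktm")
)
```

`replace_time_zone` changes what the stored instants mean (it re-labels wall time), while `convert_time_zone` keeps the instants and changes only how they are displayed.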
+- `dt.convert_time_zone`: convert from one time zone to another; +- `dt.replace_time_zone`: set/unset/change time zone; Let's look at some examples of common operations: diff --git a/examples/python_rust_compiled_function/Cargo.toml b/examples/python_rust_compiled_function/Cargo.toml index da8b5f37096a..13381f035e59 100644 --- a/examples/python_rust_compiled_function/Cargo.toml +++ b/examples/python_rust_compiled_function/Cargo.toml @@ -14,4 +14,4 @@ polars = { path = "../../crates/polars" } pyo3 = { workspace = true, features = ["extension-module"] } [build-dependencies] -pyo3-build-config = "0.20" +pyo3-build-config = "0.19" diff --git a/examples/read_parquet_cloud/Cargo.toml b/examples/read_parquet_cloud/Cargo.toml index bbb43403bd95..f6f5b56eb430 100644 --- a/examples/read_parquet_cloud/Cargo.toml +++ b/examples/read_parquet_cloud/Cargo.toml @@ -6,4 +6,4 @@ edition = "2021" [dependencies] polars = { path = "../../crates/polars", features = ["lazy", "aws", "parquet"] } -aws-creds = "0.36.0" +aws-creds = "0.35.0" diff --git a/examples/write_parquet_cloud/Cargo.toml b/examples/write_parquet_cloud/Cargo.toml index fe02ad8f8457..7bf6a24e46d3 100644 --- a/examples/write_parquet_cloud/Cargo.toml +++ b/examples/write_parquet_cloud/Cargo.toml @@ -6,5 +6,5 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -aws-creds = "0.36.0" +aws-creds = "0.35.0" polars = { path = "../../crates/polars", features = ["lazy", "aws", "parquet", "cloud_write"] } diff --git a/mkdocs.yml b/mkdocs.yml index 7734cbd11d5b..501d047b35e5 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -37,7 +37,6 @@ nav: - user-guide/expressions/window.md - user-guide/expressions/folds.md - user-guide/expressions/lists.md - - user-guide/expressions/plugins.md - user-guide/expressions/user-defined-functions.md - user-guide/expressions/structs.md - user-guide/expressions/numpy.md @@ -141,8 +140,8 @@ markdown_extensions: - pymdownx.details - attr_list - pymdownx.emoji: - emoji_index: !!python/name:material.extensions.emoji.twemoji - emoji_generator: !!python/name:material.extensions.emoji.to_svg + emoji_index: !!python/name:materialx.emoji.twemoji + emoji_generator: !!python/name:materialx.emoji.to_svg - pymdownx.superfences - pymdownx.tabbed: alternate_style: true diff --git a/py-polars/Cargo.lock b/py-polars/Cargo.lock index 8e2c6af72fd0..de57f81c1845 100644 --- a/py-polars/Cargo.lock +++ b/py-polars/Cargo.lock @@ -26,7 +26,8 @@ checksum = "aae1277d39aeec15cb388266ecc24b11c80469deae6067e17a1a7aa9e5c1f234" [[package]] name = "ahash" version = "0.8.3" -source = "git+https://github.com/orlp/aHash?branch=fix-arm-intrinsics#80685f88d3c120ef39fb3fde1c7786b044af5e8b" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2c99f64d1e06488f620f932677e24bc6e2897582980441ae90a671415bd7ec2f" dependencies = [ "cfg-if", "getrandom", @@ -1489,7 +1490,7 @@ dependencies = [ "snap", "streaming-decompression", "xxhash-rust", - "zstd 0.12.4", + "zstd", ] [[package]] @@ -1626,7 +1627,7 @@ dependencies = [ "simdutf8", "streaming-iterator", "strength_reduce", - "zstd 0.13.0", + "zstd", ] [[package]] @@ -1925,7 +1926,7 @@ dependencies = [ [[package]] name = "py-polars" -version = "0.19.10" +version = "0.19.8" dependencies = [ "ahash", "built", @@ -2394,9 +2395,9 @@ dependencies = [ [[package]] name = "simd-json" -version = "0.12.0" +version = "0.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"f0f07a84c7456b901b8dd2c1d44caca8b0fd2c2616206ee5acc9d9da61e8d9ec" +checksum = "474b451aaac1828ed12f6454a80fe58b940ae2998d10389d41533940a6f641bf" dependencies = [ "ahash", "getrandom", @@ -3114,16 +3115,7 @@ version = "0.12.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1a27595e173641171fc74a1232b7b1c7a7cb6e18222c11e9dfb9888fa424c53c" dependencies = [ - "zstd-safe 6.0.6", -] - -[[package]] -name = "zstd" -version = "0.13.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bffb3309596d527cfcba7dfc6ed6052f1d39dfbd7c867aa2e865e4a449c10110" -dependencies = [ - "zstd-safe 7.0.0", + "zstd-safe", ] [[package]] @@ -3136,15 +3128,6 @@ dependencies = [ "zstd-sys", ] -[[package]] -name = "zstd-safe" -version = "7.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "43747c7422e2924c11144d5229878b98180ef8b06cca4ab5af37afc8a8d8ea3e" -dependencies = [ - "zstd-sys", -] - [[package]] name = "zstd-sys" version = "2.0.8+zstd.1.5.5" diff --git a/py-polars/Cargo.toml b/py-polars/Cargo.toml index e6950298d037..b3d8a6222c0e 100644 --- a/py-polars/Cargo.toml +++ b/py-polars/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "py-polars" -version = "0.19.10" +version = "0.19.8" edition = "2021" [lib] @@ -138,7 +138,6 @@ binary_encoding = ["polars/binary_encoding"] list_sets = ["polars-lazy/list_sets"] list_any_all = ["polars/list_any_all"] list_drop_nulls = ["polars/list_drop_nulls"] -list_sample = ["polars/list_sample"] cutqcut = ["polars/cutqcut"] rle = ["polars/rle"] extract_groups = ["polars/extract_groups"] @@ -166,7 +165,6 @@ operations = [ "list_sets", "list_any_all", "list_drop_nulls", - "list_sample", "cutqcut", "rle", "extract_groups", @@ -237,9 +235,6 @@ lto = "thin" codegen-units = 1 lto = "fat" -[patch.crates-io] -ahash = { git = "https://github.com/orlp/aHash", branch = "fix-arm-intrinsics" } - # This is ignored here; would be set in .cargo/config.toml. # Should not be used when packaging # target-cpu = "native" diff --git a/py-polars/docs/requirements-docs.txt b/py-polars/docs/requirements-docs.txt index 592c5fc249b0..b755ae5dc02f 100644 --- a/py-polars/docs/requirements-docs.txt +++ b/py-polars/docs/requirements-docs.txt @@ -4,7 +4,7 @@ numpy pandas pyarrow -hypothesis==6.88.1 +hypothesis==6.87.1 sphinx==7.2.4 diff --git a/py-polars/docs/source/reference/dataframe/modify_select.rst b/py-polars/docs/source/reference/dataframe/modify_select.rst index dec81a18a308..c3d1f1b91c8c 100644 --- a/py-polars/docs/source/reference/dataframe/modify_select.rst +++ b/py-polars/docs/source/reference/dataframe/modify_select.rst @@ -47,7 +47,6 @@ Manipulation/selection DataFrame.replace DataFrame.replace_at_idx DataFrame.reverse - DataFrame.rolling DataFrame.row DataFrame.rows DataFrame.rows_by_key diff --git a/py-polars/docs/source/reference/expressions/list.rst b/py-polars/docs/source/reference/expressions/list.rst index f43401e20561..d56b44abcc30 100644 --- a/py-polars/docs/source/reference/expressions/list.rst +++ b/py-polars/docs/source/reference/expressions/list.rst @@ -34,7 +34,6 @@ The following methods are available under the `expr.list` attribute. 
Expr.list.mean Expr.list.min Expr.list.reverse - Expr.list.sample Expr.list.set_difference Expr.list.set_intersection Expr.list.set_symmetric_difference diff --git a/py-polars/docs/source/reference/lazyframe/modify_select.rst b/py-polars/docs/source/reference/lazyframe/modify_select.rst index 19cef033426f..1a1482ec4623 100644 --- a/py-polars/docs/source/reference/lazyframe/modify_select.rst +++ b/py-polars/docs/source/reference/lazyframe/modify_select.rst @@ -36,7 +36,6 @@ Manipulation/selection LazyFrame.merge_sorted LazyFrame.rename LazyFrame.reverse - LazyFrame.rolling LazyFrame.select LazyFrame.select_seq LazyFrame.set_sorted diff --git a/py-polars/docs/source/reference/series/list.rst b/py-polars/docs/source/reference/series/list.rst index ad766dd92eb9..7f3b709e80db 100644 --- a/py-polars/docs/source/reference/series/list.rst +++ b/py-polars/docs/source/reference/series/list.rst @@ -34,7 +34,6 @@ The following methods are available under the `Series.list` attribute. Series.list.mean Series.list.min Series.list.reverse - Series.list.sample Series.list.set_difference Series.list.set_intersection Series.list.set_symmetric_difference diff --git a/py-polars/docs/source/reference/testing.rst b/py-polars/docs/source/reference/testing.rst index 78ce4c96a0bd..4e268cec77dd 100644 --- a/py-polars/docs/source/reference/testing.rst +++ b/py-polars/docs/source/reference/testing.rst @@ -25,9 +25,7 @@ Polars provides some standard asserts for use with unit tests: :toctree: api/ testing.assert_frame_equal - testing.assert_frame_not_equal testing.assert_series_equal - testing.assert_series_not_equal Parametric testing diff --git a/py-polars/polars/__init__.py b/py-polars/polars/__init__.py index b0495b3ef1df..9abe028cbe25 100644 --- a/py-polars/polars/__init__.py +++ b/py-polars/polars/__init__.py @@ -29,7 +29,6 @@ DURATION_DTYPES, FLOAT_DTYPES, INTEGER_DTYPES, - NESTED_DTYPES, NUMERIC_DTYPES, TEMPORAL_DTYPES, Array, @@ -254,7 +253,6 @@ "DURATION_DTYPES", "FLOAT_DTYPES", "INTEGER_DTYPES", - "NESTED_DTYPES", "NUMERIC_DTYPES", "TEMPORAL_DTYPES", # polars.type_aliases diff --git a/py-polars/polars/config.py b/py-polars/polars/config.py index 3a5a1f1b07cb..2d557d914354 100644 --- a/py-polars/polars/config.py +++ b/py-polars/polars/config.py @@ -588,9 +588,7 @@ def set_fmt_table_cell_list_len(cls, n: int | None) -> type[Config]: """ Set the number of elements to display for List values. - Empty lists will always print "[]". Negative values will result in all values - being printed. A value of 0 will always "[…]" for lists with contents. A value - of 1 will print only the final item in the list. + Values less than 0 will result in all values being printed. 
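A minimal sketch of how the setting documented in this hunk is typically exercised; the frame contents are illustrative only:

```python
import polars as pl

df = pl.DataFrame({"nums": [[1, 2, 3, 4, 5, 6]]})

pl.Config.set_fmt_table_cell_list_len(3)  # show at most 3 elements per list cell
print(df)  # longer lists are elided in the printed table

pl.Config.set_fmt_table_cell_list_len(None)  # restore the default
```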
Parameters ---------- diff --git a/py-polars/polars/dataframe/frame.py b/py-polars/polars/dataframe/frame.py index ffc8f0f4223e..86d9051e557f 100644 --- a/py-polars/polars/dataframe/frame.py +++ b/py-polars/polars/dataframe/frame.py @@ -1631,11 +1631,13 @@ def __getitem__( ): df = self[:, col_selection] return df.slice(row_selection, 1) + # df[2, "a"] + if isinstance(col_selection, str): + return self[col_selection][row_selection] - # df[:, "a"] + # column selection can be "a" and ["a", "b"] if isinstance(col_selection, str): - series = self.get_column(col_selection) - return series[row_selection] + col_selection = [col_selection] # df[:, 1] if isinstance(col_selection, int): @@ -1664,7 +1666,7 @@ def __getitem__( # select single column # df["foo"] if isinstance(item, str): - return self.get_column(item) + return wrap_s(self._df.column(item)) # df[idx] if isinstance(item, int): @@ -1862,7 +1864,7 @@ def item(self, row: int | None = None, column: int | str | None = None) -> Any: s = ( self._df.select_at_idx(column) if isinstance(column, int) - else self._df.get_column(column) + else self._df.column(column) ) if s is None: raise IndexError(f"column index {column!r} is out of bounds") @@ -3968,8 +3970,8 @@ def filter( Provide multiple filters using `*args` syntax: >>> df.filter( - ... pl.col("foo") <= 2, - ... ~pl.col("ham").is_in(["b", "c"]), + ... pl.col("foo") == 1, + ... pl.col("ham") == "a", ... ) shape: (1, 3) ┌─────┬─────┬─────┐ @@ -3982,14 +3984,14 @@ def filter( Provide multiple filters using `**kwargs` syntax: - >>> df.filter(foo=2, ham="b") + >>> df.filter(foo=1, ham="a") shape: (1, 3) ┌─────┬─────┬─────┐ │ foo ┆ bar ┆ ham │ │ --- ┆ --- ┆ --- │ │ i64 ┆ i64 ┆ str │ ╞═════╪═════╪═════╡ - │ 2 ┆ 7 ┆ b │ + │ 1 ┆ 6 ┆ a │ └─────┴─────┴─────┘ """ @@ -5123,7 +5125,7 @@ def group_by( """ return GroupBy(self, by, *more_by, maintain_order=maintain_order) - def rolling( + def group_by_rolling( self, index_column: IntoExpr, *, @@ -5175,7 +5177,7 @@ def rolling( not be 24 hours, due to daylight savings). Similarly for "calendar week", "calendar month", "calendar quarter", and "calendar year". - In case of a rolling operation on an integer column, the windows are defined by: + In case of a group_by_rolling on an integer column, the windows are defined by: - **"1i" # length 1** - **"10i" # length 10** @@ -5188,7 +5190,7 @@ def rolling( This column must be sorted in ascending order (or, if `by` is specified, then it must be sorted in ascending order within each group). - In case of a rolling operation on indices, dtype needs to be one of + In case of a rolling group by on indices, dtype needs to be one of {Int32, Int64}. Note that Int32 gets temporarily cast to Int64, so if performance matters use an Int64 column. period @@ -5230,7 +5232,7 @@ def rolling( >>> df = pl.DataFrame({"dt": dates, "a": [3, 7, 5, 9, 2, 1]}).with_columns( ... pl.col("dt").str.strptime(pl.Datetime).set_sorted() ... ) - >>> out = df.rolling(index_column="dt", period="2d").agg( + >>> out = df.group_by_rolling(index_column="dt", period="2d").agg( ... [ ... pl.sum("a").alias("sum_a"), ... pl.min("a").alias("min_a"), @@ -5368,7 +5370,7 @@ def group_by_dynamic( See Also -------- - rolling + group_by_rolling Notes ----- @@ -6648,17 +6650,13 @@ def get_columns(self) -> list[Series]: def get_column(self, name: str) -> Series: """ - Get a single column by name. + Get a single column as Series by name. Parameters ---------- name : str Name of the column to retrieve. 
-        Returns
-        -------
-        Series
-
         See Also
         --------
         to_series
@@ -6676,7 +6674,11 @@ def get_column(self, name: str) -> Series:
         ]
 
         """
-        return wrap_s(self._df.get_column(name))
+        if not isinstance(name, str):
+            raise TypeError(
+                f"column name {name!r} should be a string, but is {type(name).__name__!r}"
+            )
+        return self[name]
 
     def fill_null(
         self,
@@ -9754,19 +9756,13 @@ def update(
         left_on: str | Sequence[str] | None = None,
         right_on: str | Sequence[str] | None = None,
         how: Literal["left", "inner", "outer"] = "left",
-        include_nulls: bool | None = False,
     ) -> DataFrame:
         """
-        Update the values in this `DataFrame` with the values in `other`.
-
-        By default, null values in the right dataframe are ignored. Use
-        `ignore_nulls=False` to overwrite values in this frame with null values in other
-        frame.
+        Update the values in this `DataFrame` with the non-null values in `other`.
 
         Notes
         -----
-        This is syntactic sugar for a left/inner join, with an optional coalesce when
-        `include_nulls = False`.
+        This is syntactic sugar for a left/inner join + coalesce
 
         Warnings
         --------
@@ -9790,9 +9786,6 @@ def update(
         * 'inner' keeps only those rows where the key exists in both frames.
         * 'outer' will update existing rows where the key matches while also
           adding any new rows contained in the given frame.
-        include_nulls
-            If True, null values from the right dataframe will be used to update the
-            left dataframe.
 
         Examples
         --------
@@ -9868,29 +9861,10 @@ def update(
         │ 5   ┆ -66 │
         └─────┴─────┘
 
-        Update `df` values including null values in `new_df`, using an outer join
-        strategy that defines explicit join columns in each frame:
-
-        >>> df.update(
-        ...     new_df, left_on="A", right_on="C", how="outer", include_nulls=True
-        ... )
-        shape: (5, 2)
-        ┌─────┬──────┐
-        │ A   ┆ B    │
-        │ --- ┆ ---  │
-        │ i64 ┆ i64  │
-        ╞═════╪══════╡
-        │ 1   ┆ -99  │
-        │ 2   ┆ 500  │
-        │ 3   ┆ null │
-        │ 4   ┆ 700  │
-        │ 5   ┆ -66  │
-        └─────┴──────┘
-
         """
         return (
             self.lazy()
-            .update(other.lazy(), on, left_on, right_on, how, include_nulls)
+            .update(other.lazy(), on, left_on, right_on, how)
            .collect(_eager=True)
         )
 
@@ -9932,7 +9906,7 @@ def groupby(
         """
         return self.group_by(by, *more_by, maintain_order=maintain_order)
 
-    @deprecate_renamed_function("rolling", version="0.19.0")
+    @deprecate_renamed_function("group_by_rolling", version="0.19.0")
     def groupby_rolling(
         self,
         index_column: IntoExpr,
@@ -9947,60 +9921,7 @@
         Create rolling groups based on a time, Int32, or Int64 column.
 
         .. deprecated:: 0.19.0
-            This method has been renamed to :func:`DataFrame.rolling`.
-
-        Parameters
-        ----------
-        index_column
-            Column used to group based on the time window.
-            Often of type Date/Datetime.
-            This column must be sorted in ascending order (or, if `by` is specified,
-            then it must be sorted in ascending order within each group).
-
-            In case of a rolling group by on indices, dtype needs to be one of
-            {Int32, Int64}. Note that Int32 gets temporarily cast to Int64, so if
-            performance matters use an Int64 column.
-        period
-            length of the window - must be non-negative
-        offset
-            offset of the window. Default is -period
-        closed : {'right', 'left', 'both', 'none'}
-            Define which sides of the temporal interval are closed (inclusive).
-        by
-            Also group by this column/these columns
-        check_sorted
-            When the ``by`` argument is given, polars can not check sortedness
-            by the metadata and has to do a full scan on the index column to
-            verify data is sorted. This is expensive.
If you are sure the - data within the by groups is sorted, you can set this to ``False``. - Doing so incorrectly will lead to incorrect output - - """ - return self.rolling( - index_column, - period=period, - offset=offset, - closed=closed, - by=by, - check_sorted=check_sorted, - ) - - @deprecate_renamed_function("rolling", version="0.19.9") - def group_by_rolling( - self, - index_column: IntoExpr, - *, - period: str | timedelta, - offset: str | timedelta | None = None, - closed: ClosedInterval = "right", - by: IntoExpr | Iterable[IntoExpr] | None = None, - check_sorted: bool = True, - ) -> RollingGroupBy: - """ - Create rolling groups based on a time, Int32, or Int64 column. - - .. deprecated:: 0.19.9 - This method has been renamed to :func:`DataFrame.rolling`. + This method has been renamed to :func:`DataFrame.group_by_rolling`. Parameters ---------- @@ -10029,7 +9950,7 @@ def group_by_rolling( Doing so incorrectly will lead to incorrect output """ - return self.rolling( + return self.group_by_rolling( index_column, period=period, offset=offset, diff --git a/py-polars/polars/dataframe/group_by.py b/py-polars/polars/dataframe/group_by.py index 549c78cbd3c3..e648b149970d 100644 --- a/py-polars/polars/dataframe/group_by.py +++ b/py-polars/polars/dataframe/group_by.py @@ -800,7 +800,7 @@ def __iter__(self) -> Self: groups_df = ( self.df.lazy() .with_row_count(name=temp_col) - .rolling( + .group_by_rolling( index_column=self.time_column, period=self.period, offset=self.offset, @@ -859,7 +859,7 @@ def agg( """ return ( self.df.lazy() - .rolling( + .group_by_rolling( index_column=self.time_column, period=self.period, offset=self.offset, @@ -903,7 +903,7 @@ def map_groups( """ return ( self.df.lazy() - .rolling( + .group_by_rolling( index_column=self.time_column, period=self.period, offset=self.offset, diff --git a/py-polars/polars/datatypes/__init__.py b/py-polars/polars/datatypes/__init__.py index 4576282539c4..dc69e38eca39 100644 --- a/py-polars/polars/datatypes/__init__.py +++ b/py-polars/polars/datatypes/__init__.py @@ -18,7 +18,7 @@ Int16, Int32, Int64, - IntegerType, + IntegralType, List, Null, NumericType, @@ -40,7 +40,6 @@ FLOAT_DTYPES, INTEGER_DTYPES, N_INFER_DEFAULT, - NESTED_DTYPES, NUMERIC_DTYPES, SIGNED_INTEGER_DTYPES, TEMPORAL_DTYPES, @@ -94,7 +93,7 @@ "Int32", "Int64", "Int8", - "IntegerType", + "IntegralType", "List", "Null", "NumericType", @@ -114,7 +113,6 @@ "DURATION_DTYPES", "FLOAT_DTYPES", "INTEGER_DTYPES", - "NESTED_DTYPES", "NUMERIC_DTYPES", "N_INFER_DEFAULT", "SIGNED_INTEGER_DTYPES", diff --git a/py-polars/polars/datatypes/classes.py b/py-polars/polars/datatypes/classes.py index 4e830cffe76e..39e16d6d4ff2 100644 --- a/py-polars/polars/datatypes/classes.py +++ b/py-polars/polars/datatypes/classes.py @@ -45,23 +45,24 @@ def __repr__(cls) -> str: def _string_repr(cls) -> str: return _dtype_str_repr(cls) - # Methods below defined here in signature only to satisfy mypy + def base_type(cls) -> DataTypeClass: + """Return the base type.""" + return cls - @classmethod - def base_type(cls) -> DataTypeClass: # noqa: D102 - ... + @classproperty + def is_nested(self) -> bool: + """Check if this data type is nested.""" + return False @classmethod - def is_(cls, other: PolarsDataType) -> bool: # noqa: D102 - ... + def is_(cls, other: PolarsDataType) -> bool: + """Check if this DataType is the same as another DataType.""" + return cls == other and hash(cls) == hash(other) @classmethod - def is_not(cls, other: PolarsDataType) -> bool: # noqa: D102 - ... 
- - @classproperty - def is_nested(self) -> bool: # noqa: D102 - ... + def is_not(cls, other: PolarsDataType) -> bool: + """Check if this DataType is NOT the same as another DataType.""" + return not cls.is_(other) class DataType(metaclass=DataTypeClass): @@ -96,6 +97,11 @@ def base_type(cls) -> DataTypeClass: """ return cls + @classproperty + def is_nested(self) -> bool: + """Check if this data type is nested.""" + return False + @classinstmethod # type: ignore[arg-type] def is_(self, other: PolarsDataType) -> bool: """ @@ -142,25 +148,6 @@ def is_not(self, other: PolarsDataType) -> bool: """ return not self.is_(other) - @classproperty - def is_nested(self) -> bool: - """ - Check if this data type is nested. - - .. deprecated:: 0.19.10 - Use `dtype in pl.NESTED_DTYPES` instead. - - """ - from polars.utils.deprecation import issue_deprecation_warning - - message = ( - "`DataType.is_nested` is deprecated and will be removed in the next breaking release." - " It will be changed to a classmethod rather than a property." - " To silence this warning, use `dtype in pl.NESTED_DTYPES` instead." - ) - issue_deprecation_warning(message, version="0.19.10") - return False - def _custom_reconstruct( cls: type[Any], base: type[Any], state: Any @@ -213,7 +200,7 @@ class NumericType(DataType): """Base class for numeric data types.""" -class IntegerType(NumericType): +class IntegralType(NumericType): """Base class for integral data types.""" @@ -234,53 +221,39 @@ class NestedType(DataType): @classproperty def is_nested(self) -> bool: - """ - Check if this data type is nested. - - .. deprecated:: 0.19.10 - Use `dtype in pl.NESTED_DTYPES` instead. - - """ - from polars.utils.deprecation import issue_deprecation_warning - - message = ( - "`DataType.is_nested` is deprecated and will be removed in the next breaking release." - " It will be changed to a classmethod rather than a property." - " To silence this warning, use `dtype in pl.NESTED_DTYPES` instead." - ) - issue_deprecation_warning(message, version="0.19.10") + """Check if this data type is nested.""" return True -class Int8(IntegerType): +class Int8(IntegralType): """8-bit signed integer type.""" -class Int16(IntegerType): +class Int16(IntegralType): """16-bit signed integer type.""" -class Int32(IntegerType): +class Int32(IntegralType): """32-bit signed integer type.""" -class Int64(IntegerType): +class Int64(IntegralType): """64-bit signed integer type.""" -class UInt8(IntegerType): +class UInt8(IntegralType): """8-bit unsigned integer type.""" -class UInt16(IntegerType): +class UInt16(IntegralType): """16-bit unsigned integer type.""" -class UInt32(IntegerType): +class UInt32(IntegralType): """32-bit unsigned integer type.""" -class UInt64(IntegerType): +class UInt64(IntegralType): """64-bit unsigned integer type.""" @@ -456,18 +429,18 @@ class Unknown(DataType): class List(NestedType): - """Variable length list type.""" + """Nested list/array type with variable length of inner lists.""" inner: PolarsDataType | None = None def __init__(self, inner: PolarsDataType | PythonDataType): """ - Variable length list type. + Nested list/array type with variable length of inner lists. Parameters ---------- inner - The ``DataType`` of the values within each list. 
+ The `DataType` of values within the list Examples -------- @@ -518,31 +491,26 @@ def __repr__(self) -> str: class Array(NestedType): - """Fixed length list type.""" + """Nested list/array type with fixed length of inner arrays.""" inner: PolarsDataType | None = None width: int - def __init__( # noqa: D417 - self, - *args: Any, - width: int | None = None, - inner: PolarsDataType | PythonDataType | None = None, - ): + def __init__(self, width: int, inner: PolarsDataType | PythonDataType = Null): """ - Fixed length list type. + Nested list/array type with fixed length of inner arrays. Parameters ---------- width - The length of the arrays. + The fixed size length of the inner arrays. inner - The ``DataType`` of the values within each array. + The `DataType` of values within the inner arrays Examples -------- >>> s = pl.Series( - ... "a", [[1, 2], [4, 3]], dtype=pl.Array(inner=pl.Int64, width=2) + ... "a", [[1, 2], [4, 3]], dtype=pl.Array(width=2, inner=pl.Int64) ... ) >>> s shape: (2,) @@ -553,32 +521,6 @@ def __init__( # noqa: D417 ] """ - from polars.utils.deprecation import issue_deprecation_warning - - if args: - # TODO: When removing this deprecation, update the `to_object` - # implementation in py-polars/src/conversion.rs to use `call1` instead of - # `call` - issue_deprecation_warning( - "Parameters `inner` and `width` will change positions in the next breaking release." - " Use keyword arguments to keep current behavior and silence this warning.", - version="0.19.11", - ) - if len(args) == 1: - width = args[0] - else: - width, inner = args[:2] - if width is None: - raise TypeError("`width` must be specified when initializing an `Array`") - - if inner is None: - issue_deprecation_warning( - "The default value for the `inner` parameter of `Array` will be removed in the next breaking release." 
- " Pass `inner=pl.Null`to keep current behavior and silence this warning.", - version="0.19.11", - ) - inner = Null - self.width = width self.inner = polars.datatypes.py_type_to_dtype(inner) @@ -601,11 +543,11 @@ def __eq__(self, other: PolarsDataType) -> bool: # type: ignore[override] return False def __hash__(self) -> int: - return hash((self.__class__, self.inner, self.width)) + return hash((self.__class__, self.inner)) def __repr__(self) -> str: class_name = self.__class__.__name__ - return f"{class_name}({self.inner!r}, {self.width})" + return f"{class_name}({self.inner!r})" class Field: diff --git a/py-polars/polars/datatypes/constants.py b/py-polars/polars/datatypes/constants.py index f3f2efc56431..a1654442ecc9 100644 --- a/py-polars/polars/datatypes/constants.py +++ b/py-polars/polars/datatypes/constants.py @@ -3,7 +3,6 @@ from typing import TYPE_CHECKING from polars.datatypes import ( - Array, DataTypeGroup, Date, Datetime, @@ -15,8 +14,6 @@ Int16, Int32, Int64, - List, - Struct, Time, UInt8, UInt16, @@ -76,7 +73,5 @@ FLOAT_DTYPES | INTEGER_DTYPES | frozenset([Decimal]) ) -NESTED_DTYPES: frozenset[PolarsDataType] = DataTypeGroup([List, Struct, Array]) - # number of rows to scan by default when inferring datatypes N_INFER_DEFAULT = 100 diff --git a/py-polars/polars/datatypes/convert.py b/py-polars/polars/datatypes/convert.py index 496ed957d427..7d5d61e45c7b 100644 --- a/py-polars/polars/datatypes/convert.py +++ b/py-polars/polars/datatypes/convert.py @@ -21,7 +21,6 @@ ) from polars.datatypes import ( - Array, Binary, Boolean, Categorical, @@ -204,7 +203,7 @@ def unpack_dtypes( unpacked: set[PolarsDataType] = set() for tp in dtypes: - if isinstance(tp, (List, Array)): + if isinstance(tp, List): if include_compound: unpacked.add(tp) unpacked.update(unpack_dtypes(tp.inner, include_compound=include_compound)) diff --git a/py-polars/polars/expr/array.py b/py-polars/polars/expr/array.py index d845fb7263e9..1f80cee68899 100644 --- a/py-polars/polars/expr/array.py +++ b/py-polars/polars/expr/array.py @@ -24,7 +24,7 @@ def min(self) -> Expr: -------- >>> df = pl.DataFrame( ... data={"a": [[1, 2], [4, 3]]}, - ... schema={"a": pl.Array(inner=pl.Int64, width=2)}, + ... schema={"a": pl.Array(width=2, inner=pl.Int64)}, ... ) >>> df.select(pl.col("a").arr.min()) shape: (2, 1) @@ -48,7 +48,7 @@ def max(self) -> Expr: -------- >>> df = pl.DataFrame( ... data={"a": [[1, 2], [4, 3]]}, - ... schema={"a": pl.Array(inner=pl.Int64, width=2)}, + ... schema={"a": pl.Array(width=2, inner=pl.Int64)}, ... ) >>> df.select(pl.col("a").arr.max()) shape: (2, 1) @@ -72,7 +72,7 @@ def sum(self) -> Expr: -------- >>> df = pl.DataFrame( ... data={"a": [[1, 2], [4, 3]]}, - ... schema={"a": pl.Array(inner=pl.Int64, width=2)}, + ... schema={"a": pl.Array(width=2, inner=pl.Int64)}, ... ) >>> df.select(pl.col("a").arr.sum()) shape: (2, 1) @@ -103,7 +103,7 @@ def unique(self, *, maintain_order: bool = False) -> Expr: ... { ... "a": [[1, 1, 2]], ... }, - ... schema_overrides={"a": pl.Array(inner=pl.Int64, width=3)}, + ... schema_overrides={"a": pl.Array(width=3, inner=pl.Int64)}, ... ) >>> df.select(pl.col("a").arr.unique()) shape: (1, 1) diff --git a/py-polars/polars/expr/expr.py b/py-polars/polars/expr/expr.py index 05eb15530917..cc9735faf2c1 100644 --- a/py-polars/polars/expr/expr.py +++ b/py-polars/polars/expr/expr.py @@ -3225,7 +3225,7 @@ def rolling( not be 24 hours, due to daylight savings). Similarly for "calendar week", "calendar month", "calendar quarter", and "calendar year". 
-        In case of a rolling operation on an integer column, the windows are defined by:
+        In case of a group_by_rolling on an integer column, the windows are defined by:
 
         - "1i"      # length 1
         - "10i"     # length 10
 
@@ -4494,7 +4494,7 @@ def eq(self, other: Any) -> Self:
 
     def eq_missing(self, other: Any) -> Self:
         """
-        Method equivalent of equality operator ``expr == other`` where ``None == None``.
+        Method equivalent of equality operator ``expr == other`` where `None` == `None`.
 
         This differs from default ``eq`` where null values are propagated.
 
@@ -4709,7 +4709,7 @@ def ne(self, other: Any) -> Self:
 
     def ne_missing(self, other: Any) -> Self:
         """
-        Method equivalent of equality operator ``expr != other`` where ``None == None``.
+        Method equivalent of equality operator ``expr != other`` where `None` == `None`.
 
         This differs from default ``ne`` where null values are propagated.
 
@@ -5530,7 +5530,7 @@ def rolling_min(
 
         Notes
         -----
         If you want to compute multiple aggregation statistics over the same dynamic
-        window, consider using `rolling` - this method can cache the window size
+        window, consider using `group_by_rolling` - this method can cache the window size
         computation.
 
         Examples
@@ -5736,7 +5736,7 @@ def rolling_max(
 
         Notes
         -----
         If you want to compute multiple aggregation statistics over the same dynamic
-        window, consider using `rolling` - this method can cache the window size
+        window, consider using `group_by_rolling` - this method can cache the window size
         computation.
 
         Examples
@@ -5973,7 +5973,7 @@ def rolling_mean(
 
         Notes
         -----
         If you want to compute multiple aggregation statistics over the same dynamic
-        window, consider using `rolling` - this method can cache the window size
+        window, consider using `group_by_rolling` - this method can cache the window size
         computation.
 
         Examples
@@ -6206,7 +6206,7 @@ def rolling_sum(
 
         Notes
         -----
         If you want to compute multiple aggregation statistics over the same dynamic
-        window, consider using `rolling` - this method can cache the window size
+        window, consider using `group_by_rolling` - this method can cache the window size
        computation.
 
         Examples
@@ -6442,7 +6442,7 @@ def rolling_std(
 
         Notes
         -----
         If you want to compute multiple aggregation statistics over the same dynamic
-        window, consider using `rolling` - this method can cache the window size
+        window, consider using `group_by_rolling` - this method can cache the window size
         computation.
 
         Examples
@@ -6678,7 +6678,7 @@ def rolling_var(
 
         Notes
         -----
         If you want to compute multiple aggregation statistics over the same dynamic
-        window, consider using `rolling` - this method can cache the window size
+        window, consider using `group_by_rolling` - this method can cache the window size
         computation.
 
         Examples
@@ -6917,7 +6917,7 @@ def rolling_median(
 
         Notes
         -----
         If you want to compute multiple aggregation statistics over the same dynamic
-        window, consider using `rolling` - this method can cache the window size
+        window, consider using `group_by_rolling` - this method can cache the window size
         computation.
 
         Examples
@@ -7082,7 +7082,7 @@ def rolling_quantile(
 
         Notes
         -----
         If you want to compute multiple aggregation statistics over the same dynamic
-        window, consider using `rolling` - this method can cache the window size
+        window, consider using `group_by_rolling` - this method can cache the window size
         computation.
 
         Examples
@@ -7409,7 +7409,7 @@ def rank(
 
     def diff(self, n: int = 1, null_behavior: NullBehavior = "ignore") -> Self:
         """
-        Calculate the first discrete difference between shifted items.
+        Calculate the n-th discrete difference.
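The reworded summary is easier to see with numbers; a minimal sketch of the n-th discrete difference (the values are illustrative only):

```python
import polars as pl

s = pl.Series("a", [20, 10, 30, 40])
s.diff()     # [null, -10, 20, 10]  -> each element minus the previous one
s.diff(n=2)  # [null, null, 10, 30] -> each element minus the one two slots back
```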
Parameters ---------- @@ -7464,7 +7464,7 @@ def diff(self, n: int = 1, null_behavior: NullBehavior = "ignore") -> Self: """ return self._from_pyexpr(self._pyexpr.diff(n, null_behavior)) - def pct_change(self, n: int | IntoExprColumn = 1) -> Self: + def pct_change(self, n: int = 1) -> Self: """ Computes percentage change between values. @@ -7500,7 +7500,6 @@ def pct_change(self, n: int | IntoExprColumn = 1) -> Self: └──────┴────────────┘ """ - n = parse_as_expression(n) return self._from_pyexpr(self._pyexpr.pct_change(n)) def skew(self, *, bias: bool = True) -> Self: @@ -8244,7 +8243,7 @@ def shuffle(self, seed: int | None = None) -> Self: def sample( self, - n: int | IntoExprColumn | None = None, + n: int | Expr | None = None, *, fraction: float | None = None, with_replacement: bool = False, @@ -8300,7 +8299,6 @@ def sample( self._pyexpr.sample_n(n, with_replacement, shuffle, seed) ) - @deprecate_nonkeyword_arguments(version="0.19.10") def ewm_mean( self, com: float | None = None, @@ -8368,6 +8366,7 @@ def ewm_mean( :math:`1-\alpha` and :math:`1` if ``adjust=True``, and :math:`1-\alpha` and :math:`\alpha` if ``adjust=False``. + Examples -------- >>> df = pl.DataFrame({"a": [1, 2, 3]}) @@ -8389,7 +8388,6 @@ def ewm_mean( self._pyexpr.ewm_mean(alpha, adjust, min_periods, ignore_nulls) ) - @deprecate_nonkeyword_arguments(version="0.19.10") def ewm_std( self, com: float | None = None, @@ -8482,7 +8480,6 @@ def ewm_std( self._pyexpr.ewm_std(alpha, adjust, bias, min_periods, ignore_nulls) ) - @deprecate_nonkeyword_arguments(version="0.19.10") def ewm_var( self, com: float | None = None, @@ -9511,11 +9508,10 @@ def is_last(self) -> Self: def _register_plugin( self, - *, lib: str, symbol: str, args: list[IntoExpr] | None = None, - kwargs: dict[Any, Any] | None = None, + *, is_elementwise: bool = False, input_wildcard_expansion: bool = False, auto_explode: bool = False, @@ -9540,9 +9536,6 @@ def _register_plugin( Function to load. args Arguments (other than self) passed to this function. - These arguments have to be of type Expression. - kwargs - Non-expression arguments. They must be JSON serializable. is_elementwise If the function only operates on scalars this will trigger fast paths. @@ -9559,19 +9552,11 @@ def _register_plugin( args = [] else: args = [parse_as_expression(a) for a in args] - if kwargs is None: - serialized_kwargs = b"" - else: - import pickle - - serialized_kwargs = pickle.dumps(kwargs, protocol=2) - return self._from_pyexpr( self._pyexpr.register_plugin( lib, symbol, args, - serialized_kwargs, is_elementwise, input_wildcard_expansion, auto_explode, diff --git a/py-polars/polars/expr/list.py b/py-polars/polars/expr/list.py index 88c4cb993402..5a41f4e2f213 100644 --- a/py-polars/polars/expr/list.py +++ b/py-polars/polars/expr/list.py @@ -138,64 +138,6 @@ def drop_nulls(self) -> Expr: """ return wrap_expr(self._pyexpr.list_drop_nulls()) - def sample( - self, - n: int | IntoExprColumn | None = None, - *, - fraction: float | IntoExprColumn | None = None, - with_replacement: bool = False, - shuffle: bool = False, - seed: int | None = None, - ) -> Expr: - """ - Sample from this list. - - Parameters - ---------- - n - Number of items to return. Cannot be used with `fraction`. Defaults to 1 if - `fraction` is None. - fraction - Fraction of items to return. Cannot be used with `n`. - with_replacement - Allow values to be sampled more than once. - shuffle - Shuffle the order of sampled data points. - seed - Seed for the random number generator. 
If set to None (default), a - random seed is generated for each sample operation. - - Examples - -------- - >>> df = pl.DataFrame({"values": [[1, 2, 3], [4, 5]], "n": [2, 1]}) - >>> df.select(pl.col("values").list.sample(n=pl.col("n"), seed=1)) - shape: (2, 1) - ┌───────────┐ - │ values │ - │ --- │ - │ list[i64] │ - ╞═══════════╡ - │ [2, 1] │ - │ [5] │ - └───────────┘ - - """ - if n is not None and fraction is not None: - raise ValueError("cannot specify both `n` and `fraction`") - - if fraction is not None: - fraction = parse_as_expression(fraction) - return wrap_expr( - self._pyexpr.list_sample_fraction( - fraction, with_replacement, shuffle, seed - ) - ) - - if n is None: - n = 1 - n = parse_as_expression(n) - return wrap_expr(self._pyexpr.list_sample_n(n, with_replacement, shuffle, seed)) - def sum(self) -> Expr: """ Sum all the lists in the array. @@ -664,7 +606,7 @@ def arg_max(self) -> Expr: def diff(self, n: int = 1, null_behavior: NullBehavior = "ignore") -> Expr: """ - Calculate the first discrete difference between shifted items of every sublist. + Calculate the n-th discrete difference of every sublist. Parameters ---------- diff --git a/py-polars/polars/functions/eager.py b/py-polars/polars/functions/eager.py index 6c225c37f47b..c953991eaa09 100644 --- a/py-polars/polars/functions/eager.py +++ b/py-polars/polars/functions/eager.py @@ -33,7 +33,7 @@ def concat( ---------- items DataFrames, LazyFrames, or Series to concatenate. - how : {'vertical', 'vertical_relaxed', 'diagonal', 'diagonal_relaxed', 'horizontal', 'align'} + how : {'vertical', 'vertical_relaxed', 'diagonal', 'horizontal', 'align'} Series only support the `vertical` strategy. LazyFrames do not support the `horizontal` strategy. @@ -125,7 +125,7 @@ def concat( │ 3 ┆ null ┆ 6 ┆ 8 │ └─────┴──────┴──────┴──────┘ - """ # noqa: W505 + """ # unpack/standardise (handles generator input) elems = list(items) diff --git a/py-polars/polars/io/database.py b/py-polars/polars/io/database.py index 1f0795b3df0d..ed3bfa858d5a 100644 --- a/py-polars/polars/io/database.py +++ b/py-polars/polars/io/database.py @@ -31,49 +31,42 @@ Selectable: TypeAlias = Any # type: ignore[no-redef] -class _ArrowDriverProperties_(TypedDict): - fetch_all: str # name of the method that fetches all arrow data - fetch_batches: str | None # name of the method that fetches arrow data in batches - exact_batch_size: bool | None # whether indicated batch size is respected exactly - repeat_batch_calls: bool # repeat batch calls (if False, batch call is generator) +class _DriverProperties_(TypedDict): + fetch_all: str + fetch_batches: str | None + exact_batch_size: bool | None -_ARROW_DRIVER_REGISTRY_: dict[str, _ArrowDriverProperties_] = { +_ARROW_DRIVER_REGISTRY_: dict[str, _DriverProperties_] = { "adbc_.*": { "fetch_all": "fetch_arrow_table", "fetch_batches": None, "exact_batch_size": None, - "repeat_batch_calls": False, }, "arrow_odbc_proxy": { "fetch_all": "fetch_record_batches", "fetch_batches": "fetch_record_batches", "exact_batch_size": True, - "repeat_batch_calls": False, }, "databricks": { "fetch_all": "fetchall_arrow", "fetch_batches": "fetchmany_arrow", "exact_batch_size": True, - "repeat_batch_calls": True, }, "duckdb": { "fetch_all": "fetch_arrow_table", "fetch_batches": "fetch_record_batch", "exact_batch_size": True, - "repeat_batch_calls": False, }, "snowflake": { "fetch_all": "fetch_arrow_all", "fetch_batches": "fetch_arrow_batches", "exact_batch_size": False, - "repeat_batch_calls": False, }, "turbodbc": { "fetch_all": "fetchallarrow", 
"fetch_batches": "fetcharrowbatches", "exact_batch_size": False, - "repeat_batch_calls": False, }, } @@ -128,9 +121,10 @@ def fetch_record_batches( class ConnectionExecutor: """Abstraction for querying databases with user-supplied connection objects.""" - # indicate if we can/should close the cursor on scope exit. note that we - # should never close the underlying connection, or a user-supplied cursor. - can_close_cursor: bool = False + # indicate that we acquired a cursor (and are therefore responsible for closing + # it on scope-exit). note that we should never close the underlying connection, + # or a user-supplied cursor. + acquired_cursor: bool = False def __init__(self, connection: ConnectionOrCursor) -> None: self.driver_name = ( @@ -150,57 +144,24 @@ def __exit__( exc_val: BaseException | None, exc_tb: TracebackType | None, ) -> None: - # iif we created it and are finished with it, we can - # close the cursor (but NOT the connection) - if self.can_close_cursor: + # iif we created it, close the cursor (NOT the connection) + if self.acquired_cursor: self.cursor.close() def __repr__(self) -> str: return f"<{type(self).__name__} module={self.driver_name!r}>" - def _arrow_batches( - self, - driver_properties: _ArrowDriverProperties_, - *, - batch_size: int | None, - iter_batches: bool, - ) -> Iterable[pa.RecordBatch]: - """Yield Arrow data in batches, or as a single 'fetchall' batch.""" - fetch_batches = driver_properties["fetch_batches"] - if not iter_batches or fetch_batches is None: - fetch_method = driver_properties["fetch_all"] - yield getattr(self.result, fetch_method)() - else: - size = batch_size if driver_properties["exact_batch_size"] else None - repeat_batch_calls = driver_properties["repeat_batch_calls"] - fetchmany_arrow = getattr(self.result, fetch_batches) - if not repeat_batch_calls: - yield from fetchmany_arrow(size) - else: - while True: - arrow = fetchmany_arrow(size) - if not arrow: - break - yield arrow - def _normalise_cursor(self, conn: ConnectionOrCursor) -> Cursor: """Normalise a connection object such that we have the query executor.""" if self.driver_name == "sqlalchemy" and type(conn).__name__ == "Engine": - self.can_close_cursor = True - if conn.driver == "databricks-sql-python": # type: ignore[union-attr] - # take advantage of the raw connection to get arrow integration - self.driver_name = "databricks" - return conn.raw_connection().cursor() # type: ignore[union-attr] - else: - # sqlalchemy engine; direct use is deprecated, so prefer the connection - return conn.connect() # type: ignore[union-attr] - + # sqlalchemy engine; direct use is deprecated, so prefer the connection + self.acquired_cursor = True + return conn.connect() # type: ignore[union-attr] elif hasattr(conn, "cursor"): # connection has a dedicated cursor; prefer over direct execute cursor = cursor() if callable(cursor := conn.cursor) else cursor - self.can_close_cursor = True + self.acquired_cursor = True return cursor - elif hasattr(conn, "execute"): # can execute directly (given cursor, sqlalchemy connection, etc) return conn # type: ignore[return-value] @@ -245,20 +206,22 @@ def _from_arrow( try: for driver, driver_properties in _ARROW_DRIVER_REGISTRY_.items(): if re.match(f"^{driver}$", self.driver_name): + size = batch_size if driver_properties["exact_batch_size"] else None fetch_batches = driver_properties["fetch_batches"] - self.can_close_cursor = fetch_batches is None or not iter_batches frames = ( from_arrow(batch, schema_overrides=schema_overrides) - for batch in self._arrow_batches( 
- driver_properties, - iter_batches=iter_batches, - batch_size=batch_size, + for batch in ( + getattr(self.result, fetch_batches)(size) + if (iter_batches and fetch_batches is not None) + else [ + getattr(self.result, driver_properties["fetch_all"])() + ] ) ) return frames if iter_batches else next(frames) # type: ignore[arg-type,return-value] except Exception as err: # eg: valid turbodbc/snowflake connection, but no arrow support - # compiled in to the underlying driver (or on this connection) + # available in the underlying driver or this connection arrow_not_supported = ( "does not support Apache Arrow", "Apache Arrow format is not supported", diff --git a/py-polars/polars/io/spreadsheet/functions.py b/py-polars/polars/io/spreadsheet/functions.py index 29bfac7acec1..198b69404290 100644 --- a/py-polars/polars/io/spreadsheet/functions.py +++ b/py-polars/polars/io/spreadsheet/functions.py @@ -28,7 +28,7 @@ def read_excel( engine: Literal["xlsx2csv", "openpyxl", "pyxlsb"] | None = ..., xlsx2csv_options: dict[str, Any] | None = ..., read_csv_options: dict[str, Any] | None = ..., - schema_overrides: SchemaDict | None = ..., + schema_overrides: SchemaDict | None = None, raise_if_empty: bool = ..., ) -> pl.DataFrame: ... @@ -43,7 +43,7 @@ def read_excel( engine: Literal["xlsx2csv", "openpyxl", "pyxlsb"] | None = ..., xlsx2csv_options: dict[str, Any] | None = ..., read_csv_options: dict[str, Any] | None = ..., - schema_overrides: SchemaDict | None = ..., + schema_overrides: SchemaDict | None = None, raise_if_empty: bool = ..., ) -> pl.DataFrame: ... @@ -58,7 +58,7 @@ def read_excel( engine: Literal["xlsx2csv", "openpyxl", "pyxlsb"] | None = ..., xlsx2csv_options: dict[str, Any] | None = ..., read_csv_options: dict[str, Any] | None = ..., - schema_overrides: SchemaDict | None = ..., + schema_overrides: SchemaDict | None = None, raise_if_empty: bool = ..., ) -> NoReturn: ... @@ -75,7 +75,7 @@ def read_excel( engine: Literal["xlsx2csv", "openpyxl", "pyxlsb"] | None = ..., xlsx2csv_options: dict[str, Any] | None = ..., read_csv_options: dict[str, Any] | None = ..., - schema_overrides: SchemaDict | None = ..., + schema_overrides: SchemaDict | None = None, raise_if_empty: bool = ..., ) -> dict[str, pl.DataFrame]: ... @@ -90,7 +90,7 @@ def read_excel( engine: Literal["xlsx2csv", "openpyxl", "pyxlsb"] | None = ..., xlsx2csv_options: dict[str, Any] | None = ..., read_csv_options: dict[str, Any] | None = ..., - schema_overrides: SchemaDict | None = ..., + schema_overrides: SchemaDict | None = None, raise_if_empty: bool = ..., ) -> pl.DataFrame: ... @@ -105,7 +105,7 @@ def read_excel( engine: Literal["xlsx2csv", "openpyxl", "pyxlsb"] | None = ..., xlsx2csv_options: dict[str, Any] | None = ..., read_csv_options: dict[str, Any] | None = ..., - schema_overrides: SchemaDict | None = ..., + schema_overrides: SchemaDict | None = None, raise_if_empty: bool = ..., ) -> dict[str, pl.DataFrame]: ... 
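As a usage sketch of the ``schema_overrides`` parameter whose overload defaults are adjusted above (the workbook path and sheet name below are hypothetical; only the keyword itself comes from the signatures in these hunks):

>>> import polars as pl
>>> df = pl.read_excel(  # doctest: +SKIP
...     "data.xlsx",
...     sheet_name="sales",
...     schema_overrides={"amount": pl.Float64, "region": pl.Utf8},
... )

The intent is that any column named in the override dict gets the stated dtype rather than relying on inference when the parsed sheet is converted to a DataFrame.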
@@ -445,10 +445,6 @@ def _read_spreadsheet( if hasattr(parser, "close"): parser.close() - if not parsed_sheets: - param, value = ("id", sheet_id) if sheet_name is None else ("name", sheet_name) - raise ValueError(f"no matching sheets found when `sheet_{param}` is {value!r}") - if return_multi: return parsed_sheets return next(iter(parsed_sheets.values())) @@ -552,7 +548,6 @@ def _csv_buffer_to_frame( raise ParameterCollisionError( "cannot specify columns in both `schema_overrides` and `read_csv_options['dtypes']`" ) - read_csv_options = read_csv_options.copy() read_csv_options["dtypes"] = {**csv_dtypes, **schema_overrides} # otherwise rewind the buffer and parse as csv diff --git a/py-polars/polars/lazyframe/frame.py b/py-polars/polars/lazyframe/frame.py index 784c2afa7583..6943598beffc 100644 --- a/py-polars/polars/lazyframe/frame.py +++ b/py-polars/polars/lazyframe/frame.py @@ -2946,7 +2946,7 @@ def group_by( lgb = self._ldf.group_by(exprs, maintain_order) return LazyGroupBy(lgb) - def rolling( + def group_by_rolling( self, index_column: IntoExpr, *, @@ -2998,7 +2998,7 @@ def rolling( not be 24 hours, due to daylight savings). Similarly for "calendar week", "calendar month", "calendar quarter", and "calendar year". - In case of a rolling operation on an integer column, the windows are defined by: + In case of a group_by_rolling on an integer column, the windows are defined by: - "1i" # length 1 - "10i" # length 10 @@ -3054,14 +3054,19 @@ def rolling( ... pl.col("dt").str.strptime(pl.Datetime).set_sorted() ... ) >>> out = ( - ... df.rolling(index_column="dt", period="2d") + ... df.group_by_rolling(index_column="dt", period="2d") ... .agg( - ... pl.sum("a").alias("sum_a"), - ... pl.min("a").alias("min_a"), - ... pl.max("a").alias("max_a"), + ... [ + ... pl.sum("a").alias("sum_a"), + ... pl.min("a").alias("min_a"), + ... pl.max("a").alias("max_a"), + ... ] ... ) ... .collect() ... ) + >>> assert out["sum_a"].to_list() == [3, 10, 15, 24, 11, 1] + >>> assert out["max_a"].to_list() == [3, 7, 7, 9, 9, 1] + >>> assert out["min_a"].to_list() == [3, 3, 3, 3, 2, 1] >>> out shape: (6, 4) ┌─────────────────────┬───────┬───────┬───────┐ @@ -3086,7 +3091,7 @@ def rolling( period = _timedelta_to_pl_duration(period) offset = _timedelta_to_pl_duration(offset) - lgb = self._ldf.rolling( + lgb = self._ldf.group_by_rolling( index_column, period, offset, closed, pyexprs_by, check_sorted ) return LazyGroupBy(lgb) @@ -3193,7 +3198,7 @@ def group_by_dynamic( See Also -------- - rolling + group_by_rolling Notes ----- @@ -5656,7 +5661,6 @@ def update( left_on: str | Sequence[str] | None = None, right_on: str | Sequence[str] | None = None, how: Literal["left", "inner", "outer"] = "left", - include_nulls: bool | None = False, ) -> Self: """ Update the values in this `LazyFrame` with the non-null values in `other`. @@ -5678,14 +5682,10 @@ def update( * 'inner' keeps only those rows where the key exists in both frames. * 'outer' will update existing rows where the key matches while also adding any new rows contained in the given frame. - include_nulls - If True, null values from the right dataframe will be used to update the - left dataframe. Notes ----- - This is syntactic sugar for a left/inner join, with an optional coalesce when - `include_nulls = False`. + This is syntactic sugar for a join + coalesce (upsert) operation. 
Examples -------- @@ -5761,25 +5761,6 @@ def update( │ 5 ┆ -66 │ └─────┴─────┘ - Update `df` values including null values in `new_df`, using an outer join - strategy that defines explicit join columns in each frame: - - >>> lf.update( - ... new_lf, left_on="A", right_on="C", how="outer", include_nulls=True - ... ).collect() - shape: (5, 2) - ┌─────┬──────┐ - │ A ┆ B │ - │ --- ┆ --- │ - │ i64 ┆ i64 │ - ╞═════╪══════╡ - │ 1 ┆ -99 │ - │ 2 ┆ 500 │ - │ 3 ┆ null │ - │ 4 ┆ 700 │ - │ 5 ┆ -66 │ - └─────┴──────┘ - """ if how not in ("left", "inner", "outer"): raise ValueError( @@ -5828,38 +5809,24 @@ def update( # only use non-idx right columns present in left frame right_other = set(other.columns).intersection(self.columns) - set(right_on) - # When include_nulls is True, we need to distinguish records after the join that - # were originally null in the right frame, as opposed to records that were null - # because the key was missing from the right frame. - # Add a validity column to track whether row was matched or not. - if include_nulls: - validity = ("__POLARS_VALIDITY",) - other = other.with_columns(F.lit(True).alias(validity[0])) - else: - validity = () # type: ignore[assignment] - tmp_name = "__POLARS_RIGHT" - drop_columns = [*(f"{name}{tmp_name}" for name in right_other), *validity] result = ( self.join( - other.select(*right_on, *right_other, *validity), + other.select(*right_on, *right_other), left_on=left_on, right_on=right_on, how=how, suffix=tmp_name, ) .with_columns( - ( - # use left value only when right value failed to join - F.when(F.col(validity).is_null()) - .then(F.col(name)) - .otherwise(F.col(f"{name}{tmp_name}")) - if include_nulls - else F.coalesce([f"{name}{tmp_name}", F.col(name)]) - ).alias(name) - for name in right_other + [ + F.coalesce([f"{column_name}{tmp_name}", F.col(column_name)]).alias( + column_name + ) + for column_name in right_other + ] ) - .drop(drop_columns) + .drop([f"{name}{tmp_name}" for name in right_other]) ) if row_count_used: result = result.drop(row_count_name) @@ -5895,7 +5862,7 @@ def groupby( """ return self.group_by(by, *more_by, maintain_order=maintain_order) - @deprecate_renamed_function("rolling", version="0.19.0") + @deprecate_renamed_function("group_by_rolling", version="0.19.0") def groupby_rolling( self, index_column: IntoExpr, @@ -5910,67 +5877,7 @@ def groupby_rolling( Create rolling groups based on a time, Int32, or Int64 column. .. deprecated:: 0.19.0 - This method has been renamed to :func:`LazyFrame.rolling`. - - Parameters - ---------- - index_column - Column used to group based on the time window. - Often of type Date/Datetime. - This column must be sorted in ascending order (or, if `by` is specified, - then it must be sorted in ascending order within each group). - - In case of a rolling group by on indices, dtype needs to be one of - {Int32, Int64}. Note that Int32 gets temporarily cast to Int64, so if - performance matters use an Int64 column. - period - length of the window - must be non-negative - offset - offset of the window. Default is -period - closed : {'right', 'left', 'both', 'none'} - Define which sides of the temporal interval are closed (inclusive). - by - Also group by this column/these columns - check_sorted - When the ``by`` argument is given, polars can not check sortedness - by the metadata and has to do a full scan on the index column to - verify data is sorted. This is expensive. If you are sure the - data within the by groups is sorted, you can set this to ``False``. 
- Doing so incorrectly will lead to incorrect output - - Returns - ------- - LazyGroupBy - Object you can call ``.agg`` on to aggregate by groups, the result - of which will be sorted by `index_column` (but note that if `by` columns are - passed, it will only be sorted within each `by` group). - - """ - return self.rolling( - index_column, - period=period, - offset=offset, - closed=closed, - by=by, - check_sorted=check_sorted, - ) - - @deprecate_renamed_function("rolling", version="0.19.9") - def group_by_rolling( - self, - index_column: IntoExpr, - *, - period: str | timedelta, - offset: str | timedelta | None = None, - closed: ClosedInterval = "right", - by: IntoExpr | Iterable[IntoExpr] | None = None, - check_sorted: bool = True, - ) -> LazyGroupBy: - """ - Create rolling groups based on a time, Int32, or Int64 column. - - .. deprecated:: 0.19.9 - This method has been renamed to :func:`LazyFrame.rolling`. + This method has been renamed to :func:`LazyFrame.group_by_rolling`. Parameters ---------- @@ -6006,7 +5913,7 @@ def group_by_rolling( passed, it will only be sorted within each `by` group). """ - return self.rolling( + return self.group_by_rolling( index_column, period=period, offset=offset, diff --git a/py-polars/polars/series/array.py b/py-polars/polars/series/array.py index 5c61e4ab54f8..2ece17871ed5 100644 --- a/py-polars/polars/series/array.py +++ b/py-polars/polars/series/array.py @@ -25,7 +25,7 @@ def min(self) -> Series: Examples -------- >>> s = pl.Series( - ... "a", [[1, 2], [4, 3]], dtype=pl.Array(inner=pl.Int64, width=2) + ... "a", [[1, 2], [4, 3]], dtype=pl.Array(width=2, inner=pl.Int64) ... ) >>> s.arr.min() shape: (2,) @@ -44,7 +44,7 @@ def max(self) -> Series: Examples -------- >>> s = pl.Series( - ... "a", [[1, 2], [4, 3]], dtype=pl.Array(inner=pl.Int64, width=2) + ... "a", [[1, 2], [4, 3]], dtype=pl.Array(width=2, inner=pl.Int64) ... ) >>> s.arr.max() shape: (2,) @@ -64,7 +64,7 @@ def sum(self) -> Series: -------- >>> df = pl.DataFrame( ... data={"a": [[1, 2], [4, 3]]}, - ... schema={"a": pl.Array(inner=pl.Int64, width=2)}, + ... schema={"a": pl.Array(width=2, inner=pl.Int64)}, ... ) >>> df.select(pl.col("a").arr.sum()) shape: (2, 1) @@ -94,7 +94,7 @@ def unique(self, *, maintain_order: bool = False) -> Series: ... { ... "a": [[1, 1, 2]], ... }, - ... schema_overrides={"a": pl.Array(inner=pl.Int64, width=3)}, + ... schema_overrides={"a": pl.Array(width=3, inner=pl.Int64)}, ... ) >>> df.select(pl.col("a").arr.unique()) shape: (1, 1) diff --git a/py-polars/polars/series/list.py b/py-polars/polars/series/list.py index 6387362d04e9..3b883df4dea3 100644 --- a/py-polars/polars/series/list.py +++ b/py-polars/polars/series/list.py @@ -125,46 +125,6 @@ def drop_nulls(self) -> Series: """ - def sample( - self, - n: int | IntoExprColumn | None = None, - *, - fraction: float | IntoExprColumn | None = None, - with_replacement: bool = False, - shuffle: bool = False, - seed: int | None = None, - ) -> Series: - """ - Sample from this list. - - Parameters - ---------- - n - Number of items to return. Cannot be used with `fraction`. Defaults to 1 if - `fraction` is None. - fraction - Fraction of items to return. Cannot be used with `n`. - with_replacement - Allow values to be sampled more than once. - shuffle - Shuffle the order of sampled data points. - seed - Seed for the random number generator. If set to None (default), a - random seed is generated for each sample operation. 
- - Examples - -------- - >>> s = pl.Series("values", [[1, 2, 3], [4, 5]]) - >>> s.list.sample(n=pl.Series("n", [2, 1]), seed=1) - shape: (2,) - Series: 'values' [list[i64]] - [ - [2, 1] - [5] - ] - - """ - def sum(self) -> Series: """Sum all the arrays in the list.""" @@ -347,7 +307,7 @@ def arg_max(self) -> Series: def diff(self, n: int = 1, null_behavior: NullBehavior = "ignore") -> Series: """ - Calculate the first discrete difference between shifted items of every sublist. + Calculate the n-th discrete difference of every sublist. Parameters ---------- diff --git a/py-polars/polars/series/series.py b/py-polars/polars/series/series.py index 90591799cdf2..d6036a83415c 100644 --- a/py-polars/polars/series/series.py +++ b/py-polars/polars/series/series.py @@ -481,30 +481,14 @@ def _comp(self, other: Any, op: ComparisonOperator) -> Series: return self.clone() elif (other is False and op == "eq") or (other is True and op == "neq"): return ~self - elif isinstance(other, float) and self.dtype in INTEGER_DTYPES: - # require upcast when comparing int series to float value - self = self.cast(Float64) - f = get_ffi_func(op + "_<>", Float64, self._s) - assert f is not None - return self._from_pyseries(f(other)) - elif isinstance(other, datetime): - if self.dtype == Date: - # require upcast when comparing date series to datetime - self = self.cast(Datetime("us")) - time_unit = "us" - elif self.dtype == Datetime: - # Use local time zone info - time_zone = self.dtype.time_zone # type: ignore[union-attr] - if str(other.tzinfo) != str(time_zone): - raise TypeError( - f"Datetime time zone {other.tzinfo!r} does not match Series timezone {time_zone!r}" - ) - time_unit = self.dtype.time_unit # type: ignore[union-attr] - else: - raise ValueError( - f"cannot compare datetime.datetime to series of type {self.dtype}" + + if isinstance(other, datetime) and self.dtype == Datetime: + time_zone = self.dtype.time_zone # type: ignore[union-attr] + if str(other.tzinfo) != str(time_zone): + raise TypeError( + f"Datetime time zone {other.tzinfo!r} does not match Series timezone {time_zone!r}" ) - ts = _datetime_to_pl_timestamp(other, time_unit) # type: ignore[arg-type] + ts = _datetime_to_pl_timestamp(other, self.dtype.time_unit) # type: ignore[union-attr] f = get_ffi_func(op + "_<>", Int64, self._s) assert f is not None return self._from_pyseries(f(ts)) @@ -513,13 +497,14 @@ def _comp(self, other: Any, op: ComparisonOperator) -> Series: f = get_ffi_func(op + "_<>", Int64, self._s) assert f is not None return self._from_pyseries(f(d)) - elif self.dtype == Categorical and not isinstance(other, Series): - other = Series([other]) elif isinstance(other, date) and self.dtype == Date: d = _date_to_pl_date(other) f = get_ffi_func(op + "_<>", Int32, self._s) assert f is not None return self._from_pyseries(f(d)) + elif self.dtype == Categorical and not isinstance(other, Series): + other = Series([other]) + if isinstance(other, Sequence) and not isinstance(other, str): other = Series("", other, dtype_if_empty=self.dtype) if isinstance(other, Series): @@ -633,7 +618,7 @@ def eq_missing(self, other: Expr) -> Expr: # type: ignore[misc] def eq_missing(self, other: Any) -> Self | Expr: """ - Method equivalent of equality operator ``series == other`` where ``None == None``. + Method equivalent of equality operator ``series == other`` where `None` == None`. This differs from the standard ``ne`` where null values are propagated. 
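A quick sketch of the ``eq_missing`` semantics described in the docstring hunk above (assuming a polars build that exposes ``Series.eq_missing``; the values are arbitrary):

>>> import polars as pl
>>> s1 = pl.Series("a", [1, None, 3])
>>> s2 = pl.Series("a", [1, None, 4])
>>> (s1 == s2).to_list()  # the standard operator propagates nulls
[True, None, False]
>>> s1.eq_missing(s2).to_list()  # eq_missing treats None == None as equal
[True, True, False]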
@@ -684,7 +669,7 @@ def ne_missing(self, other: Any) -> Self: def ne_missing(self, other: Any) -> Self | Expr: """ - Method equivalent of equality operator ``series != other`` where ``None == None``. + Method equivalent of equality operator ``series != other`` where `None` == None`. This differs from the standard ``ne`` where null values are propagated. @@ -3826,7 +3811,7 @@ def to_physical(self) -> Series: """ - def to_list(self, *, use_pyarrow: bool | None = None) -> list[Any]: + def to_list(self, *, use_pyarrow: bool = False) -> list[Any]: """ Convert this Series to a Python List. This operation clones data. @@ -3844,15 +3829,8 @@ def to_list(self, *, use_pyarrow: bool | None = None) -> list[Any]: """ - if use_pyarrow is not None: - issue_deprecation_warning( - "The parameter `use_pyarrow` for `Series.to_list` is deprecated." - " Call the method without `use_pyarrow` to silence this warning.", - version="0.19.9", - ) - if use_pyarrow: - return self.to_arrow().to_pylist() - + if use_pyarrow: + return self.to_arrow().to_pylist() return self._s.to_list() def rechunk(self, *, in_place: bool = False) -> Self: @@ -6003,7 +5981,7 @@ def rank( def diff(self, n: int = 1, null_behavior: NullBehavior = "ignore") -> Series: """ - Calculate the first discrete difference between shifted items. + Calculate the n-th discrete difference. Parameters ---------- @@ -6048,7 +6026,7 @@ def diff(self, n: int = 1, null_behavior: NullBehavior = "ignore") -> Series: """ - def pct_change(self, n: int | IntoExprColumn = 1) -> Series: + def pct_change(self, n: int = 1) -> Series: """ Computes percentage change between values. @@ -6428,7 +6406,6 @@ def shuffle(self, seed: int | None = None) -> Series: """ - @deprecate_nonkeyword_arguments(version="0.19.10") def ewm_mean( self, com: float | None = None, @@ -6496,21 +6473,8 @@ def ewm_mean( :math:`1-\alpha` and :math:`1` if ``adjust=True``, and :math:`1-\alpha` and :math:`\alpha` if ``adjust=False``. 
- Examples - -------- - >>> s = pl.Series([1, 2, 3]) - >>> s.ewm_mean(com=1) - shape: (3,) - Series: '' [f64] - [ - 1.0 - 1.666667 - 2.428571 - ] - """ - @deprecate_nonkeyword_arguments(version="0.19.10") def ewm_std( self, com: float | None = None, @@ -6596,7 +6560,6 @@ def ewm_std( """ - @deprecate_nonkeyword_arguments(version="0.19.10") def ewm_var( self, com: float | None = None, diff --git a/py-polars/polars/series/utils.py b/py-polars/polars/series/utils.py index 9e5c95477b74..00ab66ba5e5e 100644 --- a/py-polars/polars/series/utils.py +++ b/py-polars/polars/series/utils.py @@ -89,7 +89,7 @@ def _undecorated(function: Callable[P, T]) -> Callable[P, T]: def call_expr(func: SeriesMethod) -> SeriesMethod: """Dispatch Series method to an expression implementation.""" - @wraps(func) + @wraps(func) # type: ignore[arg-type] def wrapper(self: Any, *args: P.args, **kwargs: P.kwargs) -> Series: s = wrap_s(self._s) expr = F.col(s.name) diff --git a/py-polars/polars/testing/__init__.py b/py-polars/polars/testing/__init__.py index b5962f7fba2c..13cf07939044 100644 --- a/py-polars/polars/testing/__init__.py +++ b/py-polars/polars/testing/__init__.py @@ -1,13 +1,15 @@ from polars.testing.asserts import ( assert_frame_equal, + assert_frame_equal_local_categoricals, assert_frame_not_equal, assert_series_equal, assert_series_not_equal, ) __all__ = [ - "assert_frame_equal", - "assert_frame_not_equal", "assert_series_equal", "assert_series_not_equal", + "assert_frame_equal", + "assert_frame_not_equal", + "assert_frame_equal_local_categoricals", ] diff --git a/py-polars/polars/testing/_private.py b/py-polars/polars/testing/_private.py new file mode 100644 index 000000000000..f7a47c49c1fd --- /dev/null +++ b/py-polars/polars/testing/_private.py @@ -0,0 +1,38 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from polars.datatypes import Utf8 + +if TYPE_CHECKING: + from polars import DataFrame, Series + + +def _to_rust_syntax(df: DataFrame) -> str: + """Utility to generate the syntax that creates a polars 'DataFrame' in Rust.""" + syntax = "df![\n" + + def format_s(s: Series) -> str: + if s.null_count() == 0: + out = str(s.to_list()).replace("'", '"') + if s.dtype != Utf8: + out = out.lower() + return out + else: + tmp = "[" + for val in s: + if val is None: + tmp += "None, " + else: + if isinstance(val, str): + tmp += f'Some("{val}"), ' + else: + val = str(val).lower() + tmp += f"Some({val}), " + tmp = tmp[:-2] + "]" + return tmp + + for s in df: + syntax += f' "{s.name}" => {format_s(s)},\n' + syntax += "]" + return syntax diff --git a/py-polars/polars/testing/asserts.py b/py-polars/polars/testing/asserts.py new file mode 100644 index 000000000000..7fdd92edb89a --- /dev/null +++ b/py-polars/polars/testing/asserts.py @@ -0,0 +1,581 @@ +from __future__ import annotations + +import textwrap +from typing import Any, NoReturn + +from polars import functions as F +from polars.dataframe import DataFrame +from polars.datatypes import ( + FLOAT_DTYPES, + UNSIGNED_INTEGER_DTYPES, + Categorical, + DataTypeClass, + List, + Struct, + UInt64, + Utf8, + dtype_to_py_type, + unpack_dtypes, +) +from polars.exceptions import ComputeError, InvalidAssert +from polars.lazyframe import LazyFrame +from polars.series import Series +from polars.utils.deprecation import deprecate_function + + +def assert_frame_equal( + left: DataFrame | LazyFrame, + right: DataFrame | LazyFrame, + *, + check_row_order: bool = True, + check_column_order: bool = True, + check_dtype: bool = True, + check_exact: bool = False, 
+ rtol: float = 1.0e-5, + atol: float = 1.0e-8, + nans_compare_equal: bool = True, + categorical_as_str: bool = False, +) -> None: + """ + Raise detailed AssertionError if `left` does NOT equal `right`. + + Parameters + ---------- + left + the DataFrame to compare. + right + the DataFrame to compare with. + check_row_order + if False, frames will compare equal if the required rows are present, + irrespective of the order in which they appear; as this requires + sorting, you cannot set on frames that contain unsortable columns. + check_column_order + if False, frames will compare equal if the required columns are present, + irrespective of the order in which they appear. + check_dtype + if True, data types need to match exactly. + check_exact + if False, test if values are within tolerance of each other + (see `rtol` & `atol`). + rtol + relative tolerance for inexact checking. Fraction of values in `right`. + atol + absolute tolerance for inexact checking. + nans_compare_equal + if your assert/test requires float NaN != NaN, set this to False. + categorical_as_str + Cast categorical columns to string before comparing. Enabling this helps + compare DataFrames that do not share the same string cache. + + Examples + -------- + >>> from polars.testing import assert_frame_equal + >>> df1 = pl.DataFrame({"a": [1, 2, 3]}) + >>> df2 = pl.DataFrame({"a": [2, 3, 4]}) + >>> assert_frame_equal(df1, df2) # doctest: +SKIP + AssertionError: Values for column 'a' are different. + """ + collect_input_frames = isinstance(left, LazyFrame) and isinstance(right, LazyFrame) + if collect_input_frames: + objs = "LazyFrames" + elif isinstance(left, DataFrame) and isinstance(right, DataFrame): + objs = "DataFrames" + else: + raise_assert_detail("Inputs", "unexpected input types", type(left), type(right)) + + if left_not_right := [c for c in left.columns if c not in right.columns]: + raise AssertionError( + f"columns {left_not_right!r} in left frame, but not in right" + ) + + if right_not_left := [c for c in right.columns if c not in left.columns]: + raise AssertionError( + f"columns {right_not_left!r} in right frame, but not in left" + ) + + if check_column_order and left.columns != right.columns: + raise AssertionError( + f"columns are not in the same order:\n{left.columns!r}\n{right.columns!r}" + ) + + if collect_input_frames: + if check_dtype: # check this _before_ we collect + left_schema, right_schema = left.schema, right.schema + if left_schema != right_schema: + raise_assert_detail( + objs, "lazy schemas are not equal", left_schema, right_schema + ) + left, right = left.collect(), right.collect() # type: ignore[union-attr] + + if left.shape[0] != right.shape[0]: # type: ignore[union-attr] + raise_assert_detail(objs, "length mismatch", left.shape, right.shape) # type: ignore[union-attr] + + if not check_row_order: + try: + left = left.sort(by=left.columns) + right = right.sort(by=left.columns) + except ComputeError as exc: + raise InvalidAssert( + "cannot set `check_row_order=False` on frame with unsortable columns" + ) from exc + + # note: does not assume a particular column order + for c in left.columns: + try: + _assert_series_inner( + left[c], # type: ignore[arg-type, index] + right[c], # type: ignore[arg-type, index] + check_dtype=check_dtype, + check_exact=check_exact, + atol=atol, + rtol=rtol, + nans_compare_equal=nans_compare_equal, + categorical_as_str=categorical_as_str, + ) + except AssertionError as exc: + msg = f"values for column {c!r} are different." 
+ raise AssertionError(msg) from exc + + +def assert_frame_not_equal( + left: DataFrame | LazyFrame, + right: DataFrame | LazyFrame, + *, + check_row_order: bool = True, + check_column_order: bool = True, + check_dtype: bool = True, + check_exact: bool = False, + rtol: float = 1.0e-5, + atol: float = 1.0e-8, + nans_compare_equal: bool = True, + categorical_as_str: bool = False, +) -> None: + """ + Raise AssertionError if `left` DOES equal `right`. + + Parameters + ---------- + left + the DataFrame to compare. + right + the DataFrame to compare with. + check_row_order + if False, frames will compare equal if the required rows are present, + irrespective of the order in which they appear; as this requires + sorting, you cannot set on frames that contain unsortable columns. + check_column_order + if False, frames will compare equal if the required columns are present, + irrespective of the order in which they appear. + check_dtype + if True, data types need to match exactly. + check_exact + if False, test if values are within tolerance of each other + (see `rtol` & `atol`). + rtol + relative tolerance for inexact checking. Fraction of values in `right`. + atol + absolute tolerance for inexact checking. + nans_compare_equal + if your assert/test requires float NaN != NaN, set this to False. + categorical_as_str + Cast categorical columns to string before comparing. Enabling this helps + compare DataFrames that do not share the same string cache. + + Examples + -------- + >>> from polars.testing import assert_frame_not_equal + >>> df1 = pl.DataFrame({"a": [1, 2, 3]}) + >>> df2 = pl.DataFrame({"a": [2, 3, 4]}) + >>> assert_frame_not_equal(df1, df2) + + """ + try: + assert_frame_equal( + left=left, + right=right, + check_column_order=check_column_order, + check_row_order=check_row_order, + check_dtype=check_dtype, + check_exact=check_exact, + rtol=rtol, + atol=atol, + nans_compare_equal=nans_compare_equal, + categorical_as_str=categorical_as_str, + ) + except AssertionError: + return + else: + raise AssertionError("expected the input frames to be unequal") + + +def assert_series_equal( + left: Series, + right: Series, + *, + check_dtype: bool = True, + check_names: bool = True, + check_exact: bool = False, + rtol: float = 1.0e-5, + atol: float = 1.0e-8, + nans_compare_equal: bool = True, + categorical_as_str: bool = False, +) -> None: + """ + Raise detailed AssertionError if `left` does NOT equal `right`. + + Parameters + ---------- + left + the series to compare. + right + the series to compare with. + check_dtype + if True, data types need to match exactly. + check_names + if True, names need to match. + check_exact + if False, test if values are within tolerance of each other + (see `rtol` & `atol`). + rtol + relative tolerance for inexact checking. Fraction of values in `right`. + atol + absolute tolerance for inexact checking. + nans_compare_equal + if your assert/test requires float NaN != NaN, set this to False. + categorical_as_str + Cast categorical columns to string before comparing. Enabling this helps + compare DataFrames that do not share the same string cache. 
+ + Examples + -------- + >>> from polars.testing import assert_series_equal + >>> s1 = pl.Series([1, 2, 3]) + >>> s2 = pl.Series([2, 3, 4]) + >>> assert_series_equal(s1, s2) # doctest: +SKIP + + """ + if not ( + isinstance(left, Series) # type: ignore[redundant-expr] + and isinstance(right, Series) + ): + raise_assert_detail("Inputs", "unexpected input types", type(left), type(right)) + + if len(left) != len(right): + raise_assert_detail("Series", "length mismatch", len(left), len(right)) + + if check_names and left.name != right.name: + raise_assert_detail("Series", "name mismatch", left.name, right.name) + + _assert_series_inner( + left, + right, + check_dtype=check_dtype, + check_exact=check_exact, + atol=atol, + rtol=rtol, + nans_compare_equal=nans_compare_equal, + categorical_as_str=categorical_as_str, + ) + + +def assert_series_not_equal( + left: Series, + right: Series, + *, + check_dtype: bool = True, + check_names: bool = True, + check_exact: bool = False, + rtol: float = 1.0e-5, + atol: float = 1.0e-8, + nans_compare_equal: bool = True, + categorical_as_str: bool = False, +) -> None: + """ + Raise AssertionError if `left` DOES equal `right`. + + Parameters + ---------- + left + the series to compare. + right + the series to compare with. + check_dtype + if True, data types need to match exactly. + check_names + if True, names need to match. + check_exact + if False, test if values are within tolerance of each other + (see `rtol` & `atol`). + rtol + relative tolerance for inexact checking. Fraction of values in `right`. + atol + absolute tolerance for inexact checking. + nans_compare_equal + if your assert/test requires float NaN != NaN, set this to False. + categorical_as_str + Cast categorical columns to string before comparing. Enabling this helps + compare DataFrames that do not share the same string cache. 
+ + Examples + -------- + >>> from polars.testing import assert_series_not_equal + >>> s1 = pl.Series([1, 2, 3]) + >>> s2 = pl.Series([2, 3, 4]) + >>> assert_series_not_equal(s1, s2) + + """ + try: + assert_series_equal( + left=left, + right=right, + check_dtype=check_dtype, + check_names=check_names, + check_exact=check_exact, + rtol=rtol, + atol=atol, + nans_compare_equal=nans_compare_equal, + categorical_as_str=categorical_as_str, + ) + except AssertionError: + return + else: + raise AssertionError("expected the input Series to be unequal") + + +def _assert_series_inner( + left: Series, + right: Series, + *, + check_dtype: bool, + check_exact: bool, + atol: float, + rtol: float, + nans_compare_equal: bool, + categorical_as_str: bool, +) -> None: + """Compare Series dtype + values.""" + if check_dtype and left.dtype != right.dtype: + raise_assert_detail("Series", "dtype mismatch", left.dtype, right.dtype) + + if left.null_count() != right.null_count(): + raise_assert_detail( + "Series", "null_count is not equal", left.null_count(), right.null_count() + ) + + if categorical_as_str and left.dtype == Categorical: + left = left.cast(Utf8) + right = right.cast(Utf8) + + # create mask of which (if any) values are unequal + unequal = left.ne_missing(right) + + # handle NaN values (which compare unequal to themselves) + comparing_float_dtypes = left.dtype in FLOAT_DTYPES and right.dtype in FLOAT_DTYPES + if unequal.any() and nans_compare_equal: + # when both dtypes are scalar floats + if comparing_float_dtypes: + unequal = unequal & ~( + (left.is_nan() & right.is_nan()).fill_null(F.lit(False)) + ) + if comparing_float_dtypes and not nans_compare_equal: + unequal = unequal | left.is_nan() | right.is_nan() + + # check nested dtypes in separate function + if left.dtype.is_nested or right.dtype.is_nested: + if _assert_series_nested( + left=left.filter(unequal), + right=right.filter(unequal), + check_dtype=check_dtype, + check_exact=check_exact, + atol=atol, + rtol=rtol, + nans_compare_equal=nans_compare_equal, + categorical_as_str=categorical_as_str, + ): + return + + try: + can_be_subtracted = hasattr(dtype_to_py_type(left.dtype), "__sub__") + except NotImplementedError: + can_be_subtracted = False + + check_exact = ( + check_exact or not can_be_subtracted or left.is_boolean() or left.is_temporal() + ) + + # assert exact, or with tolerance + if unequal.any(): + if check_exact: + raise_assert_detail( + "Series", + "exact value mismatch", + left=list(left), + right=list(right), + ) + else: + # apply check with tolerance (to the known-unequal matches). + left, right = left.filter(unequal), right.filter(unequal) + + if all(tp in UNSIGNED_INTEGER_DTYPES for tp in (left.dtype, right.dtype)): + # avoid potential "subtract-with-overflow" panic on uint math + s_diff = Series( + "diff", [abs(v1 - v2) for v1, v2 in zip(left, right)], dtype=UInt64 + ) + else: + s_diff = (left - right).abs() + + mismatch, nan_info = False, "" + if ((s_diff > (atol + rtol * right.abs())).sum() != 0) or ( + left.is_null() != right.is_null() + ).any(): + mismatch = True + elif comparing_float_dtypes: + # note: take special care with NaN values. + # if NaNs don't compare as equal, any NaN in the left Series is + # sufficient for a mismatch because the if condition above already + # compares the null values. 
+ if not nans_compare_equal and left.is_nan().any(): + nan_info = " (nans_compare_equal=False)" + mismatch = True + elif (left.is_nan() != right.is_nan()).any(): + nan_info = f" (nans_compare_equal={nans_compare_equal})" + mismatch = True + + if mismatch: + raise_assert_detail( + "Series", + f"value mismatch{nan_info}", + left=list(left), + right=list(right), + ) + + +def _assert_series_nested( + left: Series, + right: Series, + *, + check_dtype: bool, + check_exact: bool, + atol: float, + rtol: float, + nans_compare_equal: bool, + categorical_as_str: bool, +) -> bool: + # check that float values exist at _some_ level of nesting + if not any(tp in FLOAT_DTYPES for tp in unpack_dtypes(left.dtype, right.dtype)): + return False + + # compare nested lists element-wise + elif left.dtype == List == right.dtype: + for s1, s2 in zip(left, right): + if s1 is None and s2 is None: + if nans_compare_equal: + continue + else: + raise_assert_detail( + "Series", + f"Nested value mismatch (nans_compare_equal={nans_compare_equal})", + s1, + s2, + ) + elif (s1 is None and s2 is not None) or (s2 is None and s1 is not None): + raise_assert_detail("Series", "nested value mismatch", s1, s2) + elif len(s1) != len(s2): + raise_assert_detail( + "Series", "nested list length mismatch", len(s1), len(s2) + ) + + _assert_series_inner( + s1, + s2, + check_dtype=check_dtype, + check_exact=check_exact, + atol=atol, + rtol=rtol, + nans_compare_equal=nans_compare_equal, + categorical_as_str=categorical_as_str, + ) + return True + + # unnest structs as series and compare + elif left.dtype == Struct == right.dtype: + ls, rs = left.struct.unnest(), right.struct.unnest() + if len(ls.columns) != len(rs.columns): + raise_assert_detail( + "Series", + "nested struct fields mismatch", + len(ls.columns), + len(rs.columns), + ) + elif len(ls) != len(rs): + raise_assert_detail( + "Series", "nested struct length mismatch", len(ls), len(rs) + ) + for s1, s2 in zip(ls, rs): + _assert_series_inner( + s1, + s2, + check_dtype=check_dtype, + check_exact=check_exact, + atol=atol, + rtol=rtol, + nans_compare_equal=nans_compare_equal, + categorical_as_str=categorical_as_str, + ) + return True + else: + # fall-back to outer codepath (if mismatched dtypes we would expect + # the equality check to fail - unless ALL series values are null) + return False + + +def raise_assert_detail( + obj: str, + detail: str, + left: Any, + right: Any, + exc: AssertionError | None = None, +) -> NoReturn: + """Raise a detailed assertion error.""" + __tracebackhide__ = True + + error_msg = textwrap.dedent( + f"""\ + {obj} are different ({detail}) + [left]: {left} + [right]: {right}\ + """ + ) + + raise AssertionError(error_msg) from exc + + +def is_categorical_dtype(data_type: Any) -> bool: + """Check if the input is a polars Categorical dtype.""" + return ( + type(data_type) is DataTypeClass + and issubclass(data_type, Categorical) + or isinstance(data_type, Categorical) + ) + + +@deprecate_function( + "Use `assert_frame_equal` instead and pass `categorical_as_str=True`.", + version="0.18.13", +) +def assert_frame_equal_local_categoricals(df_a: DataFrame, df_b: DataFrame) -> None: + """Assert frame equal for frames containing categoricals.""" + for (a_name, a_value), (b_name, b_value) in zip( + df_a.schema.items(), df_b.schema.items() + ): + if a_name != b_name: + print(f"{a_name} != {b_name}") + raise AssertionError + if a_value != b_value: + print(f"{a_value} != {b_value}") + raise AssertionError + + cat_to_str = F.col(Categorical).cast(str) + 
assert_frame_equal(df_a.with_columns(cat_to_str), df_b.with_columns(cat_to_str))
+    cat_to_phys = F.col(Categorical).to_physical()
+    assert_frame_equal(df_a.with_columns(cat_to_phys), df_b.with_columns(cat_to_phys))
diff --git a/py-polars/polars/testing/asserts/__init__.py b/py-polars/polars/testing/asserts/__init__.py
deleted file mode 100644
index 4e00da7cc1fa..000000000000
--- a/py-polars/polars/testing/asserts/__init__.py
+++ /dev/null
@@ -1,9 +0,0 @@
-from polars.testing.asserts.frame import assert_frame_equal, assert_frame_not_equal
-from polars.testing.asserts.series import assert_series_equal, assert_series_not_equal
-
-__all__ = [
-    "assert_frame_equal",
-    "assert_frame_not_equal",
-    "assert_series_equal",
-    "assert_series_not_equal",
-]
diff --git a/py-polars/polars/testing/asserts/frame.py b/py-polars/polars/testing/asserts/frame.py
deleted file mode 100644
index 3920dc57e050..000000000000
--- a/py-polars/polars/testing/asserts/frame.py
+++ /dev/null
@@ -1,276 +0,0 @@
-from __future__ import annotations
-
-from typing import cast
-
-from polars.dataframe import DataFrame
-from polars.exceptions import ComputeError, InvalidAssert
-from polars.lazyframe import LazyFrame
-from polars.testing.asserts.series import _assert_series_values_equal
-from polars.testing.asserts.utils import raise_assertion_error
-
-
-def assert_frame_equal(
-    left: DataFrame | LazyFrame,
-    right: DataFrame | LazyFrame,
-    *,
-    check_row_order: bool = True,
-    check_column_order: bool = True,
-    check_dtype: bool = True,
-    check_exact: bool = False,
-    rtol: float = 1e-5,
-    atol: float = 1e-8,
-    nans_compare_equal: bool = True,
-    categorical_as_str: bool = False,
-) -> None:
-    """
-    Assert that the left and right frame are equal.
-
-    Raises a detailed ``AssertionError`` if the frames differ.
-    This function is intended for use in unit tests.
-
-    Parameters
-    ----------
-    left
-        The first DataFrame or LazyFrame to compare.
-    right
-        The second DataFrame or LazyFrame to compare.
-    check_row_order
-        Require row order to match.
-
-        .. note::
-            Setting this to ``False`` requires sorting the data, which will fail on
-            frames that contain unsortable columns.
-    check_column_order
-        Require column order to match.
-    check_dtype
-        Require data types to match.
-    check_exact
-        Require data values to match exactly. If set to ``False``, values are considered
-        equal when within tolerance of each other (see ``rtol`` and ``atol``).
-        Logical types like dates are always checked exactly.
-    rtol
-        Relative tolerance for inexact checking. Fraction of values in ``right``.
-    atol
-        Absolute tolerance for inexact checking.
-    nans_compare_equal
-        Consider NaN values to be equal.
-    categorical_as_str
-        Cast categorical columns to string before comparing. Enabling this helps
-        compare columns that do not share the same string cache.
-
-    See Also
-    --------
-    assert_series_equal
-    assert_frame_not_equal
-
-    Notes
-    -----
-    When using pytest, it may be worthwhile to shorten Python traceback printing
-    by passing ``--tb=short``. The default mode tends to be unhelpfully verbose.
-    More information in the
-    `pytest docs <https://docs.pytest.org/en/latest/how-to/output.html>`_.
-
-    Examples
-    --------
-    >>> from polars.testing import assert_frame_equal
-    >>> df1 = pl.DataFrame({"a": [1, 2, 3]})
-    >>> df2 = pl.DataFrame({"a": [1, 5, 3]})
-    >>> assert_frame_equal(df1, df2)  # doctest: +SKIP
-    Traceback (most recent call last):
-    ...
- AssertionError: Series are different (value mismatch) - [left]: [1, 2, 3] - [right]: [1, 5, 3] - - The above exception was the direct cause of the following exception: - - Traceback (most recent call last): - ... - AssertionError: values for column 'a' are different - - """ - lazy = _assert_correct_input_type(left, right) - objects = "LazyFrames" if lazy else "DataFrames" - - _assert_frame_schema_equal( - left, - right, - check_column_order=check_column_order, - check_dtype=check_dtype, - objects=objects, - ) - - if lazy: - left, right = left.collect(), right.collect() # type: ignore[union-attr] - left, right = cast(DataFrame, left), cast(DataFrame, right) - - if left.height != right.height: - raise_assertion_error( - objects, "number of rows does not match", left.height, right.height - ) - - if not check_row_order: - left, right = _sort_dataframes(left, right) - - for c in left.columns: - try: - _assert_series_values_equal( - left.get_column(c), - right.get_column(c), - check_exact=check_exact, - rtol=rtol, - atol=atol, - nans_compare_equal=nans_compare_equal, - categorical_as_str=categorical_as_str, - ) - except AssertionError as exc: - msg = f"values for column {c!r} are different" - raise AssertionError(msg) from exc - - -def _assert_correct_input_type( - left: DataFrame | LazyFrame, right: DataFrame | LazyFrame -) -> bool: - if isinstance(left, DataFrame) and isinstance(right, DataFrame): - return False - elif isinstance(left, LazyFrame) and isinstance(right, LazyFrame): - return True - else: - raise_assertion_error( - "inputs", - "unexpected input types", - type(left).__name__, - type(right).__name__, - ) - - -def _assert_frame_schema_equal( - left: DataFrame | LazyFrame, - right: DataFrame | LazyFrame, - *, - check_dtype: bool, - check_column_order: bool, - objects: str, -) -> None: - left_schema, right_schema = left.schema, right.schema - - # Fast path for equal frames - if left_schema == right_schema: - return - - # Special error message for when column names do not match - if left_schema.keys() != right_schema.keys(): - if left_not_right := [c for c in left_schema if c not in right_schema]: - msg = f"columns {left_not_right!r} in left {objects[:-1]}, but not in right" - raise AssertionError(msg) - else: - right_not_left = [c for c in right_schema if c not in left_schema] - msg = f"columns {right_not_left!r} in right {objects[:-1]}, but not in left" - raise AssertionError(msg) - - if check_column_order: - left_columns, right_columns = list(left_schema), list(right_schema) - if left_columns != right_columns: - detail = "columns are not in the same order" - raise_assertion_error(objects, detail, left_columns, right_columns) - - if check_dtype: - left_schema_dict, right_schema_dict = dict(left_schema), dict(right_schema) - if check_column_order or left_schema_dict != right_schema_dict: - detail = "dtypes do not match" - raise_assertion_error(objects, detail, left_schema_dict, right_schema_dict) - - -def _sort_dataframes(left: DataFrame, right: DataFrame) -> tuple[DataFrame, DataFrame]: - by = left.columns - try: - left = left.sort(by) - right = right.sort(by) - except ComputeError as exc: - msg = "cannot set `check_row_order=False` on frame with unsortable columns" - raise InvalidAssert(msg) from exc - return left, right - - -def assert_frame_not_equal( - left: DataFrame | LazyFrame, - right: DataFrame | LazyFrame, - *, - check_row_order: bool = True, - check_column_order: bool = True, - check_dtype: bool = True, - check_exact: bool = False, - rtol: float = 1e-5, - atol: float = 
1e-8, - nans_compare_equal: bool = True, - categorical_as_str: bool = False, -) -> None: - """ - Assert that the left and right frame are **not** equal. - - This function is intended for use in unit tests. - - Parameters - ---------- - left - The first DataFrame or LazyFrame to compare. - right - The second DataFrame or LazyFrame to compare. - check_row_order - Require row order to match. - - .. note:: - Setting this to ``False`` requires sorting the data, which will fail on - frames that contain unsortable columns. - check_column_order - Require column order to match. - check_dtype - Require data types to match. - check_exact - Require data values to match exactly. If set to ``False``, values are considered - equal when within tolerance of each other (see ``rtol`` and ``atol``). - Logical types like dates are always checked exactly. - rtol - Relative tolerance for inexact checking. Fraction of values in ``right``. - atol - Absolute tolerance for inexact checking. - nans_compare_equal - Consider NaN values to be equal. - categorical_as_str - Cast categorical columns to string before comparing. Enabling this helps - compare columns that do not share the same string cache. - - See Also - -------- - assert_frame_equal - assert_series_not_equal - - Examples - -------- - >>> from polars.testing import assert_frame_not_equal - >>> df1 = pl.DataFrame({"a": [1, 2, 3]}) - >>> df2 = pl.DataFrame({"a": [1, 2, 3]}) - >>> assert_frame_not_equal(df1, df2) # doctest: +SKIP - Traceback (most recent call last): - ... - AssertionError: frames are equal - - """ - try: - assert_frame_equal( - left=left, - right=right, - check_column_order=check_column_order, - check_row_order=check_row_order, - check_dtype=check_dtype, - check_exact=check_exact, - rtol=rtol, - atol=atol, - nans_compare_equal=nans_compare_equal, - categorical_as_str=categorical_as_str, - ) - except AssertionError: - return - else: - msg = "frames are equal" - raise AssertionError(msg) diff --git a/py-polars/polars/testing/asserts/series.py b/py-polars/polars/testing/asserts/series.py deleted file mode 100644 index 53b41ac335fe..000000000000 --- a/py-polars/polars/testing/asserts/series.py +++ /dev/null @@ -1,403 +0,0 @@ -from __future__ import annotations - -from polars.datatypes import ( - FLOAT_DTYPES, - NESTED_DTYPES, - NUMERIC_DTYPES, - UNSIGNED_INTEGER_DTYPES, - Categorical, - Int64, - List, - Struct, - UInt64, - Utf8, - unpack_dtypes, -) -from polars.exceptions import ComputeError -from polars.series import Series -from polars.testing.asserts.utils import raise_assertion_error - - -def assert_series_equal( - left: Series, - right: Series, - *, - check_dtype: bool = True, - check_names: bool = True, - check_exact: bool = False, - rtol: float = 1e-5, - atol: float = 1e-8, - nans_compare_equal: bool = True, - categorical_as_str: bool = False, -) -> None: - """ - Assert that the left and right Series are equal. - - Raises a detailed ``AssertionError`` if the Series differ. - This function is intended for use in unit tests. - - Parameters - ---------- - left - The first Series to compare. - right - The second Series to compare. - check_dtype - Require data types to match. - check_names - Require names to match. - check_exact - Require data values to match exactly. If set to ``False``, values are considered - equal when within tolerance of each other (see ``rtol`` and ``atol``). - Logical types like dates are always checked exactly. - rtol - Relative tolerance for inexact checking, given as a fraction of the values in - ``right``. 
-    atol
-        Absolute tolerance for inexact checking.
-    nans_compare_equal
-        Consider NaN values to be equal.
-    categorical_as_str
-        Cast categorical columns to string before comparing. Enabling this helps
-        compare columns that do not share the same string cache.
-
-    See Also
-    --------
-    assert_frame_equal
-    assert_series_not_equal
-
-    Notes
-    -----
-    When using pytest, it may be worthwhile to shorten Python traceback printing
-    by passing ``--tb=short``. The default mode tends to be unhelpfully verbose.
-    More information in the
-    `pytest docs <https://docs.pytest.org/en/latest/how-to/output.html>`_.
-
-    Examples
-    --------
-    >>> from polars.testing import assert_series_equal
-    >>> s1 = pl.Series([1, 2, 3])
-    >>> s2 = pl.Series([1, 5, 3])
-    >>> assert_series_equal(s1, s2)  # doctest: +SKIP
-    Traceback (most recent call last):
-    ...
-    AssertionError: Series are different (value mismatch)
-    [left]: [1, 2, 3]
-    [right]: [1, 5, 3]
-
-    """
-    if not (isinstance(left, Series) and isinstance(right, Series)):  # type: ignore[redundant-expr]
-        raise_assertion_error(
-            "inputs",
-            "unexpected input types",
-            type(left).__name__,
-            type(right).__name__,
-        )
-
-    if left.len() != right.len():
-        raise_assertion_error("Series", "length mismatch", left.len(), right.len())
-
-    if check_names and left.name != right.name:
-        raise_assertion_error("Series", "name mismatch", left.name, right.name)
-
-    if check_dtype and left.dtype != right.dtype:
-        raise_assertion_error("Series", "dtype mismatch", left.dtype, right.dtype)
-
-    _assert_series_values_equal(
-        left,
-        right,
-        check_exact=check_exact,
-        rtol=rtol,
-        atol=atol,
-        nans_compare_equal=nans_compare_equal,
-        categorical_as_str=categorical_as_str,
-    )
-
-
-def _assert_series_values_equal(
-    left: Series,
-    right: Series,
-    *,
-    check_exact: bool,
-    rtol: float,
-    atol: float,
-    nans_compare_equal: bool,
-    categorical_as_str: bool,
-) -> None:
-    """Assert that the values in both Series are equal."""
-    # Handle categoricals
-    if categorical_as_str:
-        if left.dtype == Categorical:
-            left = left.cast(Utf8)
-        if right.dtype == Categorical:
-            right = right.cast(Utf8)
-
-    # Determine unequal elements
-    try:
-        unequal = left.ne_missing(right)
-    except ComputeError as exc:
-        raise_assertion_error(
-            "Series",
-            "incompatible data types",
-            left=left.dtype,
-            right=right.dtype,
-            cause=exc,
-        )
-
-    # Handle NaN values (which compare unequal to themselves)
-    comparing_floats = left.dtype in FLOAT_DTYPES and right.dtype in FLOAT_DTYPES
-    if comparing_floats and nans_compare_equal:
-        both_nan = (left.is_nan() & right.is_nan()).fill_null(False)
-        unequal = unequal & ~both_nan
-
-    # Check nested dtypes in separate function
-    if left.dtype in NESTED_DTYPES or right.dtype in NESTED_DTYPES:
-        if _assert_series_nested(
-            left=left.filter(unequal),
-            right=right.filter(unequal),
-            check_exact=check_exact,
-            rtol=rtol,
-            atol=atol,
-            nans_compare_equal=nans_compare_equal,
-            categorical_as_str=categorical_as_str,
-        ):
-            return
-
-    # If no differences found during exact checking, we're done
-    if not unequal.any():
-        return
-
-    # Only do inexact checking for numeric types
-    if (
-        check_exact
-        or left.dtype not in NUMERIC_DTYPES
-        or right.dtype not in NUMERIC_DTYPES
-    ):
-        raise_assertion_error(
-            "Series",
-            "exact value mismatch",
-            left=left.to_list(),
-            right=right.to_list(),
-        )
-
-    _assert_series_null_values_match(left, right)
-    if comparing_floats:
-        _assert_series_nan_values_match(
-            left, right, nans_compare_equal=nans_compare_equal
-        )
-    _assert_series_values_within_tolerance(
-        left,
-        right,
-        unequal,
-        rtol=rtol,
-        
atol=atol, - ) - - -def _assert_series_null_values_match(left: Series, right: Series) -> None: - null_value_mismatch = left.is_null() != right.is_null() - if null_value_mismatch.any(): - raise_assertion_error( - "Series", "null value mismatch", left.to_list(), right.to_list() - ) - - -def _assert_series_nan_values_match( - left: Series, right: Series, *, nans_compare_equal: bool -) -> None: - if nans_compare_equal: - nan_value_mismatch = left.is_nan() != right.is_nan() - if nan_value_mismatch.any(): - raise_assertion_error( - "Series", - "nan value mismatch - nans compare equal", - left.to_list(), - right.to_list(), - ) - - elif left.is_nan().any() or right.is_nan().any(): - raise_assertion_error( - "Series", - "nan value mismatch - nans compare unequal", - left.to_list(), - right.to_list(), - ) - - -def _assert_series_nested( - left: Series, - right: Series, - *, - check_exact: bool, - rtol: float, - atol: float, - nans_compare_equal: bool, - categorical_as_str: bool, -) -> bool: - # check that float values exist at _some_ level of nesting - if not any(tp in FLOAT_DTYPES for tp in unpack_dtypes(left.dtype, right.dtype)): - return False - - # compare nested lists element-wise - elif left.dtype == List == right.dtype: - for s1, s2 in zip(left, right): - if (s1 is None and s2 is not None) or (s2 is None and s1 is not None): - raise_assertion_error("Series", "nested value mismatch", s1, s2) - elif s1.len() != s2.len(): - raise_assertion_error( - "Series", "nested list length mismatch", len(s1), len(s2) - ) - - _assert_series_values_equal( - s1, - s2, - check_exact=check_exact, - rtol=rtol, - atol=atol, - nans_compare_equal=nans_compare_equal, - categorical_as_str=categorical_as_str, - ) - return True - - # unnest structs as series and compare - elif left.dtype == Struct == right.dtype: - ls, rs = left.struct.unnest(), right.struct.unnest() - if len(ls.columns) != len(rs.columns): - raise_assertion_error( - "Series", - "nested struct fields mismatch", - len(ls.columns), - len(rs.columns), - ) - elif len(ls) != len(rs): - raise_assertion_error( - "Series", "nested struct length mismatch", len(ls), len(rs) - ) - for s1, s2 in zip(ls, rs): - _assert_series_values_equal( - s1, - s2, - check_exact=check_exact, - rtol=rtol, - atol=atol, - nans_compare_equal=nans_compare_equal, - categorical_as_str=categorical_as_str, - ) - return True - else: - # fall-back to outer codepath (if mismatched dtypes we would expect - # the equality check to fail - unless ALL series values are null) - return False - - -def _assert_series_values_within_tolerance( - left: Series, - right: Series, - unequal: Series, - *, - rtol: float, - atol: float, -) -> None: - left_unequal, right_unequal = left.filter(unequal), right.filter(unequal) - - difference = _calc_absolute_diff(left_unequal, right_unequal) - tolerance = atol + rtol * right_unequal.abs() - exceeds_tolerance = difference > tolerance - - if exceeds_tolerance.any(): - raise_assertion_error( - "Series", - "value mismatch", - left.to_list(), - right.to_list(), - ) - - -def _calc_absolute_diff(left: Series, right: Series) -> Series: - if left.dtype in UNSIGNED_INTEGER_DTYPES and right.dtype in UNSIGNED_INTEGER_DTYPES: - try: - left = left.cast(Int64) - right = right.cast(Int64) - except ComputeError: - # Handle big UInt64 values through conversion to Python - diff = [abs(v1 - v2) for v1, v2 in zip(left, right)] - return Series(diff, dtype=UInt64) - - return (left - right).abs() - - -def assert_series_not_equal( - left: Series, - right: Series, - *, - check_dtype: bool 
= True, - check_names: bool = True, - check_exact: bool = False, - rtol: float = 1e-5, - atol: float = 1e-8, - nans_compare_equal: bool = True, - categorical_as_str: bool = False, -) -> None: - """ - Assert that the left and right Series are **not** equal. - - This function is intended for use in unit tests. - - Parameters - ---------- - left - The first Series to compare. - right - The second Series to compare. - check_dtype - Require data types to match. - check_names - Require names to match. - check_exact - Require data values to match exactly. If set to ``False``, values are considered - equal when within tolerance of each other (see ``rtol`` and ``atol``). - Logical types like dates are always checked exactly. - rtol - Relative tolerance for inexact checking, given as a fraction of the values in - ``right``. - atol - Absolute tolerance for inexact checking. - nans_compare_equal - Consider NaN values to be equal. - categorical_as_str - Cast categorical columns to string before comparing. Enabling this helps - compare columns that do not share the same string cache. - - See Also - -------- - assert_series_equal - assert_frame_not_equal - - Examples - -------- - >>> from polars.testing import assert_series_not_equal - >>> s1 = pl.Series([1, 2, 3]) - >>> s2 = pl.Series([1, 2, 3]) - >>> assert_series_not_equal(s1, s2) # doctest: +SKIP - Traceback (most recent call last): - ... - AssertionError: Series are equal - - """ - try: - assert_series_equal( - left=left, - right=right, - check_dtype=check_dtype, - check_names=check_names, - check_exact=check_exact, - rtol=rtol, - atol=atol, - nans_compare_equal=nans_compare_equal, - categorical_as_str=categorical_as_str, - ) - except AssertionError: - return - else: - msg = "Series are equal" - raise AssertionError(msg) diff --git a/py-polars/polars/testing/asserts/utils.py b/py-polars/polars/testing/asserts/utils.py deleted file mode 100644 index 1b7ac40c7814..000000000000 --- a/py-polars/polars/testing/asserts/utils.py +++ /dev/null @@ -1,12 +0,0 @@ -from __future__ import annotations - -from typing import Any, NoReturn - - -def raise_assertion_error( - objects: str, detail: str, left: Any, right: Any, *, cause: Exception | None = None -) -> NoReturn: - """Raise a detailed assertion error.""" - __tracebackhide__ = True - msg = f"{objects} are different ({detail})\n[left]: {left}\n[right]: {right}" - raise AssertionError(msg) from cause diff --git a/py-polars/polars/testing/parametric/primitives.py b/py-polars/polars/testing/parametric/primitives.py index 27194b56c5e5..5f3c34730993 100644 --- a/py-polars/polars/testing/parametric/primitives.py +++ b/py-polars/polars/testing/parametric/primitives.py @@ -26,6 +26,7 @@ ) from polars.series import Series from polars.string_cache import StringCache +from polars.testing.asserts import is_categorical_dtype from polars.testing.parametric.strategies import ( _flexhash, all_strategies, @@ -430,7 +431,7 @@ def draw_series(draw: DrawFn) -> Series: dtype=series_dtype, values=series_values, ) - if dtype == Categorical: + if is_categorical_dtype(dtype): s = s.cast(Categorical) if series_size and (chunked or (chunked is None and draw(booleans()))): split_at = series_size // 2 diff --git a/py-polars/polars/type_aliases.py b/py-polars/polars/type_aliases.py index a1bbb246b1bc..ea6764e1d2af 100644 --- a/py-polars/polars/type_aliases.py +++ b/py-polars/polars/type_aliases.py @@ -22,7 +22,7 @@ import sys from polars import DataFrame, Expr, LazyFrame, Series - from polars.datatypes import DataType, DataTypeClass, 
IntegerType, TemporalType + from polars.datatypes import DataType, DataTypeClass, IntegralType, TemporalType from polars.dependencies import numpy as np from polars.dependencies import pandas as pd from polars.dependencies import pyarrow as pa @@ -36,7 +36,7 @@ # Data types PolarsDataType: TypeAlias = Union["DataTypeClass", "DataType"] PolarsTemporalType: TypeAlias = Union[Type["TemporalType"], "TemporalType"] -PolarsIntegerType: TypeAlias = Union[Type["IntegerType"], "IntegerType"] +PolarsIntegerType: TypeAlias = Union[Type["IntegralType"], "IntegralType"] OneOrMoreDataTypes: TypeAlias = Union[PolarsDataType, Iterable[PolarsDataType]] PythonDataType: TypeAlias = Union[ Type[int], diff --git a/py-polars/polars/utils/_async.py b/py-polars/polars/utils/_async.py index 3294bca9428f..42ddfe85c313 100644 --- a/py-polars/polars/utils/_async.py +++ b/py-polars/polars/utils/_async.py @@ -24,8 +24,8 @@ def __init__(self) -> None: "polars.collect_all_async(gevent=True)" ) - from gevent.event import AsyncResult # type: ignore[import-untyped] - from gevent.hub import get_hub # type: ignore[import-untyped] + from gevent.event import AsyncResult # type: ignore[import] + from gevent.hub import get_hub # type: ignore[import] self._value: None | Exception | PyDataFrame | list[PyDataFrame] = None self._result = AsyncResult() diff --git a/py-polars/polars/utils/_construction.py b/py-polars/polars/utils/_construction.py index 8baf75cd19fb..b98bfe41a764 100644 --- a/py-polars/polars/utils/_construction.py +++ b/py-polars/polars/utils/_construction.py @@ -860,7 +860,7 @@ def dict_to_pydf( lambda t: pl.Series(t[0], t[1]) if isinstance(t[1], np.ndarray) else t[1], - list(data.items()), + [(k, v) for k, v in data.items()], ), ) ) @@ -1024,11 +1024,7 @@ def _sequence_of_sequence_to_pydf( local_schema_override = ( include_unknowns(schema_overrides, column_names) if schema_overrides else {} ) - if ( - column_names - and len(first_element) > 0 - and len(first_element) != len(column_names) - ): + if column_names and first_element and len(first_element) != len(column_names): raise ShapeError("the row data does not match the number of columns") unpack_nested = False diff --git a/py-polars/polars/utils/convert.py b/py-polars/polars/utils/convert.py index ec380519b8f6..85caae3266b3 100644 --- a/py-polars/polars/utils/convert.py +++ b/py-polars/polars/utils/convert.py @@ -97,7 +97,7 @@ def _negate_duration(duration: str) -> str: return f"-{duration}" -def _datetime_to_pl_timestamp(dt: datetime, time_unit: TimeUnit) -> int: +def _datetime_to_pl_timestamp(dt: datetime, time_unit: TimeUnit | None) -> int: """Convert a python datetime to a timestamp in given time unit.""" if dt.tzinfo is None: # Make sure to use UTC rather than system time zone. 
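For reference, a minimal standalone sketch of the conversion this hunk touches, assuming only the standard library: a possibly naive Python datetime becomes an integer timestamp at a given time unit, with naive values pinned to UTC as the comment above notes. The helper name `to_timestamp` and its unit strings are illustrative only, not the polars-internal API:

    from datetime import datetime, timedelta, timezone

    EPOCH_UTC = datetime(1970, 1, 1, tzinfo=timezone.utc)

    def to_timestamp(dt: datetime, time_unit: str = "us") -> int:
        # Interpret naive datetimes as UTC rather than the system time zone.
        if dt.tzinfo is None:
            dt = dt.replace(tzinfo=timezone.utc)
        # Exact integer microseconds since the epoch; avoids float rounding.
        us = (dt - EPOCH_UTC) // timedelta(microseconds=1)
        return {"ms": us // 1_000, "us": us, "ns": us * 1_000}[time_unit]

    # e.g. to_timestamp(datetime(2021, 1, 1), "us") == 1_609_459_200_000_000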
diff --git a/py-polars/pyproject.toml b/py-polars/pyproject.toml index 24b6ac8e6e46..21ce228ce519 100644 --- a/py-polars/pyproject.toml +++ b/py-polars/pyproject.toml @@ -190,7 +190,6 @@ strict = true [tool.pytest.ini_options] addopts = [ - "--tb=short", "--strict-config", "--strict-markers", "--import-mode=importlib", diff --git a/py-polars/requirements-dev.txt b/py-polars/requirements-dev.txt index f63137249749..33094478f6f2 100644 --- a/py-polars/requirements-dev.txt +++ b/py-polars/requirements-dev.txt @@ -53,7 +53,7 @@ gevent # TOOLING # ------- -hypothesis==6.88.1 +hypothesis==6.87.1 pytest==7.4.0 pytest-cov==4.1.0 pytest-xdist==3.3.1 diff --git a/py-polars/requirements-lint.txt b/py-polars/requirements-lint.txt index 40588f518050..0748b30ce8e8 100644 --- a/py-polars/requirements-lint.txt +++ b/py-polars/requirements-lint.txt @@ -1,5 +1,5 @@ -black==23.10.0 +black==23.9.1 blackdoc==0.3.8 -mypy==1.6.0 -ruff==0.1.0 -typos==1.16.20 +mypy==1.5.1 +ruff==0.0.287 +typos==1.16.8 diff --git a/py-polars/src/conversion.rs b/py-polars/src/conversion.rs index cdc0e9b69b8c..71871a48bfca 100644 --- a/py-polars/src/conversion.rs +++ b/py-polars/src/conversion.rs @@ -312,10 +312,7 @@ impl ToPyObject for Wrap { DataType::Array(inner, size) => { let inner = Wrap(*inner.clone()).to_object(py); let list_class = pl.getattr(intern!(py, "Array")).unwrap(); - let kwargs = PyDict::new(py); - kwargs.set_item("inner", inner).unwrap(); - kwargs.set_item("width", size).unwrap(); - list_class.call((), Some(kwargs)).unwrap().into() + list_class.call1((*size, inner)).unwrap().into() }, DataType::List(inner) => { let inner = Wrap(*inner.clone()).to_object(py); diff --git a/py-polars/src/dataframe.rs b/py-polars/src/dataframe.rs index 98e8bae5ce89..582253078265 100644 --- a/py-polars/src/dataframe.rs +++ b/py-polars/src/dataframe.rs @@ -1030,7 +1030,7 @@ impl PyDataFrame { self.df.find_idx_by_name(name) } - pub fn get_column(&self, name: &str) -> PyResult { + pub fn column(&self, name: &str) -> PyResult { let series = self .df .column(name) diff --git a/py-polars/src/error.rs b/py-polars/src/error.rs index 07046772316a..4f747dfe3118 100644 --- a/py-polars/src/error.rs +++ b/py-polars/src/error.rs @@ -64,17 +64,17 @@ impl Debug for PyPolarsErr { } } -create_exception!(polars.exceptions, ColumnNotFoundError, PyException); -create_exception!(polars.exceptions, ComputeError, PyException); -create_exception!(polars.exceptions, DuplicateError, PyException); -create_exception!(polars.exceptions, InvalidOperationError, PyException); -create_exception!(polars.exceptions, NoDataError, PyException); -create_exception!(polars.exceptions, OutOfBoundsError, PyException); -create_exception!(polars.exceptions, SchemaError, PyException); -create_exception!(polars.exceptions, SchemaFieldNotFoundError, PyException); -create_exception!(polars.exceptions, ShapeError, PyException); -create_exception!(polars.exceptions, StringCacheMismatchError, PyException); -create_exception!(polars.exceptions, StructFieldNotFoundError, PyException); +create_exception!(exceptions, ColumnNotFoundError, PyException); +create_exception!(exceptions, ComputeError, PyException); +create_exception!(exceptions, DuplicateError, PyException); +create_exception!(exceptions, InvalidOperationError, PyException); +create_exception!(exceptions, NoDataError, PyException); +create_exception!(exceptions, OutOfBoundsError, PyException); +create_exception!(exceptions, SchemaError, PyException); +create_exception!(exceptions, SchemaFieldNotFoundError, PyException); 
+create_exception!(exceptions, ShapeError, PyException); +create_exception!(exceptions, StringCacheMismatchError, PyException); +create_exception!(exceptions, StructFieldNotFoundError, PyException); #[macro_export] macro_rules! raise_err( diff --git a/py-polars/src/expr/general.rs b/py-polars/src/expr/general.rs index 46137400969c..88173745d9b1 100644 --- a/py-polars/src/expr/general.rs +++ b/py-polars/src/expr/general.rs @@ -741,8 +741,8 @@ impl PyExpr { } #[cfg(feature = "pct_change")] - fn pct_change(&self, n: Self) -> Self { - self.inner.clone().pct_change(n.inner).into() + fn pct_change(&self, n: i64) -> Self { + self.inner.clone().pct_change(n).into() } fn skew(&self, bias: bool) -> Self { @@ -892,12 +892,11 @@ impl PyExpr { lib: &str, symbol: &str, args: Vec, - kwargs: Vec, is_elementwise: bool, input_wildcard_expansion: bool, auto_explode: bool, cast_to_supertypes: bool, - ) -> PyResult { + ) -> Self { use polars_plan::prelude::*; let inner = self.inner.clone(); @@ -912,12 +911,11 @@ impl PyExpr { input.push(a.inner) } - Ok(Expr::Function { + Expr::Function { input, function: FunctionExpr::FfiPlugin { lib: Arc::from(lib), symbol: Arc::from(symbol), - kwargs: Arc::from(kwargs), }, options: FunctionOptions { collect_groups, @@ -927,6 +925,6 @@ impl PyExpr { ..Default::default() }, } - .into()) + .into() } } diff --git a/py-polars/src/expr/list.rs b/py-polars/src/expr/list.rs index a8a6db6613b9..dbac07c08a3b 100644 --- a/py-polars/src/expr/list.rs +++ b/py-polars/src/expr/list.rs @@ -115,36 +115,6 @@ impl PyExpr { self.inner.clone().list().drop_nulls().into() } - #[cfg(feature = "list_sample")] - fn list_sample_n( - &self, - n: PyExpr, - with_replacement: bool, - shuffle: bool, - seed: Option, - ) -> Self { - self.inner - .clone() - .list() - .sample_n(n.inner, with_replacement, shuffle, seed) - .into() - } - - #[cfg(feature = "list_sample")] - fn list_sample_fraction( - &self, - fraction: PyExpr, - with_replacement: bool, - shuffle: bool, - seed: Option, - ) -> Self { - self.inner - .clone() - .list() - .sample_fraction(fraction.inner, with_replacement, shuffle, seed) - .into() - } - #[cfg(feature = "list_take")] fn list_take(&self, index: PyExpr, null_on_oob: bool) -> Self { self.inner diff --git a/py-polars/src/lazyframe.rs b/py-polars/src/lazyframe.rs index 64815d0f551c..003266bbe02d 100644 --- a/py-polars/src/lazyframe.rs +++ b/py-polars/src/lazyframe.rs @@ -661,7 +661,7 @@ impl PyLazyFrame { PyLazyGroupBy { lgb: Some(lazy_gb) } } - fn rolling( + fn group_by_rolling( &mut self, index_column: PyExpr, period: &str, diff --git a/py-polars/src/lazygroupby.rs b/py-polars/src/lazygroupby.rs index 2364fad0094d..d74163b8e43b 100644 --- a/py-polars/src/lazygroupby.rs +++ b/py-polars/src/lazygroupby.rs @@ -19,18 +19,18 @@ pub struct PyLazyGroupBy { #[pymethods] impl PyLazyGroupBy { fn agg(&mut self, aggs: Vec) -> PyLazyFrame { - let lgb = self.lgb.clone().unwrap(); + let lgb = self.lgb.take().unwrap(); let aggs = aggs.to_exprs(); lgb.agg(aggs).into() } fn head(&mut self, n: usize) -> PyLazyFrame { - let lgb = self.lgb.clone().unwrap(); + let lgb = self.lgb.take().unwrap(); lgb.head(Some(n)).into() } fn tail(&mut self, n: usize) -> PyLazyFrame { - let lgb = self.lgb.clone().unwrap(); + let lgb = self.lgb.take().unwrap(); lgb.tail(Some(n)).into() } @@ -39,7 +39,7 @@ impl PyLazyGroupBy { lambda: PyObject, schema: Option>, ) -> PyResult { - let lgb = self.lgb.clone().unwrap(); + let lgb = self.lgb.take().unwrap(); let schema = match schema { Some(schema) => Arc::new(schema.0), None => 
LazyFrame::from(lgb.logical_plan.clone()) diff --git a/py-polars/src/series/export.rs b/py-polars/src/series/export.rs index 2ec3eec4d343..fb1f9df18e08 100644 --- a/py-polars/src/series/export.rs +++ b/py-polars/src/series/export.rs @@ -69,11 +69,6 @@ impl PySeries { PyArray1::from_iter(py, ca.into_iter().map(|opt_v| opt_v.to_object(py))); Ok(np_arr.into_py(py)) }, - DataType::Null => { - let n = s.len(); - let np_arr = PyArray1::from_iter(py, std::iter::repeat(f32::NAN).take(n)); - Ok(np_arr.into_py(py)) - }, dt => { raise_err!( format!("'to_numpy' not supported for dtype: {dt:?}"), diff --git a/py-polars/src/series/numpy_ufunc.rs b/py-polars/src/series/numpy_ufunc.rs index 91265aca0789..f37e20e33b91 100644 --- a/py-polars/src/series/numpy_ufunc.rs +++ b/py-polars/src/series/numpy_ufunc.rs @@ -86,8 +86,11 @@ macro_rules! impl_ufuncs { assert_eq!(get_refcnt(out_array), 3); let validity = self.series.chunks()[0].validity().cloned(); - let ca = - ChunkedArray::<$type>::from_vec_validity(self.name(), av, validity); + let ca = ChunkedArray::<$type>::new_from_owned_with_null_bitmap( + self.name(), + av, + validity, + ); PySeries::new(ca.into_series()) }, Err(e) => { diff --git a/py-polars/tests/parametric/test_groupby_rolling.py b/py-polars/tests/parametric/test_groupby_rolling.py index 62c1b188c12e..e5b0e8716c53 100644 --- a/py-polars/tests/parametric/test_groupby_rolling.py +++ b/py-polars/tests/parametric/test_groupby_rolling.py @@ -28,7 +28,7 @@ data=st.data(), time_unit=strategy_time_unit, ) -def test_rolling( +def test_group_by_rolling( period: str, offset: str, closed: ClosedInterval, @@ -57,7 +57,7 @@ def test_rolling( ) ) df = dataframe.sort("ts") - result = df.rolling("ts", period=period, offset=offset, closed=closed).agg( + result = df.group_by_rolling("ts", period=period, offset=offset, closed=closed).agg( pl.col("value") ) diff --git a/py-polars/tests/test_udfs.py b/py-polars/tests/test_udfs.py index c3f9c32360fa..956edaecd1f6 100644 --- a/py-polars/tests/test_udfs.py +++ b/py-polars/tests/test_udfs.py @@ -159,7 +159,7 @@ ) def test_bytecode_parser_expression(col: str, func: str, expected: str) -> None: try: - import udfs # type: ignore[import-not-found] + import udfs # type: ignore[import] except ModuleNotFoundError as exc: assert "No module named 'udfs'" in str(exc) # noqa: PT017 # Skip test if udfs can't be imported because it's not in the path. 
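For reference, pytest ships a built-in equivalent of the try/except import guard above: `pytest.importorskip` imports a module and skips the calling test when it is unavailable instead of failing it. A sketch, not part of this patch:

    import pytest

    # Import `udfs` if it is importable; otherwise skip the calling test.
    udfs = pytest.importorskip("udfs")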
diff --git a/py-polars/tests/unit/dataframe/test_df.py b/py-polars/tests/unit/dataframe/test_df.py index 0f32b704deff..b30c8bc44e21 100644 --- a/py-polars/tests/unit/dataframe/test_df.py +++ b/py-polars/tests/unit/dataframe/test_df.py @@ -124,7 +124,7 @@ def test_selection() -> None: # select columns by mask assert df[:2, :1].rows() == [(1,), (2,)] - assert df[:2, ["a"]].rows() == [(1,), (2,)] + assert df[:2, "a"].rows() == [(1,), (2,)] # type: ignore[attr-defined] # column selection by string(s) in first dimension assert df["a"].to_list() == [1, 2, 3] @@ -136,7 +136,7 @@ def test_selection() -> None: assert_frame_equal(df[-1], pl.DataFrame({"a": [3], "b": [3.0], "c": ["c"]})) # row, column selection when using two dimensions - assert df[:, "a"].to_list() == [1, 2, 3] + assert df[:, 0].to_list() == [1, 2, 3] assert df[:, 1].to_list() == [1.0, 2.0, 3.0] assert df[:2, 2].to_list() == ["a", "b"] @@ -155,6 +155,7 @@ def test_selection() -> None: assert typing.cast(float, df[1, 1]) == 2.0 assert typing.cast(int, df[2, 0]) == 3 + assert df[[0, 1], "b"].rows() == [(1.0,), (2.0,)] # type: ignore[attr-defined] assert df[[2], ["a", "b"]].rows() == [(3, 3.0)] assert df.to_series(0).name == "a" assert (df["a"] == df["a"]).sum() == 3 diff --git a/py-polars/tests/unit/datatypes/test_array.py b/py-polars/tests/unit/datatypes/test_array.py index fc4064addbb6..a0eaeb3d7e0c 100644 --- a/py-polars/tests/unit/datatypes/test_array.py +++ b/py-polars/tests/unit/datatypes/test_array.py @@ -9,7 +9,7 @@ def test_cast_list_array() -> None: payload = [[1, 2, 3], [4, 2, 3]] s = pl.Series(payload) - dtype = pl.Array(inner=pl.Int64, width=3) + dtype = pl.Array(width=3, inner=pl.Int64) out = s.cast(dtype) assert out.dtype == dtype assert out.to_list() == payload @@ -20,19 +20,19 @@ def test_cast_list_array() -> None: pl.ComputeError, match=r"incompatible offsets in source list", ): - s.cast(pl.Array(inner=pl.Int64, width=2)) + s.cast(pl.Array(width=2, inner=pl.Int64)) def test_array_construction() -> None: payload = [[1, 2, 3], [4, 2, 3]] - dtype = pl.Array(inner=pl.Int64, width=3) + dtype = pl.Array(width=3, inner=pl.Int64) s = pl.Series(payload, dtype=dtype) assert s.dtype == dtype assert s.to_list() == payload # inner type - dtype = pl.Array(inner=pl.UInt8, width=2) + dtype = pl.Array(2, pl.UInt8) payload = [[1, 2], [3, 4]] s = pl.Series(payload, dtype=dtype) assert s.dtype == dtype @@ -41,13 +41,13 @@ def test_array_construction() -> None: # create using schema df = pl.DataFrame( schema={ - "a": pl.Array(inner=pl.Float32, width=3), - "b": pl.Array(inner=pl.Datetime("ms"), width=5), + "a": pl.Array(width=3, inner=pl.Float32), + "b": pl.Array(width=5, inner=pl.Datetime("ms")), } ) assert df.dtypes == [ - pl.Array(inner=pl.Float32, width=3), - pl.Array(inner=pl.Datetime("ms"), width=5), + pl.Array(width=3, inner=pl.Float32), + pl.Array(width=5, inner=pl.Datetime("ms")), ] assert df.rows() == [] @@ -56,9 +56,7 @@ def test_array_in_group_by() -> None: df = pl.DataFrame( [ pl.Series("id", [1, 2]), - pl.Series( - "list", [[1, 2], [5, 5]], dtype=pl.Array(inner=pl.UInt8, width=2) - ), + pl.Series("list", [[1, 2], [5, 5]], dtype=pl.Array(2, pl.UInt8)), ] ) @@ -85,7 +83,7 @@ def test_array_in_group_by() -> None: def test_array_invalid_operation() -> None: s = pl.Series( [[1, 2], [8, 9]], - dtype=pl.Array(inner=pl.Int32, width=2), + dtype=pl.Array(width=2, inner=pl.Int32), ) with pytest.raises( InvalidOperationError, @@ -96,22 +94,11 @@ def test_array_invalid_operation() -> None: def test_array_concat() -> None: a_df = 
pl.DataFrame({"a": [[0, 1], [1, 0]]}).select( - pl.col("a").cast(pl.Array(inner=pl.Int32, width=2)) + pl.col("a").cast(pl.Array(width=2, inner=pl.Int32)) ) b_df = pl.DataFrame({"a": [[1, 1], [0, 0]]}).select( - pl.col("a").cast(pl.Array(inner=pl.Int32, width=2)) + pl.col("a").cast(pl.Array(width=2, inner=pl.Int32)) ) assert pl.concat([a_df, b_df]).to_dict(False) == { "a": [[0, 1], [1, 0], [1, 1], [0, 0]] } - - -def test_array_init_deprecation() -> None: - with pytest.deprecated_call(): - pl.Array(2) - with pytest.deprecated_call(): - pl.Array(2, pl.Utf8) - with pytest.deprecated_call(): - pl.Array(2, inner=pl.Utf8) - with pytest.deprecated_call(): - pl.Array(width=2) diff --git a/py-polars/tests/unit/datatypes/test_categorical.py b/py-polars/tests/unit/datatypes/test_categorical.py index be707cb207a4..af727b187a75 100644 --- a/py-polars/tests/unit/datatypes/test_categorical.py +++ b/py-polars/tests/unit/datatypes/test_categorical.py @@ -422,18 +422,3 @@ def test_categorical_collect_11408() -> None: "groups": ["a", "b", "c"], "cats": ["a", "b", "c"], } - - -def test_categorical_nested_cast_unchecked() -> None: - s = pl.Series("cat", [["cat"]]).cast(pl.List(pl.Categorical)) - assert pl.Series([s]).to_list() == [[["cat"]]] - - -def test_categorical_update_lengths() -> None: - with pl.StringCache(): - s1 = pl.Series(["", ""], dtype=pl.Categorical) - s2 = pl.Series([None, "", ""], dtype=pl.Categorical) - - s = pl.concat([s1, s2], rechunk=False) - assert s.null_count() == 1 - assert s.len() == 5 diff --git a/py-polars/tests/unit/datatypes/test_list.py b/py-polars/tests/unit/datatypes/test_list.py index 23e1a7be42b4..abe5ec123780 100644 --- a/py-polars/tests/unit/datatypes/test_list.py +++ b/py-polars/tests/unit/datatypes/test_list.py @@ -43,7 +43,7 @@ def test_dtype() -> None: "dt": pl.List(pl.Date), "dtm": pl.List(pl.Datetime), } - assert all(tp in pl.NESTED_DTYPES for tp in df.dtypes) + assert all(tp.is_nested for tp in df.dtypes) assert df.schema["i"].inner == pl.Int8 # type: ignore[union-attr] assert df.rows() == [ ( @@ -69,15 +69,17 @@ def test_categorical() -> None: out = ( df.group_by(["a", "b"]) .agg( - pl.col("c").count().alias("num_different_c"), - pl.col("c").alias("c_values"), + [ + pl.col("c").count().alias("num_different_c"), + pl.col("c").alias("c_values"), + ] ) .filter(pl.col("num_different_c") >= 2) .to_series(3) ) assert out.inner_dtype == pl.Categorical - assert out.inner_dtype not in pl.NESTED_DTYPES + assert not out.inner_dtype.is_nested def test_cast_inner() -> None: @@ -466,10 +468,10 @@ def test_list_recursive_categorical_cast() -> None: @pytest.mark.parametrize( ("data", "expected_data", "dtype"), [ - ([None, 1, 2], [None, [1], [2]], pl.Int64), - ([None, 1.0, 2.0], [None, [1.0], [2.0]], pl.Float64), - ([None, "x", "y"], [None, ["x"], ["y"]], pl.Utf8), - ([None, True, False], [None, [True], [False]], pl.Boolean), + ([1, 2], [[1], [2]], pl.Int64), + ([1.0, 2.0], [[1.0], [2.0]], pl.Float64), + (["x", "y"], [["x"], ["y"]], pl.Utf8), + ([True, False], [[True], [False]], pl.Boolean), ], ) def test_non_nested_cast_to_list( @@ -563,11 +565,3 @@ def test_list_inner_cast_physical_11513() -> None: }, ) assert df.select(pl.col("struct").take(0)).to_dict(False) == {"struct": [[]]} - - -@pytest.mark.parametrize( - ("dtype", "expected"), [(pl.List, True), (pl.Struct, True), (pl.Utf8, False)] -) -def test_list_is_nested_deprecated(dtype: PolarsDataType, expected: bool) -> None: - with pytest.deprecated_call(): - assert dtype.is_nested is expected diff --git 
a/py-polars/tests/unit/datatypes/test_struct.py b/py-polars/tests/unit/datatypes/test_struct.py index 542c6f3b1b4d..b372be41f045 100644 --- a/py-polars/tests/unit/datatypes/test_struct.py +++ b/py-polars/tests/unit/datatypes/test_struct.py @@ -152,7 +152,7 @@ def test_struct_unnest_multiple() -> None: # List input result = df_structs.unnest(["s1", "s2"]) assert_frame_equal(result, df) - assert all(tp in pl.NESTED_DTYPES for tp in df_structs.dtypes) + assert all(tp.is_nested for tp in df_structs.dtypes) # Positional input result = df_structs.unnest("s1", "s2") @@ -645,8 +645,8 @@ def test_empty_struct() -> None: pl.List, pl.List(pl.Null), pl.List(pl.Utf8), - pl.Array(inner=pl.Null, width=32), - pl.Array(inner=pl.UInt8, width=16), + pl.Array(32), + pl.Array(16, inner=pl.UInt8), pl.Struct, pl.Struct([pl.Field("", pl.Null)]), pl.Struct([pl.Field("x", pl.UInt32), pl.Field("y", pl.Float64)]), diff --git a/py-polars/tests/unit/datatypes/test_temporal.py b/py-polars/tests/unit/datatypes/test_temporal.py index aa79cfa82c3e..290b2087fbc2 100644 --- a/py-polars/tests/unit/datatypes/test_temporal.py +++ b/py-polars/tests/unit/datatypes/test_temporal.py @@ -612,7 +612,7 @@ def test_rolling() -> None: period: str | timedelta for period in ("2d", timedelta(days=2)): # type: ignore[assignment] - out = df.rolling(index_column="dt", period=period).agg( + out = df.group_by_rolling(index_column="dt", period=period).agg( [ pl.sum("a").alias("sum_a"), pl.min("a").alias("min_a"), @@ -820,7 +820,7 @@ def test_asof_join_tolerance_grouper() -> None: def test_rolling_group_by_by_argument() -> None: df = pl.DataFrame({"times": range(10), "groups": [1] * 4 + [2] * 6}) - out = df.rolling("times", period="5i", by=["groups"]).agg( + out = df.group_by_rolling("times", period="5i", by=["groups"]).agg( pl.col("times").alias("agg_list") ) @@ -846,7 +846,7 @@ def test_rolling_group_by_by_argument() -> None: assert_frame_equal(out, expected) -def test_rolling_mean_3020() -> None: +def test_group_by_rolling_mean_3020() -> None: df = pl.DataFrame( { "Date": [ @@ -864,7 +864,7 @@ def test_rolling_mean_3020() -> None: period: str | timedelta for period in ("1w", timedelta(days=7)): # type: ignore[assignment] - result = df.rolling(index_column="Date", period=period).agg( + result = df.group_by_rolling(index_column="Date", period=period).agg( pl.col("val").mean().alias("val_mean") ) expected = pl.DataFrame( @@ -1275,7 +1275,7 @@ def test_unique_counts_on_dates() -> None: } -def test_rolling_by_ordering() -> None: +def test_group_by_rolling_by_ordering() -> None: # we must check that the keys still match the time labels after the rolling window # with a `by` argument. 
df = pl.DataFrame( @@ -1294,7 +1294,7 @@ def test_rolling_by_ordering() -> None: } ).set_sorted("dt") - assert df.rolling( + assert df.group_by_rolling( index_column="dt", period="2m", closed="both", @@ -1321,7 +1321,7 @@ def test_rolling_by_ordering() -> None: } -def test_rolling_by_() -> None: +def test_group_by_rolling_by_() -> None: df = pl.DataFrame({"group": pl.arange(0, 3, eager=True)}).join( pl.DataFrame( { @@ -1334,13 +1334,13 @@ def test_rolling_by_() -> None: ) out = ( df.sort("datetime") - .rolling(index_column="datetime", by="group", period=timedelta(days=3)) + .group_by_rolling(index_column="datetime", by="group", period=timedelta(days=3)) .agg([pl.count().alias("count")]) ) expected = ( df.sort(["group", "datetime"]) - .rolling(index_column="datetime", by="group", period="3d") + .group_by_rolling(index_column="datetime", by="group", period="3d") .agg([pl.count().alias("count")]) ) assert_frame_equal(out.sort(["group", "datetime"]), expected) @@ -2590,7 +2590,7 @@ def test_rolling_group_by_empty_groups_by_take_6330() -> None: .set_sorted("Date") ) assert ( - df.rolling( + df.group_by_rolling( index_column="Date", period="2i", offset="-2i", @@ -2752,7 +2752,7 @@ def test_pytime_conversion(tm: time) -> None: assert s.to_list() == [tm] -def test_rolling_duplicates() -> None: +def test_group_by_rolling_duplicates() -> None: df = pl.DataFrame( { "ts": [datetime(2000, 1, 1, 0, 0), datetime(2000, 1, 1, 0, 0)], diff --git a/py-polars/tests/unit/io/test_database_read.py b/py-polars/tests/unit/io/test_database_read.py index 824ffb989fd2..aa872c00a85b 100644 --- a/py-polars/tests/unit/io/test_database_read.py +++ b/py-polars/tests/unit/io/test_database_read.py @@ -15,12 +15,9 @@ import polars as pl from polars.exceptions import UnsuitableSQLError -from polars.io.database import _ARROW_DRIVER_REGISTRY_ from polars.testing import assert_frame_equal if TYPE_CHECKING: - import pyarrow as pa - from polars.type_aliases import DbReadEngine, SchemaDefinition, SchemaDict @@ -87,77 +84,6 @@ class ExceptionTestParams(NamedTuple): kwargs: dict[str, Any] | None = None -class MockConnection: - """Mock connection class for databases we can't test in CI.""" - - def __init__( - self, - driver: str, - batch_size: int | None, - test_data: pa.Table, - repeat_batch_calls: bool, - ) -> None: - self.__class__.__module__ = driver - self._cursor = MockCursor( - repeat_batch_calls=repeat_batch_calls, - batched=(batch_size is not None), - test_data=test_data, - ) - - def close(self) -> None: # noqa: D102 - pass - - def cursor(self) -> Any: # noqa: D102 - return self._cursor - - -class MockCursor: - """Mock cursor class for databases we can't test in CI.""" - - def __init__( - self, - batched: bool, - test_data: pa.Table, - repeat_batch_calls: bool, - ) -> None: - self.resultset = MockResultSet(test_data, batched, repeat_batch_calls) - self.called: list[str] = [] - self.batched = batched - self.n_calls = 1 - - def __getattr__(self, item: str) -> Any: - if "fetch" in item: - self.called.append(item) - return self.resultset - super().__getattr__(item) # type: ignore[misc] - - def close(self) -> Any: # noqa: D102 - pass - - def execute(self, query: str) -> Any: # noqa: D102 - return self - - -class MockResultSet: - """Mock resultset class for databases we can't test in CI.""" - - def __init__( - self, test_data: pa.Table, batched: bool, repeat_batch_calls: bool = False - ): - self.test_data = test_data - self.repeat_batched_calls = repeat_batch_calls - self.batched = batched - self.n_calls = 1 - - def __call__(self, 
*args: Any, **kwargs: Any) -> Any: # noqa: D102 - if self.repeat_batched_calls: - res = self.test_data[: None if self.n_calls else 0] - self.n_calls -= 1 - else: - res = iter((self.test_data,)) - return res - - @pytest.mark.write_disk() @pytest.mark.parametrize( ( @@ -381,9 +307,45 @@ def test_read_database_parameterisd(tmp_path: Path) -> None: ) -@pytest.mark.parametrize( - ("driver", "batch_size", "iter_batches", "expected_call"), - [ +def test_read_database_mocked() -> None: + arr = pl.DataFrame({"x": [1, 2, 3], "y": ["aa", "bb", "cc"]}).to_arrow() + + class MockConnection: + def __init__(self, driver: str, batch_size: int | None = None) -> None: + self.__class__.__module__ = driver + self._cursor = MockCursor(batched=batch_size is not None) + + def close(self) -> None: + pass + + def cursor(self) -> Any: + return self._cursor + + class MockCursor: + def __init__(self, batched: bool) -> None: + self.called: list[str] = [] + self.batched = batched + + def __getattr__(self, item: str) -> Any: + if "fetch" in item: + res = ( + (lambda *args, **kwargs: (arr for _ in range(1))) + if self.batched + else (lambda *args, **kwargs: arr) + ) + self.called.append(item) + return res + super().__getattr__(item) # type: ignore[misc] + + def close(self) -> Any: + pass + + def execute(self, query: str) -> Any: + return self + + # since we don't have access to snowflake/databricks/etc from CI we + # mock them so we can check that we're calling the expected methods + for driver, batch_size, iter_batches, expected_call in ( ("snowflake", None, False, "fetch_arrow_all"), ("snowflake", 10_000, False, "fetch_arrow_all"), ("snowflake", 10_000, True, "fetch_arrow_batches"), @@ -396,34 +358,20 @@ def test_read_database_parameterisd(tmp_path: Path) -> None: ("adbc_driver_postgresql", None, False, "fetch_arrow_table"), ("adbc_driver_postgresql", 75_000, False, "fetch_arrow_table"), ("adbc_driver_postgresql", 75_000, True, "fetch_arrow_table"), - ], -) -def test_read_database_mocked( - driver: str, batch_size: int | None, iter_batches: bool, expected_call: str -) -> None: - # since we don't have access to snowflake/databricks/etc from CI we - # mock them so we can check that we're calling the expected methods - arrow = pl.DataFrame({"x": [1, 2, 3], "y": ["aa", "bb", "cc"]}).to_arrow() - mc = MockConnection( - driver, - batch_size, - test_data=arrow, - repeat_batch_calls=_ARROW_DRIVER_REGISTRY_.get(driver, {}).get( # type: ignore[call-overload] - "repeat_batch_calls", False - ), - ) - res = pl.read_database( # type: ignore[call-overload] - query="SELECT * FROM test_data", - connection=mc, - iter_batches=iter_batches, - batch_size=batch_size, - ) - if iter_batches: - assert isinstance(res, GeneratorType) - res = pl.concat(res) + ): + mc = MockConnection(driver, batch_size) + res = pl.read_database( # type: ignore[call-overload] + query="SELECT * FROM test_data", + connection=mc, + iter_batches=iter_batches, + batch_size=batch_size, + ) + assert expected_call in mc.cursor().called + if iter_batches: + assert isinstance(res, GeneratorType) + res = pl.concat(res) - assert expected_call in mc.cursor().called - assert res.rows() == [(1, "aa"), (2, "bb"), (3, "cc")] + assert res.rows() == [(1, "aa"), (2, "bb"), (3, "cc")] @pytest.mark.parametrize( diff --git a/py-polars/tests/unit/io/test_hive.py b/py-polars/tests/unit/io/test_hive.py index 4e0bea5316f6..83fce2da1194 100644 --- a/py-polars/tests/unit/io/test_hive.py +++ b/py-polars/tests/unit/io/test_hive.py @@ -98,15 +98,3 @@ def test_hive_partitioned_projection_pushdown( 
columns = ["sugars_g", "category"] for streaming in [True, False]: assert q.select(columns).collect(streaming=streaming).columns == columns - - # test that hive partition columns are projected with the correct height when - # the projection contains only hive partition columns (11796) - for parallel in ("row_groups", "columns"): - q = pl.scan_parquet( - root / "**/*.parquet", hive_partitioning=True, parallel=parallel # type: ignore[arg-type] - ) - - expect = q.collect().select("category") - actual = q.select("category").collect() - - assert expect.frame_equal(actual) diff --git a/py-polars/tests/unit/io/test_spreadsheet.py b/py-polars/tests/unit/io/test_spreadsheet.py index 94868f0e467d..a777e8af318b 100644 --- a/py-polars/tests/unit/io/test_spreadsheet.py +++ b/py-polars/tests/unit/io/test_spreadsheet.py @@ -1,7 +1,6 @@ from __future__ import annotations import warnings -from collections import OrderedDict from datetime import date, datetime from io import BytesIO from typing import TYPE_CHECKING, Any, Callable, Literal @@ -195,36 +194,6 @@ def test_read_excel_basic_datatypes( assert_frame_equal(df, df) -@pytest.mark.parametrize( - ("read_spreadsheet", "source", "params"), - [ - (pl.read_excel, "path_xlsx", {"engine": "xlsx2csv"}), - (pl.read_excel, "path_xlsx", {"engine": "openpyxl"}), - (pl.read_excel, "path_xlsb", {"engine": "pyxlsb"}), - (pl.read_ods, "path_ods", {}), - ], -) -def test_read_invalid_worksheet( - read_spreadsheet: Callable[..., dict[str, pl.DataFrame]], - source: str, - params: dict[str, str], - request: pytest.FixtureRequest, -) -> None: - spreadsheet_path = request.getfixturevalue(source) - for param, sheet_id, sheet_name in ( - ("id", 999, None), - ("name", None, "not_a_sheet_name"), - ): - value = sheet_id if param == "id" else sheet_name - with pytest.raises( - ValueError, - match=f"no matching sheets found when `sheet_{param}` is {value!r}", - ): - read_spreadsheet( - spreadsheet_path, sheet_id=sheet_id, sheet_name=sheet_name, **params - ) - - @pytest.mark.parametrize("engine", ["xlsx2csv", "openpyxl"]) def test_write_excel_bytes(engine: Literal["xlsx2csv", "openpyxl", "pyxlsb"]) -> None: df = pl.DataFrame({"A": [1, 2, 3, 4, 5]}) @@ -306,22 +275,6 @@ def test_schema_overrides(path_xlsx: Path, path_xlsb: Path, path_ods: Path) -> N read_csv_options={"dtypes": {"cardinality": pl.Int32}}, ) - # read multiple sheets in conjunction with 'schema_overrides' - # (note: reading the same sheet twice simulates the issue in #11850) - overrides = OrderedDict( - [ - ("cardinality", pl.UInt32), - ("rows_by_key", pl.Float32), - ("iter_groups", pl.Float64), - ] - ) - df = pl.read_excel( # type: ignore[call-overload] - path_xlsx, - sheet_name=["test4", "test4"], - schema_overrides=overrides, - ) - assert df["test4"].schema == overrides - def test_unsupported_engine() -> None: with pytest.raises(NotImplementedError): diff --git a/py-polars/tests/unit/namespaces/test_array.py b/py-polars/tests/unit/namespaces/test_array.py index cc20cba7feca..ac69510cd8ed 100644 --- a/py-polars/tests/unit/namespaces/test_array.py +++ b/py-polars/tests/unit/namespaces/test_array.py @@ -5,19 +5,19 @@ def test_arr_min_max() -> None: - s = pl.Series("a", [[1, 2], [4, 3]], dtype=pl.Array(inner=pl.Int64, width=2)) + s = pl.Series("a", [[1, 2], [4, 3]], dtype=pl.Array(width=2, inner=pl.Int64)) assert s.arr.max().to_list() == [2, 4] assert s.arr.min().to_list() == [1, 3] def test_arr_sum() -> None: - s = pl.Series("a", [[1, 2], [4, 3]], dtype=pl.Array(inner=pl.Int64, width=2)) + s = pl.Series("a", [[1, 2], 
[4, 3]], dtype=pl.Array(width=2, inner=pl.Int64)) assert s.arr.sum().to_list() == [3, 7] def test_arr_unique() -> None: df = pl.DataFrame( - {"a": pl.Series("a", [[1, 1], [4, 3]], dtype=pl.Array(inner=pl.Int64, width=2))} + {"a": pl.Series("a", [[1, 1], [4, 3]], dtype=pl.Array(width=2, inner=pl.Int64))} ) out = df.select(pl.col("a").arr.unique(maintain_order=True)) @@ -26,5 +26,5 @@ def test_arr_unique() -> None: def test_array_to_numpy() -> None: - s = pl.Series([[1, 2], [3, 4], [5, 6]], dtype=pl.Array(inner=pl.Int64, width=2)) + s = pl.Series([[1, 2], [3, 4], [5, 6]], dtype=pl.Array(width=2, inner=pl.Int64)) assert (s.to_numpy() == np.array([[1, 2], [3, 4], [5, 6]])).all() diff --git a/py-polars/tests/unit/namespaces/test_list.py b/py-polars/tests/unit/namespaces/test_list.py index 89ceff4e6a0a..9c0eeb51e063 100644 --- a/py-polars/tests/unit/namespaces/test_list.py +++ b/py-polars/tests/unit/namespaces/test_list.py @@ -179,37 +179,6 @@ def test_list_drop_nulls() -> None: assert_frame_equal(df, expected_df) -def test_list_sample() -> None: - s = pl.Series("values", [[1, 2, 3, None], [None, None], [1, 2], None]) - - expected_sample_n = pl.Series("values", [[3, 1], [None], [2], None]) - assert_series_equal( - s.list.sample(n=pl.Series([2, 1, 1, 1]), seed=1), expected_sample_n - ) - - expected_sample_frac = pl.Series("values", [[3, 1], [None], [1, 2], None]) - assert_series_equal( - s.list.sample(fraction=pl.Series([0.5, 0.5, 1.0, 0.3]), seed=1), - expected_sample_frac, - ) - - df = pl.DataFrame( - { - "values": [[1, 2, 3, None], [None, None], [3, 4]], - "n": [2, 1, 2], - "frac": [0.5, 0.5, 1.0], - } - ) - df = df.select( - sample_n=pl.col("values").list.sample(n=pl.col("n"), seed=1), - sample_frac=pl.col("values").list.sample(fraction=pl.col("frac"), seed=1), - ) - expected_df = pl.DataFrame( - {"sample_n": [[3, 1], [None], [3, 4]], "sample_frac": [[3, 1], [None], [3, 4]]} - ) - assert_frame_equal(df, expected_df) - - def test_list_diff() -> None: s = pl.Series("a", [[1, 2], [10, 2, 1]]) expected = pl.Series("a", [[None, 1], [None, -8, -1]]) diff --git a/py-polars/tests/unit/operations/map/test_map_groups.py b/py-polars/tests/unit/operations/map/test_map_groups.py index 90699300ea9d..64cc7c443adf 100644 --- a/py-polars/tests/unit/operations/map/test_map_groups.py +++ b/py-polars/tests/unit/operations/map/test_map_groups.py @@ -49,7 +49,9 @@ def function(df: pl.DataFrame) -> pl.DataFrame: pl.col("b").max(), ) - result = df.rolling("a", period="2i").map_groups(function, schema=df.schema) + result = df.group_by_rolling("a", period="2i").map_groups( + function, schema=df.schema + ) expected = pl.DataFrame( [ @@ -160,7 +162,7 @@ def test_apply_deprecated() -> None: with pytest.deprecated_call(): df.group_by("a").apply(lambda x: x) with pytest.deprecated_call(): - df.rolling("a", period="2i").apply(lambda x: x, schema=None) + df.group_by_rolling("a", period="2i").apply(lambda x: x, schema=None) with pytest.deprecated_call(): df.group_by_dynamic("a", every="2i").apply(lambda x: x, schema=None) with pytest.deprecated_call(): diff --git a/py-polars/tests/unit/operations/rolling/test_rolling.py b/py-polars/tests/unit/operations/rolling/test_rolling.py index 0f763f1a04c2..639b7b297031 100644 --- a/py-polars/tests/unit/operations/rolling/test_rolling.py +++ b/py-polars/tests/unit/operations/rolling/test_rolling.py @@ -36,7 +36,7 @@ def example_df() -> pl.DataFrame: ["1d", "2d", "3d", timedelta(days=1), timedelta(days=2), timedelta(days=3)], ) @pytest.mark.parametrize("closed", ["left", "right", "none", 
"both"]) -def test_rolling_kernels_and_rolling( +def test_rolling_kernels_and_group_by_rolling( example_df: pl.DataFrame, period: str | timedelta, closed: ClosedInterval ) -> None: out1 = example_df.set_sorted("dt").select( @@ -56,7 +56,7 @@ def test_rolling_kernels_and_rolling( ) out2 = ( example_df.set_sorted("dt") - .rolling("dt", period=period, closed=closed) + .group_by_rolling("dt", period=period, closed=closed) .agg( [ pl.col("values").sum().alias("sum"), @@ -145,7 +145,7 @@ def test_rolling_negative_offset( "value": [1, 2, 3, 4], } ) - result = df.rolling("ts", period="2d", offset=offset, closed=closed).agg( + result = df.group_by_rolling("ts", period="2d", offset=offset, closed=closed).agg( pl.col("value") ) expected = pl.DataFrame( @@ -271,7 +271,7 @@ def test_rolling_group_by_extrema() -> None: ).with_columns(pl.col("col1").reverse().alias("row_nr")) assert ( - df.rolling( + df.group_by_rolling( index_column="row_nr", period="3i", ) @@ -310,7 +310,7 @@ def test_rolling_group_by_extrema() -> None: ).with_columns(pl.col("col1").alias("row_nr")) assert ( - df.rolling( + df.group_by_rolling( index_column="row_nr", period="3i", ) @@ -348,7 +348,7 @@ def test_rolling_group_by_extrema() -> None: ).with_columns(pl.col("col1").sort().alias("row_nr")) assert ( - df.rolling( + df.group_by_rolling( index_column="row_nr", period="3i", ) @@ -379,7 +379,7 @@ def test_rolling_slice_pushdown() -> None: df = pl.DataFrame({"a": [1, 2, 3], "b": ["a", "a", "b"], "c": [1, 3, 5]}).lazy() df = ( df.sort("a") - .rolling( + .group_by_rolling( "a", by="b", period="2i", @@ -407,7 +407,7 @@ def test_overlapping_groups_4628() -> None: } ) assert ( - df.rolling(index_column=pl.col("index").set_sorted(), period="3i").agg( + df.group_by_rolling(index_column=pl.col("index").set_sorted(), period="3i").agg( [ pl.col("val").diff(n=1).alias("val.diff"), (pl.col("val") - pl.col("val").shift(1)).alias("val - val.shift"), @@ -473,7 +473,7 @@ def test_rolling_var_numerical_stability_5197() -> None: assert res[:4] == [None] * 4 -def test_rolling_iter() -> None: +def test_group_by_rolling_iter() -> None: df = pl.DataFrame( { "date": [date(2020, 1, 1), date(2020, 1, 2), date(2020, 1, 5)], @@ -485,7 +485,7 @@ def test_rolling_iter() -> None: # Without 'by' argument result1 = [ (name, data.shape) - for name, data in df.rolling(index_column="date", period="2d") + for name, data in df.group_by_rolling(index_column="date", period="2d") ] expected1 = [ (date(2020, 1, 1), (1, 3)), @@ -497,7 +497,7 @@ def test_rolling_iter() -> None: # With 'by' argument result2 = [ (name, data.shape) - for name, data in df.rolling(index_column="date", period="2d", by="a") + for name, data in df.group_by_rolling(index_column="date", period="2d", by="a") ] expected2 = [ ((1, date(2020, 1, 1)), (1, 3)), @@ -507,18 +507,18 @@ def test_rolling_iter() -> None: assert result2 == expected2 -def test_rolling_negative_period() -> None: +def test_group_by_rolling_negative_period() -> None: df = pl.DataFrame({"ts": [datetime(2020, 1, 1)], "value": [1]}).with_columns( pl.col("ts").set_sorted() ) with pytest.raises( ComputeError, match="rolling window period should be strictly positive" ): - df.rolling("ts", period="-1d", offset="-1d").agg(pl.col("value")) + df.group_by_rolling("ts", period="-1d", offset="-1d").agg(pl.col("value")) with pytest.raises( ComputeError, match="rolling window period should be strictly positive" ): - df.lazy().rolling("ts", period="-1d", offset="-1d").agg( + df.lazy().group_by_rolling("ts", period="-1d", offset="-1d").agg( 
pl.col("value") ).collect() with pytest.raises(ComputeError, match="window size should be strictly positive"): diff --git a/py-polars/tests/unit/operations/test_aggregations.py b/py-polars/tests/unit/operations/test_aggregations.py index 0cd22d787120..5dff526eabc7 100644 --- a/py-polars/tests/unit/operations/test_aggregations.py +++ b/py-polars/tests/unit/operations/test_aggregations.py @@ -132,6 +132,7 @@ def test_quantile_vs_numpy(tp: type, n: int) -> None: np_result = np.quantile(a, q) except IndexError: np_result = None + pass if np_result: # nan check if np_result != np_result: diff --git a/py-polars/tests/unit/operations/test_explode.py b/py-polars/tests/unit/operations/test_explode.py index 85b7bf892dba..754a3e27bc1c 100644 --- a/py-polars/tests/unit/operations/test_explode.py +++ b/py-polars/tests/unit/operations/test_explode.py @@ -309,7 +309,7 @@ def test_explode_inner_null() -> None: def test_explode_array() -> None: df = pl.LazyFrame( {"a": [[1, 2], [2, 3]], "b": [1, 2]}, - schema_overrides={"a": pl.Array(inner=pl.Int64, width=2)}, + schema_overrides={"a": pl.Array(2, inner=pl.Int64)}, ) expected = pl.DataFrame({"a": [1, 2, 2, 3], "b": [1, 1, 2, 2]}) for ex in ("a", ~cs.integer()): diff --git a/py-polars/tests/unit/operations/test_group_by.py b/py-polars/tests/unit/operations/test_group_by.py index 1902459bb05a..d3fb4d6dcb49 100644 --- a/py-polars/tests/unit/operations/test_group_by.py +++ b/py-polars/tests/unit/operations/test_group_by.py @@ -744,32 +744,7 @@ def test_groupby_rolling_deprecated() -> None: .collect() ) - expected = df.rolling("date", period="2d").agg(pl.sum("value")) - assert_frame_equal(result, expected, check_row_order=False) - assert_frame_equal(result_lazy, expected, check_row_order=False) - - -def test_group_by_rolling_deprecated() -> None: - df = pl.DataFrame( - { - "date": pl.datetime_range( - datetime(2020, 1, 1), datetime(2020, 1, 5), eager=True - ), - "value": [1, 2, 3, 4, 5], - } - ) - - with pytest.deprecated_call(): - result = df.group_by_rolling("date", period="2d").agg(pl.sum("value")) - with pytest.deprecated_call(): - result_lazy = ( - df.lazy() - .groupby_rolling("date", period="2d") - .agg(pl.sum("value")) - .collect() - ) - - expected = df.rolling("date", period="2d").agg(pl.sum("value")) + expected = df.group_by_rolling("date", period="2d").agg(pl.sum("value")) assert_frame_equal(result, expected, check_row_order=False) assert_frame_equal(result_lazy, expected, check_row_order=False) @@ -813,19 +788,3 @@ def test_group_by_list_scalar_11749() -> None: "group_name": ["a;b", "c;d"], "eq": [[True, True, True, True], [True, False]], } - - -def test_group_by_with_expr_as_key() -> None: - gb = pl.select(x=1).group_by(pl.col("x").alias("key")) - assert gb.agg(pl.all().first()).frame_equal(gb.agg(pl.first("x"))) - - # tests: 11766 - assert gb.head(0).frame_equal(gb.agg(pl.col("x").head(0)).explode("x")) - assert gb.tail(0).frame_equal(gb.agg(pl.col("x").tail(0)).explode("x")) - - -def test_lazy_group_by_reuse_11767() -> None: - lgb = pl.select(x=1).lazy().group_by("x") - a = lgb.count() - b = lgb.count() - assert a.collect().frame_equal(b.collect()) diff --git a/py-polars/tests/unit/operations/test_group_by_dynamic.py b/py-polars/tests/unit/operations/test_group_by_dynamic.py index 3862901e5a7b..d8eff5bc96e8 100644 --- a/py-polars/tests/unit/operations/test_group_by_dynamic.py +++ b/py-polars/tests/unit/operations/test_group_by_dynamic.py @@ -364,7 +364,7 @@ def test_sorted_flag_group_by_dynamic() -> None: ) -def 
test_rolling_dynamic_sortedness_check() -> None: +def test_group_by_rolling_dynamic_sortedness_check() -> None: # when the by argument is passed, the sortedness flag # will be unset as the take shuffles data, so we must explicitly # check the sortedness diff --git a/py-polars/tests/unit/operations/test_group_by_rolling.py b/py-polars/tests/unit/operations/test_group_by_rolling.py index 435b7c34bdec..8220dbba6042 100644 --- a/py-polars/tests/unit/operations/test_group_by_rolling.py +++ b/py-polars/tests/unit/operations/test_group_by_rolling.py @@ -32,7 +32,7 @@ def test_rolling_group_by_overlapping_groups() -> None: ( df.with_row_count() .with_columns(pl.col("row_nr").cast(pl.Int32)) - .rolling( + .group_by_rolling( index_column="row_nr", period="5i", ) @@ -48,7 +48,7 @@ def test_rolling_group_by_overlapping_groups() -> None: @pytest.mark.parametrize("lazy", [True, False]) -def test_rolling_agg_input_types(lazy: bool) -> None: +def test_group_by_rolling_agg_input_types(lazy: bool) -> None: df = pl.DataFrame({"index_column": [0, 1, 2, 3], "b": [1, 3, 1, 2]}).set_sorted( "index_column" ) @@ -56,24 +56,24 @@ def test_rolling_agg_input_types(lazy: bool) -> None: for bad_param in bad_agg_parameters(): with pytest.raises(TypeError): # noqa: PT012 - result = df_or_lazy.rolling(index_column="index_column", period="2i").agg( - bad_param - ) + result = df_or_lazy.group_by_rolling( + index_column="index_column", period="2i" + ).agg(bad_param) if lazy: result.collect() # type: ignore[union-attr] expected = pl.DataFrame({"index_column": [0, 1, 2, 3], "b": [1, 4, 4, 3]}) for good_param in good_agg_parameters(): - result = df_or_lazy.rolling(index_column="index_column", period="2i").agg( - good_param - ) + result = df_or_lazy.group_by_rolling( + index_column="index_column", period="2i" + ).agg(good_param) if lazy: result = result.collect() # type: ignore[union-attr] assert_frame_equal(result, expected) -def test_rolling_negative_offset_3914() -> None: +def test_group_by_rolling_negative_offset_3914() -> None: df = pl.DataFrame( { "datetime": pl.datetime_range( @@ -81,7 +81,7 @@ def test_rolling_negative_offset_3914() -> None: ), } ) - assert df.rolling(index_column="datetime", period="2d", offset="-4d").agg( + assert df.group_by_rolling(index_column="datetime", period="2d", offset="-4d").agg( pl.count().alias("count") )["count"].to_list() == [0, 0, 1, 2, 2] @@ -91,7 +91,7 @@ def test_rolling_negative_offset_3914() -> None: } ) - assert df.rolling(index_column="ints", period="2i", offset="-5i").agg( + assert df.group_by_rolling(index_column="ints", period="2i", offset="-5i").agg( [pl.col("ints").alias("matches")] )["matches"].to_list() == [ [], @@ -118,7 +118,7 @@ def test_rolling_negative_offset_3914() -> None: @pytest.mark.parametrize("time_zone", [None, "US/Central"]) -def test_rolling_negative_offset_crossing_dst(time_zone: str | None) -> None: +def test_group_by_rolling_negative_offset_crossing_dst(time_zone: str | None) -> None: df = pl.DataFrame( { "datetime": pl.datetime_range( @@ -131,9 +131,9 @@ def test_rolling_negative_offset_crossing_dst(time_zone: str | None) -> None: "value": [1, 4, 9, 155], } ) - result = df.rolling(index_column="datetime", period="2d", offset="-1d").agg( - pl.col("value") - ) + result = df.group_by_rolling( + index_column="datetime", period="2d", offset="-1d" + ).agg(pl.col("value")) expected = pl.DataFrame( { "datetime": pl.datetime_range( @@ -163,7 +163,7 @@ def test_rolling_negative_offset_crossing_dst(time_zone: str | None) -> None: ("1d", "none", [[9], [155], [], 
[]]), ], ) -def test_rolling_non_negative_offset_9077( +def test_group_by_rolling_non_negative_offset_9077( time_zone: str | None, offset: str, closed: ClosedInterval, @@ -181,7 +181,7 @@ def test_rolling_non_negative_offset_9077( "value": [1, 4, 9, 155], } ) - result = df.rolling( + result = df.group_by_rolling( index_column="datetime", period="2d", offset=offset, closed=closed ).agg(pl.col("value")) expected = pl.DataFrame( @@ -199,7 +199,7 @@ def test_rolling_non_negative_offset_9077( assert_frame_equal(result, expected) -def test_rolling_dynamic_sortedness_check() -> None: +def test_group_by_rolling_dynamic_sortedness_check() -> None: # when the by argument is passed, the sortedness flag # will be unset as the take shuffles data, so we must explicitly # check the sortedness @@ -211,17 +211,19 @@ def test_rolling_dynamic_sortedness_check() -> None: ) with pytest.raises(pl.ComputeError, match=r"input data is not sorted"): - df.rolling("idx", period="2i", by="group").agg(pl.col("idx").alias("idx1")) + df.group_by_rolling("idx", period="2i", by="group").agg( + pl.col("idx").alias("idx1") + ) # no `by` argument with pytest.raises( pl.InvalidOperationError, match=r"argument in operation 'group_by_rolling' is not explicitly sorted", ): - df.rolling("idx", period="2i").agg(pl.col("idx").alias("idx1")) + df.group_by_rolling("idx", period="2i").agg(pl.col("idx").alias("idx1")) -def test_rolling_empty_groups_9973() -> None: +def test_group_by_rolling_empty_groups_9973() -> None: dt1 = date(2001, 1, 1) dt2 = date(2001, 1, 2) @@ -248,7 +250,7 @@ def test_rolling_empty_groups_9973() -> None: } ) - out = data.rolling( + out = data.group_by_rolling( index_column="date", by="id", period="2d", @@ -260,7 +262,7 @@ def test_rolling_empty_groups_9973() -> None: assert_frame_equal(out, expected) -def test_rolling_duplicates_11281() -> None: +def test_group_by_rolling_duplicates_11281() -> None: df = pl.DataFrame( { "ts": [ @@ -274,6 +276,6 @@ def test_rolling_duplicates_11281() -> None: "val": [1, 2, 2, 2, 3, 4], } ).sort("ts") - result = df.rolling("ts", period="1d", closed="left").agg(pl.col("val")) + result = df.group_by_rolling("ts", period="1d", closed="left").agg(pl.col("val")) expected = df.with_columns(val=pl.Series([[], [1], [1], [1], [2, 2, 2], [3]])) assert_frame_equal(result, expected) diff --git a/py-polars/tests/unit/operations/test_join.py b/py-polars/tests/unit/operations/test_join.py index 844b02988f3e..ce529cea450a 100644 --- a/py-polars/tests/unit/operations/test_join.py +++ b/py-polars/tests/unit/operations/test_join.py @@ -561,28 +561,6 @@ def test_update() -> None: a.update(b.rename({"b": "a"}), how="outer", on="a").collect().to_series() ) - # check behavior of include_nulls=True - df = pl.DataFrame( - { - "A": [1, 2, 3, 4], - "B": [400, 500, 600, 700], - } - ) - new_df = pl.DataFrame( - { - "B": [-66, None, -99], - "C": [5, 3, 1], - } - ) - out = df.update(new_df, left_on="A", right_on="C", how="outer", include_nulls=True) - expected = pl.DataFrame( - { - "A": [1, 2, 3, 4, 5], - "B": [-99, 500, None, 700, -66], - } - ) - assert_frame_equal(out, expected) - # edge-case #11684 x = pl.DataFrame({"a": [0, 1]}) y = pl.DataFrame({"a": [2, 3]}) diff --git a/py-polars/tests/unit/series/test_series.py b/py-polars/tests/unit/series/test_series.py index c77cb2cecb5b..387fcec9a603 100644 --- a/py-polars/tests/unit/series/test_series.py +++ b/py-polars/tests/unit/series/test_series.py @@ -243,18 +243,6 @@ def test_equal() -> None: assert s3.dt.convert_time_zone("Asia/Tokyo").series_equal(s4) is 
True -@pytest.mark.parametrize( - "dtype", - [pl.Int64, pl.Float64, pl.Utf8, pl.Boolean], -) -def test_eq_missing_list_and_primitive(dtype: PolarsDataType) -> None: - s1 = pl.Series([None, None], dtype=dtype) - s2 = pl.Series([None, None], dtype=pl.List(dtype)) - - expected = pl.Series([True, True]) - assert_series_equal(s1.eq_missing(s2), expected) - - def test_to_frame() -> None: s1 = pl.Series([1, 2]) s2 = pl.Series("s", [1, 2]) @@ -614,25 +602,21 @@ def test_to_pandas() -> None: pass -def test_series_to_list() -> None: - s = pl.Series("a", range(20)) - result = s.to_list() - assert isinstance(result, list) - assert len(result) == 20 +def test_to_python() -> None: + a = pl.Series("a", range(20)) + b = a.to_list() + assert isinstance(b, list) + assert len(b) == 20 + + b = a.to_list(use_pyarrow=True) + assert isinstance(b, list) + assert len(b) == 20 a = pl.Series("a", [1, None, 2]) assert a.null_count() == 1 assert a.to_list() == [1, None, 2] -def test_series_to_list_use_pyarrow_deprecated() -> None: - s = pl.Series("a", range(20)) - with pytest.deprecated_call(): - result = s.to_list(use_pyarrow=True) - assert isinstance(result, list) - assert len(result) == 20 - - def test_to_struct() -> None: s = pl.Series("nums", ["12 34", "56 78", "90 00"]).str.extract_all(r"\d+") @@ -1279,7 +1263,6 @@ def test_pct_change() -> None: s = pl.Series("a", [1, 2, 4, 8, 16, 32, 64]) expected = pl.Series("a", [None, None, float("inf"), 3.0, 3.0, 3.0, 3.0]) assert_series_equal(s.pct_change(2), expected) - assert_series_equal(s.pct_change(pl.Series([2])), expected) # negative assert pl.Series(range(5)).pct_change(-1).to_list() == [ -1.0, @@ -1450,10 +1433,10 @@ def test_bitwise() -> None: # ensure mistaken use of logical 'and'/'or' raises an exception with pytest.raises(TypeError, match="ambiguous"): - a and b # type: ignore[redundant-expr] + a and b with pytest.raises(TypeError, match="ambiguous"): - a or b # type: ignore[redundant-expr] + a or b def test_to_numpy(monkeypatch: Any) -> None: @@ -1568,21 +1551,6 @@ def test_comparisons_int_series_to_float() -> None: assert_series_equal(srs_int - True, pl.Series([0, 1, 2, 3])) -def test_comparisons_int_series_to_float_scalar() -> None: - srs_int = pl.Series([1, 2, 3, 4]) - - assert_series_equal(srs_int < 1.5, pl.Series([True, False, False, False])) - assert_series_equal(srs_int > 1.5, pl.Series([False, True, True, True])) - - -def test_comparisons_datetime_series_to_date_scalar() -> None: - srs_date = pl.Series([date(2023, 1, 1), date(2023, 1, 2), date(2023, 1, 3)]) - dt = datetime(2023, 1, 1, 12, 0, 0) - - assert_series_equal(srs_date < dt, pl.Series([True, False, False])) - assert_series_equal(srs_date > dt, pl.Series([False, True, True])) - - def test_comparisons_float_series_to_int() -> None: srs_float = pl.Series([1.0, 2.0, 3.0, 4.0]) @@ -2163,9 +2131,7 @@ def test_ewm_mean() -> None: def test_ewm_mean_leading_nulls() -> None: for min_periods in [1, 2, 3]: assert ( - pl.Series([1, 2, 3, 4]) - .ewm_mean(com=3, min_periods=min_periods) - .null_count() + pl.Series([1, 2, 3, 4]).ewm_mean(3, min_periods=min_periods).null_count() == min_periods - 1 ) assert pl.Series([None, 1.0, 1.0, 1.0]).ewm_mean( @@ -2775,28 +2741,3 @@ def test_series_getitem_out_of_bounds_negative() -> None: IndexError, match="index -10 is out of bounds for sequence of length 2" ): s[-10] - - -def test_series_cmp_fast_paths() -> None: - assert ( - pl.Series([None], dtype=pl.Int32) != pl.Series([1, 2], dtype=pl.Int32) - ).to_list() == [None, None] - assert ( - pl.Series([None], 
dtype=pl.Int32) == pl.Series([1, 2], dtype=pl.Int32)
-    ).to_list() == [None, None]
-
-    assert (
-        pl.Series([None], dtype=pl.Utf8) != pl.Series(["a", "b"], dtype=pl.Utf8)
-    ).to_list() == [None, None]
-    assert (
-        pl.Series([None], dtype=pl.Utf8) == pl.Series(["a", "b"], dtype=pl.Utf8)
-    ).to_list() == [None, None]
-
-    assert (
-        pl.Series([None], dtype=pl.Boolean)
-        != pl.Series([True, False], dtype=pl.Boolean)
-    ).to_list() == [None, None]
-    assert (
-        pl.Series([None], dtype=pl.Boolean)
-        == pl.Series([False, False], dtype=pl.Boolean)
-    ).to_list() == [None, None]
diff --git a/py-polars/tests/unit/sql/test_sql.py b/py-polars/tests/unit/sql/test_sql.py
index ca281c380ab0..aad05b5f7b5e 100644
--- a/py-polars/tests/unit/sql/test_sql.py
+++ b/py-polars/tests/unit/sql/test_sql.py
@@ -1,6 +1,5 @@
 from __future__ import annotations
 
-import datetime
 import math
 from pathlib import Path
 
@@ -942,14 +941,12 @@ def test_sql_trim(foods_ipc_path: Path) -> None:
             "BY NAME",
             [(1, "zz"), (2, "yy"), (3, "xx")],
         ),
-        pytest.param(
+        (
+            # note: waiting for "https://github.com/sqlparser-rs/sqlparser-rs/pull/997"
             ["c1", "c2"],
             ["c2", "c1"],
             "DISTINCT BY NAME",
-            [(1, "zz"), (2, "yy"), (3, "xx")],
-            # TODO: Remove xfail marker when supported added in sqlparser-rs
-            # https://github.com/sqlparser-rs/sqlparser-rs/pull/997
-            marks=pytest.mark.xfail,
+            None,  # [(1, "zz"), (2, "yy"), (3, "xx")],
         ),
     ],
 )
@@ -957,7 +954,7 @@ def test_sql_union(
     cols1: list[str],
     cols2: list[str],
     union_subtype: str,
-    expected: list[tuple[int, str]],
+    expected: list[tuple[int, str]] | None,
 ) -> None:
     with pl.SQLContext(
         frame1=pl.DataFrame({"c1": [1, 2], "c2": ["zz", "yy"]}),
@@ -969,7 +966,11 @@ def test_sql_union(
         UNION {union_subtype}
         SELECT {', '.join(cols2)} FROM frame2
         """
-    assert sorted(ctx.execute(query).rows()) == expected
+    if expected is not None:
+        assert sorted(ctx.execute(query).rows()) == expected
+    else:
+        with pytest.raises(pl.ComputeError, match="sql parser error"):
+            ctx.execute(query)
 
 
 def test_sql_nullif_coalesce(foods_ipc_path: Path) -> None:
@@ -1207,27 +1208,3 @@ def test_sql_unary_ops_8890(match_float: bool) -> None:
         "c": [-3, -3],
         "d": [4, 4],
     }
-
-
-def test_sql_date() -> None:
-    df = pl.DataFrame(
-        {
-            "date": [
-                datetime.date(2021, 3, 15),
-                datetime.date(2021, 3, 28),
-                datetime.date(2021, 4, 4),
-            ],
-            "version": ["0.0.1", "0.7.3", "0.7.4"],
-        }
-    )
-
-    with pl.SQLContext(df=df, eager_execution=True) as ctx:
-        expected = pl.DataFrame({"date": [True, False, False]})
-        assert ctx.execute("SELECT date < DATE('2021-03-20') from df").frame_equal(
-            expected
-        )
-
-        expected = pl.DataFrame({"literal": ["2023-03-01"]})
-        assert pl.select(
-            pl.sql_expr("""CAST(DATE('2023-03', '%Y-%m') as STRING)""")
-        ).frame_equal(expected)
diff --git a/py-polars/tests/unit/streaming/test_streaming_categoricals.py b/py-polars/tests/unit/streaming/test_streaming_categoricals.py
deleted file mode 100644
index 776e0c0ce377..000000000000
--- a/py-polars/tests/unit/streaming/test_streaming_categoricals.py
+++ /dev/null
@@ -1,14 +0,0 @@
-import polars as pl
-
-
-def test_streaming_nested_categorical() -> None:
-    assert (
-        pl.LazyFrame({"numbers": [1, 1, 2], "cat": [["str"], ["foo"], ["bar"]]})
-        .with_columns(pl.col("cat").cast(pl.List(pl.Categorical)))
-        .group_by("numbers")
-        .agg(pl.col("cat").first())
-        .sort("numbers")
-    ).collect(streaming=True).to_dict(False) == {
-        "numbers": [1, 2],
-        "cat": [["str"], ["bar"]],
-    }
diff --git a/py-polars/tests/unit/test_constructors.py 
b/py-polars/tests/unit/test_constructors.py index a09eb469eeab..82263ba09a81 100644 --- a/py-polars/tests/unit/test_constructors.py +++ b/py-polars/tests/unit/test_constructors.py @@ -592,10 +592,6 @@ def test_init_ndarray(monkeypatch: Any) -> None: assert df.shape == (2, 1) assert df.rows() == [([0, 1, 2, 3, 4],), ([5, 6, 7, 8, 9],)] - test_rows = [(1, 2), (3, 4)] - df = pl.DataFrame([np.array(test_rows[0]), np.array(test_rows[1])], orient="row") - assert_frame_equal(df, pl.DataFrame(test_rows, orient="row")) - # numpy arrays containing NaN df0 = pl.DataFrame( data={"x": [1.0, 2.5, float("nan")], "y": [4.0, float("nan"), 6.5]}, diff --git a/py-polars/tests/unit/test_errors.py b/py-polars/tests/unit/test_errors.py index f45ee6445bf1..059c63776c75 100644 --- a/py-polars/tests/unit/test_errors.py +++ b/py-polars/tests/unit/test_errors.py @@ -105,7 +105,7 @@ def test_string_numeric_comp_err() -> None: def test_panic_error() -> None: with pytest.raises( pl.PolarsPanicError, - match="dimensions cannot be empty", + match="""dimensions cannot be empty""", ): pl.Series("a", [1, 2, 3]).reshape(()) diff --git a/py-polars/tests/unit/test_exprs.py b/py-polars/tests/unit/test_exprs.py index a75fa521663a..8518755eb4f1 100644 --- a/py-polars/tests/unit/test_exprs.py +++ b/py-polars/tests/unit/test_exprs.py @@ -428,10 +428,10 @@ def test_logical_boolean() -> None: # note, cannot use expressions in logical # boolean context (eg: and/or/not operators) with pytest.raises(TypeError, match="ambiguous"): - pl.col("colx") and pl.col("coly") # type: ignore[redundant-expr] + pl.col("colx") and pl.col("coly") with pytest.raises(TypeError, match="ambiguous"): - pl.col("colx") or pl.col("coly") # type: ignore[redundant-expr] + pl.col("colx") or pl.col("coly") df = pl.DataFrame({"a": [1, 2, 3, 4, 5], "b": [1, 2, 3, 4, 5]}) diff --git a/py-polars/tests/unit/test_interop.py b/py-polars/tests/unit/test_interop.py index af4efa23edcc..d3981cc8528a 100644 --- a/py-polars/tests/unit/test_interop.py +++ b/py-polars/tests/unit/test_interop.py @@ -79,14 +79,6 @@ def test_to_numpy_no_zero_copy( series.to_numpy(zero_copy_only=True, use_pyarrow=use_pyarrow) -def test_to_numpy_empty_no_pyarrow() -> None: - series = pl.Series([], dtype=pl.Null) - result = series.to_numpy() - assert result.dtype == pl.Float32 - assert result.shape == (0,) - assert result.size == 0 - - def test_from_pandas() -> None: df = pd.DataFrame( { diff --git a/py-polars/tests/unit/test_predicates.py b/py-polars/tests/unit/test_predicates.py index 41e2207e844c..2af6376d05ef 100644 --- a/py-polars/tests/unit/test_predicates.py +++ b/py-polars/tests/unit/test_predicates.py @@ -192,14 +192,3 @@ def test_predicate_pushdown_group_by_keys() -> None: .filter(pl.col("group") == 1) .explain() ) - - -def test_no_predicate_push_down_with_cast_and_alias_11883() -> None: - df = pl.DataFrame({"a": [1, 2, 3]}) - out = ( - df.lazy() - .select(pl.col("a").cast(pl.Int64).alias("b")) - .filter(pl.col("b") == 1) - .filter((pl.col("b") >= 1) & (pl.col("b") < 1)) - ) - assert 'SELECTION: "None"' in out.explain(predicate_pushdown=True) diff --git a/py-polars/tests/unit/test_projections.py b/py-polars/tests/unit/test_projections.py index 85dc01ce006d..2e276b837b01 100644 --- a/py-polars/tests/unit/test_projections.py +++ b/py-polars/tests/unit/test_projections.py @@ -320,9 +320,3 @@ def test_projection_rename_10595() -> None: assert lf.select("a", "b").rename({"b": "a", "a": "b"}).select( "a" ).collect().schema == {"a": pl.Float32} - - -def test_projection_count_11841() -> None: - 
pl.LazyFrame({"x": 1}).select(records=pl.count()).select( - pl.lit(1).alias("x"), pl.all() - ).collect() diff --git a/py-polars/tests/unit/testing/test_assert_series_equal.py b/py-polars/tests/unit/test_testing.py similarity index 57% rename from py-polars/tests/unit/testing/test_assert_series_equal.py rename to py-polars/tests/unit/test_testing.py index b7dbd56c57d6..c44ea15961db 100644 --- a/py-polars/tests/unit/testing/test_assert_series_equal.py +++ b/py-polars/tests/unit/test_testing.py @@ -7,7 +7,14 @@ import pytest import polars as pl -from polars.testing import assert_series_equal, assert_series_not_equal +from polars.exceptions import InvalidAssert +from polars.testing import ( + assert_frame_equal, + assert_frame_equal_local_categoricals, + assert_frame_not_equal, + assert_series_equal, + assert_series_not_equal, +) def test_compare_series_value_mismatch() -> None: @@ -95,10 +102,35 @@ def test_compare_series_nulls() -> None: srs2 = pl.Series([1, None, None]) assert_series_not_equal(srs1, srs2) - with pytest.raises(AssertionError, match="value mismatch"): + with pytest.raises(AssertionError, match="null_count is not equal"): assert_series_equal(srs1, srs2) +def test_series_cmp_fast_paths() -> None: + assert ( + pl.Series([None], dtype=pl.Int32) != pl.Series([1, 2], dtype=pl.Int32) + ).to_list() == [None, None] + assert ( + pl.Series([None], dtype=pl.Int32) == pl.Series([1, 2], dtype=pl.Int32) + ).to_list() == [None, None] + + assert ( + pl.Series([None], dtype=pl.Utf8) != pl.Series(["a", "b"], dtype=pl.Utf8) + ).to_list() == [None, None] + assert ( + pl.Series([None], dtype=pl.Utf8) == pl.Series(["a", "b"], dtype=pl.Utf8) + ).to_list() == [None, None] + + assert ( + pl.Series([None], dtype=pl.Boolean) + != pl.Series([True, False], dtype=pl.Boolean) + ).to_list() == [None, None] + assert ( + pl.Series([None], dtype=pl.Boolean) + == pl.Series([False, False], dtype=pl.Boolean) + ).to_list() == [None, None] + + def test_compare_series_value_mismatch_string() -> None: srs1 = pl.Series(["hello", "no"]) srs2 = pl.Series(["hello", "yes"]) @@ -115,7 +147,7 @@ def test_compare_series_type_mismatch() -> None: srs2 = pl.DataFrame({"col1": [2, 3, 4]}) with pytest.raises( - AssertionError, match=r"inputs are different \(unexpected input types\)" + AssertionError, match=r"Inputs are different \(unexpected input types\)" ): assert_series_equal(srs1, srs2) # type: ignore[arg-type] @@ -134,7 +166,7 @@ def test_compare_series_name_mismatch() -> None: assert_series_equal(srs1, srs2) -def test_compare_series_length_mismatch() -> None: +def test_compare_series_shape_mismatch() -> None: srs1 = pl.Series(values=[1, 2, 3, 4], name="srs1") srs2 = pl.Series(values=[1, 2, 3], name="srs2") @@ -154,6 +186,399 @@ def test_compare_series_value_exact_mismatch() -> None: assert_series_equal(srs1, srs2, check_exact=True) +def test_compare_frame_equal_nans() -> None: + nan = float("NaN") + df1 = pl.DataFrame( + data={"x": [1.0, nan], "y": [nan, 2.0]}, + schema=[("x", pl.Float32), ("y", pl.Float64)], + ) + assert_frame_equal(df1, df1, check_exact=True) + + df2 = pl.DataFrame( + data={"x": [1.0, nan], "y": [None, 2.0]}, + schema=[("x", pl.Float32), ("y", pl.Float64)], + ) + assert_frame_not_equal(df1, df2) + with pytest.raises(AssertionError, match="values for column 'y' are different"): + assert_frame_equal(df1, df2, check_exact=True) + + +def test_compare_frame_equal_nested_nans() -> None: + nan = float("NaN") + + # list dtype + df1 = pl.DataFrame( + data={"x": [[1.0, nan]], "y": [[nan, 2.0]]}, + schema=[("x", 
pl.List(pl.Float32)), ("y", pl.List(pl.Float64))], + ) + assert_frame_equal(df1, df1, check_exact=True) + + df2 = pl.DataFrame( + data={"x": [[1.0, nan]], "y": [[None, 2.0]]}, + schema=[("x", pl.List(pl.Float32)), ("y", pl.List(pl.Float64))], + ) + assert_frame_not_equal(df1, df2) + with pytest.raises(AssertionError, match="values for column 'y' are different"): + assert_frame_equal(df1, df2, check_exact=True) + + # struct dtype + df3 = pl.from_dicts( + [ + { + "id": 1, + "struct": [ + {"x": "text", "y": [0.0, nan]}, + {"x": "text", "y": [0.0, nan]}, + ], + }, + { + "id": 2, + "struct": [ + {"x": "text", "y": [1]}, + {"x": "text", "y": [1]}, + ], + }, + ] + ) + df4 = pl.from_dicts( + [ + { + "id": 1, + "struct": [ + {"x": "text", "y": [0.0, nan], "z": ["$"]}, + {"x": "text", "y": [0.0, nan], "z": ["$"]}, + ], + }, + { + "id": 2, + "struct": [ + {"x": "text", "y": [nan, 1], "z": ["!"]}, + {"x": "text", "y": [nan, 1], "z": ["?"]}, + ], + }, + ] + ) + + assert_frame_equal(df3, df3) + assert_frame_not_equal(df3, df3, nans_compare_equal=False) + + assert_frame_equal(df4, df4) + assert_frame_not_equal(df4, df4, nans_compare_equal=False) + + assert_frame_not_equal(df3, df4) + for check_dtype in (True, False): + with pytest.raises(AssertionError, match="mismatch|different"): + assert_frame_equal(df3, df4, check_dtype=check_dtype) + + +def test_assert_frame_equal_pass() -> None: + df1 = pl.DataFrame({"a": [1, 2]}) + df2 = pl.DataFrame({"a": [1, 2]}) + assert_frame_equal(df1, df2) + + +def test_assert_frame_equal_types() -> None: + df1 = pl.DataFrame({"a": [1, 2]}) + srs1 = pl.Series(values=[1, 2], name="a") + with pytest.raises( + AssertionError, match=r"Inputs are different \(unexpected input types\)" + ): + assert_frame_equal(df1, srs1) # type: ignore[arg-type] + + +def test_assert_frame_equal_length_mismatch() -> None: + df1 = pl.DataFrame({"a": [1, 2]}) + df2 = pl.DataFrame({"a": [1, 2, 3]}) + with pytest.raises( + AssertionError, match=r"DataFrames are different \(length mismatch\)" + ): + assert_frame_equal(df1, df2) + + +def test_assert_frame_equal_column_mismatch() -> None: + df1 = pl.DataFrame({"a": [1, 2]}) + df2 = pl.DataFrame({"b": [1, 2]}) + with pytest.raises( + AssertionError, match="columns \\['a'\\] in left frame, but not in right" + ): + assert_frame_equal(df1, df2) + + +def test_assert_frame_equal_column_mismatch2() -> None: + df1 = pl.DataFrame({"a": [1, 2]}) + df2 = pl.DataFrame({"a": [1, 2], "b": [3, 4], "c": [5, 6]}) + with pytest.raises( + AssertionError, match="columns \\['b', 'c'\\] in right frame, but not in left" + ): + assert_frame_equal(df1, df2) + + +def test_assert_frame_equal_column_mismatch_order() -> None: + df1 = pl.DataFrame({"b": [3, 4], "a": [1, 2]}) + df2 = pl.DataFrame({"a": [1, 2], "b": [3, 4]}) + with pytest.raises(AssertionError, match="columns are not in the same order"): + assert_frame_equal(df1, df2) + + assert_frame_equal(df1, df2, check_column_order=False) + + +def test_assert_frame_equal_ignore_row_order() -> None: + df1 = pl.DataFrame({"a": [1, 2], "b": [4, 3]}) + df2 = pl.DataFrame({"a": [2, 1], "b": [3, 4]}) + df3 = pl.DataFrame({"b": [3, 4], "a": [2, 1]}) + with pytest.raises(AssertionError, match="values for column 'a' are different."): + assert_frame_equal(df1, df2) + + assert_frame_equal(df1, df2, check_row_order=False) + # eg: + # ┌─────┬─────┐ ┌─────┬─────┐ + # │ a ┆ b │ │ a ┆ b │ + # │ --- ┆ --- │ │ --- ┆ --- │ + # │ i64 ┆ i64 │ (eq) │ i64 ┆ i64 │ + # ╞═════╪═════╡ == ╞═════╪═════╡ + # │ 1 ┆ 4 │ │ 2 ┆ 3 │ + # │ 2 ┆ 3 │ │ 1 ┆ 4 │ + # 
└─────┴─────┘ └─────┴─────┘ + + with pytest.raises(AssertionError, match="columns are not in the same order"): + assert_frame_equal(df1, df3, check_row_order=False) + + assert_frame_equal(df1, df3, check_row_order=False, check_column_order=False) + + # note: not all column types support sorting + with pytest.raises( + InvalidAssert, + match="cannot set `check_row_order=False`.*unsortable columns", + ): + assert_frame_equal( + left=pl.DataFrame({"a": [[1, 2], [3, 4]], "b": [3, 4]}), + right=pl.DataFrame({"a": [[3, 4], [1, 2]], "b": [4, 3]}), + check_row_order=False, + ) + + +@pytest.mark.parametrize( + ("df1", "df2", "kwargs"), + [ + pytest.param( + pl.DataFrame({"a": [0.2, 0.3]}), + pl.DataFrame({"a": [0.2, 0.3]}), + {"atol": 1e-15}, + id="equal_floats_low_atol", + ), + pytest.param( + pl.DataFrame({"a": [0.2, 0.3]}), + pl.DataFrame({"a": [0.2, 0.3000000000000001]}), + {"atol": 1e-15}, + id="approx_equal_float_low_atol", + ), + pytest.param( + pl.DataFrame({"a": [0.2, 0.3]}), + pl.DataFrame({"a": [0.2, 0.31]}), + {"atol": 0.1}, + id="approx_equal_float_high_atol", + ), + pytest.param( + pl.DataFrame({"a": [0.2, 1.3]}), + pl.DataFrame({"a": [0.2, 0.9]}), + {"atol": 1}, + id="approx_equal_float_integer_atol", + ), + pytest.param( + pl.DataFrame({"a": [0.0, 1.0, 2.0]}, schema={"a": pl.Float64}), + pl.DataFrame({"a": [0, 1, 2]}, schema={"a": pl.Int64}), + {"check_dtype": False}, + id="equal_int_float_integer_no_check_dtype", + ), + pytest.param( + pl.DataFrame({"a": [0, 1, 2]}, schema={"a": pl.Float64}), + pl.DataFrame({"a": [0, 1, 2]}, schema={"a": pl.Float32}), + {"check_dtype": False}, + id="equal_int_float_integer_no_check_dtype", + ), + pytest.param( + pl.DataFrame({"a": [0, 1, 2]}, schema={"a": pl.Int64}), + pl.DataFrame({"a": [0, 1, 2]}, schema={"a": pl.Int64}), + {}, + id="equal_int", + ), + pytest.param( + pl.DataFrame({"a": ["a", "b", "c"]}, schema={"a": pl.Utf8}), + pl.DataFrame({"a": ["a", "b", "c"]}, schema={"a": pl.Utf8}), + {}, + id="equal_str", + ), + pytest.param( + pl.DataFrame({"a": [[0.2, 0.3]]}), + pl.DataFrame({"a": [[0.2, 0.300001]]}), + {"atol": 1e-5}, + id="list_of_float_low_atol", + ), + pytest.param( + pl.DataFrame({"a": [[0.2, 0.3]]}), + pl.DataFrame({"a": [[0.2, 0.31]]}), + {"atol": 0.1}, + id="list_of_float_high_atol", + ), + pytest.param( + pl.DataFrame({"a": [[0.2, 1.3]]}), + pl.DataFrame({"a": [[0.2, 0.9]]}), + {"atol": 1}, + id="list_of_float_integer_atol", + ), + pytest.param( + pl.DataFrame({"a": [[0.2, 0.3]]}), + pl.DataFrame({"a": [[0.2, 0.300000001]]}), + {"rtol": 1e-5}, + id="list_of_float_low_rtol", + ), + pytest.param( + pl.DataFrame({"a": [[0.2, 0.3]]}), + pl.DataFrame({"a": [[0.2, 0.301]]}), + {"rtol": 0.1}, + id="list_of_float_high_rtol", + ), + pytest.param( + pl.DataFrame({"a": [[0.2, 1.3]]}), + pl.DataFrame({"a": [[0.2, 0.9]]}), + {"rtol": 1}, + id="list_of_float_integer_rtol", + ), + pytest.param( + pl.DataFrame({"a": [[None, 1.3]]}), + pl.DataFrame({"a": [[None, 0.9]]}), + {"rtol": 1}, + id="list_of_none_and_float_integer_rtol", + ), + pytest.param( + pl.DataFrame({"a": [[None, 1.3]]}), + pl.DataFrame({"a": [[None, 0.9]]}), + {"rtol": 1, "nans_compare_equal": False}, + id="list_of_none_and_float_integer_rtol", + ), + pytest.param( + pl.DataFrame({"a": [[[0.2, 3.0]]]}), + pl.DataFrame({"a": [[[0.2, 3.00000001]]]}), + {"atol": 0.1, "nans_compare_equal": True}, + id="nested_list_of_float_atol_high_nans_compare_equal_true", + ), + pytest.param( + pl.DataFrame({"a": [[[0.2, 3.0]]]}), + pl.DataFrame({"a": [[[0.2, 3.00000001]]]}), + {"atol": 0.1, 
"nans_compare_equal": False}, + id="nested_list_of_float_atol_high_nans_compare_equal_false", + ), + ], +) +def test_assert_frame_equal_passes_assertion( + df1: pl.DataFrame, + df2: pl.DataFrame, + kwargs: Any, +) -> None: + assert_frame_equal(df1, df2, **kwargs) + with pytest.raises(AssertionError): + assert_frame_not_equal(df1, df2, **kwargs) + + +@pytest.mark.parametrize( + ("df1", "df2", "kwargs"), + [ + pytest.param( + pl.DataFrame({"a": [[0.2, 0.3]]}), + pl.DataFrame({"a": [[0.2, 0.3, 0.4]]}), + {}, + id="list_of_float_different_lengths", + ), + pytest.param( + pl.DataFrame({"a": [[0.2, 0.3]]}), + pl.DataFrame({"a": [[0.2, 0.3000000000000001]]}), + {"check_exact": True}, + id="list_of_float_check_exact", + ), + pytest.param( + pl.DataFrame({"a": [[0.2, 0.3]]}), + pl.DataFrame({"a": [[0.2, 0.300001]]}), + {"atol": 1e-15, "rtol": 0}, + id="list_of_float_too_low_atol", + ), + pytest.param( + pl.DataFrame({"a": [[0.2, 0.3]]}), + pl.DataFrame({"a": [[0.2, 0.30000001]]}), + {"atol": -1, "rtol": 0}, + id="list_of_float_negative_atol", + ), + pytest.param( + pl.DataFrame({"a": [[math.nan, 1.3]]}), + pl.DataFrame({"a": [[math.nan, 0.9]]}), + {"rtol": 1, "nans_compare_equal": False}, + id="list_of_nan_and_float_integer_rtol", + ), + pytest.param( + pl.DataFrame({"a": [[2.0, 3.0]]}), + pl.DataFrame({"a": [[2, 3]]}), + {"check_exact": False, "check_dtype": True}, + id="list_of_float_list_of_int_check_dtype_true", + ), + pytest.param( + pl.DataFrame({"a": [[[0.2, math.nan, 3.0]]]}), + pl.DataFrame({"a": [[[0.2, math.nan, 3.11]]]}), + {"atol": 0.1, "rtol": 0, "nans_compare_equal": True}, + id="nested_list_of_float_and_nan_atol_high_nans_compare_equal_true", + ), + pytest.param( + pl.DataFrame({"a": [[[0.2, math.nan, 3.0]]]}), + pl.DataFrame({"a": [[[0.2, math.nan, 3.0]]]}), + {"nans_compare_equal": False}, + id="nested_list_of_float_and_nan_atol_high_nans_compare_equal_false", + ), + pytest.param( + pl.DataFrame({"a": [[[0.2, 3.0]]]}), + pl.DataFrame({"a": [[[0.2, 3.11]]]}), + {"atol": 0.1, "nans_compare_equal": False}, + id="nested_list_of_float_atol_high_nans_compare_equal_false", + ), + pytest.param( + pl.DataFrame({"a": [[[[0.2, 3.0]]]]}), + pl.DataFrame({"a": [[[[0.2, 3.11]]]]}), + {"atol": 0.1, "rtol": 0, "nans_compare_equal": True}, + id="double_nested_list_of_float_atol_high_nans_compare_equal_true", + ), + pytest.param( + pl.DataFrame({"a": [[[[0.2, math.nan, 3.0]]]]}), + pl.DataFrame({"a": [[[[0.2, math.nan, 3.0]]]]}), + {"atol": 0.1, "nans_compare_equal": False}, + id="double_nested_list_of_float_and_nan_atol_high_nans_compare_equal_false", + ), + pytest.param( + pl.DataFrame({"a": [[[[0.2, 3.0]]]]}), + pl.DataFrame({"a": [[[[0.2, 3.11]]]]}), + {"atol": 0.1, "rtol": 0, "nans_compare_equal": False}, + id="double_nested_list_of_float_atol_high_nans_compare_equal_false", + ), + pytest.param( + pl.DataFrame({"a": [[[[[0.2, 3.0]]]]]}), + pl.DataFrame({"a": [[[[[0.2, 3.11]]]]]}), + {"atol": 0.1, "rtol": 0, "nans_compare_equal": True}, + id="triple_nested_list_of_float_atol_high_nans_compare_equal_true", + ), + pytest.param( + pl.DataFrame({"a": [[[[[0.2, math.nan, 3.0]]]]]}), + pl.DataFrame({"a": [[[[[0.2, math.nan, 3.0]]]]]}), + {"atol": 0.1, "nans_compare_equal": False}, + id="triple_nested_list_of_float_and_nan_atol_high_nans_compare_equal_true", + ), + ], +) +def test_assert_frame_equal_raises_assertion_error( + df1: pl.DataFrame, + df2: pl.DataFrame, + kwargs: Any, +) -> None: + with pytest.raises(AssertionError): + assert_frame_equal(df1, df2, **kwargs) + assert_frame_not_equal(df1, 
df2, **kwargs) + + def test_assert_series_equal_int_overflow() -> None: # internally may call 'abs' if not check_exact, which can overflow on signed int s0 = pl.Series([-128], dtype=pl.Int8) @@ -596,7 +1021,7 @@ def test_assert_series_equal_raises_assertion_error( def test_assert_series_equal_categorical() -> None: s1 = pl.Series(["a", "b", "a"], dtype=pl.Categorical) s2 = pl.Series(["a", "b", "a"], dtype=pl.Categorical) - with pytest.raises(AssertionError, match="incompatible data types"): + with pytest.raises(pl.ComputeError, match="cannot compare categoricals"): assert_series_equal(s1, s2) assert_series_equal(s1, s2, categorical_as_str=True) @@ -609,91 +1034,9 @@ def test_assert_series_equal_categorical_vs_str() -> None: with pytest.raises(AssertionError, match="dtype mismatch"): assert_series_equal(s1, s2, categorical_as_str=True) - assert_series_equal(s1, s2, check_dtype=False, categorical_as_str=True) - assert_series_equal(s2, s1, check_dtype=False, categorical_as_str=True) - - -def test_assert_series_equal_incompatible_data_types() -> None: - s1 = pl.Series(["a", "b", "a"], dtype=pl.Categorical) - s2 = pl.Series([0, 1, 0], dtype=pl.Int8) - - with pytest.raises(AssertionError, match="incompatible data types"): - assert_series_equal(s1, s2, check_dtype=False) - - -def test_assert_series_equal_full_series() -> None: - s1 = pl.Series([1, 2, 3]) - s2 = pl.Series([1, 2, 4]) - msg = ( - r"Series are different \(value mismatch\)\n" - r"\[left\]: \[1, 2, 3\]\n" - r"\[right\]: \[1, 2, 4\]" - ) - with pytest.raises(AssertionError, match=msg): - assert_series_equal(s1, s2) - - -def test_assert_series_not_equal() -> None: - s = pl.Series("a", [1, 2]) - with pytest.raises(AssertionError, match="Series are equal"): - assert_series_not_equal(s, s) - - -def test_assert_series_equal_nested_list_float() -> None: - # First entry has only integers - s1 = pl.Series([[1.0, 2.0], [3.0, 4.0]], dtype=pl.List(pl.Float64)) - s2 = pl.Series([[1.0, 2.0], [3.0, 4.9]], dtype=pl.List(pl.Float64)) - - with pytest.raises(AssertionError): - assert_series_equal(s1, s2) - - -def test_assert_series_equal_nested_struct_float() -> None: - s1 = pl.Series( - [{"a": 1.0, "b": 2.0}, {"a": 3.0, "b": 4.0}], - dtype=pl.Struct({"a": pl.Float64, "b": pl.Float64}), - ) - s2 = pl.Series( - [{"a": 1.0, "b": 2.0}, {"a": 3.0, "b": 4.9}], - dtype=pl.Struct({"a": pl.Float64, "b": pl.Float64}), - ) - - with pytest.raises(AssertionError): - assert_series_equal(s1, s2) - - -def test_assert_series_equal_nested_list_full_null() -> None: - # First entry has only integers - s1 = pl.Series([None, None], dtype=pl.List(pl.Float64)) - s2 = pl.Series([None, None], dtype=pl.List(pl.Float64)) - - assert_series_equal(s1, s2) - - -def test_assert_series_equal_nested_list_nan() -> None: - s1 = pl.Series([[1.0, 2.0], [3.0, float("nan")]], dtype=pl.List(pl.Float64)) - s2 = pl.Series([[1.0, 2.0], [3.0, float("nan")]], dtype=pl.List(pl.Float64)) - - with pytest.raises(AssertionError): - assert_series_equal(s1, s2, nans_compare_equal=False) - - -def test_assert_series_equal_nested_list_none() -> None: - s1 = pl.Series([[1.0, 2.0], None], dtype=pl.List(pl.Float64)) - s2 = pl.Series([[1.0, 2.0], None], dtype=pl.List(pl.Float64)) - - assert_series_equal(s1, s2, nans_compare_equal=False) - - -def test_assert_series_equal_full_none_nested_not_nested() -> None: - s1 = pl.Series([None, None], dtype=pl.List(pl.Float64)) - s2 = pl.Series([None, None], dtype=pl.Float64) - - assert_series_equal(s1, s2, check_dtype=False) - -def 
test_assert_series_equal_unsigned_ints_underflow() -> None:
-    s1 = pl.Series([1, 3], dtype=pl.UInt8)
-    s2 = pl.Series([2, 4], dtype=pl.Int64)
+def test_assert_frame_equal_local_categoricals_deprecated() -> None:
+    df = pl.Series(["a", "b", "a"], dtype=pl.Categorical).to_frame()
 
-    assert_series_equal(s1, s2, atol=1, check_dtype=False)
+    with pytest.deprecated_call():
+        assert_frame_equal_local_categoricals(df, df)
diff --git a/py-polars/tests/unit/testing/__init__.py b/py-polars/tests/unit/testing/__init__.py
deleted file mode 100644
index e69de29bb2d1..000000000000
diff --git a/py-polars/tests/unit/testing/test_assert_frame_equal.py b/py-polars/tests/unit/testing/test_assert_frame_equal.py
deleted file mode 100644
index c9a3bf19814d..000000000000
--- a/py-polars/tests/unit/testing/test_assert_frame_equal.py
+++ /dev/null
@@ -1,420 +0,0 @@
-from __future__ import annotations
-
-import math
-from typing import Any
-
-import pytest
-
-import polars as pl
-from polars.exceptions import InvalidAssert
-from polars.testing import assert_frame_equal, assert_frame_not_equal
-
-
-@pytest.mark.parametrize(
-    ("df1", "df2", "kwargs"),
-    [
-        pytest.param(
-            pl.DataFrame({"a": [0.2, 0.3]}),
-            pl.DataFrame({"a": [0.2, 0.3]}),
-            {"atol": 1e-15},
-            id="equal_floats_low_atol",
-        ),
-        pytest.param(
-            pl.DataFrame({"a": [0.2, 0.3]}),
-            pl.DataFrame({"a": [0.2, 0.3000000000000001]}),
-            {"atol": 1e-15},
-            id="approx_equal_float_low_atol",
-        ),
-        pytest.param(
-            pl.DataFrame({"a": [0.2, 0.3]}),
-            pl.DataFrame({"a": [0.2, 0.31]}),
-            {"atol": 0.1},
-            id="approx_equal_float_high_atol",
-        ),
-        pytest.param(
-            pl.DataFrame({"a": [0.2, 1.3]}),
-            pl.DataFrame({"a": [0.2, 0.9]}),
-            {"atol": 1},
-            id="approx_equal_float_integer_atol",
-        ),
-        pytest.param(
-            pl.DataFrame({"a": [0.0, 1.0, 2.0]}, schema={"a": pl.Float64}),
-            pl.DataFrame({"a": [0, 1, 2]}, schema={"a": pl.Int64}),
-            {"check_dtype": False},
-            id="equal_int_float_integer_no_check_dtype",
-        ),
-        pytest.param(
-            pl.DataFrame({"a": [0, 1, 2]}, schema={"a": pl.Float64}),
-            pl.DataFrame({"a": [0, 1, 2]}, schema={"a": pl.Float32}),
-            {"check_dtype": False},
-            id="equal_int_float_integer_no_check_dtype",
-        ),
-        pytest.param(
-            pl.DataFrame({"a": [0, 1, 2]}, schema={"a": pl.Int64}),
-            pl.DataFrame({"a": [0, 1, 2]}, schema={"a": pl.Int64}),
-            {},
-            id="equal_int",
-        ),
-        pytest.param(
-            pl.DataFrame({"a": ["a", "b", "c"]}, schema={"a": pl.Utf8}),
-            pl.DataFrame({"a": ["a", "b", "c"]}, schema={"a": pl.Utf8}),
-            {},
-            id="equal_str",
-        ),
-        pytest.param(
-            pl.DataFrame({"a": [[0.2, 0.3]]}),
-            pl.DataFrame({"a": [[0.2, 0.300001]]}),
-            {"atol": 1e-5},
-            id="list_of_float_low_atol",
-        ),
-        pytest.param(
-            pl.DataFrame({"a": [[0.2, 0.3]]}),
-            pl.DataFrame({"a": [[0.2, 0.31]]}),
-            {"atol": 0.1},
-            id="list_of_float_high_atol",
-        ),
-        pytest.param(
-            pl.DataFrame({"a": [[0.2, 1.3]]}),
-            pl.DataFrame({"a": [[0.2, 0.9]]}),
-            {"atol": 1},
-            id="list_of_float_integer_atol",
-        ),
-        pytest.param(
-            pl.DataFrame({"a": [[0.2, 0.3]]}),
-            pl.DataFrame({"a": [[0.2, 0.300000001]]}),
-            {"rtol": 1e-5},
-            id="list_of_float_low_rtol",
-        ),
-        pytest.param(
-            pl.DataFrame({"a": [[0.2, 0.3]]}),
-            pl.DataFrame({"a": [[0.2, 0.301]]}),
-            {"rtol": 0.1},
-            id="list_of_float_high_rtol",
-        ),
-        pytest.param(
-            pl.DataFrame({"a": [[0.2, 1.3]]}),
-            pl.DataFrame({"a": [[0.2, 0.9]]}),
-            {"rtol": 1},
-            id="list_of_float_integer_rtol",
-        ),
-        pytest.param(
-            pl.DataFrame({"a": [[None, 1.3]]}),
-            pl.DataFrame({"a": [[None, 0.9]]}),
-            {"rtol": 1},
-            id="list_of_none_and_float_integer_rtol",
-        ),
-        
pytest.param( - pl.DataFrame({"a": [[None, 1.3]]}), - pl.DataFrame({"a": [[None, 0.9]]}), - {"rtol": 1, "nans_compare_equal": False}, - id="list_of_none_and_float_integer_rtol", - ), - pytest.param( - pl.DataFrame({"a": [[[0.2, 3.0]]]}), - pl.DataFrame({"a": [[[0.2, 3.00000001]]]}), - {"atol": 0.1, "nans_compare_equal": True}, - id="nested_list_of_float_atol_high_nans_compare_equal_true", - ), - pytest.param( - pl.DataFrame({"a": [[[0.2, 3.0]]]}), - pl.DataFrame({"a": [[[0.2, 3.00000001]]]}), - {"atol": 0.1, "nans_compare_equal": False}, - id="nested_list_of_float_atol_high_nans_compare_equal_false", - ), - ], -) -def test_assert_frame_equal_passes_assertion( - df1: pl.DataFrame, - df2: pl.DataFrame, - kwargs: dict[str, Any], -) -> None: - assert_frame_equal(df1, df2, **kwargs) - with pytest.raises(AssertionError): - assert_frame_not_equal(df1, df2, **kwargs) - - -@pytest.mark.parametrize( - ("df1", "df2", "kwargs"), - [ - pytest.param( - pl.DataFrame({"a": [[0.2, 0.3]]}), - pl.DataFrame({"a": [[0.2, 0.3, 0.4]]}), - {}, - id="list_of_float_different_lengths", - ), - pytest.param( - pl.DataFrame({"a": [[0.2, 0.3]]}), - pl.DataFrame({"a": [[0.2, 0.3000000000000001]]}), - {"check_exact": True}, - id="list_of_float_check_exact", - ), - pytest.param( - pl.DataFrame({"a": [[0.2, 0.3]]}), - pl.DataFrame({"a": [[0.2, 0.300001]]}), - {"atol": 1e-15, "rtol": 0}, - id="list_of_float_too_low_atol", - ), - pytest.param( - pl.DataFrame({"a": [[0.2, 0.3]]}), - pl.DataFrame({"a": [[0.2, 0.30000001]]}), - {"atol": -1, "rtol": 0}, - id="list_of_float_negative_atol", - ), - pytest.param( - pl.DataFrame({"a": [[math.nan, 1.3]]}), - pl.DataFrame({"a": [[math.nan, 0.9]]}), - {"rtol": 1, "nans_compare_equal": False}, - id="list_of_nan_and_float_integer_rtol", - ), - pytest.param( - pl.DataFrame({"a": [[2.0, 3.0]]}), - pl.DataFrame({"a": [[2, 3]]}), - {"check_exact": False, "check_dtype": True}, - id="list_of_float_list_of_int_check_dtype_true", - ), - pytest.param( - pl.DataFrame({"a": [[[0.2, math.nan, 3.0]]]}), - pl.DataFrame({"a": [[[0.2, math.nan, 3.11]]]}), - {"atol": 0.1, "rtol": 0, "nans_compare_equal": True}, - id="nested_list_of_float_and_nan_atol_high_nans_compare_equal_true", - ), - pytest.param( - pl.DataFrame({"a": [[[0.2, math.nan, 3.0]]]}), - pl.DataFrame({"a": [[[0.2, math.nan, 3.0]]]}), - {"nans_compare_equal": False}, - id="nested_list_of_float_and_nan_atol_high_nans_compare_equal_false", - ), - pytest.param( - pl.DataFrame({"a": [[[0.2, 3.0]]]}), - pl.DataFrame({"a": [[[0.2, 3.11]]]}), - {"atol": 0.1, "nans_compare_equal": False}, - id="nested_list_of_float_atol_high_nans_compare_equal_false", - ), - pytest.param( - pl.DataFrame({"a": [[[[0.2, 3.0]]]]}), - pl.DataFrame({"a": [[[[0.2, 3.11]]]]}), - {"atol": 0.1, "rtol": 0, "nans_compare_equal": True}, - id="double_nested_list_of_float_atol_high_nans_compare_equal_true", - ), - pytest.param( - pl.DataFrame({"a": [[[[0.2, math.nan, 3.0]]]]}), - pl.DataFrame({"a": [[[[0.2, math.nan, 3.0]]]]}), - {"atol": 0.1, "nans_compare_equal": False}, - id="double_nested_list_of_float_and_nan_atol_high_nans_compare_equal_false", - ), - pytest.param( - pl.DataFrame({"a": [[[[0.2, 3.0]]]]}), - pl.DataFrame({"a": [[[[0.2, 3.11]]]]}), - {"atol": 0.1, "rtol": 0, "nans_compare_equal": False}, - id="double_nested_list_of_float_atol_high_nans_compare_equal_false", - ), - pytest.param( - pl.DataFrame({"a": [[[[[0.2, 3.0]]]]]}), - pl.DataFrame({"a": [[[[[0.2, 3.11]]]]]}), - {"atol": 0.1, "rtol": 0, "nans_compare_equal": True}, - 
id="triple_nested_list_of_float_atol_high_nans_compare_equal_true", - ), - pytest.param( - pl.DataFrame({"a": [[[[[0.2, math.nan, 3.0]]]]]}), - pl.DataFrame({"a": [[[[[0.2, math.nan, 3.0]]]]]}), - {"atol": 0.1, "nans_compare_equal": False}, - id="triple_nested_list_of_float_and_nan_atol_high_nans_compare_equal_true", - ), - ], -) -def test_assert_frame_equal_raises_assertion_error( - df1: pl.DataFrame, - df2: pl.DataFrame, - kwargs: dict[str, Any], -) -> None: - with pytest.raises(AssertionError): - assert_frame_equal(df1, df2, **kwargs) - assert_frame_not_equal(df1, df2, **kwargs) - - -def test_compare_frame_equal_nans() -> None: - nan = float("NaN") - df1 = pl.DataFrame( - data={"x": [1.0, nan], "y": [nan, 2.0]}, - schema=[("x", pl.Float32), ("y", pl.Float64)], - ) - assert_frame_equal(df1, df1, check_exact=True) - - df2 = pl.DataFrame( - data={"x": [1.0, nan], "y": [None, 2.0]}, - schema=[("x", pl.Float32), ("y", pl.Float64)], - ) - assert_frame_not_equal(df1, df2) - with pytest.raises(AssertionError, match="values for column 'y' are different"): - assert_frame_equal(df1, df2, check_exact=True) - - -def test_compare_frame_equal_nested_nans() -> None: - nan = float("NaN") - - # list dtype - df1 = pl.DataFrame( - data={"x": [[1.0, nan]], "y": [[nan, 2.0]]}, - schema=[("x", pl.List(pl.Float32)), ("y", pl.List(pl.Float64))], - ) - assert_frame_equal(df1, df1, check_exact=True) - - df2 = pl.DataFrame( - data={"x": [[1.0, nan]], "y": [[None, 2.0]]}, - schema=[("x", pl.List(pl.Float32)), ("y", pl.List(pl.Float64))], - ) - assert_frame_not_equal(df1, df2) - with pytest.raises(AssertionError, match="values for column 'y' are different"): - assert_frame_equal(df1, df2, check_exact=True) - - # struct dtype - df3 = pl.from_dicts( - [ - { - "id": 1, - "struct": [ - {"x": "text", "y": [0.0, nan]}, - {"x": "text", "y": [0.0, nan]}, - ], - }, - { - "id": 2, - "struct": [ - {"x": "text", "y": [1]}, - {"x": "text", "y": [1]}, - ], - }, - ] - ) - df4 = pl.from_dicts( - [ - { - "id": 1, - "struct": [ - {"x": "text", "y": [0.0, nan], "z": ["$"]}, - {"x": "text", "y": [0.0, nan], "z": ["$"]}, - ], - }, - { - "id": 2, - "struct": [ - {"x": "text", "y": [nan, 1], "z": ["!"]}, - {"x": "text", "y": [nan, 1], "z": ["?"]}, - ], - }, - ] - ) - - assert_frame_equal(df3, df3) - assert_frame_not_equal(df3, df3, nans_compare_equal=False) - - assert_frame_equal(df4, df4) - assert_frame_not_equal(df4, df4, nans_compare_equal=False) - - assert_frame_not_equal(df3, df4) - for check_dtype in (True, False): - with pytest.raises(AssertionError, match="mismatch|different"): - assert_frame_equal(df3, df4, check_dtype=check_dtype) - - -def test_assert_frame_equal_pass() -> None: - df1 = pl.DataFrame({"a": [1, 2]}) - df2 = pl.DataFrame({"a": [1, 2]}) - assert_frame_equal(df1, df2) - - -def test_assert_frame_equal_types() -> None: - df1 = pl.DataFrame({"a": [1, 2]}) - srs1 = pl.Series(values=[1, 2], name="a") - with pytest.raises( - AssertionError, match=r"inputs are different \(unexpected input types\)" - ): - assert_frame_equal(df1, srs1) # type: ignore[arg-type] - - -def test_assert_frame_equal_length_mismatch() -> None: - df1 = pl.DataFrame({"a": [1, 2]}) - df2 = pl.DataFrame({"a": [1, 2, 3]}) - with pytest.raises( - AssertionError, - match=r"DataFrames are different \(number of rows does not match\)", - ): - assert_frame_equal(df1, df2) - - -def test_assert_frame_equal_column_mismatch() -> None: - df1 = pl.DataFrame({"a": [1, 2]}) - df2 = pl.DataFrame({"b": [1, 2]}) - with pytest.raises( - AssertionError, match="columns 
\\['a'\\] in left DataFrame, but not in right" - ): - assert_frame_equal(df1, df2) - - -def test_assert_frame_equal_column_mismatch2() -> None: - df1 = pl.LazyFrame({"a": [1, 2]}) - df2 = pl.LazyFrame({"a": [1, 2], "b": [3, 4], "c": [5, 6]}) - with pytest.raises( - AssertionError, - match="columns \\['b', 'c'\\] in right LazyFrame, but not in left", - ): - assert_frame_equal(df1, df2) - - -def test_assert_frame_equal_column_mismatch_order() -> None: - df1 = pl.DataFrame({"b": [3, 4], "a": [1, 2]}) - df2 = pl.DataFrame({"a": [1, 2], "b": [3, 4]}) - with pytest.raises(AssertionError, match="columns are not in the same order"): - assert_frame_equal(df1, df2) - - assert_frame_equal(df1, df2, check_column_order=False) - - -def test_assert_frame_equal_ignore_row_order() -> None: - df1 = pl.DataFrame({"a": [1, 2], "b": [4, 3]}) - df2 = pl.DataFrame({"a": [2, 1], "b": [3, 4]}) - df3 = pl.DataFrame({"b": [3, 4], "a": [2, 1]}) - with pytest.raises(AssertionError, match="values for column 'a' are different"): - assert_frame_equal(df1, df2) - - assert_frame_equal(df1, df2, check_row_order=False) - # eg: - # ┌─────┬─────┐ ┌─────┬─────┐ - # │ a ┆ b │ │ a ┆ b │ - # │ --- ┆ --- │ │ --- ┆ --- │ - # │ i64 ┆ i64 │ (eq) │ i64 ┆ i64 │ - # ╞═════╪═════╡ == ╞═════╪═════╡ - # │ 1 ┆ 4 │ │ 2 ┆ 3 │ - # │ 2 ┆ 3 │ │ 1 ┆ 4 │ - # └─────┴─────┘ └─────┴─────┘ - - with pytest.raises(AssertionError, match="columns are not in the same order"): - assert_frame_equal(df1, df3, check_row_order=False) - - assert_frame_equal(df1, df3, check_row_order=False, check_column_order=False) - - # note: not all column types support sorting - with pytest.raises( - InvalidAssert, - match="cannot set `check_row_order=False`.*unsortable columns", - ): - assert_frame_equal( - left=pl.DataFrame({"a": [[1, 2], [3, 4]], "b": [3, 4]}), - right=pl.DataFrame({"a": [[3, 4], [1, 2]], "b": [4, 3]}), - check_row_order=False, - ) - - -def test_assert_frame_equal_dtypes_mismatch() -> None: - data = {"a": [1, 2], "b": [3, 4]} - df1 = pl.DataFrame(data, schema={"a": pl.Int8, "b": pl.Int16}) - df2 = pl.DataFrame(data, schema={"b": pl.Int16, "a": pl.Int16}) - - with pytest.raises(AssertionError, match="dtypes do not match"): - assert_frame_equal(df1, df2, check_column_order=False) - - -def test_assert_frame_not_equal() -> None: - df = pl.DataFrame({"a": [1, 2]}) - with pytest.raises(AssertionError, match="frames are equal"): - assert_frame_not_equal(df, df)
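
As a companion to the assertion tests touched in this series, the following is a
minimal illustrative sketch of the assert_frame_equal / assert_frame_not_equal
semantics the hunks above exercise. It assumes a polars version contemporaneous
with this patch series and uses only parameters that already appear in the
hunks (check_row_order, check_column_order, check_exact, atol).

import polars as pl
from polars.testing import assert_frame_equal, assert_frame_not_equal

# Identical content with rows and columns ordered differently: this passes
# only when both ordering checks are relaxed.
df1 = pl.DataFrame({"a": [1, 2], "b": [4.0, 3.0]})
df2 = pl.DataFrame({"b": [3.0, 4.0], "a": [2, 1]})
assert_frame_equal(df1, df2, check_row_order=False, check_column_order=False)

# A tiny float difference passes under a loose absolute tolerance ...
assert_frame_equal(
    pl.DataFrame({"x": [0.30000001]}),
    pl.DataFrame({"x": [0.3]}),
    atol=1e-5,
)

# ... but exact checking treats the same frames as unequal.
assert_frame_not_equal(
    pl.DataFrame({"x": [0.30000001]}),
    pl.DataFrame({"x": [0.3]}),
    check_exact=True,
)

Note that relaxed row order works by sorting both frames first, which is why
the tests above expect InvalidAssert for frames with unsortable (list) columns.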