diff --git a/crates/polars-plan/src/dsl/function_expr/datetime.rs b/crates/polars-plan/src/dsl/function_expr/datetime.rs index 1a6251a43f82..ea1c6f17c225 100644 --- a/crates/polars-plan/src/dsl/function_expr/datetime.rs +++ b/crates/polars-plan/src/dsl/function_expr/datetime.rs @@ -397,7 +397,12 @@ pub(super) fn truncate(s: &[Series]) -> PolarsResult { _ => time_series.datetime()?.truncate(None, every)?.into_series(), }, DataType::Date => time_series.date()?.truncate(None, every)?.into_series(), - dt => polars_bail!(opq = round, got = dt, expected = "date/datetime"), + DataType::Duration(_) => time_series.duration()?.truncate(None, every)?.into_series(), + dt => polars_bail!( + opq = truncate, + got = dt, + expected = "date/datetime/duration" + ), }; out.set_sorted_flag(time_series.is_sorted_flag()); Ok(out) @@ -498,7 +503,12 @@ pub(super) fn round(s: &[Series]) -> PolarsResult { .unwrap() .round(every, None)? .into_series(), - dt => polars_bail!(opq = round, got = dt, expected = "date/datetime"), + DataType::Duration(_) => time_series + .duration() + .unwrap() + .round(every, None)? 
/// Rounds `t` to the nearest multiple of `every`; values exactly half-way
/// between two multiples move away from zero.
///
/// `every` must be non-zero: the integer remainder below panics on a zero
/// divisor, so callers are expected to validate it first.
fn simple_round(t: i64, every: i64) -> i64 {
    // Push the value half a bucket further in the direction of its own sign
    // (`signum` is 0 for `t == 0`, so zero is left untouched).
    let shifted = t + t.signum() * every / 2;
    // `%` keeps the sign of the dividend, so dropping the remainder pulls the
    // shifted value back toward zero onto a multiple of `every`.
    shifted - shifted % every
}
every.negative { - polars_bail!(ComputeError: "cannot round a Datetime to a negative duration") - } - - let w = Window::new(every, every, offset); + let w = Window::new(every_parsed, every_parsed, offset); func(&w, timestamp, tz).map(Some) }, _ => Ok(None), @@ -98,11 +94,9 @@ impl PolarsRound for DateChunked { let out = match every.len() { 1 => { if let Some(every) = every.get(0) { - let every = Duration::parse(every); - if every.negative { - polars_bail!(ComputeError: "cannot round a Date to a negative duration") - } - let w = Window::new(every, every, offset); + let every_parsed = Duration::parse(every); + polars_ensure!(!every_parsed.negative & !every_parsed.is_zero(), InvalidOperation: "cannot round a Date to a non-positive Duration"); + let w = Window::new(every_parsed, every_parsed, offset); self.try_apply_nonnull_values_generic(|t| { Ok( (w.round_ms(MILLISECONDS_IN_DAY * t as i64, None)? @@ -118,14 +112,11 @@ impl PolarsRound for DateChunked { let mut duration_cache = FastFixedCache::new((every.len() as f64).sqrt() as usize); match (opt_t, opt_every) { (Some(t), Some(every)) => { - let every = *duration_cache + let every_parsed = *duration_cache .get_or_insert_with(every, |every| Duration::parse(every)); + polars_ensure!(!every_parsed.negative & !every_parsed.is_zero(), InvalidOperation: "cannot round a Date to a non-positive Duration"); - if every.negative { - polars_bail!(ComputeError: "cannot round a Date to a negative duration") - } - - let w = Window::new(every, every, offset); + let w = Window::new(every_parsed, every_parsed, offset); Ok(Some( (w.round_ms(MILLISECONDS_IN_DAY * t as i64, None)? 
#[cfg(feature = "dtype-duration")]
impl PolarsRound for DurationChunked {
    /// Rounds each duration to the nearest multiple of `every` using
    /// `simple_round` (half-way values move away from zero).
    ///
    /// `every` holds duration strings. When it has exactly one element that
    /// value is parsed once and applied to every duration; otherwise `self`
    /// and `every` are combined pairwise through
    /// `broadcast_try_binary_elementwise`, and a null on either side yields a
    /// null result. The time-zone argument is unused for durations.
    ///
    /// # Errors
    ///
    /// Returns `InvalidOperation` when an `every` value is negative or zero,
    /// or when `is_constant_duration(None)` rejects it (per the error
    /// message: one that involves weeks / months).
    fn round(&self, every: &StringChunked, _tz: Option<&Tz>) -> PolarsResult<Self> {
        let time_unit = self.time_unit();

        // Validate a parsed `every` and express it as an integer count of
        // `self`'s time unit. Both the scalar and the element-wise path go
        // through this closure so they reject exactly the same inputs. In
        // particular a zero `every` must never reach `simple_round`, which
        // takes a remainder by it and would panic on a zero divisor.
        let every_in_time_unit = |every_parsed: Duration| -> PolarsResult<i64> {
            polars_ensure!(!every_parsed.negative & !every_parsed.is_zero(), InvalidOperation: "cannot round a Duration to a non-positive Duration");
            polars_ensure!(every_parsed.is_constant_duration(None), InvalidOperation:"cannot round a Duration to a non-constant Duration (i.e. one that involves weeks / months)");
            Ok(match time_unit {
                TimeUnit::Milliseconds => every_parsed.duration_ms(),
                TimeUnit::Microseconds => every_parsed.duration_us(),
                TimeUnit::Nanoseconds => every_parsed.duration_ns(),
            })
        };

        if every.len() == 1 {
            return match every.get(0) {
                // Scalar `every`: parse and validate once, then map the raw values.
                Some(every) => {
                    let every = every_in_time_unit(Duration::parse(every))?;
                    Ok(self
                        .apply_values(|t| simple_round(t, every))
                        .into_duration(time_unit))
                },
                // A null scalar `every` makes every result null.
                None => Ok(Int64Chunked::full_null(self.name(), self.len())
                    .into_duration(time_unit)),
            };
        }

        // A sqrt(n) cache is not too small, not too large.
        let mut duration_cache = FastFixedCache::new((every.len() as f64).sqrt() as usize);

        let out = broadcast_try_binary_elementwise(self, every, |opt_timestamp, opt_every| match (
            opt_timestamp,
            opt_every,
        ) {
            (Some(t), Some(every)) => {
                let every_parsed =
                    *duration_cache.get_or_insert_with(every, |every| Duration::parse(every));
                // `.map` (rather than `?` + `Ok`) keeps the closure's error
                // type pinned to the one returned by the validation closure.
                every_in_time_unit(every_parsed).map(|every| Some(simple_round(t, every)))
            },
            _ => Ok(None),
        });
        Ok(out?.into_duration(time_unit))
    }
}
let every_parsed = *duration_cache.get_or_insert_with(every, |every| Duration::parse(every)); + polars_ensure!(!every_parsed.negative & !every_parsed.is_zero(), InvalidOperation: "cannot truncate a Datetime to a non-positive Duration"); - if every.negative { - polars_bail!(ComputeError: "cannot truncate a Datetime to a negative duration") - } - - let w = Window::new(every, every, offset); + let w = Window::new(every_parsed, every_parsed, offset); func(&w, timestamp, tz).map(Some) }, _ => Ok(None), @@ -97,11 +94,9 @@ impl PolarsTruncate for DateChunked { let out = match every.len() { 1 => { if let Some(every) = every.get(0) { - let every = Duration::parse(every); - if every.negative { - polars_bail!(ComputeError: "cannot truncate a Date to a negative duration") - } - let w = Window::new(every, every, offset); + let every_parsed = Duration::parse(every); + polars_ensure!(!every_parsed.negative & !every_parsed.is_zero(), InvalidOperation: "cannot truncate a Date to a non-positive Duration"); + let w = Window::new(every_parsed, every_parsed, offset); self.try_apply_nonnull_values_generic(|t| { Ok((w.truncate_ms(MILLISECONDS_IN_DAY * t as i64, None)? / MILLISECONDS_IN_DAY) as i32) @@ -115,14 +110,11 @@ impl PolarsTruncate for DateChunked { let mut duration_cache = FastFixedCache::new((every.len() as f64).sqrt() as usize); match (opt_t, opt_every) { (Some(t), Some(every)) => { - let every = *duration_cache + let every_parsed = *duration_cache .get_or_insert_with(every, |every| Duration::parse(every)); + polars_ensure!(!every_parsed.negative & !every_parsed.is_zero(), InvalidOperation: "cannot truncate a Date to a non-positive Duration"); - if every.negative { - polars_bail!(ComputeError: "cannot truncate a Date to a negative duration") - } - - let w = Window::new(every, every, offset); + let w = Window::new(every_parsed, every_parsed, offset); Ok(Some( (w.truncate_ms(MILLISECONDS_IN_DAY * t as i64, None)? 
#[cfg(feature = "dtype-duration")]
impl PolarsTruncate for DurationChunked {
    /// Maps each duration onto a multiple of `every` via `simple_truncate`.
    ///
    /// `every` holds duration strings. A single-element `every` is parsed once
    /// and applied to all values; otherwise `self` and `every` are combined
    /// pairwise through `broadcast_try_binary_elementwise`, with a null on
    /// either side producing a null. The time-zone argument is unused.
    ///
    /// # Errors
    ///
    /// Returns `InvalidOperation` when an `every` value is negative or zero,
    /// or when `is_constant_duration(None)` rejects it (per the error
    /// message: one that involves weeks / months).
    fn truncate(&self, _tz: Option<&Tz>, every: &StringChunked) -> PolarsResult<Self> {
        let unit = self.time_unit();

        // Checks a parsed bucket width and converts it to an integer number of
        // `unit`s; shared by the scalar and the element-wise path.
        let bucket_width = |parsed: Duration| -> PolarsResult<i64> {
            polars_ensure!(!parsed.negative & !parsed.is_zero(), InvalidOperation: "cannot truncate a Duration to a non-positive Duration");
            polars_ensure!(parsed.is_constant_duration(None), InvalidOperation:"cannot truncate a Duration to a non-constant Duration (i.e. one that involves weeks / months)");
            let width = match unit {
                TimeUnit::Milliseconds => parsed.duration_ms(),
                TimeUnit::Microseconds => parsed.duration_us(),
                TimeUnit::Nanoseconds => parsed.duration_ns(),
            };
            Ok(width)
        };

        if every.len() == 1 {
            return match every.get(0) {
                // A null scalar `every` makes every result null.
                None => Ok(Int64Chunked::full_null(self.name(), self.len()).into_duration(unit)),
                Some(raw) => {
                    let width = bucket_width(Duration::parse(raw))?;
                    let truncated = self.apply_values(|t: i64| simple_truncate(t, width));
                    Ok(truncated.into_duration(unit))
                },
            };
        }

        // A sqrt(n) cache is not too small, not too large.
        let mut duration_cache = FastFixedCache::new((every.len() as f64).sqrt() as usize);

        let truncated = broadcast_try_binary_elementwise(self, every, |opt_t, opt_every| {
            // Either side being null yields a null result.
            let (Some(t), Some(raw)) = (opt_t, opt_every) else {
                return Ok(None);
            };
            let parsed = *duration_cache.get_or_insert_with(raw, |raw| Duration::parse(raw));
            bucket_width(parsed).map(|width| Some(simple_truncate(t, width)))
        })?;
        Ok(truncated.into_duration(unit))
    }
}
+ Divide the dates, datetimes, or durations into buckets. .. warning:: This functionality is considered **unstable**. It may be changed at any point without it being considered a breaking change. - Each date/datetime in the first half of the interval + Each date/datetime/duration in the first half of the interval is mapped to the start of its bucket. - Each date/datetime in the second half of the interval + Each date/datetime/duration in the second half of the interval is mapped to the end of its bucket. Ambiguous results are localised using the DST offset of the original timestamp - for example, rounding `'2022-11-06 01:20:00 CST'` by `'1h'` results in @@ -326,6 +331,10 @@ def round(self, every: str | dt.timedelta | IntoExprColumn) -> Expr: not be 24 hours, due to daylight savings). Similarly for "calendar week", "calendar month", "calendar quarter", and "calendar year". + Durations may not be rounded to a period length `every` containing calendar + days, weeks, months, quarters, or years, as these are not constant time + intervals. + Examples -------- >>> from datetime import timedelta, datetime diff --git a/py-polars/polars/series/datetime.py b/py-polars/polars/series/datetime.py index 8c8bfb32bad8..de97e54389bd 100644 --- a/py-polars/polars/series/datetime.py +++ b/py-polars/polars/series/datetime.py @@ -1644,10 +1644,11 @@ def offset_by(self, by: str | Expr) -> Series: def truncate(self, every: str | dt.timedelta | IntoExprColumn) -> Series: """ - Divide the date/ datetime range into buckets. + Divide the dates, datetimes, or durations into buckets. - Each date/datetime is mapped to the start of its bucket using the corresponding - local datetime. Note that weekly buckets start on Monday. + For dates or datetimes, each date/datetime is mapped to the start of its bucket + using the corresponding local datetime. + Note that weekly buckets start on Monday. 
Ambiguous results are localised using the DST offset of the original timestamp - for example, truncating `'2022-11-06 01:30:00 CST'` by `'1h'` results in `'2022-11-06 01:00:00 CST'`, whereas truncating `'2022-11-06 01:30:00 CDT'` by @@ -1683,6 +1684,10 @@ def truncate(self, every: str | dt.timedelta | IntoExprColumn) -> Series: not be 24 hours, due to daylight savings). Similarly for "calendar week", "calendar month", "calendar quarter", and "calendar year". + Durations may not be truncated to a period length `every` containing calendar + days, weeks, months, quarters, or years, as these are not constant time + intervals. + Returns ------- Series @@ -1758,17 +1763,17 @@ def truncate(self, every: str | dt.timedelta | IntoExprColumn) -> Series: @unstable() def round(self, every: str | dt.timedelta | IntoExprColumn) -> Series: """ - Divide the date/ datetime range into buckets. + Divide the dates, datetimes, or durations into buckets. .. warning:: This functionality is considered **unstable**. It may be changed at any point without it being considered a breaking change. - Each date/datetime in the first half of the interval is mapped to the start of - its bucket. - Each date/datetime in the second half of the interval is mapped to the end of - its bucket. - Ambiguous results are localized using the DST offset of the original timestamp - + Each date/datetime/duration in the first half of the interval + is mapped to the start of its bucket. + Each date/datetime/duration in the second half of the interval + is mapped to the end of its bucket. + Ambiguous results are localised using the DST offset of the original timestamp - for example, rounding `'2022-11-06 01:20:00 CST'` by `'1h'` results in `'2022-11-06 01:00:00 CST'`, whereas rounding `'2022-11-06 01:20:00 CDT'` by `'1h'` results in `'2022-11-06 01:00:00 CDT'`. @@ -1808,6 +1813,10 @@ def round(self, every: str | dt.timedelta | IntoExprColumn) -> Series: not be 24 hours, due to daylight savings). 
Similarly for "calendar week", "calendar month", "calendar quarter", and "calendar year". + Durations may not be rounded to a period length `every` containing calendar + days, weeks, months, quarters, or years, as these are not constant time + intervals. + Examples -------- >>> from datetime import timedelta, datetime diff --git a/py-polars/tests/unit/operations/namespaces/temporal/test_datetime.py b/py-polars/tests/unit/operations/namespaces/temporal/test_datetime.py index c9a43984cd45..0329400012b9 100644 --- a/py-polars/tests/unit/operations/namespaces/temporal/test_datetime.py +++ b/py-polars/tests/unit/operations/namespaces/temporal/test_datetime.py @@ -452,176 +452,6 @@ def test_duration_extract_times( assert_series_equal(getattr(duration.dt, unit_attr)(), expected) -@pytest.mark.parametrize( - ("time_unit", "every"), - [ - ("ms", "1h"), - ("us", "1h0m0s"), - ("ns", timedelta(hours=1)), - ], - ids=["milliseconds", "microseconds", "nanoseconds"], -) -def test_truncate( - time_unit: TimeUnit, - every: str | timedelta, -) -> None: - start, stop = datetime(2022, 1, 1), datetime(2022, 1, 2) - s = pl.datetime_range( - start, - stop, - timedelta(minutes=30), - time_unit=time_unit, - eager=True, - ).alias(f"dates[{time_unit}]") - - # can pass strings and time-deltas - out = s.dt.truncate(every) - assert out.dt[0] == start - assert out.dt[1] == start - assert out.dt[2] == start + timedelta(hours=1) - assert out.dt[3] == start + timedelta(hours=1) - # ... 
- assert out.dt[-3] == stop - timedelta(hours=1) - assert out.dt[-2] == stop - timedelta(hours=1) - assert out.dt[-1] == stop - - -def test_truncate_negative() -> None: - """Test that truncating to a negative duration gives a helpful error message.""" - df = pl.DataFrame( - { - "date": [date(1895, 5, 7), date(1955, 11, 5)], - "datetime": [datetime(1895, 5, 7), datetime(1955, 11, 5)], - "duration": ["-1m", "1m"], - } - ) - - with pytest.raises( - ComputeError, match="cannot truncate a Date to a negative duration" - ): - df.select(pl.col("date").dt.truncate("-1m")) - - with pytest.raises( - ComputeError, match="cannot truncate a Datetime to a negative duration" - ): - df.select(pl.col("datetime").dt.truncate("-1m")) - - with pytest.raises( - ComputeError, match="cannot truncate a Date to a negative duration" - ): - df.select(pl.col("date").dt.truncate(pl.col("duration"))) - - with pytest.raises( - ComputeError, match="cannot truncate a Datetime to a negative duration" - ): - df.select(pl.col("datetime").dt.truncate(pl.col("duration"))) - - -@pytest.mark.parametrize( - ("time_unit", "every"), - [ - ("ms", "1h"), - ("us", "1h0m0s"), - ("ns", timedelta(hours=1)), - ], - ids=["milliseconds", "microseconds", "nanoseconds"], -) -def test_round( - time_unit: TimeUnit, - every: str | timedelta, -) -> None: - start, stop = datetime(2022, 1, 1), datetime(2022, 1, 2) - s = pl.datetime_range( - start, - stop, - timedelta(minutes=30), - time_unit=time_unit, - eager=True, - ).alias(f"dates[{time_unit}]") - - # can pass strings and time-deltas - out = s.dt.round(every) - assert out.dt[0] == start - assert out.dt[1] == start + timedelta(hours=1) - assert out.dt[2] == start + timedelta(hours=1) - assert out.dt[3] == start + timedelta(hours=2) - # ... 
- assert out.dt[-3] == stop - timedelta(hours=1) - assert out.dt[-2] == stop - assert out.dt[-1] == stop - - -def test_round_expr() -> None: - df = pl.DataFrame( - { - "date": [ - datetime(2022, 11, 14), - datetime(2023, 10, 11), - datetime(2022, 3, 20, 5, 7, 18), - datetime(2022, 4, 3, 13, 30, 32), - None, - datetime(2022, 12, 1), - ], - "every": ["1y", "1mo", "1m", "1m", "1mo", None], - } - ) - - output = df.select( - all_expr=pl.col("date").dt.round(every=pl.col("every")), - date_lit=pl.lit(datetime(2022, 4, 3, 13, 30, 32)).dt.round( - every=pl.col("every") - ), - every_lit=pl.col("date").dt.round("1d"), - ) - - expected = pl.DataFrame( - { - "all_expr": [ - datetime(2023, 1, 1), - datetime(2023, 10, 1), - datetime(2022, 3, 20, 5, 7), - datetime(2022, 4, 3, 13, 31), - None, - None, - ], - "date_lit": [ - datetime(2022, 1, 1), - datetime(2022, 4, 1), - datetime(2022, 4, 3, 13, 31), - datetime(2022, 4, 3, 13, 31), - datetime(2022, 4, 1), - None, - ], - "every_lit": [ - datetime(2022, 11, 14), - datetime(2023, 10, 11), - datetime(2022, 3, 20), - datetime(2022, 4, 4), - None, - datetime(2022, 12, 1), - ], - } - ) - - assert_frame_equal(output, expected) - - all_lit = pl.select(all_lit=pl.lit(datetime(2022, 3, 20, 5, 7)).dt.round("1h")) - assert all_lit.to_dict(as_series=False) == {"all_lit": [datetime(2022, 3, 20, 5)]} - - -def test_round_negative() -> None: - """Test that rounding to a negative duration gives a helpful error message.""" - with pytest.raises( - ComputeError, match="cannot round a Date to a negative duration" - ): - pl.Series([date(1895, 5, 7)]).dt.round("-1m") - - with pytest.raises( - ComputeError, match="cannot round a Datetime to a negative duration" - ): - pl.Series([datetime(1895, 5, 7)]).dt.round("-1m") - - @pytest.mark.parametrize( ("time_unit", "date_in_that_unit"), [ diff --git a/py-polars/tests/unit/operations/namespaces/temporal/test_round.py b/py-polars/tests/unit/operations/namespaces/temporal/test_round.py index 
1ac7acc3edcd..25b7304208cf 100644 --- a/py-polars/tests/unit/operations/namespaces/temporal/test_round.py +++ b/py-polars/tests/unit/operations/namespaces/temporal/test_round.py @@ -9,7 +9,8 @@ import polars as pl from polars._utils.convert import parse_as_duration_string -from polars.testing import assert_series_equal +from polars.exceptions import InvalidOperationError +from polars.testing import assert_frame_equal, assert_series_equal if TYPE_CHECKING: from zoneinfo import ZoneInfo @@ -189,3 +190,174 @@ def test_round_datetime_w_expression(time_unit: TimeUnit) -> None: result = df.select(pl.col("a").dt.round(pl.col("b")))["a"] assert result[0] == datetime(2020, 1, 1) assert result[1] == datetime(2020, 1, 21) + + +@pytest.mark.parametrize( + ("time_unit", "every"), + [ + ("ms", "1h"), + ("us", "1h0m0s"), + ("ns", timedelta(hours=1)), + ], + ids=["milliseconds", "microseconds", "nanoseconds"], +) +def test_round( + time_unit: TimeUnit, + every: str | timedelta, +) -> None: + start, stop = datetime(2022, 1, 1), datetime(2022, 1, 2) + s = pl.datetime_range( + start, + stop, + timedelta(minutes=30), + time_unit=time_unit, + eager=True, + ).alias(f"dates[{time_unit}]") + + # can pass strings and time-deltas + out = s.dt.round(every) + assert out.dt[0] == start + assert out.dt[1] == start + timedelta(hours=1) + assert out.dt[2] == start + timedelta(hours=1) + assert out.dt[3] == start + timedelta(hours=2) + # ... 
+ assert out.dt[-3] == stop - timedelta(hours=1) + assert out.dt[-2] == stop + assert out.dt[-1] == stop + + +@pytest.mark.parametrize("time_unit", ["ms", "us", "ns"]) +def test_round_duration(time_unit: TimeUnit) -> None: + durations = pl.Series( + [ + timedelta(seconds=21), + timedelta(seconds=35), + timedelta(seconds=59), + None, + timedelta(seconds=-35), + ] + ).dt.cast_time_unit(time_unit) + + expected = pl.Series( + [ + timedelta(seconds=20), + timedelta(seconds=40), + timedelta(seconds=60), + None, + timedelta(seconds=-40), + ] + ).dt.cast_time_unit(time_unit) + + assert_series_equal(durations.dt.round("10s"), expected) + + +def test_round_duration_zero() -> None: + """Rounding to the nearest zero should raise a descriptive error.""" + durations = pl.Series([timedelta(seconds=21), timedelta(seconds=35)]) + + with pytest.raises( + InvalidOperationError, + match="cannot round a Duration to a non-positive Duration", + ): + durations.dt.round("0s") + + +@pytest.mark.parametrize("every", ["mo", "q", "y"]) +def test_round_duration_non_constant(every: str) -> None: + # Duration series can't be rounded to non-constant durations + durations = pl.Series([timedelta(seconds=21)]) + + with pytest.raises(InvalidOperationError): + durations.dt.round("1" + every) + + +@pytest.mark.parametrize("time_unit", ["ms", "us", "ns"]) +def test_round_duration_half(time_unit: TimeUnit) -> None: + # Values at halfway points should round away from zero + durations = pl.Series( + [timedelta(minutes=-30), timedelta(minutes=30), timedelta(minutes=90)] + ).dt.cast_time_unit(time_unit) + + expected = pl.Series( + [timedelta(hours=-1), timedelta(hours=1), timedelta(hours=2)] + ).dt.cast_time_unit(time_unit) + + assert_series_equal(durations.dt.round("1h"), expected) + + +def test_round_expr() -> None: + df = pl.DataFrame( + { + "date": [ + datetime(2022, 11, 14), + datetime(2023, 10, 11), + datetime(2022, 3, 20, 5, 7, 18), + datetime(2022, 4, 3, 13, 30, 32), + None, + datetime(2022, 12, 1), + 
def test_round_negative() -> None:
    """Rounding by a negative `every` must raise a descriptive error per dtype."""
    # One (value, expected error message) pair per temporal dtype.
    failing_cases = [
        (date(1895, 5, 7), "cannot round a Date to a non-positive Duration"),
        (datetime(1895, 5, 7), "cannot round a Datetime to a non-positive Duration"),
        (timedelta(days=1), "cannot round a Duration to a non-positive Duration"),
    ]
    for value, message in failing_cases:
        with pytest.raises(InvalidOperationError, match=message):
            pl.Series([value]).dt.round("-1m")
polars._utils.convert import parse_as_duration_string +from polars.exceptions import InvalidOperationError from polars.testing import assert_series_equal if TYPE_CHECKING: @@ -119,3 +120,149 @@ def test_fast_path_vs_slow_path(datetimes: list[datetime], every: str) -> None: # Definitely uses slowpath: expected = s.dt.truncate(pl.Series([every] * len(datetimes))) assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + ("time_unit", "every"), + [ + ("ms", "1h"), + ("us", "1h0m0s"), + ("ns", timedelta(hours=1)), + ], + ids=["milliseconds", "microseconds", "nanoseconds"], +) +def test_truncate( + time_unit: TimeUnit, + every: str | timedelta, +) -> None: + start, stop = datetime(2022, 1, 1), datetime(2022, 1, 2) + s = pl.datetime_range( + start, + stop, + timedelta(minutes=30), + time_unit=time_unit, + eager=True, + ).alias(f"dates[{time_unit}]") + + # can pass strings and time-deltas + out = s.dt.truncate(every) + assert out.dt[0] == start + assert out.dt[1] == start + assert out.dt[2] == start + timedelta(hours=1) + assert out.dt[3] == start + timedelta(hours=1) + # ... 
def test_truncate_negative() -> None:
    """Truncating by a negative `every` must raise a descriptive error per dtype."""
    df = pl.DataFrame(
        {
            "date": [date(1895, 5, 7), date(1955, 11, 5)],
            "datetime": [datetime(1895, 5, 7), datetime(1955, 11, 5)],
            "duration": [timedelta(minutes=1), timedelta(minutes=-1)],
            "every": ["-1m", "1m"],
        }
    )
    # Column under test -> error message expected for its dtype.
    expected_messages = {
        "date": "cannot truncate a Date to a non-positive Duration",
        "datetime": "cannot truncate a Datetime to a non-positive Duration",
        "duration": "cannot truncate a Duration to a non-positive Duration",
    }
    # First the literal `every`, then the per-row expression whose first
    # value is negative.
    for every in ("-1m", pl.col("every")):
        for column, message in expected_messages.items():
            with pytest.raises(InvalidOperationError, match=message):
                df.select(pl.col(column).dt.truncate(every))