From 78b8601934dbaa132dce5e9dfc459129c22ce286 Mon Sep 17 00:00:00 2001 From: Rob <124158982+rob-sil@users.noreply.github.com> Date: Fri, 16 Feb 2024 04:54:49 -0800 Subject: [PATCH] Handle rounding to the nearest zero --- crates/polars-time/src/round.rs | 4 ++++ crates/polars-time/src/truncate.rs | 12 ++++++++---- py-polars/tests/unit/namespaces/test_datetime.py | 16 ++++++++++++++++ 3 files changed, 28 insertions(+), 4 deletions(-) diff --git a/crates/polars-time/src/round.rs b/crates/polars-time/src/round.rs index c7f512d510964..0fd166fbad717 100644 --- a/crates/polars-time/src/round.rs +++ b/crates/polars-time/src/round.rs @@ -53,6 +53,10 @@ impl PolarsRound for DurationChunked { TimeUnit::Milliseconds => (every.duration_ms(), offset.duration_ms()), }; + if every == 0 { + polars_bail!(InvalidOperation: "duration cannot be zero.") + } + let out = self.apply_values(|duration| { // Round half-way values away from zero let half_away = duration.signum() * every / 2; diff --git a/crates/polars-time/src/truncate.rs b/crates/polars-time/src/truncate.rs index c7487cb72cd70..809cbe9d6a5a2 100644 --- a/crates/polars-time/src/truncate.rs +++ b/crates/polars-time/src/truncate.rs @@ -113,13 +113,17 @@ impl PolarsTruncate for DurationChunked { if every_duration.is_constant_duration() { let every_units = to_time_unit(&every_duration); + if every_units == 0 { + polars_bail!(InvalidOperation: "duration cannot be zero.") + } + Ok(self .0 .apply_values(|duration| duration - duration % every_units + offset_units)) } else { - Err(polars_err!(InvalidOperation: + polars_bail!(InvalidOperation: "Cannot truncate a Duration series to a non-constant duration." - )) + ) } } else { Ok(Int64Chunked::full_null(self.name(), self.len())) @@ -133,9 +137,9 @@ impl PolarsTruncate for DurationChunked { Ok(Some(duration - duration % every_units + offset_units)) } else { - Err(polars_err!(InvalidOperation: + polars_bail!(InvalidOperation: "Cannot truncate a Duration series to a non-constant duration." - )) + ) } } else { Ok(None) diff --git a/py-polars/tests/unit/namespaces/test_datetime.py b/py-polars/tests/unit/namespaces/test_datetime.py index e36340bada2b4..5f2483d690b42 100644 --- a/py-polars/tests/unit/namespaces/test_datetime.py +++ b/py-polars/tests/unit/namespaces/test_datetime.py @@ -538,6 +538,14 @@ def test_truncate_duration(time_unit: TimeUnit) -> None: assert_series_equal(durations.dt.truncate("10s"), expected) +def test_truncate_duration_zero() -> None: + """Truncating to the nearest zero should raise a descriptive error.""" + durations = pl.Series([timedelta(seconds=21), timedelta(seconds=35)]) + + with pytest.raises(InvalidOperationError, match="duration cannot be zero"): + durations.dt.truncate("0s") + + @pytest.mark.parametrize("every_unit", ["d", "w", "mo", "q", "y"]) def test_truncated_duration_non_constant(every_unit: str) -> None: # Duration series can't be truncated to non-constant durations @@ -614,6 +622,14 @@ def test_round_duration(time_unit: TimeUnit) -> None: assert_series_equal(durations.dt.round("10s"), expected) +def test_round_duration_zero() -> None: + """Rounding to the nearest zero should raise a descriptive error.""" + durations = pl.Series([timedelta(seconds=21), timedelta(seconds=35)]) + + with pytest.raises(InvalidOperationError, match="duration cannot be zero"): + durations.dt.round("0s") + + @pytest.mark.parametrize("every", ["d", "w", "mo", "q", "y"]) def test_round_duration_non_constant(every: str) -> None: # Duration series can't be rounded to non-constant durations