From 1554a579a9464d9f2e3d4e84435bd1e5fe927c22 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Mon, 15 Jul 2024 17:31:47 +0100 Subject: [PATCH] align formulae between datetime and duration --- crates/polars-time/src/round.rs | 21 +++++++------------ crates/polars-time/src/truncate.rs | 14 +++++++------ .../namespaces/temporal/test_truncate.py | 2 +- 3 files changed, 17 insertions(+), 20 deletions(-) diff --git a/crates/polars-time/src/round.rs b/crates/polars-time/src/round.rs index 8a304958db32..5821f3e786fb 100644 --- a/crates/polars-time/src/round.rs +++ b/crates/polars-time/src/round.rs @@ -12,6 +12,11 @@ pub trait PolarsRound { Self: Sized; } +fn simple_round(t: i64, every: i64) -> i64 { + let half_away = t.signum() * every / 2; + t + half_away - (t + half_away) % every +} + impl PolarsRound for DatetimeChunked { fn round(&self, every: &StringChunked, tz: Option<&Tz>) -> PolarsResult { let time_zone = self.time_zone(); @@ -33,11 +38,7 @@ impl PolarsRound for DatetimeChunked { TimeUnit::Nanoseconds => every_parsed.duration_ns(), }; return Ok(self - .apply_values(|t| { - // Round half-way values away from zero - let half_away = t.signum() * every / 2; - t + half_away - (t + half_away) % every - }) + .apply_values(|t| simple_round(t, every)) .into_datetime(self.time_unit(), time_zone.clone())); } else { let w = Window::new(every_parsed, every_parsed, offset); @@ -143,11 +144,7 @@ impl PolarsRound for DurationChunked { TimeUnit::Nanoseconds => every_parsed.duration_ns(), }; return Ok(self - .apply_values(|t| { - // Round half-way values away from zero - let half_away = t.signum() * every / 2; - t + half_away - (t + half_away) % every - }) + .apply_values(|t| simple_round(t, every)) .into_duration(self.time_unit())); } else { return Ok(Int64Chunked::full_null(self.name(), self.len()) @@ -172,9 +169,7 @@ impl PolarsRound for DurationChunked { TimeUnit::Microseconds => every_parsed.duration_us(), TimeUnit::Nanoseconds => every_parsed.duration_ns(), }; - // Round half-way values away from zero - let half_away = t.signum() * every / 2; - Ok(Some(t + half_away - (t + half_away) % every)) + Ok(Some(simple_round(t, every))) }, _ => Ok(None), }); diff --git a/crates/polars-time/src/truncate.rs b/crates/polars-time/src/truncate.rs index b7380378e1cd..25d6c3f2a206 100644 --- a/crates/polars-time/src/truncate.rs +++ b/crates/polars-time/src/truncate.rs @@ -12,6 +12,11 @@ pub trait PolarsTruncate { Self: Sized; } +fn simple_truncate(t: i64, every: i64) -> i64 { + let remainder = t % every; + t - (remainder + every * (remainder < 0) as i64) +} + impl PolarsTruncate for DatetimeChunked { fn truncate(&self, tz: Option<&Tz>, every: &StringChunked) -> PolarsResult { let time_zone = self.time_zone(); @@ -33,10 +38,7 @@ impl PolarsTruncate for DatetimeChunked { TimeUnit::Nanoseconds => every_parsed.duration_ns(), }; return Ok(self - .apply_values(|t| { - let remainder = t % every; - t - (remainder + every * (remainder < 0) as i64) - }) + .apply_values(|t| simple_truncate(t, every)) .into_datetime(self.time_unit(), time_zone.clone())); } else { let w = Window::new(every_parsed, every_parsed, offset); @@ -140,7 +142,7 @@ impl PolarsTruncate for DurationChunked { TimeUnit::Nanoseconds => every_parsed.duration_ns(), }; return Ok(self - .apply_values(|t| t - t % every) + .apply_values(|t: i64| simple_truncate(t, every)) .into_duration(self.time_unit())); } else { return Ok(Int64Chunked::full_null(self.name(), self.len()) @@ -165,7 +167,7 @@ impl PolarsTruncate for DurationChunked { TimeUnit::Microseconds => every_parsed.duration_us(), TimeUnit::Nanoseconds => every_parsed.duration_ns(), }; - Ok(Some(t - t % every)) + Ok(Some(simple_truncate(t, every))) }, _ => Ok(None), }); diff --git a/py-polars/tests/unit/operations/namespaces/temporal/test_truncate.py b/py-polars/tests/unit/operations/namespaces/temporal/test_truncate.py index 7a7fee428d5b..868422ac65e4 100644 --- a/py-polars/tests/unit/operations/namespaces/temporal/test_truncate.py +++ b/py-polars/tests/unit/operations/namespaces/temporal/test_truncate.py @@ -174,7 +174,7 @@ def test_truncate_duration(time_unit: TimeUnit) -> None: timedelta(seconds=30), timedelta(seconds=50), None, - timedelta(seconds=-30), + timedelta(seconds=-40), ] ).dt.cast_time_unit(time_unit)