Skip to content

Commit

Permalink
Handle rounding to the nearest zero
Browse files Browse the repository at this point in the history
  • Loading branch information
rob-sil committed Feb 16, 2024
1 parent b269d3c commit 78b8601
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 4 deletions.
4 changes: 4 additions & 0 deletions crates/polars-time/src/round.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,10 @@ impl PolarsRound for DurationChunked {
TimeUnit::Milliseconds => (every.duration_ms(), offset.duration_ms()),
};

if every == 0 {
polars_bail!(InvalidOperation: "duration cannot be zero.")
}

let out = self.apply_values(|duration| {
// Round half-way values away from zero
let half_away = duration.signum() * every / 2;
Expand Down
12 changes: 8 additions & 4 deletions crates/polars-time/src/truncate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,13 +113,17 @@ impl PolarsTruncate for DurationChunked {
if every_duration.is_constant_duration() {
let every_units = to_time_unit(&every_duration);

if every_units == 0 {
polars_bail!(InvalidOperation: "duration cannot be zero.")
}

Ok(self
.0
.apply_values(|duration| duration - duration % every_units + offset_units))
} else {
Err(polars_err!(InvalidOperation:
polars_bail!(InvalidOperation:
"Cannot truncate a Duration series to a non-constant duration."
))
)
}
} else {
Ok(Int64Chunked::full_null(self.name(), self.len()))
Expand All @@ -133,9 +137,9 @@ impl PolarsTruncate for DurationChunked {

Ok(Some(duration - duration % every_units + offset_units))
} else {
Err(polars_err!(InvalidOperation:
polars_bail!(InvalidOperation:
"Cannot truncate a Duration series to a non-constant duration."
))
)
}
} else {
Ok(None)
Expand Down
16 changes: 16 additions & 0 deletions py-polars/tests/unit/namespaces/test_datetime.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,6 +538,14 @@ def test_truncate_duration(time_unit: TimeUnit) -> None:
assert_series_equal(durations.dt.truncate("10s"), expected)


def test_truncate_duration_zero() -> None:
"""Truncating to the nearest zero should raise a descriptive error."""
durations = pl.Series([timedelta(seconds=21), timedelta(seconds=35)])

with pytest.raises(InvalidOperationError, match="duration cannot be zero"):
durations.dt.truncate("0s")


@pytest.mark.parametrize("every_unit", ["d", "w", "mo", "q", "y"])
def test_truncated_duration_non_constant(every_unit: str) -> None:
# Duration series can't be truncated to non-constant durations
Expand Down Expand Up @@ -614,6 +622,14 @@ def test_round_duration(time_unit: TimeUnit) -> None:
assert_series_equal(durations.dt.round("10s"), expected)


def test_round_duration_zero() -> None:
"""Rounding to the nearest zero should raise a descriptive error."""
durations = pl.Series([timedelta(seconds=21), timedelta(seconds=35)])

with pytest.raises(InvalidOperationError, match="duration cannot be zero"):
durations.dt.round("0s")


@pytest.mark.parametrize("every", ["d", "w", "mo", "q", "y"])
def test_round_duration_non_constant(every: str) -> None:
# Duration series can't be rounded to non-constant durations
Expand Down

0 comments on commit 78b8601

Please sign in to comment.