Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Implement round and truncate for Duration columns #12597

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions crates/polars-plan/src/dsl/function_expr/datetime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -397,7 +397,12 @@ pub(super) fn truncate(s: &[Series]) -> PolarsResult<Series> {
_ => time_series.datetime()?.truncate(None, every)?.into_series(),
},
DataType::Date => time_series.date()?.truncate(None, every)?.into_series(),
dt => polars_bail!(opq = round, got = dt, expected = "date/datetime"),
DataType::Duration(_) => time_series.duration()?.truncate(None, every)?.into_series(),
dt => polars_bail!(
opq = truncate,
got = dt,
expected = "date/datetime/duration"
),
};
out.set_sorted_flag(time_series.is_sorted_flag());
Ok(out)
Expand Down Expand Up @@ -498,7 +503,12 @@ pub(super) fn round(s: &[Series]) -> PolarsResult<Series> {
.unwrap()
.round(every, None)?
.into_series(),
dt => polars_bail!(opq = round, got = dt, expected = "date/datetime"),
DataType::Duration(_) => time_series
.duration()
.unwrap()
.round(every, None)?
.into_series(),
dt => polars_bail!(opq = round, got = dt, expected = "date/datetime/duration"),
})
}

Expand Down
83 changes: 63 additions & 20 deletions crates/polars-time/src/round.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,7 @@ impl PolarsRound for DatetimeChunked {
if every.len() == 1 {
if let Some(every) = every.get(0) {
let every_parsed = Duration::parse(every);
if every_parsed.negative {
polars_bail!(ComputeError: "cannot round a Datetime to a negative duration")
}
polars_ensure!(!every_parsed.negative & !every_parsed.is_zero(), InvalidOperation: "cannot round a Datetime to a non-positive Duration");
if (time_zone.is_none() || time_zone.as_deref() == Some("UTC"))
&& (every_parsed.months() == 0 && every_parsed.weeks() == 0)
{
Expand Down Expand Up @@ -76,14 +74,11 @@ impl PolarsRound for DatetimeChunked {
opt_every,
) {
(Some(timestamp), Some(every)) => {
let every =
let every_parsed =
*duration_cache.get_or_insert_with(every, |every| Duration::parse(every));
polars_ensure!(!every_parsed.negative & !every_parsed.is_zero(), InvalidOperation: "cannot round a Datetime to a non-positive Duration");

if every.negative {
polars_bail!(ComputeError: "cannot round a Datetime to a negative duration")
}

let w = Window::new(every, every, offset);
let w = Window::new(every_parsed, every_parsed, offset);
func(&w, timestamp, tz).map(Some)
},
_ => Ok(None),
Expand All @@ -98,11 +93,9 @@ impl PolarsRound for DateChunked {
let out = match every.len() {
1 => {
if let Some(every) = every.get(0) {
let every = Duration::parse(every);
if every.negative {
polars_bail!(ComputeError: "cannot round a Date to a negative duration")
}
let w = Window::new(every, every, offset);
let every_parsed = Duration::parse(every);
polars_ensure!(!every_parsed.negative & !every_parsed.is_zero(), InvalidOperation: "cannot round a Date to a non-positive Duration");
let w = Window::new(every_parsed, every_parsed, offset);
self.try_apply_nonnull_values_generic(|t| {
Ok(
(w.round_ms(MILLISECONDS_IN_DAY * t as i64, None)?
Expand All @@ -118,14 +111,11 @@ impl PolarsRound for DateChunked {
let mut duration_cache = FastFixedCache::new((every.len() as f64).sqrt() as usize);
match (opt_t, opt_every) {
(Some(t), Some(every)) => {
let every = *duration_cache
let every_parsed = *duration_cache
.get_or_insert_with(every, |every| Duration::parse(every));
polars_ensure!(!every_parsed.negative & !every_parsed.is_zero(), InvalidOperation: "cannot round a Date to a non-positive Duration");

if every.negative {
polars_bail!(ComputeError: "cannot round a Date to a negative duration")
}

let w = Window::new(every, every, offset);
let w = Window::new(every_parsed, every_parsed, offset);
Ok(Some(
(w.round_ms(MILLISECONDS_IN_DAY * t as i64, None)?
/ MILLISECONDS_IN_DAY) as i32,
Expand All @@ -138,3 +128,56 @@ impl PolarsRound for DateChunked {
Ok(out?.into_date())
}
}

#[cfg(feature = "dtype-duration")]
impl PolarsRound for DurationChunked {
fn round(&self, every: &StringChunked, _tz: Option<&Tz>) -> PolarsResult<Self> {
if every.len() == 1 {
if let Some(every) = every.get(0) {
let every_parsed = Duration::parse(every);
polars_ensure!(!every_parsed.negative & !every_parsed.is_zero(), InvalidOperation: "cannot round a Duration to a non-positive Duration");
polars_ensure!(every_parsed.is_constant_duration(None), InvalidOperation:"cannot round a Duration to a non-constant Duration (i.e. one that involves weeks / months)");
let every = match self.time_unit() {
TimeUnit::Milliseconds => every_parsed.duration_ms(),
TimeUnit::Microseconds => every_parsed.duration_us(),
TimeUnit::Nanoseconds => every_parsed.duration_ns(),
};
return Ok(self
.apply_values(|t| {
// Round half-way values away from zero
let half_away = t.signum() * every / 2;
t + half_away - (t + half_away) % every
})
.into_duration(self.time_unit()));
} else {
return Ok(Int64Chunked::full_null(self.name(), self.len())
.into_duration(self.time_unit()));
}
}

// A sqrt(n) cache is not too small, not too large.
let mut duration_cache = FastFixedCache::new((every.len() as f64).sqrt() as usize);

let out = broadcast_try_binary_elementwise(self, every, |opt_timestamp, opt_every| match (
opt_timestamp,
opt_every,
) {
(Some(t), Some(every)) => {
let every_parsed =
*duration_cache.get_or_insert_with(every, |every| Duration::parse(every));
polars_ensure!(!every_parsed.negative, InvalidOperation: "cannot round a Duration to a negative duration");
polars_ensure!(every_parsed.is_constant_duration(None), InvalidOperation:"cannot round a Duration to a non-constant Duration (i.e. one that involves weeks / months)");
let every = match self.time_unit() {
TimeUnit::Milliseconds => every_parsed.duration_ms(),
TimeUnit::Microseconds => every_parsed.duration_us(),
TimeUnit::Nanoseconds => every_parsed.duration_ns(),
};
// Round half-way values away from zero
let half_away = t.signum() * every / 2;
Ok(Some(t + half_away - (t + half_away) % every))
},
_ => Ok(None),
});
Ok(out?.into_duration(self.time_unit()))
}
}
77 changes: 57 additions & 20 deletions crates/polars-time/src/truncate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,7 @@ impl PolarsTruncate for DatetimeChunked {
if every.len() == 1 {
if let Some(every) = every.get(0) {
let every_parsed = Duration::parse(every);
if every_parsed.negative {
polars_bail!(ComputeError: "cannot truncate a Datetime to a negative duration")
}
polars_ensure!(!every_parsed.negative & !every_parsed.is_zero(), InvalidOperation: "cannot truncate a Datetime to a non-positive Duration");
if (time_zone.is_none() || time_zone.as_deref() == Some("UTC"))
&& (every_parsed.months() == 0 && every_parsed.weeks() == 0)
{
Expand Down Expand Up @@ -75,14 +73,11 @@ impl PolarsTruncate for DatetimeChunked {
opt_every,
) {
(Some(timestamp), Some(every)) => {
let every =
let every_parsed =
*duration_cache.get_or_insert_with(every, |every| Duration::parse(every));
polars_ensure!(!every_parsed.negative & !every_parsed.is_zero(), InvalidOperation: "cannot truncate a Datetime to a non-positive Duration");

if every.negative {
polars_bail!(ComputeError: "cannot truncate a Datetime to a negative duration")
}

let w = Window::new(every, every, offset);
let w = Window::new(every_parsed, every_parsed, offset);
func(&w, timestamp, tz).map(Some)
},
_ => Ok(None),
Expand All @@ -97,11 +92,9 @@ impl PolarsTruncate for DateChunked {
let out = match every.len() {
1 => {
if let Some(every) = every.get(0) {
let every = Duration::parse(every);
if every.negative {
polars_bail!(ComputeError: "cannot truncate a Date to a negative duration")
}
let w = Window::new(every, every, offset);
let every_parsed = Duration::parse(every);
polars_ensure!(!every_parsed.negative & !every_parsed.is_zero(), InvalidOperation: "cannot truncate a Date to a non-positive Duration");
let w = Window::new(every_parsed, every_parsed, offset);
self.try_apply_nonnull_values_generic(|t| {
Ok((w.truncate_ms(MILLISECONDS_IN_DAY * t as i64, None)?
/ MILLISECONDS_IN_DAY) as i32)
Expand All @@ -115,14 +108,11 @@ impl PolarsTruncate for DateChunked {
let mut duration_cache = FastFixedCache::new((every.len() as f64).sqrt() as usize);
match (opt_t, opt_every) {
(Some(t), Some(every)) => {
let every = *duration_cache
let every_parsed = *duration_cache
.get_or_insert_with(every, |every| Duration::parse(every));
polars_ensure!(!every_parsed.negative & !every_parsed.is_zero(), InvalidOperation: "cannot truncate a Date to a non-positive Duration");

if every.negative {
polars_bail!(ComputeError: "cannot truncate a Date to a negative duration")
}

let w = Window::new(every, every, offset);
let w = Window::new(every_parsed, every_parsed, offset);
Ok(Some(
(w.truncate_ms(MILLISECONDS_IN_DAY * t as i64, None)?
/ MILLISECONDS_IN_DAY) as i32,
Expand All @@ -135,3 +125,50 @@ impl PolarsTruncate for DateChunked {
Ok(out?.into_date())
}
}

#[cfg(feature = "dtype-duration")]
impl PolarsTruncate for DurationChunked {
fn truncate(&self, _tz: Option<&Tz>, every: &StringChunked) -> PolarsResult<Self> {
if every.len() == 1 {
if let Some(every) = every.get(0) {
let every_parsed = Duration::parse(every);
polars_ensure!(!every_parsed.negative & !every_parsed.is_zero(), InvalidOperation: "cannot truncate a Duration to a non-positive Duration");
polars_ensure!(every_parsed.is_constant_duration(None), InvalidOperation:"cannot truncate a Duration to a non-constant Duration (i.e. one that involves weeks / months)");
let every = match self.time_unit() {
TimeUnit::Milliseconds => every_parsed.duration_ms(),
TimeUnit::Microseconds => every_parsed.duration_us(),
TimeUnit::Nanoseconds => every_parsed.duration_ns(),
};
return Ok(self
.apply_values(|t| t - t % every)
.into_duration(self.time_unit()));
} else {
return Ok(Int64Chunked::full_null(self.name(), self.len())
.into_duration(self.time_unit()));
}
}

// A sqrt(n) cache is not too small, not too large.
let mut duration_cache = FastFixedCache::new((every.len() as f64).sqrt() as usize);

let out = broadcast_try_binary_elementwise(self, every, |opt_timestamp, opt_every| match (
opt_timestamp,
opt_every,
) {
(Some(t), Some(every)) => {
let every_parsed =
*duration_cache.get_or_insert_with(every, |every| Duration::parse(every));
polars_ensure!(!every_parsed.negative & !every_parsed.is_zero(), InvalidOperation: "cannot truncate a Duration to a non-positive Duration");
polars_ensure!(every_parsed.is_constant_duration(None), InvalidOperation:"cannot truncate a Duration to a non-constant Duration (i.e. one that involves weeks / months)");
let every = match self.time_unit() {
TimeUnit::Milliseconds => every_parsed.duration_ms(),
TimeUnit::Microseconds => every_parsed.duration_us(),
TimeUnit::Nanoseconds => every_parsed.duration_ns(),
};
Ok(Some(t - t % every))
MarcoGorelli marked this conversation as resolved.
Show resolved Hide resolved
},
_ => Ok(None),
});
Ok(out?.into_duration(self.time_unit()))
}
}
21 changes: 15 additions & 6 deletions py-polars/polars/expr/datetime.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,10 +152,11 @@ def add_business_days(

def truncate(self, every: str | dt.timedelta | Expr) -> Expr:
"""
Divide the date/datetime range into buckets.
Divide the dates, datetimes, or durations into buckets.

Each date/datetime is mapped to the start of its bucket using the corresponding
local datetime. Note that weekly buckets start on Monday.
For dates or datetimes, each date/datetime is mapped to the start of its bucket
using the corresponding local datetime.
Note that weekly buckets start on Monday.
Ambiguous results are localised using the DST offset of the original timestamp -
for example, truncating `'2022-11-06 01:30:00 CST'` by `'1h'` results in
`'2022-11-06 01:00:00 CST'`, whereas truncating `'2022-11-06 01:30:00 CDT'` by
Expand Down Expand Up @@ -192,6 +193,10 @@ def truncate(self, every: str | dt.timedelta | Expr) -> Expr:
not be 24 hours, due to daylight savings). Similarly for "calendar week",
"calendar month", "calendar quarter", and "calendar year".

Durations may not be truncated to a period length `every` containing calendar
days, weeks, months, quarters, or years, as these are not constant time
intervals.

Returns
-------
Expr
Expand Down Expand Up @@ -278,15 +283,15 @@ def truncate(self, every: str | dt.timedelta | Expr) -> Expr:
@unstable()
def round(self, every: str | dt.timedelta | IntoExprColumn) -> Expr:
"""
Divide the date/datetime range into buckets.
Divide the dates, datetimes, or durations into buckets.

.. warning::
This functionality is considered **unstable**. It may be changed
at any point without it being considered a breaking change.

Each date/datetime in the first half of the interval
Each date/datetime/duration in the first half of the interval
is mapped to the start of its bucket.
Each date/datetime in the second half of the interval
Each date/datetime/duration in the second half of the interval
is mapped to the end of its bucket.
Ambiguous results are localised using the DST offset of the original timestamp -
for example, rounding `'2022-11-06 01:20:00 CST'` by `'1h'` results in
Expand Down Expand Up @@ -326,6 +331,10 @@ def round(self, every: str | dt.timedelta | IntoExprColumn) -> Expr:
not be 24 hours, due to daylight savings). Similarly for "calendar week",
"calendar month", "calendar quarter", and "calendar year".

Durations may not be rounded to a period length `every` containing calendar
days, weeks, months, quarters, or years, as these are not constant time
intervals.

Examples
--------
>>> from datetime import timedelta, datetime
Expand Down
27 changes: 18 additions & 9 deletions py-polars/polars/series/datetime.py
Original file line number Diff line number Diff line change
Expand Up @@ -1644,10 +1644,11 @@ def offset_by(self, by: str | Expr) -> Series:

def truncate(self, every: str | dt.timedelta | IntoExprColumn) -> Series:
"""
Divide the date/ datetime range into buckets.
Divide the dates, datetimes, or durations into buckets.

Each date/datetime is mapped to the start of its bucket using the corresponding
local datetime. Note that weekly buckets start on Monday.
For dates or datetimes, each date/datetime is mapped to the start of its bucket
using the corresponding local datetime.
Note that weekly buckets start on Monday.
Ambiguous results are localised using the DST offset of the original timestamp -
for example, truncating `'2022-11-06 01:30:00 CST'` by `'1h'` results in
`'2022-11-06 01:00:00 CST'`, whereas truncating `'2022-11-06 01:30:00 CDT'` by
Expand Down Expand Up @@ -1683,6 +1684,10 @@ def truncate(self, every: str | dt.timedelta | IntoExprColumn) -> Series:
not be 24 hours, due to daylight savings). Similarly for "calendar week",
"calendar month", "calendar quarter", and "calendar year".

Durations may not be truncated to a period length `every` containing calendar
days, weeks, months, quarters, or years, as these are not constant time
intervals.

Returns
-------
Series
Expand Down Expand Up @@ -1758,17 +1763,17 @@ def truncate(self, every: str | dt.timedelta | IntoExprColumn) -> Series:
@unstable()
def round(self, every: str | dt.timedelta | IntoExprColumn) -> Series:
"""
Divide the date/ datetime range into buckets.
Divide the dates, datetimes, or durations into buckets.

.. warning::
This functionality is considered **unstable**. It may be changed
at any point without it being considered a breaking change.

Each date/datetime in the first half of the interval is mapped to the start of
its bucket.
Each date/datetime in the second half of the interval is mapped to the end of
its bucket.
Ambiguous results are localized using the DST offset of the original timestamp -
Each date/datetime/duration in the first half of the interval
is mapped to the start of its bucket.
Each date/datetime/duration in the second half of the interval
is mapped to the end of its bucket.
Ambiguous results are localised using the DST offset of the original timestamp -
for example, rounding `'2022-11-06 01:20:00 CST'` by `'1h'` results in
`'2022-11-06 01:00:00 CST'`, whereas rounding `'2022-11-06 01:20:00 CDT'` by
`'1h'` results in `'2022-11-06 01:00:00 CDT'`.
Expand Down Expand Up @@ -1808,6 +1813,10 @@ def round(self, every: str | dt.timedelta | IntoExprColumn) -> Series:
not be 24 hours, due to daylight savings). Similarly for "calendar week",
"calendar month", "calendar quarter", and "calendar year".

Durations may not be rounded to a period length `every` containing calendar
days, weeks, months, quarters, or years, as these are not constant time
intervals.

Examples
--------
>>> from datetime import timedelta, datetime
Expand Down
Loading