fix: Correctly process subseconds in pl.duration (#11748)
stinodego authored Oct 16, 2023
1 parent bda942d commit 6c7446f
Showing 7 changed files with 250 additions and 221 deletions.
106 changes: 57 additions & 49 deletions crates/polars-plan/src/dsl/functions/temporal.rs
@@ -256,95 +256,103 @@ impl DurationArgs {
 #[cfg(feature = "temporal")]
 pub fn duration(args: DurationArgs) -> Expr {
     let function = SpecialEq::new(Arc::new(move |s: &mut [Series]| {
         assert_eq!(s.len(), 8);
         if s.iter().any(|s| s.is_empty()) {
             return Ok(Some(Series::new_empty(
                 s[0].name(),
                 &DataType::Duration(args.time_unit),
             )));
         }

-        let days = s[0].cast(&DataType::Int64).unwrap();
-        let seconds = s[1].cast(&DataType::Int64).unwrap();
-        let mut nanoseconds = s[2].cast(&DataType::Int64).unwrap();
-        let mut microseconds = s[3].cast(&DataType::Int64).unwrap();
-        let mut milliseconds = s[4].cast(&DataType::Int64).unwrap();
-        let minutes = s[5].cast(&DataType::Int64).unwrap();
-        let hours = s[6].cast(&DataType::Int64).unwrap();
-        let weeks = s[7].cast(&DataType::Int64).unwrap();
+        // TODO: Handle overflow for UInt64
+        let weeks = s[0].cast(&DataType::Int64).unwrap();
+        let days = s[1].cast(&DataType::Int64).unwrap();
+        let hours = s[2].cast(&DataType::Int64).unwrap();
+        let minutes = s[3].cast(&DataType::Int64).unwrap();
+        let seconds = s[4].cast(&DataType::Int64).unwrap();
+        let mut milliseconds = s[5].cast(&DataType::Int64).unwrap();
+        let mut microseconds = s[6].cast(&DataType::Int64).unwrap();
+        let mut nanoseconds = s[7].cast(&DataType::Int64).unwrap();

-        let max_len = s.iter().map(|s| s.len()).max().unwrap();
-
-        let condition = |s: &Series| {
-            // check if not literal 0 || full column
-            (s.len() != max_len && s.get(0).unwrap() != AnyValue::Int64(0)) || s.len() == max_len
-        };
-
-        let multiplier = match args.time_unit {
-            TimeUnit::Nanoseconds => NANOSECONDS,
-            TimeUnit::Microseconds => MICROSECONDS,
-            TimeUnit::Milliseconds => MILLISECONDS,
-        };
+        let is_scalar = |s: &Series| s.len() == 1;
+        let is_zero_scalar = |s: &Series| is_scalar(s) && s.get(0).unwrap() == AnyValue::Int64(0);

+        // Process subseconds
+        let max_len = s.iter().map(|s| s.len()).max().unwrap();
         let mut duration = match args.time_unit {
+            TimeUnit::Microseconds => {
+                if is_scalar(&microseconds) {
+                    microseconds = microseconds.new_from_index(0, max_len);
+                }
+                if !is_zero_scalar(&nanoseconds) {
+                    microseconds = microseconds + (nanoseconds / 1_000);
+                }
+                if !is_zero_scalar(&milliseconds) {
+                    microseconds = microseconds + (milliseconds * 1_000);
+                }
+                microseconds
+            },
             TimeUnit::Nanoseconds => {
-                if nanoseconds.len() != max_len {
+                if is_scalar(&nanoseconds) {
                     nanoseconds = nanoseconds.new_from_index(0, max_len);
                 }
-                if condition(&microseconds) {
+                if !is_zero_scalar(&microseconds) {
                     nanoseconds = nanoseconds + (microseconds * 1_000);
                 }
-                if condition(&milliseconds) {
+                if !is_zero_scalar(&milliseconds) {
                     nanoseconds = nanoseconds + (milliseconds * 1_000_000);
                 }
                 nanoseconds
             },
-            TimeUnit::Microseconds => {
-                if microseconds.len() != max_len {
-                    microseconds = microseconds.new_from_index(0, max_len);
-                }
-                if condition(&milliseconds) {
-                    microseconds = microseconds + (milliseconds * 1_000);
-                }
-                microseconds
-            },
             TimeUnit::Milliseconds => {
-                if milliseconds.len() != max_len {
+                if is_scalar(&milliseconds) {
                     milliseconds = milliseconds.new_from_index(0, max_len);
                 }
+                if !is_zero_scalar(&nanoseconds) {
+                    milliseconds = milliseconds + (nanoseconds / 1_000_000);
+                }
+                if !is_zero_scalar(&microseconds) {
+                    milliseconds = milliseconds + (microseconds / 1_000);
+                }
                 milliseconds
             },
         };

-        if condition(&seconds) {
-            duration = duration + (seconds * multiplier);
+        // Process other duration specifiers
+        let multiplier = match args.time_unit {
+            TimeUnit::Nanoseconds => NANOSECONDS,
+            TimeUnit::Microseconds => MICROSECONDS,
+            TimeUnit::Milliseconds => MILLISECONDS,
+        };
+        if !is_zero_scalar(&seconds) {
+            duration = duration + seconds * multiplier;
         }
-        if condition(&days) {
-            duration = duration + (days * multiplier * SECONDS_IN_DAY);
+        if !is_zero_scalar(&minutes) {
+            duration = duration + minutes * (multiplier * 60);
         }
-        if condition(&minutes) {
-            duration = duration + minutes * multiplier * 60;
+        if !is_zero_scalar(&hours) {
+            duration = duration + hours * (multiplier * 60 * 60);
         }
-        if condition(&hours) {
-            duration = duration + hours * multiplier * 60 * 60;
+        if !is_zero_scalar(&days) {
+            duration = duration + days * (multiplier * SECONDS_IN_DAY);
         }
-        if condition(&weeks) {
-            duration = duration + weeks * multiplier * SECONDS_IN_DAY * 7;
+        if !is_zero_scalar(&weeks) {
+            duration = duration + weeks * (multiplier * SECONDS_IN_DAY * 7);
         }

         duration.cast(&DataType::Duration(args.time_unit)).map(Some)
     }) as Arc<dyn SeriesUdf>);

+    // TODO: Make non-anonymous
     Expr::AnonymousFunction {
         input: vec![
+            args.weeks,
             args.days,
+            args.hours,
+            args.minutes,
             args.seconds,
-            args.nanoseconds,
-            args.microseconds,
             args.milliseconds,
-            args.minutes,
-            args.hours,
-            args.weeks,
+            args.microseconds,
+            args.nanoseconds,
         ],
         function,
         output_type: GetOutput::from_type(DataType::Duration(args.time_unit)),
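The following is a minimal Python sketch (not part of this commit) of the subsecond handling introduced above, modelled on the removed test_duration_time_units test further down. The specific argument values and the printed expectation are my own illustration derived from the new Rust logic: with a "us" time unit, milliseconds are scaled up by 1_000 and nanoseconds are divided by 1_000 before being summed with the microseconds argument.

import polars as pl

# Subsecond arguments combined at the "us" time unit.
# With the new logic: 4 ms -> 4_000 us, 5 us -> 5 us, 6_000 ns -> 6 us,
# so the duration should come out as 4_011 microseconds.
# (In the removed branches, nanoseconds were never folded in when the
# target unit was microseconds, so that argument was silently dropped.)
result = (
    pl.LazyFrame()
    .select(
        pl.duration(
            milliseconds=4,
            microseconds=5,
            nanoseconds=6_000,
            time_unit="us",
        )
    )
    .collect()
)
print(result["duration"].item())  # expected: timedelta(microseconds=4011)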
117 changes: 4 additions & 113 deletions py-polars/tests/unit/datatypes/test_temporal.py
@@ -817,111 +817,6 @@ def test_asof_join_tolerance_grouper() -> None:
     assert_frame_equal(out, expected)


-def test_datetime_duration_offset() -> None:
-    df = pl.DataFrame(
-        {
-            "datetime": [
-                datetime(1999, 1, 1, 7),
-                datetime(2022, 1, 2, 14),
-                datetime(3000, 12, 31, 21),
-            ],
-            "add": [1, 2, -1],
-        }
-    )
-    out = df.select(
-        [
-            (pl.col("datetime") + pl.duration(weeks="add")).alias("add_weeks"),
-            (pl.col("datetime") + pl.duration(days="add")).alias("add_days"),
-            (pl.col("datetime") + pl.duration(hours="add")).alias("add_hours"),
-            (pl.col("datetime") + pl.duration(seconds="add")).alias("add_seconds"),
-            (pl.col("datetime") + pl.duration(microseconds=pl.col("add") * 1000)).alias(
-                "add_usecs"
-            ),
-        ]
-    )
-    expected = pl.DataFrame(
-        {
-            "add_weeks": [
-                datetime(1999, 1, 8, 7),
-                datetime(2022, 1, 16, 14),
-                datetime(3000, 12, 24, 21),
-            ],
-            "add_days": [
-                datetime(1999, 1, 2, 7),
-                datetime(2022, 1, 4, 14),
-                datetime(3000, 12, 30, 21),
-            ],
-            "add_hours": [
-                datetime(1999, 1, 1, hour=8),
-                datetime(2022, 1, 2, hour=16),
-                datetime(3000, 12, 31, hour=20),
-            ],
-            "add_seconds": [
-                datetime(1999, 1, 1, 7, second=1),
-                datetime(2022, 1, 2, 14, second=2),
-                datetime(3000, 12, 31, 20, 59, 59),
-            ],
-            "add_usecs": [
-                datetime(1999, 1, 1, 7, microsecond=1000),
-                datetime(2022, 1, 2, 14, microsecond=2000),
-                datetime(3000, 12, 31, 20, 59, 59, 999000),
-            ],
-        }
-    )
-    assert_frame_equal(out, expected)
-
-
-def test_date_duration_offset() -> None:
-    df = pl.DataFrame(
-        {
-            "date": [date(10, 1, 1), date(2000, 7, 5), date(9990, 12, 31)],
-            "offset": [365, 7, -31],
-        }
-    )
-    out = df.select(
-        [
-            (pl.col("date") + pl.duration(days="offset")).alias("add_days"),
-            (pl.col("date") - pl.duration(days="offset")).alias("sub_days"),
-            (pl.col("date") + pl.duration(weeks="offset")).alias("add_weeks"),
-            (pl.col("date") - pl.duration(weeks="offset")).alias("sub_weeks"),
-        ]
-    )
-    assert out.to_dict(False) == {
-        "add_days": [date(11, 1, 1), date(2000, 7, 12), date(9990, 11, 30)],
-        "sub_days": [date(9, 1, 1), date(2000, 6, 28), date(9991, 1, 31)],
-        "add_weeks": [date(16, 12, 30), date(2000, 8, 23), date(9990, 5, 28)],
-        "sub_weeks": [date(3, 1, 3), date(2000, 5, 17), date(9991, 8, 5)],
-    }
-
-
-def test_add_duration_3786() -> None:
-    df = pl.DataFrame(
-        {
-            "datetime": [datetime(2022, 1, 1), datetime(2022, 1, 2)],
-            "add": [1, 2],
-        }
-    )
-    assert df.slice(0, 1).with_columns(
-        [
-            (pl.col("datetime") + pl.duration(weeks="add")).alias("add_weeks"),
-            (pl.col("datetime") + pl.duration(days="add")).alias("add_days"),
-            (pl.col("datetime") + pl.duration(seconds="add")).alias("add_seconds"),
-            (pl.col("datetime") + pl.duration(milliseconds="add")).alias(
-                "add_milliseconds"
-            ),
-            (pl.col("datetime") + pl.duration(hours="add")).alias("add_hours"),
-        ]
-    ).to_dict(False) == {
-        "datetime": [datetime(2022, 1, 1, 0, 0)],
-        "add": [1],
-        "add_weeks": [datetime(2022, 1, 8, 0, 0)],
-        "add_days": [datetime(2022, 1, 2, 0, 0)],
-        "add_seconds": [datetime(2022, 1, 1, 0, 0, 1)],
-        "add_milliseconds": [datetime(2022, 1, 1, 0, 0, 0, 1000)],
-        "add_hours": [datetime(2022, 1, 1, 1, 0)],
-    }
-
-
 def test_rolling_group_by_by_argument() -> None:
     df = pl.DataFrame({"times": range(10), "groups": [1] * 4 + [2] * 6})

@@ -2644,16 +2539,12 @@ def test_datetime_cum_agg_schema() -> None:
     assert (
         df.lazy()
         .with_columns(
-            [
-                (pl.col("timestamp").cummin()).alias("cummin"),
-                (pl.col("timestamp").cummax()).alias("cummax"),
-            ]
+            (pl.col("timestamp").cummin()).alias("cummin"),
+            (pl.col("timestamp").cummax()).alias("cummax"),
         )
         .with_columns(
-            [
-                (pl.col("cummin") + pl.duration(hours=24)).alias("cummin+24"),
-                (pl.col("cummax") + pl.duration(hours=24)).alias("cummax+24"),
-            ]
+            (pl.col("cummin") + pl.duration(hours=24)).alias("cummin+24"),
+            (pl.col("cummax") + pl.duration(hours=24)).alias("cummax+24"),
         )
         .collect()
     ).to_dict(False) == {
Empty file.
@@ -1,6 +1,6 @@
 from __future__ import annotations

-from datetime import date, datetime, timedelta
+from datetime import date, datetime
 from typing import TYPE_CHECKING

 import pytest
@@ -94,38 +94,6 @@ def test_time() -> None:
     assert_series_equal(out["ms2"], df["micro"].rename("ms2"))


-def test_empty_duration() -> None:
-    s = pl.DataFrame([], {"days": pl.Int32}).select(pl.duration(days="days"))
-    assert s.dtypes == [pl.Duration("us")]
-    assert s.shape == (0, 1)
-
-
-@pytest.mark.parametrize(
-    ("time_unit", "expected"),
-    [
-        ("ms", timedelta(days=1, minutes=2, seconds=3, milliseconds=4)),
-        ("us", timedelta(days=1, minutes=2, seconds=3, milliseconds=4, microseconds=5)),
-        ("ns", timedelta(days=1, minutes=2, seconds=3, milliseconds=4, microseconds=5)),
-    ],
-)
-def test_duration_time_units(time_unit: TimeUnit, expected: timedelta) -> None:
-    result = pl.LazyFrame().select(
-        pl.duration(
-            days=1,
-            minutes=2,
-            seconds=3,
-            milliseconds=4,
-            microseconds=5,
-            nanoseconds=6,
-            time_unit=time_unit,
-        )
-    )
-    assert result.schema["duration"] == pl.Duration(time_unit)
-    assert result.collect()["duration"].item() == expected
-    if time_unit == "ns":
-        assert result.collect()["duration"].dt.nanoseconds().item() == 86523004005006
-
-
 def test_list_concat() -> None:
     s0 = pl.Series("a", [[1, 2]])
     s1 = pl.Series("b", [[3, 4, 5]])

