Skip to content

Commit

Permalink
fix: sum_horizontal should not always cast to int (#12031)
Browse files Browse the repository at this point in the history
  • Loading branch information
reswqa authored Oct 26, 2023
1 parent a397e97 commit 06e6b3e
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 24 deletions.
48 changes: 24 additions & 24 deletions crates/polars-ops/src/series/ops/horizontal.rs
Original file line number Diff line number Diff line change
@@ -1,34 +1,34 @@
use std::borrow::Cow;
use std::ops::{BitAnd, BitOr};

use polars_core::prelude::*;
use polars_core::POOL;
use rayon::prelude::*;

pub fn sum_horizontal(s: &[Series]) -> PolarsResult<Series> {
let out = POOL
.install(|| {
s.par_iter()
.try_fold(
|| UInt32Chunked::new("", &[0u32]).into_series(),
|acc, b| {
PolarsResult::Ok(
acc.fill_null(FillNullStrategy::Zero)?
+ b.fill_null(FillNullStrategy::Zero)?,
)
},
)
.try_reduce(
|| UInt32Chunked::new("", &[0u32]).into_series(),
|a, b| {
PolarsResult::Ok(
a.fill_null(FillNullStrategy::Zero)?
+ b.fill_null(FillNullStrategy::Zero)?,
)
},
)
})?
.with_name("sum");
Ok(out)
let sum_fn = |acc: &Series, s: &Series| {
PolarsResult::Ok(
acc.fill_null(FillNullStrategy::Zero)? + s.fill_null(FillNullStrategy::Zero)?,
)
};
let out = match s.len() {
0 => Ok(UInt32Chunked::new("", &[0u32]).into_series()),
1 => Ok(s[0].clone()),
2 => sum_fn(&s[0], &s[1]),
_ => {
// the try_reduce_with is a bit slower in parallelism,
// but I don't think it matters here as we parallelize over series, not over elements
POOL.install(|| {
s.par_iter()
.map(|s| Ok(Cow::Borrowed(s)))
.try_reduce_with(|l, r| sum_fn(&l, &r).map(Cow::Owned))
// we can unwrap the option, because we are certain there is a series
.unwrap()
.map(|cow| cow.into_owned())
})
},
};
out.map(|ok| ok.with_name("sum"))
}

pub fn any_horizontal(s: &[Series]) -> PolarsResult<Series> {
Expand Down
17 changes: 17 additions & 0 deletions py-polars/tests/unit/functions/aggregation/test_horizontal.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import datetime
from typing import Any

import pytest
Expand Down Expand Up @@ -227,3 +228,19 @@ def test_cumsum_fold() -> None:
)
result = df.select(pl.cumsum_horizontal("a", "c"))
assert result.to_dict(False) == {"cumsum": [{"a": 1, "c": 6}, {"a": 2, "c": 8}]}


def test_sum_dtype_12028() -> None:
result = pl.select(
pl.sum_horizontal([pl.duration(seconds=10)]).alias("sum_duration")
)
expected = pl.DataFrame(
[
pl.Series(
"sum_duration",
[datetime.timedelta(seconds=10)],
dtype=pl.Duration(time_unit="us"),
),
]
)
assert_frame_equal(expected, result)

0 comments on commit 06e6b3e

Please sign in to comment.