Skip to content

Commit

Permalink
reduce to HALF_EVEN
Browse files Browse the repository at this point in the history
  • Loading branch information
coastalwhite committed Nov 14, 2024
1 parent 8c5b6c1 commit cb3133f
Show file tree
Hide file tree
Showing 7 changed files with 124 additions and 183 deletions.
24 changes: 0 additions & 24 deletions crates/polars-core/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,6 @@ pub(crate) const FMT_TABLE_INLINE_COLUMN_DATA_TYPE: &str =
pub(crate) const FMT_TABLE_ROUNDED_CORNERS: &str = "POLARS_FMT_TABLE_ROUNDED_CORNERS";
pub(crate) const FMT_TABLE_CELL_LIST_LEN: &str = "POLARS_FMT_TABLE_CELL_LIST_LEN";

#[cfg(feature = "dtype-decimal")]
pub(crate) const DECIMAL_ROUNDING_MODE: &str = "POLARS_DECIMAL_ROUNDING_MODE";

/// Reports whether verbose output is enabled.
///
/// Returns `true` only when the `POLARS_VERBOSE` environment variable is set to exactly `"1"`;
/// an unset, unreadable, or differently valued variable yields `false`.
pub fn verbose() -> bool {
    matches!(std::env::var("POLARS_VERBOSE").as_deref(), Ok("1"))
}
Expand All @@ -49,27 +46,6 @@ pub fn get_rg_prefetch_size() -> usize {
.unwrap_or_else(|_| std::cmp::max(get_file_prefetch_size(), 128))
}

/// Resolves the decimal rounding mode from the environment variable named by
/// `DECIMAL_ROUNDING_MODE`.
///
/// Falls back to `RoundingMode::default()` when the variable cannot be read.
///
/// # Panics
///
/// Panics when the variable is set to a value that is not one of the recognised
/// `ROUND_*` names matched below.
#[cfg(feature = "dtype-decimal")]
pub fn get_decimal_rounding_mode() -> crate::datatypes::RoundingMode {
    use crate::datatypes::RoundingMode;

    let value = match std::env::var(DECIMAL_ROUNDING_MODE) {
        Ok(value) => value,
        Err(_) => return RoundingMode::default(),
    };

    match value.as_str() {
        "ROUND_CEILING" => RoundingMode::Ceiling,
        "ROUND_DOWN" => RoundingMode::Down,
        "ROUND_FLOOR" => RoundingMode::Floor,
        "ROUND_HALF_DOWN" => RoundingMode::HalfDown,
        "ROUND_HALF_EVEN" => RoundingMode::HalfEven,
        "ROUND_HALF_UP" => RoundingMode::HalfUp,
        "ROUND_UP" => RoundingMode::Up,
        "ROUND_05UP" => RoundingMode::Up05,
        _ => panic!("Invalid rounding mode '{value}' given through `{DECIMAL_ROUNDING_MODE}` environment value."),
    }
}

pub fn force_async() -> bool {
std::env::var("POLARS_FORCE_ASYNC")
.map(|value| value == "1")
Expand Down
2 changes: 0 additions & 2 deletions crates/polars-core/src/datatypes/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,13 @@ mod into_scalar;
#[cfg(feature = "object")]
mod static_array_collect;
mod time_unit;
mod rounding_mode;

use std::cmp::Ordering;
use std::fmt::{Display, Formatter};
use std::hash::{Hash, Hasher};
use std::ops::{Add, AddAssign, Div, Mul, Rem, Sub, SubAssign};

pub use aliases::*;
pub use rounding_mode::RoundingMode;
pub use any_value::*;
pub use arrow::array::{ArrayCollectIterExt, ArrayFromIter, ArrayFromIterDtype, StaticArray};
pub use arrow::datatypes::reshape::*;
Expand Down
20 changes: 0 additions & 20 deletions crates/polars-core/src/datatypes/rounding_mode.rs

This file was deleted.

4 changes: 0 additions & 4 deletions crates/polars-ops/src/series/ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,6 @@ mod rle;
mod rolling;
#[cfg(feature = "round_series")]
mod round;
#[cfg(all(feature = "round_series", feature = "dtype-decimal"))]
mod round_decimal;
#[cfg(feature = "search_sorted")]
mod search_sorted;
#[cfg(feature = "to_dummies")]
Expand Down Expand Up @@ -124,8 +122,6 @@ pub use rle::*;
pub use rolling::*;
#[cfg(feature = "round_series")]
pub use round::*;
#[cfg(all(feature = "round_series", feature = "dtype-decimal"))]
pub use round_decimal::*;
#[cfg(feature = "search_sorted")]
pub use search_sorted::*;
#[cfg(feature = "to_dummies")]
Expand Down
108 changes: 105 additions & 3 deletions crates/polars-ops/src/series/ops/round.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,33 @@ pub trait RoundSeries: SeriesSealed {
}
#[cfg(feature = "dtype-decimal")]
if let Some(ca) = s.try_decimal() {
return Ok(super::round_decimal::dec_round(ca, decimals).into_series());
let precision = ca.precision();
let scale = ca.scale() as u32;
if scale <= decimals {
return Ok(ca.clone().into_series());
}

let decimal_delta = scale - decimals;
let multiplier = 10i128.pow(decimal_delta);
let threshold = multiplier / 2;

let ca = ca
.apply_values(|v| {
// We use rounding=ROUND_HALF_EVEN
let rem = v % multiplier;
let is_v_floor_even = ((v - rem) / multiplier) % 2 == 0;
let threshold = threshold + i128::from(is_v_floor_even);
let round_offset = if rem.abs() >= threshold {
multiplier
} else {
0
};
let round_offset = if v < 0 { -round_offset } else { round_offset };
v - rem + round_offset
})
.into_decimal_unchecked(precision, scale as usize);

return Ok(ca.into_series());
}

polars_ensure!(s.dtype().is_numeric(), InvalidOperation: "round can only be used on numeric types" );
Expand All @@ -47,6 +73,46 @@ pub trait RoundSeries: SeriesSealed {
fn round_sig_figs(&self, digits: i32) -> PolarsResult<Series> {
let s = self.as_series();
polars_ensure!(digits >= 1, InvalidOperation: "digits must be an integer >= 1");

if let Some(ca) = s.try_decimal() {
let precision = ca.precision();
let scale = ca.scale() as u32;

let s = ca
.apply_values(|v| {
if v == 0 {
return 0;
}

let mut magnitude = v.abs().ilog10();
let magnitude_mult = 10i128.pow(magnitude); // @Q? It might be better to do this with a
// LUT.
if v.abs() > magnitude_mult {
magnitude += 1;
}
let decimals = magnitude.saturating_sub(digits as u32);
let multiplier = 10i128.pow(decimals); // @Q? It might be better to do this with a
// LUT.
let threshold = multiplier / 2;

// We use rounding=ROUND_HALF_EVEN
let rem = v % multiplier;
let is_v_floor_even = decimals <= scale && ((v - rem) / multiplier) % 2 == 0;
let threshold = threshold + i128::from(is_v_floor_even);
let round_offset = if rem.abs() >= threshold {
multiplier
} else {
0
};
let round_offset = if v < 0 { -round_offset } else { round_offset };
v - rem + round_offset
})
.into_decimal_unchecked(precision, scale as usize)
.into_series();

return Ok(s);
}

polars_ensure!(s.dtype().is_numeric(), InvalidOperation: "round_sig_figs can only be used on numeric types" );
with_match_physical_numeric_polars_type!(s.dtype(), |$T| {
let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref();
Expand Down Expand Up @@ -76,7 +142,25 @@ pub trait RoundSeries: SeriesSealed {
}
#[cfg(feature = "dtype-decimal")]
if let Some(ca) = s.try_decimal() {
return Ok(super::round_decimal::dec_round_floor(ca, 0).into_series());
let precision = ca.precision();
let scale = ca.scale() as u32;
if scale == 0 {
return Ok(ca.clone().into_series());
}

let decimal_delta = scale;
let multiplier = 10i128.pow(decimal_delta);

let ca = ca
.apply_values(|v| {
let rem = v % multiplier;
let round_offset = if v < 0 { multiplier + rem } else { rem };
let round_offset = if rem == 0 { 0 } else { round_offset };
v - round_offset
})
.into_decimal_unchecked(precision, scale as usize);

return Ok(ca.into_series());
}

polars_ensure!(s.dtype().is_numeric(), InvalidOperation: "floor can only be used on numeric types" );
Expand All @@ -97,7 +181,25 @@ pub trait RoundSeries: SeriesSealed {
}
#[cfg(feature = "dtype-decimal")]
if let Some(ca) = s.try_decimal() {
return Ok(super::round_decimal::dec_round_ceiling(ca, 0).into_series());
let precision = ca.precision();
let scale = ca.scale() as u32;
if scale == 0 {
return Ok(ca.clone().into_series());
}

let decimal_delta = scale;
let multiplier = 10i128.pow(decimal_delta);

let ca = ca
.apply_values(|v| {
let rem = v % multiplier;
let round_offset = if v < 0 { -rem } else { multiplier - rem };
let round_offset = if rem == 0 { 0 } else { round_offset };
v + round_offset
})
.into_decimal_unchecked(precision, scale as usize);

return Ok(ca.into_series());
}

polars_ensure!(s.dtype().is_numeric(), InvalidOperation: "ceil can only be used on numeric types" );
Expand Down
130 changes: 0 additions & 130 deletions crates/polars-ops/src/series/ops/round_decimal.rs

This file was deleted.

19 changes: 19 additions & 0 deletions py-polars/tests/unit/datatypes/test_decimal.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from decimal import Decimal as D
from random import choice, randrange, seed
from typing import Any, Callable, NamedTuple
from math import floor, ceil

import pytest

Expand Down Expand Up @@ -529,3 +530,21 @@ def test_decimal_strict_scale_inference_17770() -> None:
s = pl.Series(values, strict=True)
assert s.dtype == pl.Decimal(precision=None, scale=4)
assert s.to_list() == values


def test_decimal_round() -> None:
    """Compare decimal `floor`, `ceil` and `round` with Python's results on `Decimal` values."""
    dtype = pl.Decimal(3, 2)
    # One value per hundredth, from -1.50 up to and including 2.49.
    values = [D(f"{float(v) / 100.:.02f}") for v in range(-150, 250, 1)]
    series = pl.Series('a', values, dtype)

    expected_floor = pl.Series('a', list(map(floor, values)), dtype)
    expected_ceil = pl.Series('a', list(map(ceil, values)), dtype)

    assert_series_equal(series.floor(), expected_floor)
    assert_series_equal(series.ceil(), expected_ceil)

    # Includes `decimals` larger than the scale, where rounding must leave values unchanged.
    for decimals in range(10):
        rounded = series.round(decimals)
        expected_round = pl.Series('a', [round(v, decimals) for v in values], dtype)

        assert_series_equal(rounded, expected_round)

0 comments on commit cb3133f

Please sign in to comment.