Skip to content

Commit

Permalink
fix: Infer decimal scales on mixed scale input (pola-rs#17840)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 authored Jul 24, 2024
1 parent e96dd26 commit 954538e
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 3 deletions.
14 changes: 12 additions & 2 deletions crates/polars-core/src/series/any_value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ impl Series {

// TODO: Remove this when Decimal data type equality is implemented.
#[cfg(feature = "dtype-decimal")]
if !strict && dtype.is_decimal() {
if dtype.is_decimal() {
let dtype = DataType::Decimal(None, None);
return Self::from_any_values_and_dtype(name, values, &dtype, strict);
}
Expand Down Expand Up @@ -488,7 +488,17 @@ fn any_values_to_decimal(
let mut builder = PrimitiveChunkedBuilder::<Int128Type>::new("", values.len());
for av in values {
match av {
AnyValue::Decimal(v, s) if *s == scale => builder.append_value(*v),
// Allow equal or less scale. We do want to support different scales even in 'strict' mode.
AnyValue::Decimal(v, s) if *s <= scale => {
if *s == scale {
builder.append_value(*v)
} else {
match av.strict_cast(&target_dtype) {
Some(AnyValue::Decimal(i, _)) => builder.append_value(i),
_ => builder.append_null(),
}
}
},
AnyValue::Null => builder.append_null(),
av => {
if strict {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,6 @@ def test_fallback_without_dtype(
[timedelta(hours=0), 1_000],
[D("12.345"), 100],
[D("12.345"), 3.14],
[D("0.12345"), D("6789.0")],
[{"a": 1, "b": "foo"}, {"a": -1, "b": date(2020, 12, 31)}],
[{"a": None}, {"a": 1.0}, {"a": 1}],
],
Expand Down
7 changes: 7 additions & 0 deletions py-polars/tests/unit/datatypes/test_decimal.py
Original file line number Diff line number Diff line change
Expand Up @@ -515,3 +515,10 @@ def test_decimal_dynamic_float_st() -> None:
assert pl.LazyFrame({"a": [D("2.0"), D("0.5")]}).filter(
pl.col("a").is_between(0.45, 0.9)
).collect().to_dict(as_series=False) == {"a": [D("0.5")]}


def test_decimal_strict_scale_inference_17770() -> None:
values = [D("0.1"), D("0.10"), D("1.0121")]
s = pl.Series(values, strict=True)
assert s.dtype == pl.Decimal(precision=None, scale=4)
assert s.to_list() == values

0 comments on commit 954538e

Please sign in to comment.