From 954538e3356fe43959acdfe88d51ae7d0c6693f1 Mon Sep 17 00:00:00 2001 From: Ritchie Vink Date: Wed, 24 Jul 2024 14:58:35 +0200 Subject: [PATCH] fix: Infer decimal scales on mixed scale input (#17840) --- crates/polars-core/src/series/any_value.rs | 14 ++++++++++++-- .../unit/constructors/test_any_value_fallbacks.py | 1 - py-polars/tests/unit/datatypes/test_decimal.py | 7 +++++++ 3 files changed, 19 insertions(+), 3 deletions(-) diff --git a/crates/polars-core/src/series/any_value.rs b/crates/polars-core/src/series/any_value.rs index 3eedb866118b..83abf75e980d 100644 --- a/crates/polars-core/src/series/any_value.rs +++ b/crates/polars-core/src/series/any_value.rs @@ -68,7 +68,7 @@ impl Series { // TODO: Remove this when Decimal data type equality is implemented. #[cfg(feature = "dtype-decimal")] - if !strict && dtype.is_decimal() { + if dtype.is_decimal() { let dtype = DataType::Decimal(None, None); return Self::from_any_values_and_dtype(name, values, &dtype, strict); } @@ -488,7 +488,17 @@ fn any_values_to_decimal( let mut builder = PrimitiveChunkedBuilder::::new("", values.len()); for av in values { match av { - AnyValue::Decimal(v, s) if *s == scale => builder.append_value(*v), + // Allow equal or less scale. We do want to support different scales even in 'strict' mode. + AnyValue::Decimal(v, s) if *s <= scale => { + if *s == scale { + builder.append_value(*v) + } else { + match av.strict_cast(&target_dtype) { + Some(AnyValue::Decimal(i, _)) => builder.append_value(i), + _ => builder.append_null(), + } + } + }, AnyValue::Null => builder.append_null(), av => { if strict { diff --git a/py-polars/tests/unit/constructors/test_any_value_fallbacks.py b/py-polars/tests/unit/constructors/test_any_value_fallbacks.py index e9413edfb060..0f3c40f21925 100644 --- a/py-polars/tests/unit/constructors/test_any_value_fallbacks.py +++ b/py-polars/tests/unit/constructors/test_any_value_fallbacks.py @@ -299,7 +299,6 @@ def test_fallback_without_dtype( [timedelta(hours=0), 1_000], [D("12.345"), 100], [D("12.345"), 3.14], - [D("0.12345"), D("6789.0")], [{"a": 1, "b": "foo"}, {"a": -1, "b": date(2020, 12, 31)}], [{"a": None}, {"a": 1.0}, {"a": 1}], ], diff --git a/py-polars/tests/unit/datatypes/test_decimal.py b/py-polars/tests/unit/datatypes/test_decimal.py index c6cdaf247524..13acb7d66741 100644 --- a/py-polars/tests/unit/datatypes/test_decimal.py +++ b/py-polars/tests/unit/datatypes/test_decimal.py @@ -515,3 +515,10 @@ def test_decimal_dynamic_float_st() -> None: assert pl.LazyFrame({"a": [D("2.0"), D("0.5")]}).filter( pl.col("a").is_between(0.45, 0.9) ).collect().to_dict(as_series=False) == {"a": [D("0.5")]} + + +def test_decimal_strict_scale_inference_17770() -> None: + values = [D("0.1"), D("0.10"), D("1.0121")] + s = pl.Series(values, strict=True) + assert s.dtype == pl.Decimal(precision=None, scale=4) + assert s.to_list() == values