diff --git a/crates/polars-arrow/src/compute/decimal.rs b/crates/polars-arrow/src/compute/decimal.rs index cb265d43d75e..04066f1d9629 100644 --- a/crates/polars-arrow/src/compute/decimal.rs +++ b/crates/polars-arrow/src/compute/decimal.rs @@ -1,4 +1,4 @@ -use atoi::atoi; +use atoi::FromRadix10SignedChecked; fn significant_digits(bytes: &[u8]) -> u8 { (bytes.len() as u8) - leading_zeros(bytes) @@ -9,12 +9,17 @@ fn leading_zeros(bytes: &[u8]) -> u8 { } fn split_decimal_bytes(bytes: &[u8]) -> (Option<&[u8]>, Option<&[u8]>) { - let mut a = bytes.split(|x| *x == b'.'); + let mut a = bytes.splitn(2, |x| *x == b'.'); let lhs = a.next(); let rhs = a.next(); (lhs, rhs) } +fn parse_integer_checked(bytes: &[u8]) -> Option { + let (n, len) = i128::from_radix_10_signed_checked(bytes); + n.filter(|_| len == bytes.len()) +} + pub fn infer_scale(bytes: &[u8]) -> Option { let (_lhs, rhs) = split_decimal_bytes(bytes); rhs.map(significant_digits) @@ -26,67 +31,64 @@ pub fn infer_scale(bytes: &[u8]) -> Option { pub(super) fn deserialize_decimal(bytes: &[u8], precision: Option, scale: u8) -> Option { let (lhs, rhs) = split_decimal_bytes(bytes); let precision = precision.unwrap_or(u8::MAX); - match (lhs, rhs) { - (Some(lhs), Some(rhs)) => atoi::(lhs).and_then(|x| { - atoi::(rhs) - .map(|y| (x, lhs, y, rhs)) - .and_then(|(lhs, lhs_b, rhs, rhs_b)| { - let lhs_s = significant_digits(lhs_b); - let leading_zeros_rhs = leading_zeros(rhs_b); - let rhs_s = rhs_b.len() as u8 - leading_zeros_rhs; - - // parameters don't match bytes - if lhs_s + rhs_s > precision || rhs_s > scale { - None - } - // significant digits don't fit scale - else if rhs_s < scale { - // scale: 2 - // number: x.09 - // significant digits: 1 - // leading_zeros: 1 - // parsed: 9 - // so this is correct - if leading_zeros_rhs + rhs_s == scale { - Some((lhs, rhs)) + + let lhs_b = lhs?; + parse_integer_checked(lhs_b).and_then(|x| { + match rhs { + Some(rhs) => { + parse_integer_checked(rhs) + .map(|y| (x, lhs_b, y, rhs)) + .and_then(|(lhs, lhs_b, rhs, rhs_b)| { + let lhs_s = significant_digits(lhs_b); + let leading_zeros_rhs = leading_zeros(rhs_b); + let rhs_s = rhs_b.len() as u8 - leading_zeros_rhs; + + // parameters don't match bytes + if lhs_s + rhs_s > precision || rhs_s > scale { + None + } + // significant digits don't fit scale + else if rhs_s < scale { + // scale: 2 + // number: x.09 + // significant digits: 1 + // leading_zeros: 1 + // parsed: 9 + // so this is correct + if leading_zeros_rhs + rhs_s == scale { + Some((lhs, rhs)) + } + // scale: 2 + // number: x.9 + // significant digits: 1 + // parsed: 9 + // so we must multiply by 10 to get 90 + else { + let diff = scale as u32 - (rhs_s + leading_zeros_rhs) as u32; + Some((lhs, rhs * 10i128.pow(diff))) + } } // scale: 2 - // number: x.9 - // significant digits: 1 - // parsed: 9 - // so we must multiply by 10 to get 90 + // number: x.90 + // significant digits: 2 + // parsed: 90 + // so this is correct else { - let diff = scale as u32 - (rhs_s + leading_zeros_rhs) as u32; - Some((lhs, rhs * 10i128.pow(diff))) + Some((lhs, rhs)) } - } - // scale: 2 - // number: x.90 - // significant digits: 2 - // parsed: 90 - // so this is correct - else { - Some((lhs, rhs)) - } - }) - .map(|(lhs, rhs)| { - lhs * 10i128.pow(scale as u32) + (if lhs < 0 { -rhs } else { rhs }) - }) - }), - (None, Some(rhs)) => { - if rhs.len() > precision as usize || rhs.len() != scale as usize { - return None; - } - atoi::(rhs) - }, - (Some(lhs), None) => { - if lhs.len() > precision as usize || scale != 0 { - return None; - } - atoi::(lhs) - }, - (None, None) => None, - } + }) + .map(|(lhs, rhs)| { + lhs * 10i128.pow(scale as u32) + (if lhs < 0 { -rhs } else { rhs }) + }) + }, + None => { + if lhs_b.len() > precision as usize || scale != 0 { + return None; + } + parse_integer_checked(lhs_b) + }, + } + }) } #[cfg(test)] @@ -127,5 +129,32 @@ mod test { deserialize_decimal(val.as_bytes(), precision, scale), Some(1000000000000000000) ); + let scale = 5; + let val = "12ABC.34"; + assert_eq!(deserialize_decimal(val.as_bytes(), precision, scale), None); + + let val = "1ABC2.34"; + assert_eq!(deserialize_decimal(val.as_bytes(), precision, scale), None); + + let val = "12.3ABC4"; + assert_eq!(deserialize_decimal(val.as_bytes(), precision, scale), None); + + let val = "12.3.ABC4"; + assert_eq!(deserialize_decimal(val.as_bytes(), precision, scale), None); + + let val = ""; + assert_eq!(deserialize_decimal(val.as_bytes(), precision, scale), None); + + let val = "5."; + assert_eq!( + deserialize_decimal(val.as_bytes(), precision, scale), + Some(500000i128) + ); + + let val = ".5"; + assert_eq!( + deserialize_decimal(val.as_bytes(), precision, scale), + Some(50000i128) + ); } }