diff --git a/crates/proof-of-sql/src/base/database/column_operation.rs b/crates/proof-of-sql/src/base/database/column_operation.rs index f7a82a0f0..2af3197e2 100644 --- a/crates/proof-of-sql/src/base/database/column_operation.rs +++ b/crates/proof-of-sql/src/base/database/column_operation.rs @@ -17,44 +17,163 @@ use proof_of_sql_parser::intermediate_ast::BinaryOperator; // For decimal type manipulation please refer to // https://learn.microsoft.com/en-us/sql/t-sql/data-types/precision-scale-and-length-transact-sql?view=sql-server-ver16 -/// Determine the output type of an add or subtract operation if it is possible -/// to add or subtract the two input types. If the types are not compatible, return -/// an error. -/// -/// # Panics -/// -/// - Panics if `lhs` or `rhs` does not have a precision or scale when they are expected to be numeric types. -/// - Panics if `lhs` or `rhs` is an integer, and `lhs.max_integer_type(&rhs)` returns `None`. -pub fn try_add_subtract_column_types( - lhs: ColumnType, - rhs: ColumnType, - operator: BinaryOperator, -) -> ColumnOperationResult { - if !lhs.is_numeric() || !rhs.is_numeric() { - return Err(ColumnOperationError::BinaryOperationInvalidColumnType { - operator, - left_type: lhs, - right_type: rhs, - }); +/// A trait for column operations. +pub trait ColumnOperation { + /// Determine the output type of an add or subtract operation if it is possible + /// to add or subtract the two input types. If the types are not compatible, return + /// an error. + /// + /// # Panics + /// + /// - Panics if `lhs` or `rhs` does not have a precision or scale when they are expected to be numeric types. + /// - Panics if `lhs` or `rhs` is an integer, and `lhs.max_integer_type(&rhs)` returns `None`. + fn try_add_subtract_column_types( + self, + rhs: ColumnType, + operator: BinaryOperator, + ) -> ColumnOperationResult; + + /// Determine the output type of a multiplication operation if it is possible + /// to multiply the two input types. If the types are not compatible, return + /// an error. + /// + /// # Panics + /// + /// - Panics if `lhs` or `rhs` does not have a precision or scale when they are expected to be numeric types. + /// - Panics if `lhs` or `rhs` is an integer, and `lhs.max_integer_type(&rhs)` returns `None`. + fn try_multiply_column_types(self, rhs: ColumnType) -> ColumnOperationResult; + + /// Determine the output type of a division operation if it is possible + /// to multiply the two input types. If the types are not compatible, return + /// an error. + /// + /// # Panics + /// + /// - Panics if `lhs` or `rhs` does not have a precision or scale when they are expected to be numeric types. + /// - Panics if `lhs` or `rhs` is an integer, and `lhs.max_integer_type(&rhs)` returns `None`. + fn try_divide_column_types(self, rhs: ColumnType) -> ColumnOperationResult; +} + +impl ColumnOperation for ColumnType { + fn try_add_subtract_column_types( + self, + rhs: ColumnType, + operator: BinaryOperator, + ) -> ColumnOperationResult { + let lhs = self; + if !lhs.is_numeric() || !rhs.is_numeric() { + return Err(ColumnOperationError::BinaryOperationInvalidColumnType { + operator, + left_type: lhs, + right_type: rhs, + }); + } + if lhs.is_integer() && rhs.is_integer() { + // We can unwrap here because we know that both types are integers + return Ok(lhs.max_integer_type(&rhs).unwrap()); + } + if lhs == ColumnType::Scalar || rhs == ColumnType::Scalar { + Ok(ColumnType::Scalar) + } else { + let left_precision_value = + lhs.precision_value().expect("Numeric types have precision") as i16; + let right_precision_value = + rhs.precision_value().expect("Numeric types have precision") as i16; + let left_scale = lhs.scale().expect("Numeric types have scale"); + let right_scale = rhs.scale().expect("Numeric types have scale"); + let scale = left_scale.max(right_scale); + let precision_value: i16 = scale as i16 + + (left_precision_value - left_scale as i16) + .max(right_precision_value - right_scale as i16) + + 1_i16; + let precision = u8::try_from(precision_value) + .map_err(|_| ColumnOperationError::DecimalConversionError { + source: DecimalError::InvalidPrecision { + error: precision_value.to_string(), + }, + }) + .and_then(|p| { + Precision::new(p).map_err(|_| ColumnOperationError::DecimalConversionError { + source: DecimalError::InvalidPrecision { + error: p.to_string(), + }, + }) + })?; + Ok(ColumnType::Decimal75(precision, scale)) + } } - if lhs.is_integer() && rhs.is_integer() { - // We can unwrap here because we know that both types are integers - return Ok(lhs.max_integer_type(&rhs).unwrap()); + + fn try_multiply_column_types(self, rhs: ColumnType) -> ColumnOperationResult { + let lhs = self; + if !lhs.is_numeric() || !rhs.is_numeric() { + return Err(ColumnOperationError::BinaryOperationInvalidColumnType { + operator: BinaryOperator::Multiply, + left_type: lhs, + right_type: rhs, + }); + } + if lhs.is_integer() && rhs.is_integer() { + // We can unwrap here because we know that both types are integers + return Ok(lhs.max_integer_type(&rhs).unwrap()); + } + if lhs == ColumnType::Scalar || rhs == ColumnType::Scalar { + Ok(ColumnType::Scalar) + } else { + let left_precision_value = lhs.precision_value().expect("Numeric types have precision"); + let right_precision_value = + rhs.precision_value().expect("Numeric types have precision"); + let precision_value = left_precision_value + right_precision_value + 1; + let precision = Precision::new(precision_value).map_err(|_| { + ColumnOperationError::DecimalConversionError { + source: DecimalError::InvalidPrecision { + error: format!( + "Required precision {precision_value} is beyond what we can support" + ), + }, + } + })?; + let left_scale = lhs.scale().expect("Numeric types have scale"); + let right_scale = rhs.scale().expect("Numeric types have scale"); + let scale = left_scale.checked_add(right_scale).ok_or( + ColumnOperationError::DecimalConversionError { + source: DecimalError::InvalidScale { + scale: left_scale as i16 + right_scale as i16, + }, + }, + )?; + Ok(ColumnType::Decimal75(precision, scale)) + } } - if lhs == ColumnType::Scalar || rhs == ColumnType::Scalar { - Ok(ColumnType::Scalar) - } else { + + fn try_divide_column_types(self, rhs: ColumnType) -> ColumnOperationResult { + let lhs = self; + if !lhs.is_numeric() + || !rhs.is_numeric() + || lhs == ColumnType::Scalar + || rhs == ColumnType::Scalar + { + return Err(ColumnOperationError::BinaryOperationInvalidColumnType { + operator: BinaryOperator::Division, + left_type: lhs, + right_type: rhs, + }); + } + if lhs.is_integer() && rhs.is_integer() { + // We can unwrap here because we know that both types are integers + return Ok(lhs.max_integer_type(&rhs).unwrap()); + } let left_precision_value = lhs.precision_value().expect("Numeric types have precision") as i16; let right_precision_value = rhs.precision_value().expect("Numeric types have precision") as i16; - let left_scale = lhs.scale().expect("Numeric types have scale"); - let right_scale = rhs.scale().expect("Numeric types have scale"); - let scale = left_scale.max(right_scale); - let precision_value: i16 = scale as i16 - + (left_precision_value - left_scale as i16) - .max(right_precision_value - right_scale as i16) - + 1_i16; + let left_scale = lhs.scale().expect("Numeric types have scale") as i16; + let right_scale = rhs.scale().expect("Numeric types have scale") as i16; + let raw_scale = (left_scale + right_precision_value + 1_i16).max(6_i16); + let precision_value: i16 = left_precision_value - left_scale + right_scale + raw_scale; + let scale = + i8::try_from(raw_scale).map_err(|_| ColumnOperationError::DecimalConversionError { + source: DecimalError::InvalidScale { scale: raw_scale }, + })?; let precision = u8::try_from(precision_value) .map_err(|_| ColumnOperationError::DecimalConversionError { source: DecimalError::InvalidPrecision { @@ -72,110 +191,6 @@ pub fn try_add_subtract_column_types( } } -/// Determine the output type of a multiplication operation if it is possible -/// to multiply the two input types. If the types are not compatible, return -/// an error. -/// -/// # Panics -/// -/// - Panics if `lhs` or `rhs` does not have a precision or scale when they are expected to be numeric types. -/// - Panics if `lhs` or `rhs` is an integer, and `lhs.max_integer_type(&rhs)` returns `None`. -pub fn try_multiply_column_types( - lhs: ColumnType, - rhs: ColumnType, -) -> ColumnOperationResult { - if !lhs.is_numeric() || !rhs.is_numeric() { - return Err(ColumnOperationError::BinaryOperationInvalidColumnType { - operator: BinaryOperator::Multiply, - left_type: lhs, - right_type: rhs, - }); - } - if lhs.is_integer() && rhs.is_integer() { - // We can unwrap here because we know that both types are integers - return Ok(lhs.max_integer_type(&rhs).unwrap()); - } - if lhs == ColumnType::Scalar || rhs == ColumnType::Scalar { - Ok(ColumnType::Scalar) - } else { - let left_precision_value = lhs.precision_value().expect("Numeric types have precision"); - let right_precision_value = rhs.precision_value().expect("Numeric types have precision"); - let precision_value = left_precision_value + right_precision_value + 1; - let precision = Precision::new(precision_value).map_err(|_| { - ColumnOperationError::DecimalConversionError { - source: DecimalError::InvalidPrecision { - error: format!( - "Required precision {precision_value} is beyond what we can support" - ), - }, - } - })?; - let left_scale = lhs.scale().expect("Numeric types have scale"); - let right_scale = rhs.scale().expect("Numeric types have scale"); - let scale = left_scale.checked_add(right_scale).ok_or( - ColumnOperationError::DecimalConversionError { - source: DecimalError::InvalidScale { - scale: left_scale as i16 + right_scale as i16, - }, - }, - )?; - Ok(ColumnType::Decimal75(precision, scale)) - } -} - -/// Determine the output type of a division operation if it is possible -/// to multiply the two input types. If the types are not compatible, return -/// an error. -/// -/// # Panics -/// -/// - Panics if `lhs` or `rhs` does not have a precision or scale when they are expected to be numeric types. -/// - Panics if `lhs` or `rhs` is an integer, and `lhs.max_integer_type(&rhs)` returns `None`. -pub fn try_divide_column_types( - lhs: ColumnType, - rhs: ColumnType, -) -> ColumnOperationResult { - if !lhs.is_numeric() - || !rhs.is_numeric() - || lhs == ColumnType::Scalar - || rhs == ColumnType::Scalar - { - return Err(ColumnOperationError::BinaryOperationInvalidColumnType { - operator: BinaryOperator::Division, - left_type: lhs, - right_type: rhs, - }); - } - if lhs.is_integer() && rhs.is_integer() { - // We can unwrap here because we know that both types are integers - return Ok(lhs.max_integer_type(&rhs).unwrap()); - } - let left_precision_value = lhs.precision_value().expect("Numeric types have precision") as i16; - let right_precision_value = rhs.precision_value().expect("Numeric types have precision") as i16; - let left_scale = lhs.scale().expect("Numeric types have scale") as i16; - let right_scale = rhs.scale().expect("Numeric types have scale") as i16; - let raw_scale = (left_scale + right_precision_value + 1_i16).max(6_i16); - let precision_value: i16 = left_precision_value - left_scale + right_scale + raw_scale; - let scale = - i8::try_from(raw_scale).map_err(|_| ColumnOperationError::DecimalConversionError { - source: DecimalError::InvalidScale { scale: raw_scale }, - })?; - let precision = u8::try_from(precision_value) - .map_err(|_| ColumnOperationError::DecimalConversionError { - source: DecimalError::InvalidPrecision { - error: precision_value.to_string(), - }, - }) - .and_then(|p| { - Precision::new(p).map_err(|_| ColumnOperationError::DecimalConversionError { - source: DecimalError::InvalidPrecision { - error: p.to_string(), - }, - }) - })?; - Ok(ColumnType::Decimal75(precision, scale)) -} - // Unary operations /// Negate a slice of boolean values. @@ -733,7 +748,7 @@ where T1: Copy, { let new_column_type = - try_add_subtract_column_types(left_column_type, right_column_type, BinaryOperator::Add)?; + left_column_type.try_add_subtract_column_types(right_column_type, BinaryOperator::Add)?; let new_precision_value = new_column_type .precision_value() .expect("numeric columns have precision"); @@ -789,11 +804,8 @@ where T0: Copy, T1: Copy, { - let new_column_type = try_add_subtract_column_types( - left_column_type, - right_column_type, - BinaryOperator::Subtract, - )?; + let new_column_type = left_column_type + .try_add_subtract_column_types(right_column_type, BinaryOperator::Subtract)?; let new_precision_value = new_column_type .precision_value() .expect("numeric columns have precision"); @@ -849,7 +861,7 @@ where T0: Copy, T1: Copy, { - let new_column_type = try_multiply_column_types(left_column_type, right_column_type)?; + let new_column_type = left_column_type.try_multiply_column_types(right_column_type)?; let new_precision_value = new_column_type .precision_value() .expect("numeric columns have precision"); @@ -887,7 +899,7 @@ where T0: Copy + Debug + Into, T1: Copy + Debug + Into, { - let new_column_type = try_divide_column_types(left_column_type, right_column_type)?; + let new_column_type = left_column_type.try_divide_column_types(right_column_type)?; let new_precision_value = new_column_type .precision_value() .expect("numeric columns have precision"); @@ -934,49 +946,63 @@ mod test { // lhs and rhs are integers with the same precision let lhs = ColumnType::SmallInt; let rhs = ColumnType::SmallInt; - let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Add).unwrap(); + let actual = lhs + .try_add_subtract_column_types(rhs, BinaryOperator::Add) + .unwrap(); let expected = ColumnType::SmallInt; assert_eq!(expected, actual); // lhs and rhs are integers with different precision let lhs = ColumnType::SmallInt; let rhs = ColumnType::Int; - let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Add).unwrap(); + let actual = lhs + .try_add_subtract_column_types(rhs, BinaryOperator::Add) + .unwrap(); let expected = ColumnType::Int; assert_eq!(expected, actual); // lhs is an integer and rhs is a scalar let lhs = ColumnType::SmallInt; let rhs = ColumnType::Scalar; - let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Add).unwrap(); + let actual = lhs + .try_add_subtract_column_types(rhs, BinaryOperator::Add) + .unwrap(); let expected = ColumnType::Scalar; assert_eq!(expected, actual); // lhs is a decimal with nonnegative scale and rhs is an integer let lhs = ColumnType::Decimal75(Precision::new(10).unwrap(), 2); let rhs = ColumnType::SmallInt; - let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Add).unwrap(); + let actual = lhs + .try_add_subtract_column_types(rhs, BinaryOperator::Add) + .unwrap(); let expected = ColumnType::Decimal75(Precision::new(11).unwrap(), 2); assert_eq!(expected, actual); // lhs and rhs are both decimals with nonnegative scale let lhs = ColumnType::Decimal75(Precision::new(20).unwrap(), 3); let rhs = ColumnType::Decimal75(Precision::new(10).unwrap(), 2); - let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Add).unwrap(); + let actual = lhs + .try_add_subtract_column_types(rhs, BinaryOperator::Add) + .unwrap(); let expected = ColumnType::Decimal75(Precision::new(21).unwrap(), 3); assert_eq!(expected, actual); // lhs is an integer and rhs is a decimal with negative scale let lhs = ColumnType::SmallInt; let rhs = ColumnType::Decimal75(Precision::new(10).unwrap(), -2); - let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Add).unwrap(); + let actual = lhs + .try_add_subtract_column_types(rhs, BinaryOperator::Add) + .unwrap(); let expected = ColumnType::Decimal75(Precision::new(13).unwrap(), 0); assert_eq!(expected, actual); // lhs and rhs are both decimals one of which has negative scale let lhs = ColumnType::Decimal75(Precision::new(40).unwrap(), -13); let rhs = ColumnType::Decimal75(Precision::new(15).unwrap(), 5); - let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Add).unwrap(); + let actual = lhs + .try_add_subtract_column_types(rhs, BinaryOperator::Add) + .unwrap(); let expected = ColumnType::Decimal75(Precision::new(59).unwrap(), 5); assert_eq!(expected, actual); @@ -984,7 +1010,9 @@ mod test { // and with result having maximum precision let lhs = ColumnType::Decimal75(Precision::new(74).unwrap(), -13); let rhs = ColumnType::Decimal75(Precision::new(15).unwrap(), -14); - let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Add).unwrap(); + let actual = lhs + .try_add_subtract_column_types(rhs, BinaryOperator::Add) + .unwrap(); let expected = ColumnType::Decimal75(Precision::new(75).unwrap(), -13); assert_eq!(expected, actual); } @@ -994,14 +1022,14 @@ mod test { let lhs = ColumnType::SmallInt; let rhs = ColumnType::VarChar; assert!(matches!( - try_add_subtract_column_types(lhs, rhs, BinaryOperator::Add), + lhs.try_add_subtract_column_types(rhs, BinaryOperator::Add), Err(ColumnOperationError::BinaryOperationInvalidColumnType { .. }) )); let lhs = ColumnType::VarChar; let rhs = ColumnType::VarChar; assert!(matches!( - try_add_subtract_column_types(lhs, rhs, BinaryOperator::Add), + lhs.try_add_subtract_column_types(rhs, BinaryOperator::Add), Err(ColumnOperationError::BinaryOperationInvalidColumnType { .. }) )); } @@ -1011,7 +1039,7 @@ mod test { let lhs = ColumnType::Decimal75(Precision::new(75).unwrap(), 4); let rhs = ColumnType::Decimal75(Precision::new(73).unwrap(), 4); assert!(matches!( - try_add_subtract_column_types(lhs, rhs, BinaryOperator::Add), + lhs.try_add_subtract_column_types(rhs, BinaryOperator::Add), Err(ColumnOperationError::DecimalConversionError { source: DecimalError::InvalidPrecision { .. } }) @@ -1020,7 +1048,7 @@ mod test { let lhs = ColumnType::Int; let rhs = ColumnType::Decimal75(Precision::new(75).unwrap(), 10); assert!(matches!( - try_add_subtract_column_types(lhs, rhs, BinaryOperator::Add), + lhs.try_add_subtract_column_types(rhs, BinaryOperator::Add), Err(ColumnOperationError::DecimalConversionError { source: DecimalError::InvalidPrecision { .. } }) @@ -1032,49 +1060,63 @@ mod test { // lhs and rhs are integers with the same precision let lhs = ColumnType::SmallInt; let rhs = ColumnType::SmallInt; - let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Subtract).unwrap(); + let actual = lhs + .try_add_subtract_column_types(rhs, BinaryOperator::Subtract) + .unwrap(); let expected = ColumnType::SmallInt; assert_eq!(expected, actual); // lhs and rhs are integers with different precision let lhs = ColumnType::SmallInt; let rhs = ColumnType::Int; - let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Subtract).unwrap(); + let actual = lhs + .try_add_subtract_column_types(rhs, BinaryOperator::Subtract) + .unwrap(); let expected = ColumnType::Int; assert_eq!(expected, actual); // lhs is an integer and rhs is a scalar let lhs = ColumnType::SmallInt; let rhs = ColumnType::Scalar; - let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Subtract).unwrap(); + let actual = lhs + .try_add_subtract_column_types(rhs, BinaryOperator::Subtract) + .unwrap(); let expected = ColumnType::Scalar; assert_eq!(expected, actual); // lhs is a decimal and rhs is an integer let lhs = ColumnType::Decimal75(Precision::new(10).unwrap(), 2); let rhs = ColumnType::SmallInt; - let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Subtract).unwrap(); + let actual = lhs + .try_add_subtract_column_types(rhs, BinaryOperator::Subtract) + .unwrap(); let expected = ColumnType::Decimal75(Precision::new(11).unwrap(), 2); assert_eq!(expected, actual); // lhs and rhs are both decimals with nonnegative scale let lhs = ColumnType::Decimal75(Precision::new(20).unwrap(), 3); let rhs = ColumnType::Decimal75(Precision::new(10).unwrap(), 2); - let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Subtract).unwrap(); + let actual = lhs + .try_add_subtract_column_types(rhs, BinaryOperator::Subtract) + .unwrap(); let expected = ColumnType::Decimal75(Precision::new(21).unwrap(), 3); assert_eq!(expected, actual); // lhs is an integer and rhs is a decimal with negative scale let lhs = ColumnType::SmallInt; let rhs = ColumnType::Decimal75(Precision::new(10).unwrap(), -2); - let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Subtract).unwrap(); + let actual = lhs + .try_add_subtract_column_types(rhs, BinaryOperator::Subtract) + .unwrap(); let expected = ColumnType::Decimal75(Precision::new(13).unwrap(), 0); assert_eq!(expected, actual); // lhs and rhs are both decimals one of which has negative scale let lhs = ColumnType::Decimal75(Precision::new(40).unwrap(), -13); let rhs = ColumnType::Decimal75(Precision::new(15).unwrap(), 5); - let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Subtract).unwrap(); + let actual = lhs + .try_add_subtract_column_types(rhs, BinaryOperator::Subtract) + .unwrap(); let expected = ColumnType::Decimal75(Precision::new(59).unwrap(), 5); assert_eq!(expected, actual); @@ -1082,7 +1124,9 @@ mod test { // and with result having maximum precision let lhs = ColumnType::Decimal75(Precision::new(61).unwrap(), -13); let rhs = ColumnType::Decimal75(Precision::new(73).unwrap(), -14); - let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Subtract).unwrap(); + let actual = lhs + .try_add_subtract_column_types(rhs, BinaryOperator::Subtract) + .unwrap(); let expected = ColumnType::Decimal75(Precision::new(75).unwrap(), -13); assert_eq!(expected, actual); } @@ -1092,14 +1136,14 @@ mod test { let lhs = ColumnType::SmallInt; let rhs = ColumnType::VarChar; assert!(matches!( - try_add_subtract_column_types(lhs, rhs, BinaryOperator::Subtract), + lhs.try_add_subtract_column_types(rhs, BinaryOperator::Subtract), Err(ColumnOperationError::BinaryOperationInvalidColumnType { .. }) )); let lhs = ColumnType::VarChar; let rhs = ColumnType::VarChar; assert!(matches!( - try_add_subtract_column_types(lhs, rhs, BinaryOperator::Subtract), + lhs.try_add_subtract_column_types(rhs, BinaryOperator::Subtract), Err(ColumnOperationError::BinaryOperationInvalidColumnType { .. }) )); } @@ -1109,7 +1153,7 @@ mod test { let lhs = ColumnType::Decimal75(Precision::new(75).unwrap(), 0); let rhs = ColumnType::Decimal75(Precision::new(73).unwrap(), 1); assert!(matches!( - try_add_subtract_column_types(lhs, rhs, BinaryOperator::Subtract), + lhs.try_add_subtract_column_types(rhs, BinaryOperator::Subtract), Err(ColumnOperationError::DecimalConversionError { source: DecimalError::InvalidPrecision { .. } }) @@ -1118,7 +1162,7 @@ mod test { let lhs = ColumnType::Int128; let rhs = ColumnType::Decimal75(Precision::new(75).unwrap(), 12); assert!(matches!( - try_add_subtract_column_types(lhs, rhs, BinaryOperator::Subtract), + lhs.try_add_subtract_column_types(rhs, BinaryOperator::Subtract), Err(ColumnOperationError::DecimalConversionError { source: DecimalError::InvalidPrecision { .. } }) @@ -1130,49 +1174,49 @@ mod test { // lhs and rhs are integers with the same precision let lhs = ColumnType::SmallInt; let rhs = ColumnType::SmallInt; - let actual = try_multiply_column_types(lhs, rhs).unwrap(); + let actual = lhs.try_multiply_column_types(rhs).unwrap(); let expected = ColumnType::SmallInt; assert_eq!(expected, actual); // lhs and rhs are integers with different precision let lhs = ColumnType::SmallInt; let rhs = ColumnType::Int; - let actual = try_multiply_column_types(lhs, rhs).unwrap(); + let actual = lhs.try_multiply_column_types(rhs).unwrap(); let expected = ColumnType::Int; assert_eq!(expected, actual); // lhs is an integer and rhs is a scalar let lhs = ColumnType::SmallInt; let rhs = ColumnType::Scalar; - let actual = try_multiply_column_types(lhs, rhs).unwrap(); + let actual = lhs.try_multiply_column_types(rhs).unwrap(); let expected = ColumnType::Scalar; assert_eq!(expected, actual); // lhs is a decimal and rhs is an integer let lhs = ColumnType::Decimal75(Precision::new(10).unwrap(), 2); let rhs = ColumnType::SmallInt; - let actual = try_multiply_column_types(lhs, rhs).unwrap(); + let actual = lhs.try_multiply_column_types(rhs).unwrap(); let expected = ColumnType::Decimal75(Precision::new(16).unwrap(), 2); assert_eq!(expected, actual); // lhs and rhs are both decimals with nonnegative scale let lhs = ColumnType::Decimal75(Precision::new(20).unwrap(), 3); let rhs = ColumnType::Decimal75(Precision::new(10).unwrap(), 2); - let actual = try_multiply_column_types(lhs, rhs).unwrap(); + let actual = lhs.try_multiply_column_types(rhs).unwrap(); let expected = ColumnType::Decimal75(Precision::new(31).unwrap(), 5); assert_eq!(expected, actual); // lhs is an integer and rhs is a decimal with negative scale let lhs = ColumnType::SmallInt; let rhs = ColumnType::Decimal75(Precision::new(10).unwrap(), -2); - let actual = try_multiply_column_types(lhs, rhs).unwrap(); + let actual = lhs.try_multiply_column_types(rhs).unwrap(); let expected = ColumnType::Decimal75(Precision::new(16).unwrap(), -2); assert_eq!(expected, actual); // lhs and rhs are both decimals one of which has negative scale let lhs = ColumnType::Decimal75(Precision::new(40).unwrap(), -13); let rhs = ColumnType::Decimal75(Precision::new(15).unwrap(), 5); - let actual = try_multiply_column_types(lhs, rhs).unwrap(); + let actual = lhs.try_multiply_column_types(rhs).unwrap(); let expected = ColumnType::Decimal75(Precision::new(56).unwrap(), -8); assert_eq!(expected, actual); @@ -1180,7 +1224,7 @@ mod test { // and with result having maximum precision let lhs = ColumnType::Decimal75(Precision::new(61).unwrap(), -13); let rhs = ColumnType::Decimal75(Precision::new(13).unwrap(), -14); - let actual = try_multiply_column_types(lhs, rhs).unwrap(); + let actual = lhs.try_multiply_column_types(rhs).unwrap(); let expected = ColumnType::Decimal75(Precision::new(75).unwrap(), -27); assert_eq!(expected, actual); } @@ -1190,14 +1234,14 @@ mod test { let lhs = ColumnType::SmallInt; let rhs = ColumnType::VarChar; assert!(matches!( - try_multiply_column_types(lhs, rhs), + lhs.try_multiply_column_types(rhs), Err(ColumnOperationError::BinaryOperationInvalidColumnType { .. }) )); let lhs = ColumnType::VarChar; let rhs = ColumnType::VarChar; assert!(matches!( - try_multiply_column_types(lhs, rhs), + lhs.try_multiply_column_types(rhs), Err(ColumnOperationError::BinaryOperationInvalidColumnType { .. }) )); } @@ -1208,7 +1252,7 @@ mod test { let lhs = ColumnType::Decimal75(Precision::new(38).unwrap(), 4); let rhs = ColumnType::Decimal75(Precision::new(37).unwrap(), 4); assert!(matches!( - try_multiply_column_types(lhs, rhs), + lhs.try_multiply_column_types(rhs), Err(ColumnOperationError::DecimalConversionError { source: DecimalError::InvalidPrecision { .. } }) @@ -1217,7 +1261,7 @@ mod test { let lhs = ColumnType::Int; let rhs = ColumnType::Decimal75(Precision::new(65).unwrap(), 0); assert!(matches!( - try_multiply_column_types(lhs, rhs), + lhs.try_multiply_column_types(rhs), Err(ColumnOperationError::DecimalConversionError { source: DecimalError::InvalidPrecision { .. } }) @@ -1227,7 +1271,7 @@ mod test { let lhs = ColumnType::Decimal75(Precision::new(5).unwrap(), -64_i8); let rhs = ColumnType::Decimal75(Precision::new(5).unwrap(), -65_i8); assert!(matches!( - try_multiply_column_types(lhs, rhs), + lhs.try_multiply_column_types(rhs), Err(ColumnOperationError::DecimalConversionError { source: DecimalError::InvalidScale { .. } }) @@ -1236,7 +1280,7 @@ mod test { let lhs = ColumnType::Decimal75(Precision::new(5).unwrap(), 64_i8); let rhs = ColumnType::Decimal75(Precision::new(5).unwrap(), 64_i8); assert!(matches!( - try_multiply_column_types(lhs, rhs), + lhs.try_multiply_column_types(rhs), Err(ColumnOperationError::DecimalConversionError { source: DecimalError::InvalidScale { .. } }) @@ -1248,49 +1292,49 @@ mod test { // lhs and rhs are integers with the same precision let lhs = ColumnType::SmallInt; let rhs = ColumnType::SmallInt; - let actual = try_divide_column_types(lhs, rhs).unwrap(); + let actual = lhs.try_divide_column_types(rhs).unwrap(); let expected = ColumnType::SmallInt; assert_eq!(expected, actual); // lhs and rhs are integers with different precision let lhs = ColumnType::SmallInt; let rhs = ColumnType::Int; - let actual = try_divide_column_types(lhs, rhs).unwrap(); + let actual = lhs.try_divide_column_types(rhs).unwrap(); let expected = ColumnType::Int; assert_eq!(expected, actual); // lhs is a decimal with nonnegative scale and rhs is an integer let lhs = ColumnType::Decimal75(Precision::new(10).unwrap(), 2); let rhs = ColumnType::SmallInt; - let actual = try_divide_column_types(lhs, rhs).unwrap(); + let actual = lhs.try_divide_column_types(rhs).unwrap(); let expected = ColumnType::Decimal75(Precision::new(16).unwrap(), 8); assert_eq!(expected, actual); // lhs is an integer and rhs is a decimal with nonnegative scale let lhs = ColumnType::SmallInt; let rhs = ColumnType::Decimal75(Precision::new(10).unwrap(), 2); - let actual = try_divide_column_types(lhs, rhs).unwrap(); + let actual = lhs.try_divide_column_types(rhs).unwrap(); let expected = ColumnType::Decimal75(Precision::new(18).unwrap(), 11); assert_eq!(expected, actual); // lhs and rhs are both decimals with nonnegative scale let lhs = ColumnType::Decimal75(Precision::new(20).unwrap(), 3); let rhs = ColumnType::Decimal75(Precision::new(10).unwrap(), 2); - let actual = try_divide_column_types(lhs, rhs).unwrap(); + let actual = lhs.try_divide_column_types(rhs).unwrap(); let expected = ColumnType::Decimal75(Precision::new(33).unwrap(), 14); assert_eq!(expected, actual); // lhs is an integer and rhs is a decimal with negative scale let lhs = ColumnType::SmallInt; let rhs = ColumnType::Decimal75(Precision::new(10).unwrap(), -2); - let actual = try_divide_column_types(lhs, rhs).unwrap(); + let actual = lhs.try_divide_column_types(rhs).unwrap(); let expected = ColumnType::Decimal75(Precision::new(14).unwrap(), 11); assert_eq!(expected, actual); // lhs and rhs are both decimals one of which has negative scale let lhs = ColumnType::Decimal75(Precision::new(40).unwrap(), -13); let rhs = ColumnType::Decimal75(Precision::new(15).unwrap(), 5); - let actual = try_divide_column_types(lhs, rhs).unwrap(); + let actual = lhs.try_divide_column_types(rhs).unwrap(); let expected = ColumnType::Decimal75(Precision::new(64).unwrap(), 6); assert_eq!(expected, actual); @@ -1298,7 +1342,7 @@ mod test { // and with result having maximum precision let lhs = ColumnType::Decimal75(Precision::new(70).unwrap(), -13); let rhs = ColumnType::Decimal75(Precision::new(13).unwrap(), -14); - let actual = try_divide_column_types(lhs, rhs).unwrap(); + let actual = lhs.try_divide_column_types(rhs).unwrap(); let expected = ColumnType::Decimal75(Precision::new(75).unwrap(), 6); assert_eq!(expected, actual); } @@ -1308,21 +1352,21 @@ mod test { let lhs = ColumnType::SmallInt; let rhs = ColumnType::VarChar; assert!(matches!( - try_divide_column_types(lhs, rhs), + lhs.try_divide_column_types(rhs), Err(ColumnOperationError::BinaryOperationInvalidColumnType { .. }) )); let lhs = ColumnType::VarChar; let rhs = ColumnType::VarChar; assert!(matches!( - try_divide_column_types(lhs, rhs), + lhs.try_divide_column_types(rhs), Err(ColumnOperationError::BinaryOperationInvalidColumnType { .. }) )); let lhs = ColumnType::Scalar; let rhs = ColumnType::Scalar; assert!(matches!( - try_divide_column_types(lhs, rhs), + lhs.try_divide_column_types(rhs), Err(ColumnOperationError::BinaryOperationInvalidColumnType { .. }) )); } @@ -1333,7 +1377,7 @@ mod test { let lhs = ColumnType::Decimal75(Precision::new(71).unwrap(), -13); let rhs = ColumnType::Decimal75(Precision::new(13).unwrap(), -14); assert!(matches!( - try_divide_column_types(lhs, rhs), + lhs.try_divide_column_types(rhs), Err(ColumnOperationError::DecimalConversionError { source: DecimalError::InvalidPrecision { .. } }) @@ -1342,7 +1386,7 @@ mod test { let lhs = ColumnType::Int; let rhs = ColumnType::Decimal75(Precision::new(68).unwrap(), 67); assert!(matches!( - try_divide_column_types(lhs, rhs), + lhs.try_divide_column_types(rhs), Err(ColumnOperationError::DecimalConversionError { source: DecimalError::InvalidPrecision { .. } }) @@ -1352,7 +1396,7 @@ mod test { let lhs = ColumnType::Decimal75(Precision::new(15).unwrap(), 53_i8); let rhs = ColumnType::Decimal75(Precision::new(75).unwrap(), 40_i8); assert!(matches!( - try_divide_column_types(lhs, rhs), + lhs.try_divide_column_types(rhs), Err(ColumnOperationError::DecimalConversionError { source: DecimalError::InvalidScale { .. } }) diff --git a/crates/proof-of-sql/src/base/database/mod.rs b/crates/proof-of-sql/src/base/database/mod.rs index a1bb6dd7f..6e8d01501 100644 --- a/crates/proof-of-sql/src/base/database/mod.rs +++ b/crates/proof-of-sql/src/base/database/mod.rs @@ -8,9 +8,7 @@ mod column; pub use column::{Column, ColumnField, ColumnRef, ColumnType}; mod column_operation; -pub use column_operation::{ - try_add_subtract_column_types, try_divide_column_types, try_multiply_column_types, -}; +pub use column_operation::ColumnOperation; mod column_operation_error; pub use column_operation_error::{ColumnOperationError, ColumnOperationResult}; diff --git a/crates/proof-of-sql/src/sql/parse/query_context_builder.rs b/crates/proof-of-sql/src/sql/parse/query_context_builder.rs index 819bbbec6..d4d516153 100644 --- a/crates/proof-of-sql/src/sql/parse/query_context_builder.rs +++ b/crates/proof-of-sql/src/sql/parse/query_context_builder.rs @@ -1,9 +1,7 @@ use super::{ConversionError, ConversionResult, QueryContext}; +use crate::base::database::ColumnOperation; use crate::base::{ - database::{ - try_add_subtract_column_types, try_multiply_column_types, ColumnRef, ColumnType, - SchemaAccessor, TableRef, - }, + database::{ColumnRef, ColumnType, SchemaAccessor, TableRef}, math::decimal::Precision, }; use alloc::{boxed::Box, string::ToString, vec::Vec}; @@ -294,14 +292,15 @@ pub(crate) fn type_check_binary_operation( | (ColumnType::TimestampTZ(_, _), ColumnType::TimestampTZ(_, _)) ) } - BinaryOperator::Add => { - try_add_subtract_column_types(*left_dtype, *right_dtype, BinaryOperator::Add).is_ok() - } - BinaryOperator::Subtract => { - try_add_subtract_column_types(*left_dtype, *right_dtype, BinaryOperator::Subtract) - .is_ok() - } - BinaryOperator::Multiply => try_multiply_column_types(*left_dtype, *right_dtype).is_ok(), + BinaryOperator::Add => (*left_dtype) + .try_add_subtract_column_types(*right_dtype, BinaryOperator::Add) + .is_ok(), + BinaryOperator::Subtract => (*left_dtype) + .try_add_subtract_column_types(*right_dtype, BinaryOperator::Subtract) + .is_ok(), + BinaryOperator::Multiply => (*left_dtype) + .try_multiply_column_types(*right_dtype) + .is_ok(), BinaryOperator::Division => left_dtype.is_numeric() && right_dtype.is_numeric(), } } diff --git a/crates/proof-of-sql/src/sql/proof_exprs/add_subtract_expr.rs b/crates/proof-of-sql/src/sql/proof_exprs/add_subtract_expr.rs index 2b7611c06..0d700b8c3 100644 --- a/crates/proof-of-sql/src/sql/proof_exprs/add_subtract_expr.rs +++ b/crates/proof-of-sql/src/sql/proof_exprs/add_subtract_expr.rs @@ -1,11 +1,9 @@ use super::{add_subtract_columns, scale_and_add_subtract_eval, DynProofExpr, ProofExpr}; +use crate::base::database::ColumnOperation; use crate::{ base::{ commitment::Commitment, - database::{ - try_add_subtract_column_types, Column, ColumnRef, ColumnType, CommitmentAccessor, - DataAccessor, - }, + database::{Column, ColumnRef, ColumnType, CommitmentAccessor, DataAccessor}, map::IndexSet, proof::ProofError, }, @@ -48,7 +46,9 @@ impl ProofExpr for AddSubtractExpr { } else { BinaryOperator::Add }; - try_add_subtract_column_types(self.lhs.data_type(), self.rhs.data_type(), operator) + self.lhs + .data_type() + .try_add_subtract_column_types(self.rhs.data_type(), operator) .expect("Failed to add/subtract column types") } diff --git a/crates/proof-of-sql/src/sql/proof_exprs/multiply_expr.rs b/crates/proof-of-sql/src/sql/proof_exprs/multiply_expr.rs index 391091745..3dfb7e631 100644 --- a/crates/proof-of-sql/src/sql/proof_exprs/multiply_expr.rs +++ b/crates/proof-of-sql/src/sql/proof_exprs/multiply_expr.rs @@ -1,11 +1,9 @@ use super::{DynProofExpr, ProofExpr}; +use crate::base::database::ColumnOperation; use crate::{ base::{ commitment::Commitment, - database::{ - try_multiply_column_types, Column, ColumnRef, ColumnType, CommitmentAccessor, - DataAccessor, - }, + database::{Column, ColumnRef, ColumnType, CommitmentAccessor, DataAccessor}, map::IndexSet, proof::ProofError, }, @@ -44,7 +42,9 @@ impl ProofExpr for MultiplyExpr { } fn data_type(&self) -> ColumnType { - try_multiply_column_types(self.lhs.data_type(), self.rhs.data_type()) + self.lhs + .data_type() + .try_multiply_column_types(self.rhs.data_type()) .expect("Failed to multiply column types") }