Skip to content

Commit

Permalink
refactor!: remove parser structs from ColumnOperationError (#344)
Browse files Browse the repository at this point in the history
Please be sure to look over the pull request guidelines here:
https://github.com/spaceandtimelabs/sxt-proof-of-sql/blob/main/CONTRIBUTING.md#submit-pr.

# Please go through the following checklist
- [x] The PR title and commit messages adhere to guidelines here:
https://github.com/spaceandtimelabs/sxt-proof-of-sql/blob/main/CONTRIBUTING.md.
In particular `!` is used if and only if at least one breaking change
has been introduced.
- [x] I have run the ci check script with `source
scripts/run_ci_checks.sh`.

# Rationale for this change
This change is made to simplify column operations partly in order to
simplify `owned_column_operation.rs` and partly in preparation for the
new version of provable arithmetic expressions.
<!--
Why are you proposing this change? If this is already explained clearly
in the linked issue then this section is not needed.
Explaining clearly why changes are proposed helps reviewers understand
your changes and offer better suggestions for fixes.

 Example:
 Add `NestedLoopJoinExec`.
 Closes #345.

Since we added `HashJoinExec` in #323 it has been possible to do
provable inner joins. However performance is not satisfactory in some
cases. Hence we need to fix the problem by implement
`NestedLoopJoinExec` and speed up the code
 for `HashJoinExec`.
-->

# What changes are included in this PR?
- replace `BinaryOp` and `UnaryOp` in `ColumnOperationError` with
strings
<!--
There is no need to duplicate the description in the ticket here but it
is sometimes worth providing a summary of the individual changes in this
PR.

Example:
- Add `NestedLoopJoinExec`.
- Speed up `HashJoinExec`.
- Route joins to `NestedLoopJoinExec` if the outer input is sufficiently
small.
-->

# Are these changes tested?
<!--
We typically require tests for all PRs in order to:
1. Prevent the code from being accidentally broken by subsequent changes
2. Serve as another way to document the expected behavior of the code

If tests are not included in your PR, please explain why (for example,
are they covered by existing tests)?

Example:
Yes.
-->
Existing tests should pass
  • Loading branch information
iajoiner authored Nov 11, 2024
2 parents 87c1a41 + 4b46324 commit fdb4c9e
Show file tree
Hide file tree
Showing 6 changed files with 59 additions and 77 deletions.
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use crate::base::{database::ColumnType, math::decimal::DecimalError};
use alloc::string::String;
use core::result::Result;
use proof_of_sql_parser::intermediate_ast::{BinaryOperator, UnaryOperator};
use snafu::Snafu;

/// Errors from operations on columns.
Expand All @@ -19,8 +18,8 @@ pub enum ColumnOperationError {
/// Incorrect `ColumnType` in binary operations
#[snafu(display("{operator:?}(lhs: {left_type:?}, rhs: {right_type:?}) is not supported"))]
BinaryOperationInvalidColumnType {
/// `BinaryOperator` that caused the error
operator: BinaryOperator,
/// Binary operator that caused the error
operator: String,
/// `ColumnType` of left operand
left_type: ColumnType,
/// `ColumnType` of right operand
Expand All @@ -30,8 +29,8 @@ pub enum ColumnOperationError {
/// Incorrect `ColumnType` in unary operations
#[snafu(display("{operator:?}(operand: {operand_type:?}) is not supported"))]
UnaryOperationInvalidColumnType {
/// `UnaryOperator` that caused the error
operator: UnaryOperator,
/// Unary operator that caused the error
operator: String,
/// `ColumnType` of the operand
operand_type: ColumnType,
},
Expand Down
80 changes: 39 additions & 41 deletions crates/proof-of-sql/src/base/database/column_type_operation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ use crate::base::{
math::decimal::{DecimalError, Precision},
};
use alloc::{format, string::ToString};
use proof_of_sql_parser::intermediate_ast::BinaryOperator;
// For decimal type manipulation please refer to
// https://learn.microsoft.com/en-us/sql/t-sql/data-types/precision-scale-and-length-transact-sql?view=sql-server-ver16

Expand All @@ -19,11 +18,10 @@ use proof_of_sql_parser::intermediate_ast::BinaryOperator;
pub fn try_add_subtract_column_types(
lhs: ColumnType,
rhs: ColumnType,
operator: BinaryOperator,
) -> ColumnOperationResult<ColumnType> {
if !lhs.is_numeric() || !rhs.is_numeric() {
return Err(ColumnOperationError::BinaryOperationInvalidColumnType {
operator,
operator: "+/-".to_string(),
left_type: lhs,
right_type: rhs,
});
Expand Down Expand Up @@ -77,7 +75,7 @@ pub fn try_multiply_column_types(
) -> ColumnOperationResult<ColumnType> {
if !lhs.is_numeric() || !rhs.is_numeric() {
return Err(ColumnOperationError::BinaryOperationInvalidColumnType {
operator: BinaryOperator::Multiply,
operator: "*".to_string(),
left_type: lhs,
right_type: rhs,
});
Expand Down Expand Up @@ -132,7 +130,7 @@ pub fn try_divide_column_types(
|| rhs == ColumnType::Scalar
{
return Err(ColumnOperationError::BinaryOperationInvalidColumnType {
operator: BinaryOperator::Division,
operator: "/".to_string(),
left_type: lhs,
right_type: rhs,
});
Expand Down Expand Up @@ -180,87 +178,87 @@ mod test {
// lhs and rhs are integers with the same precision
let lhs = ColumnType::TinyInt;
let rhs = ColumnType::TinyInt;
let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Add).unwrap();
let actual = try_add_subtract_column_types(lhs, rhs).unwrap();
let expected = ColumnType::TinyInt;
assert_eq!(expected, actual);

let lhs = ColumnType::SmallInt;
let rhs = ColumnType::SmallInt;
let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Add).unwrap();
let actual = try_add_subtract_column_types(lhs, rhs).unwrap();
let expected = ColumnType::SmallInt;
assert_eq!(expected, actual);

// lhs and rhs are integers with different precision
let lhs = ColumnType::TinyInt;
let rhs = ColumnType::SmallInt;
let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Add).unwrap();
let actual = try_add_subtract_column_types(lhs, rhs).unwrap();
let expected = ColumnType::SmallInt;
assert_eq!(expected, actual);

let lhs = ColumnType::SmallInt;
let rhs = ColumnType::Int;
let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Add).unwrap();
let actual = try_add_subtract_column_types(lhs, rhs).unwrap();
let expected = ColumnType::Int;
assert_eq!(expected, actual);

// lhs is an integer and rhs is a scalar
let lhs = ColumnType::TinyInt;
let rhs = ColumnType::Scalar;
let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Add).unwrap();
let actual = try_add_subtract_column_types(lhs, rhs).unwrap();
let expected = ColumnType::Scalar;
assert_eq!(expected, actual);

let lhs = ColumnType::SmallInt;
let rhs = ColumnType::Scalar;
let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Add).unwrap();
let actual = try_add_subtract_column_types(lhs, rhs).unwrap();
let expected = ColumnType::Scalar;
assert_eq!(expected, actual);

// lhs is a decimal with nonnegative scale and rhs is an integer
let lhs = ColumnType::Decimal75(Precision::new(10).unwrap(), 2);
let rhs = ColumnType::TinyInt;
let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Add).unwrap();
let actual = try_add_subtract_column_types(lhs, rhs).unwrap();
let expected = ColumnType::Decimal75(Precision::new(11).unwrap(), 2);
assert_eq!(expected, actual);

let lhs = ColumnType::Decimal75(Precision::new(10).unwrap(), 2);
let rhs = ColumnType::SmallInt;
let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Add).unwrap();
let actual = try_add_subtract_column_types(lhs, rhs).unwrap();
let expected = ColumnType::Decimal75(Precision::new(11).unwrap(), 2);
assert_eq!(expected, actual);

// lhs and rhs are both decimals with nonnegative scale
let lhs = ColumnType::Decimal75(Precision::new(20).unwrap(), 3);
let rhs = ColumnType::Decimal75(Precision::new(10).unwrap(), 2);
let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Add).unwrap();
let actual = try_add_subtract_column_types(lhs, rhs).unwrap();
let expected = ColumnType::Decimal75(Precision::new(21).unwrap(), 3);
assert_eq!(expected, actual);

// lhs is an integer and rhs is a decimal with negative scale
let lhs = ColumnType::TinyInt;
let rhs = ColumnType::Decimal75(Precision::new(10).unwrap(), -2);
let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Add).unwrap();
let actual = try_add_subtract_column_types(lhs, rhs).unwrap();
let expected = ColumnType::Decimal75(Precision::new(13).unwrap(), 0);
assert_eq!(expected, actual);

let lhs = ColumnType::SmallInt;
let rhs = ColumnType::Decimal75(Precision::new(10).unwrap(), -2);
let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Add).unwrap();
let actual = try_add_subtract_column_types(lhs, rhs).unwrap();
let expected = ColumnType::Decimal75(Precision::new(13).unwrap(), 0);
assert_eq!(expected, actual);

// lhs and rhs are both decimals one of which has negative scale
let lhs = ColumnType::Decimal75(Precision::new(40).unwrap(), -13);
let rhs = ColumnType::Decimal75(Precision::new(15).unwrap(), 5);
let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Add).unwrap();
let actual = try_add_subtract_column_types(lhs, rhs).unwrap();
let expected = ColumnType::Decimal75(Precision::new(59).unwrap(), 5);
assert_eq!(expected, actual);

// lhs and rhs are both decimals both with negative scale
// and with result having maximum precision
let lhs = ColumnType::Decimal75(Precision::new(74).unwrap(), -13);
let rhs = ColumnType::Decimal75(Precision::new(15).unwrap(), -14);
let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Add).unwrap();
let actual = try_add_subtract_column_types(lhs, rhs).unwrap();
let expected = ColumnType::Decimal75(Precision::new(75).unwrap(), -13);
assert_eq!(expected, actual);
}
Expand All @@ -270,21 +268,21 @@ mod test {
let lhs = ColumnType::TinyInt;
let rhs = ColumnType::VarChar;
assert!(matches!(
try_add_subtract_column_types(lhs, rhs, BinaryOperator::Add),
try_add_subtract_column_types(lhs, rhs),
Err(ColumnOperationError::BinaryOperationInvalidColumnType { .. })
));

let lhs = ColumnType::SmallInt;
let rhs = ColumnType::VarChar;
assert!(matches!(
try_add_subtract_column_types(lhs, rhs, BinaryOperator::Add),
try_add_subtract_column_types(lhs, rhs),
Err(ColumnOperationError::BinaryOperationInvalidColumnType { .. })
));

let lhs = ColumnType::VarChar;
let rhs = ColumnType::VarChar;
assert!(matches!(
try_add_subtract_column_types(lhs, rhs, BinaryOperator::Add),
try_add_subtract_column_types(lhs, rhs),
Err(ColumnOperationError::BinaryOperationInvalidColumnType { .. })
));
}
Expand All @@ -294,7 +292,7 @@ mod test {
let lhs = ColumnType::Decimal75(Precision::new(75).unwrap(), 4);
let rhs = ColumnType::Decimal75(Precision::new(73).unwrap(), 4);
assert!(matches!(
try_add_subtract_column_types(lhs, rhs, BinaryOperator::Add),
try_add_subtract_column_types(lhs, rhs),
Err(ColumnOperationError::DecimalConversionError {
source: DecimalError::InvalidPrecision { .. }
})
Expand All @@ -303,7 +301,7 @@ mod test {
let lhs = ColumnType::Int;
let rhs = ColumnType::Decimal75(Precision::new(75).unwrap(), 10);
assert!(matches!(
try_add_subtract_column_types(lhs, rhs, BinaryOperator::Add),
try_add_subtract_column_types(lhs, rhs),
Err(ColumnOperationError::DecimalConversionError {
source: DecimalError::InvalidPrecision { .. }
})
Expand All @@ -315,87 +313,87 @@ mod test {
// lhs and rhs are integers with the same precision
let lhs = ColumnType::TinyInt;
let rhs = ColumnType::TinyInt;
let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Subtract).unwrap();
let actual = try_add_subtract_column_types(lhs, rhs).unwrap();
let expected = ColumnType::TinyInt;
assert_eq!(expected, actual);

let lhs = ColumnType::SmallInt;
let rhs = ColumnType::SmallInt;
let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Subtract).unwrap();
let actual = try_add_subtract_column_types(lhs, rhs).unwrap();
let expected = ColumnType::SmallInt;
assert_eq!(expected, actual);

// lhs and rhs are integers with different precision
let lhs = ColumnType::TinyInt;
let rhs = ColumnType::SmallInt;
let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Subtract).unwrap();
let actual = try_add_subtract_column_types(lhs, rhs).unwrap();
let expected = ColumnType::SmallInt;
assert_eq!(expected, actual);

let lhs = ColumnType::SmallInt;
let rhs = ColumnType::Int;
let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Subtract).unwrap();
let actual = try_add_subtract_column_types(lhs, rhs).unwrap();
let expected = ColumnType::Int;
assert_eq!(expected, actual);

// lhs is an integer and rhs is a scalar
let lhs = ColumnType::TinyInt;
let rhs = ColumnType::Scalar;
let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Subtract).unwrap();
let actual = try_add_subtract_column_types(lhs, rhs).unwrap();
let expected = ColumnType::Scalar;
assert_eq!(expected, actual);

let lhs = ColumnType::SmallInt;
let rhs = ColumnType::Scalar;
let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Subtract).unwrap();
let actual = try_add_subtract_column_types(lhs, rhs).unwrap();
let expected = ColumnType::Scalar;
assert_eq!(expected, actual);

// lhs is a decimal and rhs is an integer
let lhs = ColumnType::Decimal75(Precision::new(10).unwrap(), 2);
let rhs = ColumnType::TinyInt;
let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Subtract).unwrap();
let actual = try_add_subtract_column_types(lhs, rhs).unwrap();
let expected = ColumnType::Decimal75(Precision::new(11).unwrap(), 2);
assert_eq!(expected, actual);

let lhs = ColumnType::Decimal75(Precision::new(10).unwrap(), 2);
let rhs = ColumnType::SmallInt;
let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Subtract).unwrap();
let actual = try_add_subtract_column_types(lhs, rhs).unwrap();
let expected = ColumnType::Decimal75(Precision::new(11).unwrap(), 2);
assert_eq!(expected, actual);

// lhs and rhs are both decimals with nonnegative scale
let lhs = ColumnType::Decimal75(Precision::new(20).unwrap(), 3);
let rhs = ColumnType::Decimal75(Precision::new(10).unwrap(), 2);
let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Subtract).unwrap();
let actual = try_add_subtract_column_types(lhs, rhs).unwrap();
let expected = ColumnType::Decimal75(Precision::new(21).unwrap(), 3);
assert_eq!(expected, actual);

// lhs is an integer and rhs is a decimal with negative scale
let lhs = ColumnType::TinyInt;
let rhs = ColumnType::Decimal75(Precision::new(10).unwrap(), -2);
let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Subtract).unwrap();
let actual = try_add_subtract_column_types(lhs, rhs).unwrap();
let expected = ColumnType::Decimal75(Precision::new(13).unwrap(), 0);
assert_eq!(expected, actual);

let lhs = ColumnType::SmallInt;
let rhs = ColumnType::Decimal75(Precision::new(10).unwrap(), -2);
let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Subtract).unwrap();
let actual = try_add_subtract_column_types(lhs, rhs).unwrap();
let expected = ColumnType::Decimal75(Precision::new(13).unwrap(), 0);
assert_eq!(expected, actual);

// lhs and rhs are both decimals one of which has negative scale
let lhs = ColumnType::Decimal75(Precision::new(40).unwrap(), -13);
let rhs = ColumnType::Decimal75(Precision::new(15).unwrap(), 5);
let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Subtract).unwrap();
let actual = try_add_subtract_column_types(lhs, rhs).unwrap();
let expected = ColumnType::Decimal75(Precision::new(59).unwrap(), 5);
assert_eq!(expected, actual);

// lhs and rhs are both decimals both with negative scale
// and with result having maximum precision
let lhs = ColumnType::Decimal75(Precision::new(61).unwrap(), -13);
let rhs = ColumnType::Decimal75(Precision::new(73).unwrap(), -14);
let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Subtract).unwrap();
let actual = try_add_subtract_column_types(lhs, rhs).unwrap();
let expected = ColumnType::Decimal75(Precision::new(75).unwrap(), -13);
assert_eq!(expected, actual);
}
Expand All @@ -405,21 +403,21 @@ mod test {
let lhs = ColumnType::TinyInt;
let rhs = ColumnType::VarChar;
assert!(matches!(
try_add_subtract_column_types(lhs, rhs, BinaryOperator::Subtract),
try_add_subtract_column_types(lhs, rhs),
Err(ColumnOperationError::BinaryOperationInvalidColumnType { .. })
));

let lhs = ColumnType::SmallInt;
let rhs = ColumnType::VarChar;
assert!(matches!(
try_add_subtract_column_types(lhs, rhs, BinaryOperator::Subtract),
try_add_subtract_column_types(lhs, rhs),
Err(ColumnOperationError::BinaryOperationInvalidColumnType { .. })
));

let lhs = ColumnType::VarChar;
let rhs = ColumnType::VarChar;
assert!(matches!(
try_add_subtract_column_types(lhs, rhs, BinaryOperator::Subtract),
try_add_subtract_column_types(lhs, rhs),
Err(ColumnOperationError::BinaryOperationInvalidColumnType { .. })
));
}
Expand All @@ -429,7 +427,7 @@ mod test {
let lhs = ColumnType::Decimal75(Precision::new(75).unwrap(), 0);
let rhs = ColumnType::Decimal75(Precision::new(73).unwrap(), 1);
assert!(matches!(
try_add_subtract_column_types(lhs, rhs, BinaryOperator::Subtract),
try_add_subtract_column_types(lhs, rhs),
Err(ColumnOperationError::DecimalConversionError {
source: DecimalError::InvalidPrecision { .. }
})
Expand All @@ -438,7 +436,7 @@ mod test {
let lhs = ColumnType::Int128;
let rhs = ColumnType::Decimal75(Precision::new(75).unwrap(), 12);
assert!(matches!(
try_add_subtract_column_types(lhs, rhs, BinaryOperator::Subtract),
try_add_subtract_column_types(lhs, rhs),
Err(ColumnOperationError::DecimalConversionError {
source: DecimalError::InvalidPrecision { .. }
})
Expand Down
Loading

0 comments on commit fdb4c9e

Please sign in to comment.