From 9306484ba08006e3201f92618fa1ba1792301707 Mon Sep 17 00:00:00 2001 From: morrySnow <101034200+morrySnow@users.noreply.github.com> Date: Tue, 24 Sep 2024 17:28:34 +0800 Subject: [PATCH] [fix](Nereids) simplify decimal comparison wrong when cast to smaller scale (#41151) pick from master #41151 --- .../expression/rules/SimplifyCastRule.java | 11 +- .../rules/SimplifyComparisonPredicate.java | 26 ++-- .../rules/SimplifyDecimalV3Comparison.java | 26 ++-- .../expressions/literal/DecimalV3Literal.java | 8 +- .../apache/doris/nereids/types/DataType.java | 2 +- .../rules/SimplifyCastRuleTest.java | 51 +++---- .../SimplifyComparisonPredicateTest.java | 128 ++++++++++++++++++ .../SimplifyDecimalV3ComparisonTest.java | 47 ++++--- .../test_simplify_decimal_comparison.groovy | 28 ++++ 9 files changed, 253 insertions(+), 74 deletions(-) create mode 100644 regression-test/suites/nereids_rules_p0/expression/test_simplify_decimal_comparison.groovy diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyCastRule.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyCastRule.java index 34143043a07022..2f23412c3fc5c0 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyCastRule.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyCastRule.java @@ -38,6 +38,7 @@ import org.apache.doris.nereids.types.VarcharType; import java.math.BigDecimal; +import java.math.RoundingMode; /** * Rewrite rule of simplify CAST expression. @@ -107,8 +108,14 @@ private Expression simplify(Cast cast, ExpressionRewriteContext context) { return new DecimalV3Literal(decimalV3Type, new BigDecimal(((BigIntLiteral) child).getValue())); } else if (child instanceof DecimalV3Literal) { - return new DecimalV3Literal(decimalV3Type, - ((DecimalV3Literal) child).getValue()); + DecimalV3Type childType = (DecimalV3Type) child.getDataType(); + if (childType.getRange() <= decimalV3Type.getRange()) { + return new DecimalV3Literal(decimalV3Type, + ((DecimalV3Literal) child).getValue() + .setScale(decimalV3Type.getScale(), RoundingMode.HALF_UP)); + } else { + return cast; + } } } } catch (Throwable t) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicate.java index 9f719c7377237f..488f7bddfc6d5e 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicate.java @@ -221,9 +221,10 @@ private Expression processDecimalV3TypeCoercion(ComparisonPredicate comparisonPr int toScale = ((DecimalV3Type) left.getDataType()).getScale(); if (comparisonPredicate instanceof EqualTo) { try { - return comparisonPredicate.withChildren(left, - new DecimalV3Literal((DecimalV3Type) left.getDataType(), - literal.getValue().setScale(toScale))); + Expression decimal = new DecimalV3Literal((DecimalV3Type) left.getDataType(), + literal.getValue().setScale(toScale, RoundingMode.UNNECESSARY)); + return TypeCoercionUtils.processComparisonPredicate((ComparisonPredicate) + comparisonPredicate.withChildren(left, decimal), left, decimal); } catch (ArithmeticException e) { if (left.nullable()) { // TODO: the ideal way is to return an If expr like: @@ -240,24 +241,27 @@ private Expression processDecimalV3TypeCoercion(ComparisonPredicate comparisonPr } } else if (comparisonPredicate instanceof NullSafeEqual) { try { - return comparisonPredicate.withChildren(left, - new DecimalV3Literal((DecimalV3Type) left.getDataType(), - literal.getValue().setScale(toScale))); + Expression decimal = new DecimalV3Literal((DecimalV3Type) left.getDataType(), + literal.getValue().setScale(toScale, RoundingMode.UNNECESSARY)); + return TypeCoercionUtils.processComparisonPredicate((ComparisonPredicate) + comparisonPredicate.withChildren(left, decimal), left, decimal); } catch (ArithmeticException e) { return BooleanLiteral.of(false); } } else if (comparisonPredicate instanceof GreaterThan || comparisonPredicate instanceof LessThanEqual) { - return comparisonPredicate.withChildren(left, literal.roundFloor(toScale)); + literal = literal.roundFloor(toScale); + return TypeCoercionUtils.processComparisonPredicate((ComparisonPredicate) + comparisonPredicate.withChildren(left, literal), left, literal); } else if (comparisonPredicate instanceof LessThan || comparisonPredicate instanceof GreaterThanEqual) { - return comparisonPredicate.withChildren(left, - literal.roundCeiling(toScale)); + literal = literal.roundCeiling(toScale); + return TypeCoercionUtils.processComparisonPredicate((ComparisonPredicate) + comparisonPredicate.withChildren(left, literal), left, literal); } } } else if (left.getDataType().isIntegerLikeType()) { - return processIntegerDecimalLiteralComparison(comparisonPredicate, left, - literal.getValue()); + return processIntegerDecimalLiteralComparison(comparisonPredicate, left, literal.getValue()); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyDecimalV3Comparison.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyDecimalV3Comparison.java index b821d7a4d19bb7..98a6a9112f872f 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyDecimalV3Comparison.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyDecimalV3Comparison.java @@ -17,6 +17,7 @@ package org.apache.doris.nereids.rules.expression.rules; +import org.apache.doris.nereids.exceptions.AnalysisException; import org.apache.doris.nereids.rules.expression.AbstractExpressionRewriteRule; import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext; import org.apache.doris.nereids.trees.expressions.Cast; @@ -25,8 +26,6 @@ import org.apache.doris.nereids.trees.expressions.literal.DecimalV3Literal; import org.apache.doris.nereids.types.DecimalV3Type; -import com.google.common.base.Preconditions; - import java.math.BigDecimal; import java.math.RoundingMode; @@ -50,15 +49,17 @@ public Expression visitComparisonPredicate(ComparisonPredicate cp, ExpressionRew if (left.getDataType() instanceof DecimalV3Type && left instanceof Cast && ((Cast) left).child().getDataType() instanceof DecimalV3Type + && ((DecimalV3Type) left.getDataType()).getScale() + >= ((DecimalV3Type) ((Cast) left).child().getDataType()).getScale() && right instanceof DecimalV3Literal) { - return doProcess(cp, (Cast) left, (DecimalV3Literal) right); + try { + return doProcess(cp, (Cast) left, (DecimalV3Literal) right); + } catch (ArithmeticException e) { + return cp; + } } - if (left != cp.left() || right != cp.right()) { - return cp.withChildren(left, right); - } else { - return cp; - } + return cp; } private Expression doProcess(ComparisonPredicate cp, Cast left, DecimalV3Literal right) { @@ -72,13 +73,16 @@ private Expression doProcess(ComparisonPredicate cp, Cast left, DecimalV3Literal } Expression castChild = left.child(); - Preconditions.checkState(castChild.getDataType() instanceof DecimalV3Type); + if (!(castChild.getDataType() instanceof DecimalV3Type)) { + throw new AnalysisException("cast child's type should be DecimalV3Type, but its type is " + + castChild.getDataType().toSql()); + } DecimalV3Type leftType = (DecimalV3Type) castChild.getDataType(); - if (scale <= leftType.getScale() && precision - scale <= leftType.getPrecision() - leftType.getScale()) { + if (scale <= leftType.getScale() && precision - scale <= leftType.getRange()) { // precision and scale of literal all smaller than left, we don't need the cast DecimalV3Literal newRight = new DecimalV3Literal( DecimalV3Type.createDecimalV3Type(leftType.getPrecision(), leftType.getScale()), - trailingZerosValue); + trailingZerosValue.setScale(leftType.getScale(), RoundingMode.UNNECESSARY)); return cp.withChildren(castChild, newRight); } else { return cp; diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/DecimalV3Literal.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/DecimalV3Literal.java index c797e93cb6d673..d80dd7a4cc38c7 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/DecimalV3Literal.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/DecimalV3Literal.java @@ -73,15 +73,11 @@ public double getDouble() { } public DecimalV3Literal roundCeiling(int newScale) { - return new DecimalV3Literal(DecimalV3Type - .createDecimalV3Type(((DecimalV3Type) dataType).getPrecision(), newScale), - value.setScale(newScale, RoundingMode.CEILING)); + return new DecimalV3Literal(value.setScale(newScale, RoundingMode.CEILING)); } public DecimalV3Literal roundFloor(int newScale) { - return new DecimalV3Literal(DecimalV3Type - .createDecimalV3Type(((DecimalV3Type) dataType).getPrecision(), newScale), - value.setScale(newScale, RoundingMode.FLOOR)); + return new DecimalV3Literal(value.setScale(newScale, RoundingMode.FLOOR)); } /** diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/DataType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/DataType.java index ba5d2b70ebabd9..0c15e39bc44fd0 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/DataType.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/DataType.java @@ -174,7 +174,7 @@ public static DataType convertPrimitiveFromStrings(List types, boolean t case "decimalv3": switch (types.size()) { case 1: - return DecimalV3Type.CATALOG_DEFAULT; + return DecimalV3Type.createDecimalV3Type(38, 9); case 2: return DecimalV3Type.createDecimalV3Type(Integer.parseInt(types.get(1))); case 3: diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyCastRuleTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyCastRuleTest.java index 658775cedad958..4799f70fbccd4b 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyCastRuleTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyCastRuleTest.java @@ -19,42 +19,45 @@ import org.apache.doris.nereids.rules.expression.ExpressionRewriteTestHelper; import org.apache.doris.nereids.rules.expression.ExpressionRuleExecutor; +import org.apache.doris.nereids.trees.expressions.Cast; import org.apache.doris.nereids.trees.expressions.Expression; -import org.apache.doris.nereids.types.DataType; +import org.apache.doris.nereids.trees.expressions.literal.DecimalV3Literal; +import org.apache.doris.nereids.trees.expressions.literal.SmallIntLiteral; +import org.apache.doris.nereids.trees.expressions.literal.StringLiteral; +import org.apache.doris.nereids.trees.expressions.literal.TinyIntLiteral; +import org.apache.doris.nereids.trees.expressions.literal.VarcharLiteral; import org.apache.doris.nereids.types.DecimalV3Type; import org.apache.doris.nereids.types.StringType; import org.apache.doris.nereids.types.VarcharType; import com.google.common.collect.ImmutableList; -import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; +import java.math.BigDecimal; + class SimplifyCastRuleTest extends ExpressionRewriteTestHelper { @Test public void testSimplify() { executor = new ExpressionRuleExecutor(ImmutableList.of(SimplifyCastRule.INSTANCE)); - assertRewriteAfterSimplify("CAST('1' AS STRING)", "'1'", - StringType.INSTANCE); - assertRewriteAfterSimplify("CAST('1' AS VARCHAR)", "'1'", - VarcharType.createVarcharType(-1)); - assertRewriteAfterSimplify("CAST(1 AS DECIMAL)", "1.000000000", - DecimalV3Type.createDecimalV3Type(38, 9)); - assertRewriteAfterSimplify("CAST(1000 AS DECIMAL)", "1000.000000000", - DecimalV3Type.createDecimalV3Type(38, 9)); - assertRewriteAfterSimplify("CAST(1 AS DECIMALV3)", "1", - DecimalV3Type.createDecimalV3Type(9, 0)); - assertRewriteAfterSimplify("CAST(1000 AS DECIMALV3)", "1000", - DecimalV3Type.createDecimalV3Type(9, 0)); + assertRewrite(new Cast(new VarcharLiteral("1"), StringType.INSTANCE), + new StringLiteral("1")); + assertRewrite(new Cast(new VarcharLiteral("1"), VarcharType.SYSTEM_DEFAULT), + new VarcharLiteral("1", -1)); + assertRewrite(new Cast(new TinyIntLiteral((byte) 1), DecimalV3Type.SYSTEM_DEFAULT), + new DecimalV3Literal(DecimalV3Type.SYSTEM_DEFAULT, new BigDecimal("1.000000000"))); + assertRewrite(new Cast(new SmallIntLiteral((short) 1000), DecimalV3Type.SYSTEM_DEFAULT), + new DecimalV3Literal(DecimalV3Type.SYSTEM_DEFAULT, new BigDecimal("1000.000000000"))); + assertRewrite(new Cast(new VarcharLiteral("1"), VarcharType.SYSTEM_DEFAULT), new VarcharLiteral("1", -1)); + assertRewrite(new Cast(new VarcharLiteral("1"), VarcharType.SYSTEM_DEFAULT), new VarcharLiteral("1", -1)); + + Expression decimalV3Literal = new DecimalV3Literal(DecimalV3Type.createDecimalV3Type(5, 3), + new BigDecimal("12.000")); + assertRewrite(new Cast(decimalV3Literal, DecimalV3Type.createDecimalV3Type(7, 3)), + new DecimalV3Literal(DecimalV3Type.createDecimalV3Type(7, 3), + new BigDecimal("12.000"))); + assertRewrite(new Cast(decimalV3Literal, DecimalV3Type.createDecimalV3Type(3, 1)), + new DecimalV3Literal(DecimalV3Type.createDecimalV3Type(3, 1), + new BigDecimal("12.0"))); } - - private void assertRewriteAfterSimplify(String expr, String expected, DataType expectedType) { - Expression needRewriteExpression = PARSER.parseExpression(expr); - Expression rewritten = SimplifyCastRule.INSTANCE.rewrite(needRewriteExpression, context); - Expression expectedExpression = PARSER.parseExpression(expected); - Assertions.assertEquals(expectedExpression.toSql(), rewritten.toSql()); - Assertions.assertEquals(expectedType, rewritten.getDataType()); - - } - } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicateTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicateTest.java index 224fa652386469..122e0b444e7e95 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicateTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicateTest.java @@ -19,25 +19,37 @@ import org.apache.doris.nereids.rules.expression.ExpressionRewriteTestHelper; import org.apache.doris.nereids.rules.expression.ExpressionRuleExecutor; +import org.apache.doris.nereids.trees.expressions.And; import org.apache.doris.nereids.trees.expressions.Cast; import org.apache.doris.nereids.trees.expressions.EqualTo; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.GreaterThan; import org.apache.doris.nereids.trees.expressions.GreaterThanEqual; +import org.apache.doris.nereids.trees.expressions.IsNull; import org.apache.doris.nereids.trees.expressions.LessThan; +import org.apache.doris.nereids.trees.expressions.LessThanEqual; +import org.apache.doris.nereids.trees.expressions.NullSafeEqual; +import org.apache.doris.nereids.trees.expressions.SlotReference; import org.apache.doris.nereids.trees.expressions.literal.BigIntLiteral; +import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral; import org.apache.doris.nereids.trees.expressions.literal.DateLiteral; import org.apache.doris.nereids.trees.expressions.literal.DateTimeLiteral; import org.apache.doris.nereids.trees.expressions.literal.DateTimeV2Literal; import org.apache.doris.nereids.trees.expressions.literal.DateV2Literal; +import org.apache.doris.nereids.trees.expressions.literal.DecimalV3Literal; import org.apache.doris.nereids.trees.expressions.literal.DoubleLiteral; +import org.apache.doris.nereids.trees.expressions.literal.NullLiteral; +import org.apache.doris.nereids.types.BooleanType; import org.apache.doris.nereids.types.DateTimeV2Type; +import org.apache.doris.nereids.types.DecimalV3Type; import org.apache.doris.nereids.types.DoubleType; import com.google.common.collect.ImmutableList; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; +import java.math.BigDecimal; + class SimplifyComparisonPredicateTest extends ExpressionRewriteTestHelper { @Test void testSimplifyComparisonPredicateRule() { @@ -137,4 +149,120 @@ void testDoubleLiteral() { Assertions.assertEquals(left.child(0).getDataType(), rewrittenExpression.child(1).getDataType()); Assertions.assertEquals(rewrittenExpression.child(0).getDataType(), rewrittenExpression.child(1).getDataType()); } + + @Test + void testDecimalV3Literal() { + executor = new ExpressionRuleExecutor(ImmutableList.of(SimplifyComparisonPredicate.INSTANCE)); + + // should not simplify + Expression leftChild = new DecimalV3Literal(new BigDecimal("1.24")); + Expression left = new Cast(leftChild, DecimalV3Type.createDecimalV3Type(2, 1)); + Expression right = new DecimalV3Literal(new BigDecimal("1.2")); + Expression expression = new EqualTo(left, right); + Expression rewrittenExpression = executor.rewrite(expression, context); + Assertions.assertInstanceOf(Cast.class, rewrittenExpression.child(0)); + Assertions.assertEquals(DecimalV3Type.createDecimalV3Type(2, 1), + rewrittenExpression.child(0).getDataType()); + + // = round UNNECESSARY + leftChild = new DecimalV3Literal(new BigDecimal("11.24")); + left = new Cast(leftChild, DecimalV3Type.createDecimalV3Type(5, 3)); + right = new DecimalV3Literal(new BigDecimal("12.340")); + expression = new EqualTo(left, right); + rewrittenExpression = executor.rewrite(expression, context); + Assertions.assertInstanceOf(DecimalV3Literal.class, rewrittenExpression.child(0)); + Assertions.assertEquals(DecimalV3Type.createDecimalV3Type(4, 2), + rewrittenExpression.child(0).getDataType()); + Assertions.assertEquals(DecimalV3Type.createDecimalV3Type(4, 2), + rewrittenExpression.child(1).getDataType()); + Assertions.assertInstanceOf(DecimalV3Literal.class, rewrittenExpression.child(1)); + Assertions.assertEquals(new BigDecimal("12.34"), ((DecimalV3Literal) rewrittenExpression.child(1)).getValue()); + + // = always not equals not null + leftChild = new DecimalV3Literal(new BigDecimal("11.24")); + left = new Cast(leftChild, DecimalV3Type.createDecimalV3Type(5, 3)); + right = new DecimalV3Literal(new BigDecimal("12.345")); + expression = new EqualTo(left, right); + rewrittenExpression = executor.rewrite(expression, context); + Assertions.assertEquals(BooleanLiteral.FALSE, rewrittenExpression); + + // = always not equals nullable + leftChild = new SlotReference("slot", DecimalV3Type.createDecimalV3Type(4, 2), true); + left = new Cast(leftChild, DecimalV3Type.createDecimalV3Type(5, 3)); + right = new DecimalV3Literal(new BigDecimal("12.345")); + expression = new EqualTo(left, right); + rewrittenExpression = executor.rewrite(expression, context); + Assertions.assertEquals(new And(new IsNull(leftChild), new NullLiteral(BooleanType.INSTANCE)), + rewrittenExpression); + + // <=> round UNNECESSARY + leftChild = new DecimalV3Literal(new BigDecimal("11.24")); + left = new Cast(leftChild, DecimalV3Type.createDecimalV3Type(5, 3)); + right = new DecimalV3Literal(new BigDecimal("12.340")); + expression = new NullSafeEqual(left, right); + rewrittenExpression = executor.rewrite(expression, context); + Assertions.assertInstanceOf(DecimalV3Literal.class, rewrittenExpression.child(0)); + Assertions.assertEquals(DecimalV3Type.createDecimalV3Type(4, 2), + rewrittenExpression.child(0).getDataType()); + Assertions.assertEquals(DecimalV3Type.createDecimalV3Type(4, 2), + rewrittenExpression.child(1).getDataType()); + Assertions.assertInstanceOf(DecimalV3Literal.class, rewrittenExpression.child(1)); + Assertions.assertEquals(new BigDecimal("12.34"), ((DecimalV3Literal) rewrittenExpression.child(1)).getValue()); + + // <=> always not equals + leftChild = new DecimalV3Literal(new BigDecimal("11.24")); + left = new Cast(leftChild, DecimalV3Type.createDecimalV3Type(5, 3)); + right = new DecimalV3Literal(new BigDecimal("12.345")); + expression = new NullSafeEqual(left, right); + rewrittenExpression = executor.rewrite(expression, context); + Assertions.assertEquals(BooleanLiteral.FALSE, rewrittenExpression); + + // > right literal should round floor + leftChild = new DecimalV3Literal(new BigDecimal("1.24")); + left = new Cast(leftChild, DecimalV3Type.createDecimalV3Type(5, 3)); + right = new DecimalV3Literal(new BigDecimal("12.345")); + expression = new GreaterThan(left, right); + rewrittenExpression = executor.rewrite(expression, context); + Assertions.assertInstanceOf(Cast.class, rewrittenExpression.child(0)); + Assertions.assertEquals(DecimalV3Type.createDecimalV3Type(4, 2), + rewrittenExpression.child(0).getDataType()); + Assertions.assertInstanceOf(DecimalV3Literal.class, rewrittenExpression.child(1)); + Assertions.assertEquals(new BigDecimal("12.34"), ((DecimalV3Literal) rewrittenExpression.child(1)).getValue()); + + // <= right literal should round floor + leftChild = new DecimalV3Literal(new BigDecimal("1.24")); + left = new Cast(leftChild, DecimalV3Type.createDecimalV3Type(5, 3)); + right = new DecimalV3Literal(new BigDecimal("12.345")); + expression = new LessThanEqual(left, right); + rewrittenExpression = executor.rewrite(expression, context); + Assertions.assertInstanceOf(Cast.class, rewrittenExpression.child(0)); + Assertions.assertEquals(DecimalV3Type.createDecimalV3Type(4, 2), + rewrittenExpression.child(0).getDataType()); + Assertions.assertInstanceOf(DecimalV3Literal.class, rewrittenExpression.child(1)); + Assertions.assertEquals(new BigDecimal("12.34"), ((DecimalV3Literal) rewrittenExpression.child(1)).getValue()); + + // >= right literal should round ceiling + leftChild = new DecimalV3Literal(new BigDecimal("1.24")); + left = new Cast(leftChild, DecimalV3Type.createDecimalV3Type(5, 3)); + right = new DecimalV3Literal(new BigDecimal("12.345")); + expression = new GreaterThanEqual(left, right); + rewrittenExpression = executor.rewrite(expression, context); + Assertions.assertInstanceOf(Cast.class, rewrittenExpression.child(0)); + Assertions.assertEquals(DecimalV3Type.createDecimalV3Type(4, 2), + rewrittenExpression.child(0).getDataType()); + Assertions.assertInstanceOf(DecimalV3Literal.class, rewrittenExpression.child(1)); + Assertions.assertEquals(new BigDecimal("12.35"), ((DecimalV3Literal) rewrittenExpression.child(1)).getValue()); + + // < right literal should round ceiling + leftChild = new DecimalV3Literal(new BigDecimal("1.24")); + left = new Cast(leftChild, DecimalV3Type.createDecimalV3Type(5, 3)); + right = new DecimalV3Literal(new BigDecimal("12.345")); + expression = new LessThan(left, right); + rewrittenExpression = executor.rewrite(expression, context); + Assertions.assertInstanceOf(Cast.class, rewrittenExpression.child(0)); + Assertions.assertEquals(DecimalV3Type.createDecimalV3Type(4, 2), + rewrittenExpression.child(0).getDataType()); + Assertions.assertInstanceOf(DecimalV3Literal.class, rewrittenExpression.child(1)); + Assertions.assertEquals(new BigDecimal("12.35"), ((DecimalV3Literal) rewrittenExpression.child(1)).getValue()); + } } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyDecimalV3ComparisonTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyDecimalV3ComparisonTest.java index ff424e4971145b..edbbd872bac67d 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyDecimalV3ComparisonTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyDecimalV3ComparisonTest.java @@ -17,40 +17,49 @@ package org.apache.doris.nereids.rules.expression.rules; -import org.apache.doris.common.Config; import org.apache.doris.nereids.rules.expression.ExpressionRewriteTestHelper; import org.apache.doris.nereids.rules.expression.ExpressionRuleExecutor; +import org.apache.doris.nereids.trees.expressions.Cast; +import org.apache.doris.nereids.trees.expressions.EqualTo; import org.apache.doris.nereids.trees.expressions.Expression; -import org.apache.doris.nereids.trees.expressions.Slot; -import org.apache.doris.nereids.trees.expressions.SlotReference; +import org.apache.doris.nereids.trees.expressions.literal.DecimalV3Literal; import org.apache.doris.nereids.types.DecimalV3Type; import com.google.common.collect.ImmutableList; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; -import java.util.HashMap; -import java.util.Map; +import java.math.BigDecimal; class SimplifyDecimalV3ComparisonTest extends ExpressionRewriteTestHelper { @Test - public void testSimplifyDecimalV3Comparison() { - Config.enable_decimal_conversion = false; - Map nameToSlot = new HashMap<>(); - nameToSlot.put("col1", new SlotReference("col1", DecimalV3Type.createDecimalV3Type(15, 2))); + void testChildScaleLargerThanCast() { executor = new ExpressionRuleExecutor(ImmutableList.of(SimplifyDecimalV3Comparison.INSTANCE)); - assertRewriteAfterSimplify("cast(col1 as decimalv3(27, 9)) > 0.6", "cast(col1 as decimalv3(27, 9)) > 0.6", nameToSlot); + Expression leftChild = new DecimalV3Literal(new BigDecimal("1.23456")); + Expression left = new Cast(leftChild, DecimalV3Type.createDecimalV3Type(3, 2)); + Expression right = new DecimalV3Literal(new BigDecimal("1.20")); + Expression expression = new EqualTo(left, right); + Expression rewrittenExpression = executor.rewrite(expression, context); + Assertions.assertInstanceOf(Cast.class, rewrittenExpression.child(0)); + Assertions.assertEquals(DecimalV3Type.createDecimalV3Type(3, 2), + rewrittenExpression.child(0).getDataType()); } - private void assertRewriteAfterSimplify(String expr, String expected, Map slotNameToSlot) { - Expression needRewriteExpression = PARSER.parseExpression(expr); - if (slotNameToSlot != null) { - needRewriteExpression = replaceUnboundSlot(needRewriteExpression, slotNameToSlot); - } - Expression rewritten = SimplifyDecimalV3Comparison.INSTANCE.rewrite(needRewriteExpression, context); - Expression expectedExpression = PARSER.parseExpression(expected); - Assertions.assertEquals(expectedExpression.toSql(), rewritten.toSql()); - } + @Test + void testChildScaleSmallerThanCast() { + executor = new ExpressionRuleExecutor(ImmutableList.of(SimplifyDecimalV3Comparison.INSTANCE)); + Expression leftChild = new DecimalV3Literal(new BigDecimal("1.23456")); + Expression left = new Cast(leftChild, DecimalV3Type.createDecimalV3Type(10, 9)); + Expression right = new DecimalV3Literal(new BigDecimal("1.200000000")); + Expression expression = new EqualTo(left, right); + Expression rewrittenExpression = executor.rewrite(expression, context); + Assertions.assertInstanceOf(DecimalV3Literal.class, rewrittenExpression.child(0)); + Assertions.assertEquals(DecimalV3Type.createDecimalV3Type(6, 5), + rewrittenExpression.child(0).getDataType()); + Assertions.assertInstanceOf(DecimalV3Literal.class, rewrittenExpression.child(1)); + Assertions.assertEquals(new BigDecimal("1.20000"), + ((DecimalV3Literal) rewrittenExpression.child(1)).getValue()); + } } diff --git a/regression-test/suites/nereids_rules_p0/expression/test_simplify_decimal_comparison.groovy b/regression-test/suites/nereids_rules_p0/expression/test_simplify_decimal_comparison.groovy new file mode 100644 index 00000000000000..103a66836c61d8 --- /dev/null +++ b/regression-test/suites/nereids_rules_p0/expression/test_simplify_decimal_comparison.groovy @@ -0,0 +1,28 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +suite("test_simplify_decimal_comparison") { + test { + sql """SELECT 1 FROM DUAL WHERE CAST(2.2222 AS DECIMAL(26, 2)) = 2.22""" + result ([[1]]) + } + + test { + sql """ SELECT 1 FROM DUAL WHERE CAST(2.2222 AS DECIMAL(26, 2)) != 2.22 """ + result ([]) + } +}