Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[fix](Nereids) simplify decimal comparison wrong when cast to smaller scale (#41151) #42871

Merged
merged 1 commit into from
Oct 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
import org.apache.doris.nereids.types.VarcharType;

import java.math.BigDecimal;
import java.math.RoundingMode;

/**
* Rewrite rule of simplify CAST expression.
Expand Down Expand Up @@ -107,8 +108,14 @@ private Expression simplify(Cast cast, ExpressionRewriteContext context) {
return new DecimalV3Literal(decimalV3Type,
new BigDecimal(((BigIntLiteral) child).getValue()));
} else if (child instanceof DecimalV3Literal) {
return new DecimalV3Literal(decimalV3Type,
((DecimalV3Literal) child).getValue());
DecimalV3Type childType = (DecimalV3Type) child.getDataType();
if (childType.getRange() <= decimalV3Type.getRange()) {
return new DecimalV3Literal(decimalV3Type,
((DecimalV3Literal) child).getValue()
.setScale(decimalV3Type.getScale(), RoundingMode.HALF_UP));
} else {
return cast;
}
}
}
} catch (Throwable t) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -221,9 +221,10 @@ private Expression processDecimalV3TypeCoercion(ComparisonPredicate comparisonPr
int toScale = ((DecimalV3Type) left.getDataType()).getScale();
if (comparisonPredicate instanceof EqualTo) {
try {
return comparisonPredicate.withChildren(left,
new DecimalV3Literal((DecimalV3Type) left.getDataType(),
literal.getValue().setScale(toScale)));
Expression decimal = new DecimalV3Literal((DecimalV3Type) left.getDataType(),
literal.getValue().setScale(toScale, RoundingMode.UNNECESSARY));
return TypeCoercionUtils.processComparisonPredicate((ComparisonPredicate)
comparisonPredicate.withChildren(left, decimal), left, decimal);
} catch (ArithmeticException e) {
if (left.nullable()) {
// TODO: the ideal way is to return an If expr like:
Expand All @@ -240,24 +241,27 @@ private Expression processDecimalV3TypeCoercion(ComparisonPredicate comparisonPr
}
} else if (comparisonPredicate instanceof NullSafeEqual) {
try {
return comparisonPredicate.withChildren(left,
new DecimalV3Literal((DecimalV3Type) left.getDataType(),
literal.getValue().setScale(toScale)));
Expression decimal = new DecimalV3Literal((DecimalV3Type) left.getDataType(),
literal.getValue().setScale(toScale, RoundingMode.UNNECESSARY));
return TypeCoercionUtils.processComparisonPredicate((ComparisonPredicate)
comparisonPredicate.withChildren(left, decimal), left, decimal);
} catch (ArithmeticException e) {
return BooleanLiteral.of(false);
}
} else if (comparisonPredicate instanceof GreaterThan
|| comparisonPredicate instanceof LessThanEqual) {
return comparisonPredicate.withChildren(left, literal.roundFloor(toScale));
literal = literal.roundFloor(toScale);
return TypeCoercionUtils.processComparisonPredicate((ComparisonPredicate)
comparisonPredicate.withChildren(left, literal), left, literal);
} else if (comparisonPredicate instanceof LessThan
|| comparisonPredicate instanceof GreaterThanEqual) {
return comparisonPredicate.withChildren(left,
literal.roundCeiling(toScale));
literal = literal.roundCeiling(toScale);
return TypeCoercionUtils.processComparisonPredicate((ComparisonPredicate)
comparisonPredicate.withChildren(left, literal), left, literal);
}
}
} else if (left.getDataType().isIntegerLikeType()) {
return processIntegerDecimalLiteralComparison(comparisonPredicate, left,
literal.getValue());
return processIntegerDecimalLiteralComparison(comparisonPredicate, left, literal.getValue());
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

package org.apache.doris.nereids.rules.expression.rules;

import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.rules.expression.AbstractExpressionRewriteRule;
import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext;
import org.apache.doris.nereids.trees.expressions.Cast;
Expand All @@ -25,8 +26,6 @@
import org.apache.doris.nereids.trees.expressions.literal.DecimalV3Literal;
import org.apache.doris.nereids.types.DecimalV3Type;

import com.google.common.base.Preconditions;

import java.math.BigDecimal;
import java.math.RoundingMode;

Expand All @@ -50,15 +49,17 @@ public Expression visitComparisonPredicate(ComparisonPredicate cp, ExpressionRew
if (left.getDataType() instanceof DecimalV3Type
&& left instanceof Cast
&& ((Cast) left).child().getDataType() instanceof DecimalV3Type
&& ((DecimalV3Type) left.getDataType()).getScale()
>= ((DecimalV3Type) ((Cast) left).child().getDataType()).getScale()
&& right instanceof DecimalV3Literal) {
return doProcess(cp, (Cast) left, (DecimalV3Literal) right);
try {
return doProcess(cp, (Cast) left, (DecimalV3Literal) right);
} catch (ArithmeticException e) {
return cp;
}
}

if (left != cp.left() || right != cp.right()) {
return cp.withChildren(left, right);
} else {
return cp;
}
return cp;
}

private Expression doProcess(ComparisonPredicate cp, Cast left, DecimalV3Literal right) {
Expand All @@ -72,13 +73,16 @@ private Expression doProcess(ComparisonPredicate cp, Cast left, DecimalV3Literal
}

Expression castChild = left.child();
Preconditions.checkState(castChild.getDataType() instanceof DecimalV3Type);
if (!(castChild.getDataType() instanceof DecimalV3Type)) {
throw new AnalysisException("cast child's type should be DecimalV3Type, but its type is "
+ castChild.getDataType().toSql());
}
DecimalV3Type leftType = (DecimalV3Type) castChild.getDataType();
if (scale <= leftType.getScale() && precision - scale <= leftType.getPrecision() - leftType.getScale()) {
if (scale <= leftType.getScale() && precision - scale <= leftType.getRange()) {
// precision and scale of literal all smaller than left, we don't need the cast
DecimalV3Literal newRight = new DecimalV3Literal(
DecimalV3Type.createDecimalV3Type(leftType.getPrecision(), leftType.getScale()),
trailingZerosValue);
trailingZerosValue.setScale(leftType.getScale(), RoundingMode.UNNECESSARY));
return cp.withChildren(castChild, newRight);
} else {
return cp;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,15 +73,11 @@ public double getDouble() {
}

public DecimalV3Literal roundCeiling(int newScale) {
return new DecimalV3Literal(DecimalV3Type
.createDecimalV3Type(((DecimalV3Type) dataType).getPrecision(), newScale),
value.setScale(newScale, RoundingMode.CEILING));
return new DecimalV3Literal(value.setScale(newScale, RoundingMode.CEILING));
}

public DecimalV3Literal roundFloor(int newScale) {
return new DecimalV3Literal(DecimalV3Type
.createDecimalV3Type(((DecimalV3Type) dataType).getPrecision(), newScale),
value.setScale(newScale, RoundingMode.FLOOR));
return new DecimalV3Literal(value.setScale(newScale, RoundingMode.FLOOR));
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ public static DataType convertPrimitiveFromStrings(List<String> types, boolean t
case "decimalv3":
switch (types.size()) {
case 1:
return DecimalV3Type.CATALOG_DEFAULT;
return DecimalV3Type.createDecimalV3Type(38, 9);
case 2:
return DecimalV3Type.createDecimalV3Type(Integer.parseInt(types.get(1)));
case 3:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,42 +19,45 @@

import org.apache.doris.nereids.rules.expression.ExpressionRewriteTestHelper;
import org.apache.doris.nereids.rules.expression.ExpressionRuleExecutor;
import org.apache.doris.nereids.trees.expressions.Cast;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.trees.expressions.literal.DecimalV3Literal;
import org.apache.doris.nereids.trees.expressions.literal.SmallIntLiteral;
import org.apache.doris.nereids.trees.expressions.literal.StringLiteral;
import org.apache.doris.nereids.trees.expressions.literal.TinyIntLiteral;
import org.apache.doris.nereids.trees.expressions.literal.VarcharLiteral;
import org.apache.doris.nereids.types.DecimalV3Type;
import org.apache.doris.nereids.types.StringType;
import org.apache.doris.nereids.types.VarcharType;

import com.google.common.collect.ImmutableList;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;

import java.math.BigDecimal;

class SimplifyCastRuleTest extends ExpressionRewriteTestHelper {

@Test
public void testSimplify() {
executor = new ExpressionRuleExecutor(ImmutableList.of(SimplifyCastRule.INSTANCE));
assertRewriteAfterSimplify("CAST('1' AS STRING)", "'1'",
StringType.INSTANCE);
assertRewriteAfterSimplify("CAST('1' AS VARCHAR)", "'1'",
VarcharType.createVarcharType(-1));
assertRewriteAfterSimplify("CAST(1 AS DECIMAL)", "1.000000000",
DecimalV3Type.createDecimalV3Type(38, 9));
assertRewriteAfterSimplify("CAST(1000 AS DECIMAL)", "1000.000000000",
DecimalV3Type.createDecimalV3Type(38, 9));
assertRewriteAfterSimplify("CAST(1 AS DECIMALV3)", "1",
DecimalV3Type.createDecimalV3Type(9, 0));
assertRewriteAfterSimplify("CAST(1000 AS DECIMALV3)", "1000",
DecimalV3Type.createDecimalV3Type(9, 0));
assertRewrite(new Cast(new VarcharLiteral("1"), StringType.INSTANCE),
new StringLiteral("1"));
assertRewrite(new Cast(new VarcharLiteral("1"), VarcharType.SYSTEM_DEFAULT),
new VarcharLiteral("1", -1));
assertRewrite(new Cast(new TinyIntLiteral((byte) 1), DecimalV3Type.SYSTEM_DEFAULT),
new DecimalV3Literal(DecimalV3Type.SYSTEM_DEFAULT, new BigDecimal("1.000000000")));
assertRewrite(new Cast(new SmallIntLiteral((short) 1000), DecimalV3Type.SYSTEM_DEFAULT),
new DecimalV3Literal(DecimalV3Type.SYSTEM_DEFAULT, new BigDecimal("1000.000000000")));
assertRewrite(new Cast(new VarcharLiteral("1"), VarcharType.SYSTEM_DEFAULT), new VarcharLiteral("1", -1));
assertRewrite(new Cast(new VarcharLiteral("1"), VarcharType.SYSTEM_DEFAULT), new VarcharLiteral("1", -1));

Expression decimalV3Literal = new DecimalV3Literal(DecimalV3Type.createDecimalV3Type(5, 3),
new BigDecimal("12.000"));
assertRewrite(new Cast(decimalV3Literal, DecimalV3Type.createDecimalV3Type(7, 3)),
new DecimalV3Literal(DecimalV3Type.createDecimalV3Type(7, 3),
new BigDecimal("12.000")));
assertRewrite(new Cast(decimalV3Literal, DecimalV3Type.createDecimalV3Type(3, 1)),
new DecimalV3Literal(DecimalV3Type.createDecimalV3Type(3, 1),
new BigDecimal("12.0")));
}

private void assertRewriteAfterSimplify(String expr, String expected, DataType expectedType) {
Expression needRewriteExpression = PARSER.parseExpression(expr);
Expression rewritten = SimplifyCastRule.INSTANCE.rewrite(needRewriteExpression, context);
Expression expectedExpression = PARSER.parseExpression(expected);
Assertions.assertEquals(expectedExpression.toSql(), rewritten.toSql());
Assertions.assertEquals(expectedType, rewritten.getDataType());

}

}
Loading
Loading