Skip to content

Commit

Permalink
[Fix](bug) Percentile* func core when percent args is negative number
Browse files Browse the repository at this point in the history
  • Loading branch information
HappenLee committed Jan 16, 2025
1 parent 4c40b4e commit 31ce20b
Show file tree
Hide file tree
Showing 6 changed files with 149 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
import org.apache.doris.nereids.trees.expressions.literal.DoubleLiteral;
import org.apache.doris.nereids.trees.expressions.shape.BinaryExpression;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.BigIntType;
Expand Down Expand Up @@ -50,7 +51,6 @@ public class Percentile extends NullableAggregateFunction
FunctionSignature.ret(DoubleType.INSTANCE).args(IntegerType.INSTANCE, DoubleType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(SmallIntType.INSTANCE, DoubleType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(TinyIntType.INSTANCE, DoubleType.INSTANCE)

);

/**
Expand Down Expand Up @@ -79,6 +79,22 @@ public void checkLegalityBeforeTypeCoercion() {
}
}

@Override
public void checkLegalityAfterRewrite() {
Expression arg1 = getArgument(1);
if (!(arg1 instanceof DoubleLiteral)) {
throw new AnalysisException(
"percentile requires second parameter must be a constant double: " + this.toSql());
}
DoubleLiteral data = (DoubleLiteral) arg1;
double realData = data.getValue();
if (realData < 0 || realData > 1) {
throw new AnalysisException(
"percentile requires second parameter must be a constant double in [0, 1]: " + this.toSql()
+ " but the value is " + realData);
}
}

/**
* withDistinctAndChildren.
*/
Expand All @@ -89,13 +105,13 @@ public Percentile withDistinctAndChildren(boolean distinct, List<Expression> chi
}

@Override
public NullableAggregateFunction withAlwaysNullable(boolean alwaysNullable) {
return new Percentile(distinct, alwaysNullable, children.get(0), children.get(1));
public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
return visitor.visitPercentile(this, context);
}

@Override
public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
return visitor.visitPercentile(this, context);
public NullableAggregateFunction withAlwaysNullable(boolean alwaysNullable) {
return new Percentile(distinct, alwaysNullable, children.get(0), children.get(1));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
import org.apache.doris.nereids.trees.expressions.literal.DoubleLiteral;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.DoubleType;

Expand Down Expand Up @@ -92,6 +93,22 @@ public void checkLegalityBeforeTypeCoercion() {
}
}

@Override
public void checkLegalityAfterRewrite() {
Expression arg1 = getArgument(1);
if (!(arg1 instanceof DoubleLiteral)) {
throw new AnalysisException(
"percentile_approx requires second parameter must be a constant double: " + this.toSql());
}
DoubleLiteral data = (DoubleLiteral) arg1;
double realData = data.getValue();
if (realData < 0 || realData > 1) {
throw new AnalysisException(
"percentile_approx requires second parameter must be a constant double in [0, 1]: " + this.toSql()
+ " but the value is " + realData);
}
}

/**
* withDistinctAndChildren.
*/
Expand All @@ -106,6 +123,11 @@ public PercentileApprox withDistinctAndChildren(boolean distinct, List<Expressio
}
}

@Override
public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
return visitor.visitPercentileApprox(this, context);
}

@Override
public PercentileApprox withAlwaysNullable(boolean alwaysNullable) {
if (children.size() == 2) {
Expand All @@ -115,11 +137,6 @@ public PercentileApprox withAlwaysNullable(boolean alwaysNullable) {
}
}

@Override
public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
return visitor.visitPercentileApprox(this, context);
}

@Override
public List<FunctionSignature> getSignatures() {
return SIGNATURES;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
import org.apache.doris.nereids.trees.expressions.literal.DoubleLiteral;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.DoubleType;

Expand Down Expand Up @@ -97,6 +98,22 @@ public void checkLegalityBeforeTypeCoercion() {
}
}

@Override
public void checkLegalityAfterRewrite() {
Expression arg2 = getArgument(2);
if (!(arg2 instanceof DoubleLiteral)) {
throw new AnalysisException(
"percentile_approx_weighted requires third parameter must be a constant double: " + this.toSql());
}
DoubleLiteral data = (DoubleLiteral) arg2;
double realData = data.getValue();
if (realData < 0 || realData > 1) {
throw new AnalysisException(
"percentile_approx_weighted requires third parameter must be a constant double in [0, 1]: "
+ this.toSql() + " but the value is " + realData);
}
}

/**
* withDistinctAndChildren.
*/
Expand All @@ -113,6 +130,11 @@ public PercentileApproxWeighted withDistinctAndChildren(boolean distinct,
}
}

@Override
public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
return visitor.visitPercentileApprox(this, context);
}

@Override
public PercentileApproxWeighted withAlwaysNullable(boolean alwaysNullable) {
if (children.size() == 3) {
Expand All @@ -124,11 +146,6 @@ public PercentileApproxWeighted withAlwaysNullable(boolean alwaysNullable) {
}
}

@Override
public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
return visitor.visitPercentileApprox(this, context);
}

@Override
public List<FunctionSignature> getSignatures() {
return SIGNATURES;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,12 @@
package org.apache.doris.nereids.trees.expressions.functions.agg;

import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
import org.apache.doris.nereids.trees.expressions.literal.ArrayLiteral;
import org.apache.doris.nereids.trees.expressions.literal.DoubleLiteral;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.expressions.shape.BinaryExpression;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.ArrayType;
Expand Down Expand Up @@ -74,6 +77,36 @@ public PercentileArray(boolean distinct, Expression arg0, Expression arg1) {
super("percentile_array", distinct, arg0, arg1);
}

@Override
public void checkLegalityBeforeTypeCoercion() {
if (!getArgument(1).isConstant()) {
throw new AnalysisException(
"percentile_array requires second parameter must be a constant : " + this.toSql());
}
}

@Override
public void checkLegalityAfterRewrite() {
Expression arg1 = getArgument(1);
if (!(arg1 instanceof ArrayLiteral)) {
throw new AnalysisException(
"percentile_approx requires second parameter must be a constant array: " + this.toSql());
}
ArrayLiteral data = (ArrayLiteral) arg1;
for (Literal d : data.getValue()) {
if (!(d instanceof DoubleLiteral)) {
throw new AnalysisException(
"percentile_array requires second parameter must be a constant array[double]: " + this.toSql());
}
double realData = ((DoubleLiteral) d).getValue();
if (realData < 0 || realData > 1) {
throw new AnalysisException(
"percentile_array requires second parameter must be a constant array[double], "
+ "double value in [0, 1]: " + this.toSql() + " but the value is " + realData);
}
}
}

/**
* withDistinctAndChildren.
*/
Expand Down
17 changes: 17 additions & 0 deletions regression-test/suites/query_p0/aggregate/aggregate.groovy
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,23 @@ suite("aggregate") {
qt_aggregate32" select topn_weighted(c_string,c_bigint,3) from ${tableName}"
qt_aggregate33" select avg_weighted(c_double,c_bigint) from ${tableName};"
qt_aggregate34" select percentile_array(c_bigint,[0.2,0.5,0.9]) from ${tableName};"

try {
sql "select percentile_array(c_bigint,[-1,0.5,0.9]) from ${tableName};"
} catch (Exception ex) {
assert("${ex}".contains("-1"))
}
try {
sql "select percentile_array(c_bigint,[0.5,0.9,3000]) from ${tableName};"
} catch (Exception ex) {
assert("${ex}".contains("3000"))
}
try {
sql "select percentile_array(c_bigint,[0.5,0.9,null]) from ${tableName};"
} catch (Exception ex) {
assert("${ex}".contains("double"))
}

qt_aggregate """
SELECT c_bigint,
CASE
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,23 @@ suite("test_aggregate_all_functions", "arrow_flight_sql") {
qt_select21_1 "select id,percentile(level + 0.1,0.55) from ${tableName_13} group by id order by id"
qt_select22_1 "select id,percentile(level + 0.1,0.805) from ${tableName_13} group by id order by id"

try {
sql "select id,percentile(level + 0.1, -1) from ${tableName_13} group by id order by id"
} catch (Exception ex) {
assert("${ex}".contains("-1"))
}
try {
sql "select id,percentile(level + 0.1, 3000) from ${tableName_13} group by id order by id"
} catch (Exception ex) {
assert("${ex}".contains("3000"))
}
try {
sql "select id,percentile(level + 0.1, null) from ${tableName_13} group by id order by id"
} catch (Exception ex) {
assert("${ex}".contains("double"))
}


sql "DROP TABLE IF EXISTS ${tableName_13}"


Expand Down Expand Up @@ -314,6 +331,23 @@ suite("test_aggregate_all_functions", "arrow_flight_sql") {
qt_select27 "select id,PERCENTILE_APPROX(level,0.55,2048) from ${tableName_14} group by id order by id"
qt_select28 "select id,PERCENTILE_APPROX(level,0.805,2048) from ${tableName_14} group by id order by id"

try {
sql "select id,PERCENTILE_APPROX(level, -1, 2048) from ${tableName_14} group by id order by id"
} catch (Exception ex) {
assert("${ex}".contains("-1"))
}
try {
sql "select id,PERCENTILE_APPROX(level, 3000 ,2048) from ${tableName_14} group by id order by id"
} catch (Exception ex) {
assert("${ex}".contains("3000"))
}
try {
sql "select id,PERCENTILE_APPROX(level, null ,2048) from ${tableName_14} group by id order by id"
} catch (Exception ex) {
assert("${ex}".contains("double"))
}


sql "DROP TABLE IF EXISTS ${tableName_14}"


Expand Down

0 comments on commit 31ce20b

Please sign in to comment.