diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/stats/FilterEstimationTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/stats/FilterEstimationTest.java
index f5491c633313e00..ecf773715fd570d 100644
--- a/fe/fe-core/src/test/java/org/apache/doris/nereids/stats/FilterEstimationTest.java
+++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/stats/FilterEstimationTest.java
@@ -34,6 +34,7 @@
 import org.apache.doris.nereids.trees.expressions.literal.DateLiteral;
 import org.apache.doris.nereids.trees.expressions.literal.DoubleLiteral;
 import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral;
+import org.apache.doris.nereids.trees.plans.algebra.Generate;
 import org.apache.doris.nereids.types.DateType;
 import org.apache.doris.nereids.types.DoubleType;
 import org.apache.doris.nereids.types.IntegerType;
@@ -64,10 +65,10 @@ public void testOrNaN() {
         Or or = new Or(greaterThan1, lessThan);
         Map columnStat = new HashMap<>();
         ColumnStatistic aStats = new ColumnStatisticBuilder().setCount(500).setNdv(500).setAvgSizeByte(4)
-                .setNumNulls(500).setDataSize(0)
+                .setNumNulls(0).setDataSize(0)
                 .setMinValue(0).setMaxValue(1000).setMinExpr(null).build();
         ColumnStatistic bStats = new ColumnStatisticBuilder().setCount(500).setNdv(500).setAvgSizeByte(4)
-                .setNumNulls(500).setDataSize(0)
+                .setNumNulls(0).setDataSize(0)
                 .setMinValue(0).setMaxValue(1000).setMinExpr(null).setIsUnknown(true).build();
         columnStat.put(a, aStats);
         columnStat.put(b, bStats);
@@ -93,10 +94,10 @@ public void testAndNaN() {
         And and = new And(greaterThan1, lessThan);
         Map columnStat = new HashMap<>();
         ColumnStatistic aStats = new ColumnStatisticBuilder().setCount(500).setNdv(500)
-                .setAvgSizeByte(4).setNumNulls(500).setDataSize(0)
+                .setAvgSizeByte(4).setNumNulls(0).setDataSize(0)
                 .setMinValue(0).setMaxValue(1000).setMinExpr(null).build();
         ColumnStatistic bStats = new ColumnStatisticBuilder().setCount(500).setNdv(500)
-                .setAvgSizeByte(4).setNumNulls(500).setDataSize(0)
+                .setAvgSizeByte(4).setNumNulls(0).setDataSize(0)
                 .setMinValue(0).setMaxValue(1000).setMinExpr(null).setIsUnknown(true).build();
         columnStat.put(a, aStats);
         columnStat.put(b, bStats);
@@ -185,13 +186,13 @@ public void test1() {
         Or or = new Or(and, equalTo);
         Map slotToColumnStat = new HashMap<>();
         ColumnStatistic aStats = new ColumnStatisticBuilder().setCount(500).setNdv(500)
-                .setAvgSizeByte(4).setNumNulls(500).setDataSize(0)
+                .setAvgSizeByte(4).setNumNulls(0).setDataSize(0)
                 .setMinValue(0).setMaxValue(1000).setMinExpr(null).build();
         ColumnStatistic bStats = new ColumnStatisticBuilder().setCount(500).setNdv(500)
-                .setAvgSizeByte(4).setNumNulls(500).setDataSize(0)
+                .setAvgSizeByte(4).setNumNulls(0).setDataSize(0)
                 .setMinValue(0).setMaxValue(1000).setMinExpr(null).build();
         ColumnStatistic cStats = new ColumnStatisticBuilder().setCount(500).setNdv(500)
-                .setAvgSizeByte(4).setNumNulls(500).setDataSize(0)
+                .setAvgSizeByte(4).setNumNulls(0).setDataSize(0)
                 .setMinValue(0).setMaxValue(1000).setMinExpr(null).build();
         slotToColumnStat.put(a, aStats);
         slotToColumnStat.put(b, bStats);
@@ -910,4 +911,193 @@ public void testIsNotNull() {
         Statistics result = filterEstimation.estimate(not, stats);
         Assertions.assertEquals(result.getRowCount(), 90);
     }
+
+    /**
+     * Estimates {@code a = 1} when 8 of the 10 rows in column a are null.
+     */
+    @Test
+    public void testNumNullsEqualTo() {
+        SlotReference a = new SlotReference("a", IntegerType.INSTANCE);
+        ColumnStatisticBuilder builder = new ColumnStatisticBuilder()
+                .setNdv(2)
+                .setAvgSizeByte(4)
+                .setNumNulls(8)
+                .setMaxValue(2)
+                .setMinValue(1)
+                .setCount(10);
+        IntegerLiteral int1 = new IntegerLiteral(1);
+        EqualTo equalTo = new EqualTo(a, int1);
+        Statistics stats = new Statistics(10, new HashMap<>());
+        stats.addColumnStats(a, builder.build());
+        FilterEstimation filterEstimation = new FilterEstimation();
+        Statistics result = filterEstimation.estimate(equalTo, stats);
+        Assertions.assertEquals(1.0, result.getRowCount(), 0.01);
+    }
+
+    /**
+     * Estimates {@code a > 1} when 8 of the 10 rows in column a are null.
+     */
+    @Test
+    public void testNumNullsComparable() {
+        SlotReference a = new SlotReference("a", IntegerType.INSTANCE);
+        ColumnStatisticBuilder builder = new ColumnStatisticBuilder()
+                .setNdv(2)
+                .setAvgSizeByte(4)
+                .setNumNulls(8)
+                .setMaxValue(2)
+                .setMinValue(1)
+                .setCount(10);
+        IntegerLiteral int1 = new IntegerLiteral(1);
+        GreaterThan greaterThan = new GreaterThan(a, int1);
+        Statistics stats = new Statistics(10, new HashMap<>());
+        stats.addColumnStats(a, builder.build());
+        FilterEstimation filterEstimation = new FilterEstimation();
+        Statistics result = filterEstimation.estimate(greaterThan, stats);
+        Assertions.assertEquals(2.0, result.getRowCount(), 0.01);
+    }
+
+    /**
+     * Estimates {@code a in (1, 2)} when 8 of the 10 rows in column a are null.
+     */
+    @Test
+    public void testNumNullsIn() {
+        SlotReference a = new SlotReference("a", IntegerType.INSTANCE);
+        ColumnStatisticBuilder builder = new ColumnStatisticBuilder()
+                .setNdv(2)
+                .setAvgSizeByte(4)
+                .setNumNulls(8)
+                .setMaxValue(2)
+                .setMinValue(1)
+                .setCount(10);
+        IntegerLiteral int1 = new IntegerLiteral(1);
+        IntegerLiteral int2 = new IntegerLiteral(2);
+        InPredicate in = new InPredicate(a, Lists.newArrayList(int1, int2));
+        Statistics stats = new Statistics(10, new HashMap<>());
+        stats.addColumnStats(a, builder.build());
+        FilterEstimation filterEstimation = new FilterEstimation();
+        Statistics result = filterEstimation.estimate(in, stats);
+        Assertions.assertEquals(10.0, result.getRowCount(), 0.01);
+    }
+
+    /**
+     * Estimates {@code not a = 1} when 8 of the 10 rows in column a are null.
+     */
+    @Test
+    public void testNumNullsNotEqualTo() {
+        SlotReference a = new SlotReference("a", IntegerType.INSTANCE);
+        ColumnStatisticBuilder builder = new ColumnStatisticBuilder()
+                .setNdv(2)
+                .setAvgSizeByte(4)
+                .setNumNulls(8)
+                .setMaxValue(2)
+                .setMinValue(1)
+                .setCount(10);
+        IntegerLiteral int1 = new IntegerLiteral(1);
+        EqualTo equalTo = new EqualTo(a, int1);
+        Not not = new Not(equalTo);
+        Statistics stats = new Statistics(10, new HashMap<>());
+        stats.addColumnStats(a, builder.build());
+        FilterEstimation filterEstimation = new FilterEstimation();
+        Statistics result = filterEstimation.estimate(not, stats);
+        Assertions.assertEquals(1.0, result.getRowCount(), 0.01);
+    }
+
+    /**
+     * Estimates {@code a not in (1, 2)} when 8 of the 10 rows in column a are null.
+     */
+    @Test
+    public void testNumNullsNotIn() {
+        SlotReference a = new SlotReference("a", IntegerType.INSTANCE);
+        ColumnStatisticBuilder builder = new ColumnStatisticBuilder()
+                .setNdv(2)
+                .setAvgSizeByte(4)
+                .setNumNulls(8)
+                .setMaxValue(2)
+                .setMinValue(1)
+                .setCount(10);
+        IntegerLiteral int1 = new IntegerLiteral(1);
+        IntegerLiteral int2 = new IntegerLiteral(2);
+        InPredicate in = new InPredicate(a, Lists.newArrayList(int1, int2));
+        Not not = new Not(in);
+        Statistics stats = new Statistics(10, new HashMap<>());
+        stats.addColumnStats(a, builder.build());
+        FilterEstimation filterEstimation = new FilterEstimation();
+        Statistics result = filterEstimation.estimate(not, stats);
+        Assertions.assertEquals(1.0, result.getRowCount(), 0.01);
+    }
+
+    /**
+     * Estimates {@code a >= 1 and a <= 2} when 8 of the 10 rows in column a are null.
+     */
+    @Test
+    public void testNumNullsAnd() {
+        SlotReference a = new SlotReference("a", IntegerType.INSTANCE);
+        ColumnStatisticBuilder builder = new ColumnStatisticBuilder()
+                .setNdv(2)
+                .setAvgSizeByte(4)
+                .setNumNulls(8)
+                .setMaxValue(2)
+                .setMinValue(1)
+                .setCount(10);
+        IntegerLiteral int1 = new IntegerLiteral(1);
+        IntegerLiteral int2 = new IntegerLiteral(2);
+        GreaterThanEqual greaterThanEqual = new GreaterThanEqual(a, int1);
+        LessThanEqual lessThanEqual = new LessThanEqual(a, int2);
+        And and = new And(greaterThanEqual, lessThanEqual);
+        Statistics stats = new Statistics(10, new HashMap<>());
+        stats.addColumnStats(a, builder.build());
+        FilterEstimation filterEstimation = new FilterEstimation();
+        Statistics result = filterEstimation.estimate(and, stats);
+        Assertions.assertEquals(2.0, result.getRowCount(), 0.01);
+    }
+
+    /**
+     * Estimates {@code a >= 2 or a <= 1} when 8 of the 10 rows in column a are null.
+     */
+    @Test
+    public void testNumNullsOr() {
+        SlotReference a = new SlotReference("a", IntegerType.INSTANCE);
+        ColumnStatisticBuilder builder = new ColumnStatisticBuilder()
+                .setNdv(2)
+                .setAvgSizeByte(4)
+                .setNumNulls(8)
+                .setMaxValue(2)
+                .setMinValue(1)
+                .setCount(10);
+        IntegerLiteral int1 = new IntegerLiteral(1);
+        IntegerLiteral int2 = new IntegerLiteral(2);
+        GreaterThanEqual greaterThanEqual = new GreaterThanEqual(a, int2);
+        LessThanEqual lessThanEqual = new LessThanEqual(a, int1);
+        Or or = new Or(greaterThanEqual, lessThanEqual);
+        Statistics stats = new Statistics(10, new HashMap<>());
+        stats.addColumnStats(a, builder.build());
+        FilterEstimation filterEstimation = new FilterEstimation();
+        Statistics result = filterEstimation.estimate(or, stats);
+        Assertions.assertEquals(2.0, result.getRowCount(), 0.01);
+    }
+
+    /**
+     * Estimates {@code a >= 1 or a is null} when 8 of the 10 rows in column a are null.
+     */
+    @Test
+    public void testNumNullsOrIsNull() {
+        SlotReference a = new SlotReference("a", IntegerType.INSTANCE);
+        ColumnStatisticBuilder builder = new ColumnStatisticBuilder()
+                .setNdv(2)
+                .setAvgSizeByte(4)
+                .setNumNulls(8)
+                .setMaxValue(2)
+                .setMinValue(1)
+                .setCount(10);
+        IntegerLiteral int1 = new IntegerLiteral(1);
+        GreaterThanEqual greaterThanEqual = new GreaterThanEqual(a, int1);
+        IsNull isNull = new IsNull(a);
+        Or or = new Or(greaterThanEqual, isNull);
+        Statistics stats = new Statistics(10, new HashMap<>());
+        stats.addColumnStats(a, builder.build());
+        FilterEstimation filterEstimation = new FilterEstimation();
+        Statistics result = filterEstimation.estimate(or, stats);
+        Assertions.assertEquals(10.0, result.getRowCount(), 0.01);
+    }
+
 }