From 9f9e46aed97526b017870bbccdae45b543c0305a Mon Sep 17 00:00:00 2001
From: Zouxxyy
Date: Wed, 24 Apr 2024 16:15:01 +0800
Subject: [PATCH] [spark] Add the fields in reservedFilters into the estimation
 of stats (#3255)

---
 .../org/apache/paimon/spark/PaimonBaseScan.scala    |  5 +++++
 .../org/apache/paimon/spark/PaimonStatistics.scala  |  2 +-
 .../spark/statistics/StatisticsHelperBase.scala     | 13 +++++++------
 .../paimon/spark/sql/AnalyzeTableTestBase.scala     |  7 +++++++
 4 files changed, 20 insertions(+), 7 deletions(-)

diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonBaseScan.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonBaseScan.scala
index a5ba887232d7..7a49167a9709 100644
--- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonBaseScan.scala
+++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonBaseScan.scala
@@ -59,6 +59,11 @@ abstract class PaimonBaseScan(
 
   lazy val statistics: Optional[stats.Statistics] = table.statistics()
 
+  lazy val requiredStatsSchema: StructType = {
+    val fieldNames = requiredSchema.fieldNames ++ reservedFilters.flatMap(_.references)
+    StructType(tableSchema.filter(field => fieldNames.contains(field.name)))
+  }
+
   lazy val readBuilder: ReadBuilder = {
     val _readBuilder = table.newReadBuilder()
 
diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonStatistics.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonStatistics.scala
index d31820cb3a1e..abaac63822a6 100644
--- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonStatistics.scala
+++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonStatistics.scala
@@ -48,7 +48,7 @@ case class PaimonStatistics[T <: PaimonBaseScan](scan: T) extends Statistics {
     if (paimonStats.isPresent) paimonStats.get().mergedRecordCount() else OptionalLong.of(rowCount)
 
   override def columnStats(): java.util.Map[NamedReference, ColumnStatistics] = {
-    val requiredFields = scan.readSchema().fieldNames.toList.asJava
+    val requiredFields = scan.requiredStatsSchema.fieldNames.toList.asJava
     val resultMap = new java.util.HashMap[NamedReference, ColumnStatistics]()
     if (paimonStats.isPresent) {
       val paimonColStats = paimonStats.get().colStats()
diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/statistics/StatisticsHelperBase.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/statistics/StatisticsHelperBase.scala
index 17eadf4f237e..1a76b2600a53 100644
--- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/statistics/StatisticsHelperBase.scala
+++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/statistics/StatisticsHelperBase.scala
@@ -30,17 +30,17 @@ import org.apache.spark.sql.connector.expressions.NamedReference
 import org.apache.spark.sql.connector.read.Statistics
 import org.apache.spark.sql.connector.read.colstats.ColumnStatistics
 import org.apache.spark.sql.sources.{And, Filter}
-import org.apache.spark.sql.types.{StructField, StructType}
+import org.apache.spark.sql.types.StructType
 
 import java.util.OptionalLong
 
 trait StatisticsHelperBase extends SQLConfHelper {
 
-  val requiredSchema: StructType
+  val requiredStatsSchema: StructType
 
   def filterStatistics(v2Stats: Statistics, filters: Seq[Filter]): Statistics = {
     val attrs: Seq[AttributeReference] =
-      requiredSchema.map(f => AttributeReference(f.name, f.dataType, f.nullable, f.metadata)())
+      requiredStatsSchema.map(f => AttributeReference(f.name, f.dataType, f.nullable, f.metadata)())
 
     val condition = filterToCondition(filters, attrs)
     if (condition.isDefined && v2Stats.numRows().isPresent) {
@@ -56,14 +56,15 @@ trait StatisticsHelperBase extends SQLConfHelper {
     StructFilters.filterToExpression(filters.reduce(And), toRef).map {
       expression =>
         expression.transform {
-          case ref: BoundReference => attrs.find(_.name == requiredSchema(ref.ordinal).name).get
+          case ref: BoundReference =>
+            attrs.find(_.name == requiredStatsSchema(ref.ordinal).name).get
         }
     }
   }
 
   private def toRef(attr: String): Option[BoundReference] = {
-    val index = requiredSchema.fieldIndex(attr)
-    val field = requiredSchema(index)
+    val index = requiredStatsSchema.fieldIndex(attr)
+    val field = requiredStatsSchema(index)
     Option.apply(BoundReference(index, field.dataType, field.nullable))
   }
 
diff --git a/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/AnalyzeTableTestBase.scala b/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/AnalyzeTableTestBase.scala
index dbe069be3f68..d1e6f73503e6 100644
--- a/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/AnalyzeTableTestBase.scala
+++ b/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/AnalyzeTableTestBase.scala
@@ -353,6 +353,13 @@ abstract class AnalyzeTableTestBase extends PaimonSparkTestBase {
       getScanStatistic(sql).rowCount.get.longValue())
     checkAnswer(spark.sql(sql), Nil)
 
+    // partition push down hit and select without it
+    sql = "SELECT id FROM T WHERE pt < 1"
+    Assertions.assertEquals(
+      if (supportsColStats()) 0L else 4L,
+      getScanStatistic(sql).rowCount.get.longValue())
+    checkAnswer(spark.sql(sql), Nil)
+
     // partition push down not hit
     sql = "SELECT * FROM T WHERE id < 1"
     Assertions.assertEquals(4L, getScanStatistic(sql).rowCount.get.longValue())
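
The heart of the change is the new requiredStatsSchema in PaimonBaseScan: column
statistics are now reported for every field that is either projected or referenced by a
reserved (partition) filter, so a predicate such as pt < 1 can still be estimated when
pt is not in the select list. Below is a minimal, self-contained Scala sketch of that
field-selection logic; the schema, the filter, and the object name are hypothetical
stand-ins for illustration, not code from this patch.

    import org.apache.spark.sql.sources.{Filter, LessThan}
    import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

    object RequiredStatsSchemaSketch {
      // Hypothetical table: "pt" is a partition column; "id" and "v" are data columns.
      val tableSchema: StructType = StructType(Seq(
        StructField("pt", IntegerType, nullable = false),
        StructField("id", IntegerType, nullable = false),
        StructField("v", IntegerType, nullable = true)))

      // SELECT id FROM T WHERE pt < 1: the projection needs only "id",
      // but the reserved (partition) filter references "pt".
      val requiredSchema: StructType = StructType(tableSchema.filter(_.name == "id"))
      val reservedFilters: Seq[Filter] = Seq(LessThan("pt", 1))

      // Same shape as the patch: keep every table field that is either projected
      // or referenced by a reserved filter, preserving table-schema order.
      def requiredStatsSchema: StructType = {
        val fieldNames = requiredSchema.fieldNames ++ reservedFilters.flatMap(_.references)
        StructType(tableSchema.filter(field => fieldNames.contains(field.name)))
      }

      def main(args: Array[String]): Unit =
        println(requiredStatsSchema.fieldNames.mkString(", ")) // prints: pt, id
    }

Running the sketch prints "pt, id": the partition column is retained for statistics even
though only id is projected, which is what lets the new
"SELECT id FROM T WHERE pt < 1" test above derive a row count of 0 from column stats.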