This is an automated email from the ASF dual-hosted git repository.

lzljs3620320 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/paimon.git


The following commit(s) were added to refs/heads/master by this push:
     new 9f9e46aed [spark] Add the fields in reservedFilters into the estimation of stats (#3255)
9f9e46aed is described below

commit 9f9e46aed97526b017870bbccdae45b543c0305a
Author: Zouxxyy <[email protected]>
AuthorDate: Wed Apr 24 16:15:01 2024 +0800

    [spark] Add the fields in reservedFilters into the estimation of stats (#3255)
---
 .../main/scala/org/apache/paimon/spark/PaimonBaseScan.scala |  5 +++++
 .../scala/org/apache/paimon/spark/PaimonStatistics.scala    |  2 +-
 .../paimon/spark/statistics/StatisticsHelperBase.scala      | 13 +++++++------
 .../org/apache/paimon/spark/sql/AnalyzeTableTestBase.scala  |  7 +++++++
 4 files changed, 20 insertions(+), 7 deletions(-)

diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonBaseScan.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonBaseScan.scala
index a5ba88723..7a49167a9 100644
--- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonBaseScan.scala
+++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonBaseScan.scala
@@ -59,6 +59,11 @@ abstract class PaimonBaseScan(
 
   lazy val statistics: Optional[stats.Statistics] = table.statistics()
 
+  lazy val requiredStatsSchema: StructType = {
+    val fieldNames = requiredSchema.fieldNames ++ reservedFilters.flatMap(_.references)
+    StructType(tableSchema.filter(field => fieldNames.contains(field.name)))
+  }
+
   lazy val readBuilder: ReadBuilder = {
     val _readBuilder = table.newReadBuilder()
 
diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonStatistics.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonStatistics.scala
index d31820cb3..abaac6382 100644
--- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonStatistics.scala
+++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonStatistics.scala
@@ -48,7 +48,7 @@ case class PaimonStatistics[T <: PaimonBaseScan](scan: T) extends Statistics {
     if (paimonStats.isPresent) paimonStats.get().mergedRecordCount() else OptionalLong.of(rowCount)
 
   override def columnStats(): java.util.Map[NamedReference, ColumnStatistics] = {
-    val requiredFields = scan.readSchema().fieldNames.toList.asJava
+    val requiredFields = scan.requiredStatsSchema.fieldNames.toList.asJava
     val resultMap = new java.util.HashMap[NamedReference, ColumnStatistics]()
     if (paimonStats.isPresent) {
       val paimonColStats = paimonStats.get().colStats()
diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/statistics/StatisticsHelperBase.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/statistics/StatisticsHelperBase.scala
index 17eadf4f2..1a76b2600 100644
--- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/statistics/StatisticsHelperBase.scala
+++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/statistics/StatisticsHelperBase.scala
@@ -30,17 +30,17 @@ import org.apache.spark.sql.connector.expressions.NamedReference
 import org.apache.spark.sql.connector.read.Statistics
 import org.apache.spark.sql.connector.read.colstats.ColumnStatistics
 import org.apache.spark.sql.sources.{And, Filter}
-import org.apache.spark.sql.types.{StructField, StructType}
+import org.apache.spark.sql.types.StructType
 
 import java.util.OptionalLong
 
 trait StatisticsHelperBase extends SQLConfHelper {
 
-  val requiredSchema: StructType
+  val requiredStatsSchema: StructType
 
   def filterStatistics(v2Stats: Statistics, filters: Seq[Filter]): Statistics = {
     val attrs: Seq[AttributeReference] =
-      requiredSchema.map(f => AttributeReference(f.name, f.dataType, f.nullable, f.metadata)())
+      requiredStatsSchema.map(f => AttributeReference(f.name, f.dataType, f.nullable, f.metadata)())
     val condition = filterToCondition(filters, attrs)
 
     if (condition.isDefined && v2Stats.numRows().isPresent) {
@@ -56,14 +56,15 @@ trait StatisticsHelperBase extends SQLConfHelper {
     StructFilters.filterToExpression(filters.reduce(And), toRef).map {
       expression =>
         expression.transform {
-          case ref: BoundReference => attrs.find(_.name == requiredSchema(ref.ordinal).name).get
+          case ref: BoundReference =>
+            attrs.find(_.name == requiredStatsSchema(ref.ordinal).name).get
         }
     }
   }
 
   private def toRef(attr: String): Option[BoundReference] = {
-    val index = requiredSchema.fieldIndex(attr)
-    val field = requiredSchema(index)
+    val index = requiredStatsSchema.fieldIndex(attr)
+    val field = requiredStatsSchema(index)
     Option.apply(BoundReference(index, field.dataType, field.nullable))
   }
 
diff --git a/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/AnalyzeTableTestBase.scala b/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/AnalyzeTableTestBase.scala
index dbe069be3..d1e6f7350 100644
--- a/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/AnalyzeTableTestBase.scala
+++ b/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/AnalyzeTableTestBase.scala
@@ -353,6 +353,13 @@ abstract class AnalyzeTableTestBase extends PaimonSparkTestBase {
       getScanStatistic(sql).rowCount.get.longValue())
     checkAnswer(spark.sql(sql), Nil)
 
+    // partition push down hit and select without it
+    sql = "SELECT id FROM T WHERE pt < 1"
+    Assertions.assertEquals(
+      if (supportsColStats()) 0L else 4L,
+      getScanStatistic(sql).rowCount.get.longValue())
+    checkAnswer(spark.sql(sql), Nil)
+
     // partition push down not hit
     sql = "SELECT * FROM T WHERE id < 1"
     Assertions.assertEquals(4L, getScanStatistic(sql).rowCount.get.longValue())

Reply via email to