This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new d77bc702767 [SPARK-39989][SQL] Support estimate column statistics if it is foldable expression
d77bc702767 is described below

commit d77bc7027675b9823bcee96589eb0dddee3cfaa0
Author: Yuming Wang <[email protected]>
AuthorDate: Tue Aug 9 10:54:41 2022 -0700

    [SPARK-39989][SQL] Support estimate column statistics if it is foldable expression
    
    ### What changes were proposed in this pull request?
    
    This PR adds support for estimating the column statistics of a foldable expression. For example, it estimates the column statistics of `'a' AS a` in `SELECT 'a' AS a FROM tbl`. The estimated `ColumnStat` depends on whether the folded value is null (a named-argument sketch of both cases follows the list):
    
    1. If the foldable expression is null:
       ```scala
       ColumnStat(Some(0), None, None, Some(rowCount), Some(size), Some(size), None, 2)
       ```
    2. If the foldable expression is not null:
       ```scala
       ColumnStat(Some(1), Some(value), Some(value), Some(0), Some(size), Some(size), None, 2)
       ```
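    
    For reference, the positional arguments in the snippets above map onto `ColumnStat`'s named parameters as used elsewhere in this patch; the trailing `None` and `2` appear to be the histogram and the statistics version. A minimal named-argument sketch of the two cases, where `rowCount`, `value`, and `size` are placeholder inputs assumed for illustration only:
    
    ```scala
    import org.apache.spark.sql.catalyst.plans.logical.ColumnStat
    
    // Placeholder inputs (assumptions for illustration only).
    val rowCount: BigInt = 100   // estimated row count of the child plan
    val value: Any = 10L         // the folded constant value
    val size: Long = 8           // expr.dataType.defaultSize
    
    // Null foldable expression: every output row is null, so nullCount == rowCount.
    val nullStat = ColumnStat(distinctCount = Some(0), min = None, max = None,
      nullCount = Some(rowCount), avgLen = Some(size), maxLen = Some(size))
    
    // Non-null foldable expression: one distinct value, min == max == the constant.
    val constStat = ColumnStat(distinctCount = Some(1), min = Some(value), max = Some(value),
      nullCount = Some(0), avgLen = Some(size), maxLen = Some(size))
    ```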
    
    ### Why are the changes needed?
    
    Improve column statistics.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Unit test.
    
    Closes #37421 from wangyum/SPARK-39989.
    
    Lead-authored-by: Yuming Wang <[email protected]>
    Co-authored-by: Yuming Wang <[email protected]>
    Signed-off-by: Dongjoon Hyun <[email protected]>
---
 .../statsEstimation/AggregateEstimation.scala      |  3 +-
 .../logical/statsEstimation/EstimationUtils.scala  | 14 +++++++--
 .../statsEstimation/ProjectEstimation.scala        |  3 +-
 .../statsEstimation/ProjectEstimationSuite.scala   | 34 +++++++++++++++++++++-
 4 files changed, 49 insertions(+), 5 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala
index ffe071ef25b..a18fd64b0c9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala
@@ -60,7 +60,8 @@ object AggregateEstimation {
         outputRows.min(childStats.rowCount.get)
       }
 
-      val aliasStats = EstimationUtils.getAliasStats(agg.expressions, childStats.attributeStats)
+      val aliasStats = EstimationUtils.getAliasStats(
+        agg.expressions, childStats.attributeStats, outputRows)
 
       val outputAttrStats = getOutputMap(
         AttributeMap(childStats.attributeStats.toSeq ++ aliasStats), agg.output)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala
index dafb979767a..d645929eea7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.plans.logical.statsEstimation
 import scala.collection.mutable.ArrayBuffer
 import scala.math.BigDecimal.RoundingMode
 
-import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeMap, Expression}
+import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeMap, EmptyRow, Expression}
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.types.{DecimalType, _}
 
@@ -82,10 +82,20 @@ object EstimationUtils {
    */
   def getAliasStats(
       expressions: Seq[Expression],
-      attributeStats: AttributeMap[ColumnStat]): Seq[(Attribute, ColumnStat)] = {
+      attributeStats: AttributeMap[ColumnStat],
+      rowCount: BigInt): Seq[(Attribute, ColumnStat)] = {
     expressions.collect {
       case alias @ Alias(attr: Attribute, _) if attributeStats.contains(attr) =>
         alias.toAttribute -> attributeStats(attr)
+      case alias @ Alias(expr: Expression, _) if expr.foldable && expr.deterministic =>
+        val value = expr.eval(EmptyRow)
+        val size = expr.dataType.defaultSize
+        val columnStat = if (value == null) {
+          ColumnStat(Some(0), None, None, Some(rowCount), Some(size), Some(size), None, 2)
+        } else {
+          ColumnStat(Some(1), Some(value), Some(value), Some(0), Some(size), Some(size), None, 2)
+        }
+        alias.toAttribute -> columnStat
     }
   }
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ProjectEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ProjectEstimation.scala
index 8e58c4f314d..a44d6262dd5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ProjectEstimation.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ProjectEstimation.scala
@@ -26,7 +26,8 @@ object ProjectEstimation {
   def estimate(project: Project): Option[Statistics] = {
     if (rowCountsExist(project.child)) {
       val childStats = project.child.stats
-      val aliasStats = EstimationUtils.getAliasStats(project.expressions, childStats.attributeStats)
+      val aliasStats = EstimationUtils.getAliasStats(
+        project.expressions, childStats.attributeStats, childStats.rowCount.get)
 
       val outputAttrStats =
         getOutputMap(AttributeMap(childStats.attributeStats.toSeq ++ aliasStats), project.output)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/ProjectEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/ProjectEstimationSuite.scala
index dcb37017329..8efb41d8b3c 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/ProjectEstimationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/ProjectEstimationSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.statsEstimation
 
 import java.sql.{Date, Timestamp}
 
-import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeMap, AttributeReference}
+import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeMap, AttributeReference, Literal}
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.types._
@@ -131,6 +131,38 @@ class ProjectEstimationSuite extends StatsEstimationTestBase {
       expectedRowCount = 2)
   }
 
+  test("SPARK-39989: Support estimate column statistics if it is foldable 
expression") {
+    val (ar1, colStat1) = (attr("key1"), ColumnStat(distinctCount = Some(2), min = Some(1),
+      max = Some(2), nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)))
+
+    val child = StatsTestPlan(
+      outputList = Seq(ar1),
+      rowCount = 2,
+      attributeStats = AttributeMap(Seq(ar1 -> colStat1)))
+
+    // nullable expression
+    val proj1 = Project(Seq(ar1, Alias(Literal(null, IntegerType), "v")()), child)
+    val expectedColStats1 = Seq(
+      "key1" -> colStat1,
+      "v" -> ColumnStat(Some(0), None, None, Some(2), Some(4), Some(4), None, 
2))
+    val expectedStats1 = Statistics(
+      sizeInBytes = 2 * (8 + 4 + 4),
+      rowCount = Some(2),
+      attributeStats = toAttributeMap(expectedColStats1, proj1))
+    assert(proj1.stats == expectedStats1)
+
+    // non-nullable expression
+    val proj2 = Project(Seq(ar1, Alias(Literal(10L, LongType), "v")()), child)
+    val expectedColStats2 = Seq(
+      "key1" -> colStat1,
+      "v" -> ColumnStat(Some(1), Some(10L), Some(10L), Some(0), Some(8), 
Some(8), None, 2))
+    val expectedStats2 = Statistics(
+      sizeInBytes = 2 * (8 + 4 + 8),
+      rowCount = Some(2),
+      attributeStats = toAttributeMap(expectedColStats2, proj2))
+    assert(proj2.stats == expectedStats2)
+  }
+
   private def checkProjectStats(
       child: LogicalPlan,
       projectAttrMap: AttributeMap[ColumnStat],
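
For a concrete sense of where these estimates would surface, here is a spark-shell sketch. It is an illustration only, not part of the commit; the table name `t`, the inserted data, and the CBO setting are assumptions.

```scala
// Assumes a SparkSession named `spark` (as in spark-shell) and cost-based
// optimization enabled so plan statistics are derived from column stats.
spark.sql("SET spark.sql.cbo.enabled=true")
spark.sql("CREATE TABLE t (id INT) USING parquet")
spark.sql("INSERT INTO t VALUES (1), (2)")
spark.sql("ANALYZE TABLE t COMPUTE STATISTICS")

// With this change, the foldable alias `a` should now have an entry in
// attributeStats (e.g. distinctCount = Some(1), nullCount = Some(0)).
val plan = spark.sql("SELECT 'a' AS a FROM t").queryExecution.optimizedPlan
println(plan.stats.attributeStats)
```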


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
