Repository: spark
Updated Branches:
  refs/heads/master 0bf605c2c -> 039ed9fe8


[SPARK-19271][SQL] Change non-cbo estimation of aggregate

## What changes were proposed in this pull request?

Change the non-CBO estimation behavior of Aggregate:
- If groupingExpressions is empty, we know the row count (= 1) and can compute 
the corresponding output size;
- otherwise, estimation falls back to UnaryNode's computeStats method, which 
should not propagate rowCount and attributeStats in Statistics because they are 
not estimated in that method.

## How was this patch tested?

Added a test case.

Author: wangzhenhua <wangzhen...@huawei.com>

Closes #16631 from wzhfy/aggNoCbo.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/039ed9fe
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/039ed9fe
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/039ed9fe

Branch: refs/heads/master
Commit: 039ed9fe8a2fdcd99e0561af64cda8fe3406bc12
Parents: 0bf605c
Author: wangzhenhua <wangzhen...@huawei.com>
Authored: Thu Jan 19 22:18:47 2017 -0800
Committer: gatorsmile <gatorsm...@gmail.com>
Committed: Thu Jan 19 22:18:47 2017 -0800

----------------------------------------------------------------------
 .../catalyst/plans/logical/LogicalPlan.scala    |  3 ++-
 .../plans/logical/basicLogicalOperators.scala   |  7 ++++--
 .../statsEstimation/AggregateEstimation.scala   |  2 +-
 .../statsEstimation/EstimationUtils.scala       |  4 ++--
 .../statsEstimation/ProjectEstimation.scala     |  2 +-
 .../AggregateEstimationSuite.scala              | 24 +++++++++++++++++++-
 .../StatsEstimationTestBase.scala               |  7 +++---
 7 files changed, 38 insertions(+), 11 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/039ed9fe/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
index 0587a59..93550e1 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
@@ -344,7 +344,8 @@ abstract class UnaryNode extends LogicalPlan {
       sizeInBytes = 1
     }
 
-    child.stats(conf).copy(sizeInBytes = sizeInBytes)
+    // Don't propagate rowCount and attributeStats, since they are not 
estimated here.
+    Statistics(sizeInBytes = sizeInBytes, isBroadcastable = 
child.stats(conf).isBroadcastable)
   }
 }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/039ed9fe/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index 3bd3143..432097d 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.catalog.{CatalogTable, 
CatalogTypes}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.plans._
-import 
org.apache.spark.sql.catalyst.plans.logical.statsEstimation.{AggregateEstimation,
 ProjectEstimation}
+import 
org.apache.spark.sql.catalyst.plans.logical.statsEstimation.{AggregateEstimation,
 EstimationUtils, ProjectEstimation}
 import org.apache.spark.sql.types._
 import org.apache.spark.util.Utils
 
@@ -541,7 +541,10 @@ case class Aggregate(
   override def computeStats(conf: CatalystConf): Statistics = {
     def simpleEstimation: Statistics = {
       if (groupingExpressions.isEmpty) {
-        super.computeStats(conf).copy(sizeInBytes = 1)
+        Statistics(
+          sizeInBytes = EstimationUtils.getOutputSize(output, outputRowCount = 
1),
+          rowCount = Some(1),
+          isBroadcastable = child.stats(conf).isBroadcastable)
       } else {
         super.computeStats(conf)
       }

http://git-wip-us.apache.org/repos/asf/spark/blob/039ed9fe/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala
index 21e94fc..ce74554 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala
@@ -53,7 +53,7 @@ object AggregateEstimation {
 
       val outputAttrStats = getOutputMap(childStats.attributeStats, agg.output)
       Some(Statistics(
-        sizeInBytes = getOutputSize(agg.output, outputAttrStats, outputRows),
+        sizeInBytes = getOutputSize(agg.output, outputRows, outputAttrStats),
         rowCount = Some(outputRows),
         attributeStats = outputAttrStats,
         isBroadcastable = childStats.isBroadcastable))

http://git-wip-us.apache.org/repos/asf/spark/blob/039ed9fe/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala
index cf4452d..e8b7942 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala
@@ -37,8 +37,8 @@ object EstimationUtils {
 
   def getOutputSize(
       attributes: Seq[Attribute],
-      attrStats: AttributeMap[ColumnStat],
-      outputRowCount: BigInt): BigInt = {
+      outputRowCount: BigInt,
+      attrStats: AttributeMap[ColumnStat] = AttributeMap(Nil)): BigInt = {
     // We assign a generic overhead for a Row object, the actual overhead is 
different for different
     // Row format.
     val sizePerRow = 8 + attributes.map { attr =>

http://git-wip-us.apache.org/repos/asf/spark/blob/039ed9fe/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ProjectEstimation.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ProjectEstimation.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ProjectEstimation.scala
index 50b869a..e9084ad 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ProjectEstimation.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ProjectEstimation.scala
@@ -36,7 +36,7 @@ object ProjectEstimation {
       val outputAttrStats =
         getOutputMap(AttributeMap(inputAttrStats.toSeq ++ aliasStats), 
project.output)
       Some(childStats.copy(
-        sizeInBytes = getOutputSize(project.output, outputAttrStats, 
childStats.rowCount.get),
+        sizeInBytes = getOutputSize(project.output, childStats.rowCount.get, 
outputAttrStats),
         attributeStats = outputAttrStats))
     } else {
       None

http://git-wip-us.apache.org/repos/asf/spark/blob/039ed9fe/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggregateEstimationSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggregateEstimationSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggregateEstimationSuite.scala
index 41a4bc3..c0b9515 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggregateEstimationSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggregateEstimationSuite.scala
@@ -90,6 +90,28 @@ class AggregateEstimationSuite extends 
StatsEstimationTestBase {
       expectedOutputRowCount = 0)
   }
 
+  test("non-cbo estimation") {
+    val attributes = Seq("key12").map(nameToAttr)
+    val child = StatsTestPlan(
+      outputList = attributes,
+      rowCount = 4,
+      // rowCount * (overhead + column size)
+      size = Some(4 * (8 + 4)),
+      attributeStats = AttributeMap(Seq("key12").map(nameToColInfo)))
+
+    val noGroupAgg = Aggregate(groupingExpressions = Nil,
+      aggregateExpressions = Seq(Alias(Count(Literal(1)), "cnt")()), child)
+    assert(noGroupAgg.stats(conf.copy(cboEnabled = false)) ==
+      // overhead + count result size
+      Statistics(sizeInBytes = 8 + 8, rowCount = Some(1)))
+
+    val hasGroupAgg = Aggregate(groupingExpressions = attributes,
+      aggregateExpressions = attributes :+ Alias(Count(Literal(1)), "cnt")(), 
child)
+    assert(hasGroupAgg.stats(conf.copy(cboEnabled = false)) ==
+      // From UnaryNode.computeStats, childSize * outputRowSize / childRowSize
+      Statistics(sizeInBytes = 48 * (8 + 4 + 8) / (8 + 4)))
+  }
+
   private def checkAggStats(
       tableColumns: Seq[String],
       tableRowCount: BigInt,
@@ -107,7 +129,7 @@ class AggregateEstimationSuite extends 
StatsEstimationTestBase {
 
     val expectedAttrStats = AttributeMap(groupByColumns.map(nameToColInfo))
     val expectedStats = Statistics(
-      sizeInBytes = getOutputSize(testAgg.output, expectedAttrStats, 
expectedOutputRowCount),
+      sizeInBytes = getOutputSize(testAgg.output, expectedOutputRowCount, 
expectedAttrStats),
       rowCount = Some(expectedOutputRowCount),
       attributeStats = expectedAttrStats)
 

http://git-wip-us.apache.org/repos/asf/spark/blob/039ed9fe/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala
index e6adb67..a5fac4b 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala
@@ -45,11 +45,12 @@ class StatsEstimationTestBase extends SparkFunSuite {
 protected case class StatsTestPlan(
     outputList: Seq[Attribute],
     rowCount: BigInt,
-    attributeStats: AttributeMap[ColumnStat]) extends LeafNode {
+    attributeStats: AttributeMap[ColumnStat],
+    size: Option[BigInt] = None) extends LeafNode {
   override def output: Seq[Attribute] = outputList
   override def computeStats(conf: CatalystConf): Statistics = Statistics(
-    // sizeInBytes in stats of StatsTestPlan is useless in cbo estimation, we 
just use a fake value
-    sizeInBytes = Int.MaxValue,
+    // If sizeInBytes is useless in testing, we just use a fake value
+    sizeInBytes = size.getOrElse(Int.MaxValue),
     rowCount = Some(rowCount),
     attributeStats = attributeStats)
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to