This is an automated email from the ASF dual-hosted git repository. yamamuro pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new b1857a4 [SPARK-26894][SQL] Handle Alias as well in AggregateEstimation to propagate child stats b1857a4 is described below commit b1857a4d7dfe17663f8adccd7825d890ae70d2a1 Author: Venkata krishnan Sowrirajan <vsowrira...@qubole.com> AuthorDate: Thu Mar 21 11:21:56 2019 +0900 [SPARK-26894][SQL] Handle Alias as well in AggregateEstimation to propagate child stats ## What changes were proposed in this pull request? Currently, aliases are not handled in AggregateEstimation, so column stats are not propagated for aliased attributes. This causes CBO join-reordering to not give optimal join plans. ProjectEstimation already takes care of aliases; we need the same logic for AggregateEstimation as well to properly propagate stats when CBO is enabled. ## How was this patch tested? This patch was manually tested using query Q83 of the TPCDS benchmark (scale 1000). Closes #23803 from venkata91/aggstats. Authored-by: Venkata krishnan Sowrirajan <vsowrira...@qubole.com> Signed-off-by: Takeshi Yamamuro <yamam...@apache.org> --- .../statsEstimation/AggregateEstimation.scala | 7 +++++-- .../logical/statsEstimation/EstimationUtils.scala | 14 ++++++++++++- .../statsEstimation/ProjectEstimation.scala | 10 +++------ .../statsEstimation/AggregateEstimationSuite.scala | 24 ++++++++++++++++++++++ 4 files changed, 45 insertions(+), 10 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala index eb56ab4..0606d0d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala @@ -17,7 +17,7 @@ package 
org.apache.spark.sql.catalyst.plans.logical.statsEstimation -import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap} import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Statistics} @@ -52,7 +52,10 @@ object AggregateEstimation { outputRows.min(childStats.rowCount.get) } - val outputAttrStats = getOutputMap(childStats.attributeStats, agg.output) + val aliasStats = EstimationUtils.getAliasStats(agg.expressions, childStats.attributeStats) + + val outputAttrStats = getOutputMap( + AttributeMap(childStats.attributeStats.toSeq ++ aliasStats), agg.output) Some(Statistics( sizeInBytes = getOutputSize(agg.output, outputRows, outputAttrStats), rowCount = Some(outputRows), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala index 211a2a0..11d2f02 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.plans.logical.statsEstimation import scala.collection.mutable.ArrayBuffer import scala.math.BigDecimal.RoundingMode -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap} +import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeMap, Expression} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.types.{DecimalType, _} @@ -71,6 +71,18 @@ object EstimationUtils { AttributeMap(output.flatMap(a => inputMap.get(a).map(a -> _))) } + /** + * Returns the stats for aliases of child's attributes + */ + def getAliasStats( + expressions: Seq[Expression], + attributeStats: AttributeMap[ColumnStat]): Seq[(Attribute, 
ColumnStat)] = { + expressions.collect { + case alias @ Alias(attr: Attribute, _) if attributeStats.contains(attr) => + alias.toAttribute -> attributeStats(attr) + } + } + def getSizePerRow( attributes: Seq[Attribute], attrStats: AttributeMap[ColumnStat] = AttributeMap(Nil)): BigInt = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ProjectEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ProjectEstimation.scala index 489eb90..6925423 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ProjectEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ProjectEstimation.scala @@ -26,14 +26,10 @@ object ProjectEstimation { def estimate(project: Project): Option[Statistics] = { if (rowCountsExist(project.child)) { val childStats = project.child.stats - val inputAttrStats = childStats.attributeStats - // Match alias with its child's column stat - val aliasStats = project.expressions.collect { - case alias @ Alias(attr: Attribute, _) if inputAttrStats.contains(attr) => - alias.toAttribute -> inputAttrStats(attr) - } + val aliasStats = EstimationUtils.getAliasStats(project.expressions, childStats.attributeStats) + val outputAttrStats = - getOutputMap(AttributeMap(inputAttrStats.toSeq ++ aliasStats), project.output) + getOutputMap(AttributeMap(childStats.attributeStats.toSeq ++ aliasStats), project.output) Some(childStats.copy( sizeInBytes = getOutputSize(project.output, childStats.rowCount.get, outputAttrStats), attributeStats = outputAttrStats)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggregateEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggregateEstimationSuite.scala index 8213d56..dfa6e46 100644 --- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggregateEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggregateEstimationSuite.scala @@ -45,6 +45,30 @@ class AggregateEstimationSuite extends StatsEstimationTestBase with PlanTest { private val nameToColInfo: Map[String, (Attribute, ColumnStat)] = columnInfo.map(kv => kv._1.name -> kv) + test("SPARK-26894: propagate child stats for aliases in Aggregate") { + val tableColumns = Seq("key11", "key12") + val groupByColumns = Seq("key11") + val attributes = groupByColumns.map(nameToAttr) + + val rowCount = 2 + val child = StatsTestPlan( + outputList = tableColumns.map(nameToAttr), + rowCount, + // rowCount * (overhead + column size) + size = Some(4 * (8 + 4)), + attributeStats = AttributeMap(tableColumns.map(nameToColInfo))) + + val testAgg = Aggregate( + groupingExpressions = attributes, + aggregateExpressions = Seq(Alias(nameToAttr("key12"), "abc")()), + child) + + val expectedColStats = Seq("abc" -> nameToColInfo("key12")._2) + val expectedAttrStats = toAttributeMap(expectedColStats, testAgg) + + assert(testAgg.stats.attributeStats == expectedAttrStats) + } + test("set an upper bound if the product of ndv's of group-by columns is too large") { // Suppose table1 (key11 int, key12 int) has 4 records: (1, 10), (1, 20), (2, 30), (2, 40) checkAggStats( --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org