Repository: spark Updated Branches: refs/heads/master f3137feec -> 72561ecf4
[SPARK-22266][SQL] The same aggregate function was evaluated multiple times ## What changes were proposed in this pull request? To let the same aggregate function that appears multiple times in an Aggregate be evaluated only once, we need to deduplicate the aggregate expressions. The original code tried to use a "distinct" call to get a set of aggregate expressions, but this did not work, since "distinct" does not compare semantic equality. And even if it did, further work would be needed in result expression rewriting. In this PR, I changed the "set" to a map from the semantic identity of an aggregate expression to the expression itself. Thus, later on, when rewriting result expressions (i.e., output expressions), the aggregate expression references can be fixed. ## How was this patch tested? Added two new tests in SQLQuerySuite. Author: maryannxue <[email protected]> Closes #19488 from maryannxue/spark-22266. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/72561ecf Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/72561ecf Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/72561ecf Branch: refs/heads/master Commit: 72561ecf4b611d68f8bf695ddd0c4c2cce3a29d9 Parents: f3137fe Author: maryannxue <[email protected]> Authored: Wed Oct 18 20:59:40 2017 +0800 Committer: Wenchen Fan <[email protected]> Committed: Wed Oct 18 20:59:40 2017 +0800 ---------------------------------------------------------------------- .../spark/sql/catalyst/planning/patterns.scala | 16 +++++++----- .../org/apache/spark/sql/SQLQuerySuite.scala | 26 ++++++++++++++++++++ 2 files changed, 36 insertions(+), 6 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/72561ecf/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala ---------------------------------------------------------------------- diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index 8d034c2..cc391aa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -205,14 +205,17 @@ object PhysicalAggregation { case logical.Aggregate(groupingExpressions, resultExpressions, child) => // A single aggregate expression might appear multiple times in resultExpressions. // In order to avoid evaluating an individual aggregate function multiple times, we'll - // build a set of the distinct aggregate expressions and build a function which can - // be used to re-write expressions so that they reference the single copy of the - // aggregate function which actually gets computed. + // build a set of semantically distinct aggregate expressions and re-write expressions so + // that they reference the single copy of the aggregate function which actually gets computed. + // Non-deterministic aggregate expressions are not deduplicated. + val equivalentAggregateExpressions = new EquivalentExpressions val aggregateExpressions = resultExpressions.flatMap { expr => expr.collect { - case agg: AggregateExpression => agg + // addExpr() always returns false for non-deterministic expressions and do not add them. 
+ case agg: AggregateExpression + if (!equivalentAggregateExpressions.addExpr(agg)) => agg } - }.distinct + } val namedGroupingExpressions = groupingExpressions.map { case ne: NamedExpression => ne -> ne @@ -236,7 +239,8 @@ object PhysicalAggregation { case ae: AggregateExpression => // The final aggregation buffer's attributes will be `finalAggregationAttributes`, // so replace each aggregate expression by its corresponding attribute in the set: - ae.resultAttribute + equivalentAggregateExpressions.getEquivalentExprs(ae).headOption + .getOrElse(ae).asInstanceOf[AggregateExpression].resultAttribute case expression => // Since we're using `namedGroupingAttributes` to extract the grouping key // columns, we need to replace grouping key expressions with their corresponding http://git-wip-us.apache.org/repos/asf/spark/blob/72561ecf/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index f0c58e2..caf332d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -27,6 +27,7 @@ import org.apache.spark.{AccumulatorSuite, SparkException} import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart} import org.apache.spark.sql.catalyst.util.StringUtils import org.apache.spark.sql.execution.aggregate +import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, SortAggregateExec} import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, CartesianProductExec, SortMergeJoinExec} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf @@ -2715,4 +2716,29 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { checkAnswer(df, Row(1, 1, 1)) } } + + test("SRARK-22266: the same aggregate function 
was calculated multiple times") { + val query = "SELECT a, max(b+1), max(b+1) + 1 FROM testData2 GROUP BY a" + val df = sql(query) + val physical = df.queryExecution.sparkPlan + val aggregateExpressions = physical.collectFirst { + case agg : HashAggregateExec => agg.aggregateExpressions + case agg : SortAggregateExec => agg.aggregateExpressions + } + assert (aggregateExpressions.isDefined) + assert (aggregateExpressions.get.size == 1) + checkAnswer(df, Row(1, 3, 4) :: Row(2, 3, 4) :: Row(3, 3, 4) :: Nil) + } + + test("Non-deterministic aggregate functions should not be deduplicated") { + val query = "SELECT a, first_value(b), first_value(b) + 1 FROM testData2 GROUP BY a" + val df = sql(query) + val physical = df.queryExecution.sparkPlan + val aggregateExpressions = physical.collectFirst { + case agg : HashAggregateExec => agg.aggregateExpressions + case agg : SortAggregateExec => agg.aggregateExpressions + } + assert (aggregateExpressions.isDefined) + assert (aggregateExpressions.get.size == 2) + } } --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
