This is an automated email from the ASF dual-hosted git repository. wenchen pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 683179c6813 [SPARK-39397][SQL] Relax AliasAwareOutputExpression to support alias with expression 683179c6813 is described below commit 683179c6813dbdccebd4063c3aac520020765692 Author: ulysses-you <ulyssesyo...@gmail.com> AuthorDate: Wed Jun 15 00:06:06 2022 +0800 [SPARK-39397][SQL] Relax AliasAwareOutputExpression to support alias with expression ### What changes were proposed in this pull request? Change AliasAwareOutputExpression to use expressions rather than attributes to track whether an expression can be normalized, so that aliased expressions can also preserve the output partitioning and ordering. ### Why are the changes needed? We pull complex keys out of grouping expressions, so a project can hold an alias of an arbitrary expression. Unfortunately, we may then lose the output partitioning, because the current AliasAwareOutputExpression only preserves aliases of plain attributes. For example, the following query introduces three exchanges instead of two: ```SQL SELECT c1 + 1, count(*) FROM t1 JOIN t2 ON c1 + 1 = c2 GROUP BY c1 + 1 ``` ### Does this PR introduce _any_ user-facing change? No; this is a performance improvement only. ### How was this patch tested? Added a new test to PlannerSuite. Closes #36785 from ulysses-you/SPARK-39397. 
Authored-by: ulysses-you <ulyssesyo...@gmail.com> Signed-off-by: Wenchen Fan <wenc...@databricks.com> --- .../sql/execution/AliasAwareOutputExpression.scala | 12 ++++++------ .../org/apache/spark/sql/execution/PlannerSuite.scala | 17 +++++++++++++++++ 2 files changed, 23 insertions(+), 6 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/AliasAwareOutputExpression.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/AliasAwareOutputExpression.scala index 23a9527a1b3..92e86637eec 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/AliasAwareOutputExpression.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/AliasAwareOutputExpression.scala @@ -16,7 +16,7 @@ */ package org.apache.spark.sql.execution -import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeMap, AttributeReference, Expression, NamedExpression, SortOrder} +import org.apache.spark.sql.catalyst.expressions.{Alias, Expression, NamedExpression, SortOrder} import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, PartitioningCollection, UnknownPartitioning} /** @@ -25,15 +25,15 @@ import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partition trait AliasAwareOutputExpression extends UnaryExecNode { protected def outputExpressions: Seq[NamedExpression] - private lazy val aliasMap = AttributeMap(outputExpressions.collect { - case a @ Alias(child: AttributeReference, _) => (child, a.toAttribute) - }) + private lazy val aliasMap = outputExpressions.collect { + case a @ Alias(child, _) => child.canonicalized -> a.toAttribute + }.toMap protected def hasAlias: Boolean = aliasMap.nonEmpty protected def normalizeExpression(exp: Expression): Expression = { - exp.transform { - case attr: AttributeReference => aliasMap.getOrElse(attr, attr) + exp.transformDown { + case e: Expression => aliasMap.getOrElse(e.canonicalized, e) } } } diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 3bc39c8b768..6f4869bf110 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -1276,6 +1276,23 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper { checkSinglePartitioning(sql("SELECT /*+ REPARTITION(1) */ * FROM VALUES(1),(2),(3) AS t(c)")) checkSinglePartitioning(sql("SELECT /*+ REPARTITION(1, c) */ * FROM VALUES(1),(2),(3) AS t(c)")) } + + test("SPARK-39397: Relax AliasAwareOutputExpression to support alias with expression") { + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { + val df1 = Seq("a").toDF("c1") + val df2 = Seq("A").toDF("c2") + val df = df1.join(df2, upper($"c1") === $"c2").groupBy(upper($"c1")).agg(max($"c1")) + val numShuffles = collect(df.queryExecution.executedPlan) { + case e: ShuffleExchangeExec => e + } + val numSorts = collect(df.queryExecution.executedPlan) { + case e: SortExec => e + } + // before: numShuffles is 3, numSorts is 4 + assert(numShuffles.size == 2) + assert(numSorts.size == 2) + } + } } // Used for unit-testing EnsureRequirements --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org