This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 683179c6813 [SPARK-39397][SQL] Relax AliasAwareOutputExpression to
support alias with expression
683179c6813 is described below
commit 683179c6813dbdccebd4063c3aac520020765692
Author: ulysses-you <[email protected]>
AuthorDate: Wed Jun 15 00:06:06 2022 +0800
[SPARK-39397][SQL] Relax AliasAwareOutputExpression to support alias with
expression
### What changes were proposed in this pull request?
Change AliasAwareOutputExpression to use expressions rather than attributes
to track whether we can normalize, so that aliased expressions can also
preserve the output partitioning and ordering.
### Why are the changes needed?
We pull out complex keys from grouping expressions, so the project can
hold an alias of an expression. Unfortunately, we may lose the output
partitioning since the current AliasAwareOutputExpression only supports
preserving aliases of attributes.
For example, the following query will introduce three exchanges instead of two.
```SQL
SELECT c1 + 1, count(*)
FROM t1
JOIN t2 ON c1 + 1 = c2
GROUP BY c1 + 1
```
### Does this PR introduce _any_ user-facing change?
No, it only improves performance.
### How was this patch tested?
Added a new test.
Closes #36785 from ulysses-you/SPARK-39397.
Authored-by: ulysses-you <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
.../sql/execution/AliasAwareOutputExpression.scala | 12 ++++++------
.../org/apache/spark/sql/execution/PlannerSuite.scala | 17 +++++++++++++++++
2 files changed, 23 insertions(+), 6 deletions(-)
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/AliasAwareOutputExpression.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/AliasAwareOutputExpression.scala
index 23a9527a1b3..92e86637eec 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/AliasAwareOutputExpression.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/AliasAwareOutputExpression.scala
@@ -16,7 +16,7 @@
*/
package org.apache.spark.sql.execution
-import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeMap,
AttributeReference, Expression, NamedExpression, SortOrder}
+import org.apache.spark.sql.catalyst.expressions.{Alias, Expression,
NamedExpression, SortOrder}
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning,
Partitioning, PartitioningCollection, UnknownPartitioning}
/**
@@ -25,15 +25,15 @@ import
org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partition
trait AliasAwareOutputExpression extends UnaryExecNode {
protected def outputExpressions: Seq[NamedExpression]
- private lazy val aliasMap = AttributeMap(outputExpressions.collect {
- case a @ Alias(child: AttributeReference, _) => (child, a.toAttribute)
- })
+ private lazy val aliasMap = outputExpressions.collect {
+ case a @ Alias(child, _) => child.canonicalized -> a.toAttribute
+ }.toMap
protected def hasAlias: Boolean = aliasMap.nonEmpty
protected def normalizeExpression(exp: Expression): Expression = {
- exp.transform {
- case attr: AttributeReference => aliasMap.getOrElse(attr, attr)
+ exp.transformDown {
+ case e: Expression => aliasMap.getOrElse(e.canonicalized, e)
}
}
}
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index 3bc39c8b768..6f4869bf110 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -1276,6 +1276,23 @@ class PlannerSuite extends SharedSparkSession with
AdaptiveSparkPlanHelper {
checkSinglePartitioning(sql("SELECT /*+ REPARTITION(1) */ * FROM
VALUES(1),(2),(3) AS t(c)"))
checkSinglePartitioning(sql("SELECT /*+ REPARTITION(1, c) */ * FROM
VALUES(1),(2),(3) AS t(c)"))
}
+
+ test("SPARK-39397: Relax AliasAwareOutputExpression to support alias with
expression") {
+ withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
+ val df1 = Seq("a").toDF("c1")
+ val df2 = Seq("A").toDF("c2")
+ val df = df1.join(df2, upper($"c1") ===
$"c2").groupBy(upper($"c1")).agg(max($"c1"))
+ val numShuffles = collect(df.queryExecution.executedPlan) {
+ case e: ShuffleExchangeExec => e
+ }
+ val numSorts = collect(df.queryExecution.executedPlan) {
+ case e: SortExec => e
+ }
+ // before: numShuffles is 3, numSorts is 4
+ assert(numShuffles.size == 2)
+ assert(numSorts.size == 2)
+ }
+ }
}
// Used for unit-testing EnsureRequirements
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]