This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 961dcdfb9845 [SPARK-45882][SQL] BroadcastHashJoinExec propagate partitioning should respect CoalescedHashPartitioning
961dcdfb9845 is described below
commit 961dcdfb98455f341c3f6279fa65aa1dd58ca199
Author: ulysses-you <[email protected]>
AuthorDate: Tue Nov 14 05:42:13 2023 +0800
[SPARK-45882][SQL] BroadcastHashJoinExec propagate partitioning should respect CoalescedHashPartitioning
### What changes were proposed in this pull request?
Add a HashPartitioningLike trait and make HashPartitioning and
CoalescedHashPartitioning extend it. When we propagate output partitioning, we
should handle HashPartitioningLike instead of HashPartitioning. This PR also
changes BroadcastHashJoinExec to match on HashPartitioningLike to avoid a
performance regression.
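A rough sketch of the new hierarchy (abridged; the full definitions in the
partitioning.scala diff below also override `children`, `nullable`, `dataType`,
`satisfies0` and the shuffle-spec methods):

```scala
trait HashPartitioningLike extends Expression with Partitioning with Unevaluable {
  def expressions: Seq[Expression]
}

case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int)
  extends HashPartitioningLike

// AQE may merge reducer partitions, but the coalesced partitioning keeps the
// same hash expressions, so it can be propagated anywhere a HashPartitioning can.
case class CoalescedHashPartitioning(from: HashPartitioning, partitions: Seq[CoalescedBoundary])
  extends HashPartitioningLike {
  override def expressions: Seq[Expression] = from.expressions
  override val numPartitions: Int = partitions.length
}
```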
### Why are the changes needed?
Avoid an unnecessary shuffle exchange. Before this change, BroadcastHashJoinExec
only expanded a streamed-side HashPartitioning with the build-side key
equivalences; a streamed side reporting CoalescedHashPartitioning (e.g. after
AQE coalesces shuffle partitions) was propagated without expansion, so a
downstream operator clustering on the join keys triggered an extra shuffle.
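For illustration, the shape of the regression (adapted from the new JoinSuite
test at the end of this diff; `testData` and `testData2` are the suite's test
tables):

```scala
// The inner GROUP BY shuffles testData2 by `a`; AQE may coalesce that shuffle,
// so the streamed side of the broadcast join reports CoalescedHashPartitioning.
val joined = spark.sql(
  """
    |select /*+ broadcast(testData) */ key, value, a
    |from testData join (
    |  select a from testData2 group by a
    |) tmp on key = a
    |""".stripMargin)

// Via the join keys, `key` is equivalent to `a`. Before this fix the join did
// not expand CoalescedHashPartitioning, so this aggregation added a second,
// unnecessary shuffle; with the fix it reuses the existing partitioning.
joined.groupBy("key").count()
```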
### Does this PR introduce _any_ user-facing change?
Yes, it avoids a performance regression (the unnecessary shuffle described above).
### How was this patch tested?
Added a unit test to JoinSuite.
### Was this patch authored or co-authored using generative AI tooling?
no
Closes #43753 from ulysses-you/partitioning.
Authored-by: ulysses-you <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
.../sql/catalyst/plans/physical/partitioning.scala | 46 ++++++++++------------
.../execution/joins/BroadcastHashJoinExec.scala | 11 +++---
.../scala/org/apache/spark/sql/JoinSuite.scala | 28 ++++++++++++-
3 files changed, 54 insertions(+), 31 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
index 0ae2857161c8..60e6e42bedf8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
@@ -258,18 +258,8 @@ case object SinglePartition extends Partitioning {
SinglePartitionShuffleSpec
}
-/**
- * Represents a partitioning where rows are split up across partitions based on the hash
- * of `expressions`. All rows where `expressions` evaluate to the same values are guaranteed to be
- * in the same partition.
- *
- * Since [[StatefulOpClusteredDistribution]] relies on this partitioning and Spark requires
- * stateful operators to retain the same physical partitioning during the lifetime of the query
- * (including restart), the result of evaluation on `partitionIdExpression` must be unchanged
- * across Spark versions. Violation of this requirement may bring silent correctness issue.
- */
-case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int)
- extends Expression with Partitioning with Unevaluable {
+trait HashPartitioningLike extends Expression with Partitioning with Unevaluable {
+ def expressions: Seq[Expression]
override def children: Seq[Expression] = expressions
override def nullable: Boolean = false
@@ -294,6 +284,20 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int)
}
}
}
+}
+
+/**
+ * Represents a partitioning where rows are split up across partitions based on the hash
+ * of `expressions`. All rows where `expressions` evaluate to the same values are guaranteed to be
+ * in the same partition.
+ *
+ * Since [[StatefulOpClusteredDistribution]] relies on this partitioning and Spark requires
+ * stateful operators to retain the same physical partitioning during the lifetime of the query
+ * (including restart), the result of evaluation on `partitionIdExpression` must be unchanged
+ * across Spark versions. Violation of this requirement may bring silent correctness issue.
+ */
+case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int)
+ extends HashPartitioningLike {
override def createShuffleSpec(distribution: ClusteredDistribution): ShuffleSpec =
HashShuffleSpec(this, distribution)
@@ -306,7 +310,6 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int)
override protected def withNewChildrenInternal(
newChildren: IndexedSeq[Expression]): HashPartitioning = copy(expressions = newChildren)
-
}
case class CoalescedBoundary(startReducerIndex: Int, endReducerIndex: Int)
@@ -316,25 +319,18 @@ case class CoalescedBoundary(startReducerIndex: Int, endReducerIndex: Int)
* fewer number of partitions.
*/
case class CoalescedHashPartitioning(from: HashPartitioning, partitions: Seq[CoalescedBoundary])
- extends Expression with Partitioning with Unevaluable {
-
- override def children: Seq[Expression] = from.expressions
- override def nullable: Boolean = from.nullable
- override def dataType: DataType = from.dataType
+ extends HashPartitioningLike {
- override def satisfies0(required: Distribution): Boolean = from.satisfies0(required)
+ override def expressions: Seq[Expression] = from.expressions
override def createShuffleSpec(distribution: ClusteredDistribution): ShuffleSpec =
CoalescedHashShuffleSpec(from.createShuffleSpec(distribution), partitions)
- override protected def withNewChildrenInternal(
- newChildren: IndexedSeq[Expression]): CoalescedHashPartitioning =
- copy(from = from.copy(expressions = newChildren))
-
override val numPartitions: Int = partitions.length
- override def toString: String = from.toString
- override def sql: String = from.sql
+ override protected def withNewChildrenInternal(
+ newChildren: IndexedSeq[Expression]): CoalescedHashPartitioning =
+ copy(from = from.copy(expressions = newChildren))
}
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
index 68022757ff24..368534d05b1f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
@@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide}
import org.apache.spark.sql.catalyst.plans._
-import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Distribution, HashPartitioning, Partitioning, PartitioningCollection, UnspecifiedDistribution}
+import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Distribution, HashPartitioningLike, Partitioning, PartitioningCollection, UnspecifiedDistribution}
import org.apache.spark.sql.execution.{CodegenSupport, SparkPlan}
import org.apache.spark.sql.execution.metric.SQLMetrics
@@ -73,7 +73,7 @@ case class BroadcastHashJoinExec(
joinType match {
case _: InnerLike if conf.broadcastHashJoinOutputPartitioningExpandLimit > 0 =>
streamedPlan.outputPartitioning match {
- case h: HashPartitioning => expandOutputPartitioning(h)
+ case h: HashPartitioningLike => expandOutputPartitioning(h)
case c: PartitioningCollection => expandOutputPartitioning(c)
case other => other
}
@@ -99,7 +99,7 @@ case class BroadcastHashJoinExec(
private def expandOutputPartitioning(
partitioning: PartitioningCollection): PartitioningCollection = {
PartitioningCollection(partitioning.partitionings.flatMap {
- case h: HashPartitioning => expandOutputPartitioning(h).partitionings
+ case h: HashPartitioningLike => expandOutputPartitioning(h).partitionings
case c: PartitioningCollection => Seq(expandOutputPartitioning(c))
case other => Seq(other)
})
@@ -111,11 +111,12 @@ case class BroadcastHashJoinExec(
// the expanded partitioning will have the following expressions:
// Seq("a", "b", "c"), Seq("a", "b", "y"), Seq("a", "x", "c"), Seq("a", "x", "y").
// The expanded expressions are returned as PartitioningCollection.
- private def expandOutputPartitioning(partitioning: HashPartitioning): PartitioningCollection = {
+ private def expandOutputPartitioning(
+ partitioning: HashPartitioningLike): PartitioningCollection = {
PartitioningCollection(partitioning.multiTransformDown {
case e: Expression if streamedKeyToBuildKeyMapping.contains(e.canonicalized) =>
e +: streamedKeyToBuildKeyMapping(e.canonicalized)
- }.asInstanceOf[LazyList[HashPartitioning]]
+ }.asInstanceOf[LazyList[HashPartitioningLike]]
.take(conf.broadcastHashJoinOutputPartitioningExpandLimit))
}
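(Illustration, not part of the patch: the expansion that the widened match
above now also applies to coalesced partitionings.)

```scala
// Given streamed output partitioning on Seq(a, b, c) and key equivalences
// b <-> x and c <-> y from the join condition, the join propagates every
// combination, capped by conf.broadcastHashJoinOutputPartitioningExpandLimit:
//   PartitioningCollection(Seq(
//     HashPartitioning(Seq(a, b, c), n), HashPartitioning(Seq(a, b, y), n),
//     HashPartitioning(Seq(a, x, c), n), HashPartitioning(Seq(a, x, y), n)))
// With HashPartitioningLike, the same expansion applies when the streamed side
// reports a CoalescedHashPartitioning instead of a bare HashPartitioning.
```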
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index c41b85f75e58..909a05ce26f7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -32,7 +32,7 @@ import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight}
import org.apache.spark.sql.catalyst.plans.logical.{Filter, HintInfo, Join, JoinHint, NO_BROADCAST_AND_REPLICATION}
import org.apache.spark.sql.execution.{BinaryExecNode, FilterExec, ProjectExec, SortExec, SparkPlan, WholeStageCodegenExec}
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
-import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
+import org.apache.spark.sql.execution.exchange.{ShuffleExchangeExec, ShuffleExchangeLike}
import org.apache.spark.sql.execution.joins._
import org.apache.spark.sql.execution.python.BatchEvalPythonExec
import org.apache.spark.sql.internal.SQLConf
@@ -1729,4 +1729,30 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
checkAnswer(joined, expected)
}
+
+ test("SPARK-45882: BroadcastHashJoinExec propagate partitioning should respect " +
+ "CoalescedHashPartitioning") {
+ val cached = spark.sql(
+ """
+ |select /*+ broadcast(testData) */ key, value, a
+ |from testData join (
+ | select a from testData2 group by a
+ |)tmp on key = a
+ |""".stripMargin).cache()
+ try {
+ val df = cached.groupBy("key").count()
+ val expected = Seq(Row(1, 1), Row(2, 1), Row(3, 1))
+ assert(find(df.queryExecution.executedPlan) {
+ case _: ShuffleExchangeLike => true
+ case _ => false
+ }.size == 1, df.queryExecution)
+ checkAnswer(df, expected)
+ assert(find(df.queryExecution.executedPlan) {
+ case _: ShuffleExchangeLike => true
+ case _ => false
+ }.isEmpty, df.queryExecution)
+ } finally {
+ cached.unpersist()
+ }
+ }
}