This is an automated email from the ASF dual-hosted git repository.
ptoth pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 895666b9e369 [SPARK-55551][SQL] Improve `BroadcastHashJoinExec` output
partitioning
895666b9e369 is described below
commit 895666b9e3690659e1841600753c33c41fdbd207
Author: Peter Toth <[email protected]>
AuthorDate: Thu Feb 19 16:57:34 2026 +0100
[SPARK-55551][SQL] Improve `BroadcastHashJoinExec` output partitioning
### What changes were proposed in this pull request?
This is a minor refactor of `BroadcastHashJoinExec.outputPartitioning` to:
- simplify the logic and
- make it future proof by using `Partitioning with Expression` instead of
`HashPartitioningLike`.
### Why are the changes needed?
Code cleanup, and adding support for future partitionings that implement
`Expression` but not `HashPartitioningLike`. (Like `KeyedPartitioning` is in
https://github.com/apache/spark/pull/54330.)
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Existing tests.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #54335 from
peter-toth/SPARK-55551-improve-broadcasthashjoinexec-output-partitioning.
Authored-by: Peter Toth <[email protected]>
Signed-off-by: Peter Toth <[email protected]>
---
.../execution/joins/BroadcastHashJoinExec.scala | 54 +++++++++++-----------
.../sql/execution/joins/BroadcastJoinSuite.scala | 9 ++--
2 files changed, 31 insertions(+), 32 deletions(-)
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
index 944ee3b05909..4fd5eb50dbfc 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
@@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight,
BuildSide, JoinSelectionHelper}
import org.apache.spark.sql.catalyst.plans._
-import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution,
Distribution, HashPartitioningLike, Partitioning, PartitioningCollection,
UnspecifiedDistribution}
+import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution,
Distribution, Partitioning, PartitioningCollection, UnspecifiedDistribution}
import org.apache.spark.sql.execution.{CodegenSupport, SparkPlan}
import org.apache.spark.sql.execution.metric.SQLMetrics
@@ -72,10 +72,14 @@ case class BroadcastHashJoinExec private(
override lazy val outputPartitioning: Partitioning = {
joinType match {
case _: InnerLike if conf.broadcastHashJoinOutputPartitioningExpandLimit
> 0 =>
- streamedPlan.outputPartitioning match {
- case h: HashPartitioningLike => expandOutputPartitioning(h)
- case c: PartitioningCollection => expandOutputPartitioning(c)
- case other => other
+ val expandedPartitioning =
expandOutputPartitioning(streamedPlan.outputPartitioning)
+ expandedPartitioning match {
+ // We don't need to handle the empty case, since it could only occur
if
+ // `streamedPlan.outputPartitioning` were an empty
`PartitioningCollection`, but its
+ // constructor prevents that.
+
+ case p :: Nil => p
+ case ps => PartitioningCollection(ps)
}
case _ => streamedPlan.outputPartitioning
}
@@ -95,29 +99,25 @@ case class BroadcastHashJoinExec private(
mapping.toMap
}
- // Expands the given partitioning collection recursively.
- private def expandOutputPartitioning(
- partitioning: PartitioningCollection): PartitioningCollection = {
- PartitioningCollection(partitioning.partitionings.flatMap {
- case h: HashPartitioningLike => expandOutputPartitioning(h).partitionings
- case c: PartitioningCollection => Seq(expandOutputPartitioning(c))
+ // Expands the given partitioning recursively.
+ private def expandOutputPartitioning(partitioning: Partitioning):
Seq[Partitioning] = {
+ partitioning match {
+ case c: PartitioningCollection =>
c.partitionings.flatMap(expandOutputPartitioning)
+ case p: Partitioning with Expression =>
+ // Expands the given partitioning, that is also an expression, by
substituting streamed keys
+ // with build keys.
+ // For example, if the expressions for the given partitioning are
Seq("a", "b", "c") where
+ // the streamed keys are Seq("b", "c") and the build keys are Seq("x",
"y"), the expanded
+ // partitioning will have the following expressions:
+ // Seq("a", "b", "c"), Seq("a", "b", "y"), Seq("a", "x", "c"),
Seq("a", "x", "y").
+ // The expanded expressions are returned as `Seq[Partitioning]`.
+ p.multiTransformDown {
+ case e: Expression if
streamedKeyToBuildKeyMapping.contains(e.canonicalized) =>
+ e +: streamedKeyToBuildKeyMapping(e.canonicalized)
+ }.asInstanceOf[LazyList[Partitioning]]
+ .take(conf.broadcastHashJoinOutputPartitioningExpandLimit)
case other => Seq(other)
- })
- }
-
- // Expands the given hash partitioning by substituting streamed keys with
build keys.
- // For example, if the expressions for the given partitioning are Seq("a",
"b", "c")
- // where the streamed keys are Seq("b", "c") and the build keys are Seq("x",
"y"),
- // the expanded partitioning will have the following expressions:
- // Seq("a", "b", "c"), Seq("a", "b", "y"), Seq("a", "x", "c"), Seq("a", "x",
"y").
- // The expanded expressions are returned as PartitioningCollection.
- private def expandOutputPartitioning(
- partitioning: HashPartitioningLike): PartitioningCollection = {
- PartitioningCollection(partitioning.multiTransformDown {
- case e: Expression if
streamedKeyToBuildKeyMapping.contains(e.canonicalized) =>
- e +: streamedKeyToBuildKeyMapping(e.canonicalized)
- }.asInstanceOf[LazyList[HashPartitioningLike]]
- .take(conf.broadcastHashJoinOutputPartitioningExpandLimit))
+ }
}
protected override def doExecute(): RDD[InternalRow] = {
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
index 9bd858608cb9..66139134d50e 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
@@ -592,11 +592,10 @@ abstract class BroadcastJoinSuiteBase extends QueryTest
with SQLTestUtils
HashPartitioning(Seq(l3), 1)))),
right = DummySparkPlan())
expected = PartitioningCollection(Seq(
- PartitioningCollection(Seq(
- HashPartitioning(Seq(l1), 1),
- HashPartitioning(Seq(r1), 1),
- HashPartitioning(Seq(l2), 1),
- HashPartitioning(Seq(r2), 1))),
+ HashPartitioning(Seq(l1), 1),
+ HashPartitioning(Seq(r1), 1),
+ HashPartitioning(Seq(l2), 1),
+ HashPartitioning(Seq(r2), 1),
HashPartitioning(Seq(l3), 1),
HashPartitioning(Seq(r3), 1)))
assert(bhj.outputPartitioning === expected)
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]