This is an automated email from the ASF dual-hosted git repository.

ptoth pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 895666b9e369 [SPARK-55551][SQL] Improve `BroadcastHashJoinExec` output 
partitioning
895666b9e369 is described below

commit 895666b9e3690659e1841600753c33c41fdbd207
Author: Peter Toth <[email protected]>
AuthorDate: Thu Feb 19 16:57:34 2026 +0100

    [SPARK-55551][SQL] Improve `BroadcastHashJoinExec` output partitioning
    
    ### What changes were proposed in this pull request?
    
    This is a minor refactor of `BroadcastHashJoinExec.outputPartitioning` to:
    - simplify the logic and
    - make it future proof by using `Partitioning with Expression` instead of 
`HashPartitioningLike`.
    
    ### Why are the changes needed?
    Code cleanup and add support for future partitionings that implement 
`Expression` but not `HashPartitioningLike`. (Like `KeyedPartitioning` is in 
https://github.com/apache/spark/pull/54330.)
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
    Existing tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No.
    
    Closes #54335 from 
peter-toth/SPARK-55551-improve-broadcasthashjoinexec-output-partitioning.
    
    Authored-by: Peter Toth <[email protected]>
    Signed-off-by: Peter Toth <[email protected]>
---
 .../execution/joins/BroadcastHashJoinExec.scala    | 54 +++++++++++-----------
 .../sql/execution/joins/BroadcastJoinSuite.scala   |  9 ++--
 2 files changed, 31 insertions(+), 32 deletions(-)

diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
index 944ee3b05909..4fd5eb50dbfc 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
@@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, 
BuildSide, JoinSelectionHelper}
 import org.apache.spark.sql.catalyst.plans._
-import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, 
Distribution, HashPartitioningLike, Partitioning, PartitioningCollection, 
UnspecifiedDistribution}
+import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, 
Distribution, Partitioning, PartitioningCollection, UnspecifiedDistribution}
 import org.apache.spark.sql.execution.{CodegenSupport, SparkPlan}
 import org.apache.spark.sql.execution.metric.SQLMetrics
 
@@ -72,10 +72,14 @@ case class BroadcastHashJoinExec private(
   override lazy val outputPartitioning: Partitioning = {
     joinType match {
       case _: InnerLike if conf.broadcastHashJoinOutputPartitioningExpandLimit 
> 0 =>
-        streamedPlan.outputPartitioning match {
-          case h: HashPartitioningLike => expandOutputPartitioning(h)
-          case c: PartitioningCollection => expandOutputPartitioning(c)
-          case other => other
+        val expandedPartitioning = 
expandOutputPartitioning(streamedPlan.outputPartitioning)
+        expandedPartitioning match {
+          // We don't need to handle the empty case, since it could only occur 
if
+          // `streamedPlan.outputPartitioning` were an empty 
`PartitioningCollection`, but its
+          // constructor prevents that.
+
+          case p :: Nil => p
+          case ps => PartitioningCollection(ps)
         }
       case _ => streamedPlan.outputPartitioning
     }
@@ -95,29 +99,25 @@ case class BroadcastHashJoinExec private(
     mapping.toMap
   }
 
-  // Expands the given partitioning collection recursively.
-  private def expandOutputPartitioning(
-      partitioning: PartitioningCollection): PartitioningCollection = {
-    PartitioningCollection(partitioning.partitionings.flatMap {
-      case h: HashPartitioningLike => expandOutputPartitioning(h).partitionings
-      case c: PartitioningCollection => Seq(expandOutputPartitioning(c))
+  // Expands the given partitioning recursively.
+  private def expandOutputPartitioning(partitioning: Partitioning): 
Seq[Partitioning] = {
+    partitioning match {
+      case c: PartitioningCollection => 
c.partitionings.flatMap(expandOutputPartitioning)
+      case p: Partitioning with Expression =>
+        // Expands the given partitioning, that is also an expression, by 
substituting streamed keys
+        // with build keys.
+        // For example, if the expressions for the given partitioning are 
Seq("a", "b", "c") where
+        // the streamed keys are Seq("b", "c") and the build keys are Seq("x", 
"y"), the expanded
+        // partitioning will have the following expressions:
+        // Seq("a", "b", "c"), Seq("a", "b", "y"), Seq("a", "x", "c"), 
Seq("a", "x", "y").
+        // The expanded expressions are returned as `Seq[Partitioning]`.
+        p.multiTransformDown {
+          case e: Expression if 
streamedKeyToBuildKeyMapping.contains(e.canonicalized) =>
+            e +: streamedKeyToBuildKeyMapping(e.canonicalized)
+        }.asInstanceOf[LazyList[Partitioning]]
+          .take(conf.broadcastHashJoinOutputPartitioningExpandLimit)
       case other => Seq(other)
-    })
-  }
-
-  // Expands the given hash partitioning by substituting streamed keys with 
build keys.
-  // For example, if the expressions for the given partitioning are Seq("a", 
"b", "c")
-  // where the streamed keys are Seq("b", "c") and the build keys are Seq("x", 
"y"),
-  // the expanded partitioning will have the following expressions:
-  // Seq("a", "b", "c"), Seq("a", "b", "y"), Seq("a", "x", "c"), Seq("a", "x", 
"y").
-  // The expanded expressions are returned as PartitioningCollection.
-  private def expandOutputPartitioning(
-      partitioning: HashPartitioningLike): PartitioningCollection = {
-    PartitioningCollection(partitioning.multiTransformDown {
-      case e: Expression if 
streamedKeyToBuildKeyMapping.contains(e.canonicalized) =>
-        e +: streamedKeyToBuildKeyMapping(e.canonicalized)
-    }.asInstanceOf[LazyList[HashPartitioningLike]]
-      .take(conf.broadcastHashJoinOutputPartitioningExpandLimit))
+    }
   }
 
   protected override def doExecute(): RDD[InternalRow] = {
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
index 9bd858608cb9..66139134d50e 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
@@ -592,11 +592,10 @@ abstract class BroadcastJoinSuiteBase extends QueryTest 
with SQLTestUtils
         HashPartitioning(Seq(l3), 1)))),
       right = DummySparkPlan())
     expected = PartitioningCollection(Seq(
-      PartitioningCollection(Seq(
-        HashPartitioning(Seq(l1), 1),
-        HashPartitioning(Seq(r1), 1),
-        HashPartitioning(Seq(l2), 1),
-        HashPartitioning(Seq(r2), 1))),
+      HashPartitioning(Seq(l1), 1),
+      HashPartitioning(Seq(r1), 1),
+      HashPartitioning(Seq(l2), 1),
+      HashPartitioning(Seq(r2), 1),
       HashPartitioning(Seq(l3), 1),
       HashPartitioning(Seq(r3), 1)))
     assert(bhj.outputPartitioning === expected)


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to