This is an automated email from the ASF dual-hosted git repository.

gengliangwang pushed a commit to branch branch-4.x
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-4.x by this push:
     new 32105b5373a0 [SPARK-57026][SQL] SortMergeJoinExec and 
ShuffledHashJoinExec: replace anonymous TaskCompletionListener with shared 
JoinHelper methods
32105b5373a0 is described below

commit 32105b5373a03109a742f2b74ba6f9b4cf5505d7
Author: Gengliang Wang <[email protected]>
AuthorDate: Sat May 30 20:13:12 2026 -0700

    [SPARK-57026][SQL] SortMergeJoinExec and ShuffledHashJoinExec: replace 
anonymous TaskCompletionListener with shared JoinHelper methods
    
    ### What changes were proposed in this pull request?
    
    This is a sub-task of 
[SPARK-56908](https://issues.apache.org/jira/browse/SPARK-56908).
    
    Two join operators register anonymous `TaskCompletionListener`s whose bodies do not depend on the join's row schema or key types, so they can be shared:
    
    - `SortMergeJoinExec.doProduce` registers a per-stage anonymous inner class 
that adds `matches.spillSize()` to the `spillSize` metric.
    - `ShuffledHashJoinExec.buildSideOrFullOuterJoinNonUniqueKey` registers a 
runtime anonymous closure that adds the `OpenHashSet[Long]` memory footprint 
(bit-set + data array) to `buildDataSize`.
    
    Hoist both into shared static helpers in the existing `JoinHelper` class in
`sql/core/src/main/java/org/apache/spark/sql/execution/joins/JoinHelper.java`:
    
    ```java
    recordSpillSizeOnTaskCompletion(ExternalAppendOnlyUnsafeRowArray, SQLMetric)
    recordOpenHashSetMemoryUsageOnTaskCompletion(OpenHashSet<?>, SQLMetric)
    ```
    
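    Also remove the now-unused `SortMergeJoinExec.getTaskContext()`, whose only caller was the generated listener-registration code.
    
    Note that the new `buildDataSize` helper multiplies by `8L`, so the OpenHashSet footprint estimate is now computed in `long` arithmetic. The removed Scala closure computed it in `Int` arithmetic before adding it to the metric, which could overflow for very large sets.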
    
    ### Why are the changes needed?
    
    - Smaller generated Java per `SortMergeJoinExec` whole-stage-codegen stage: 
one anonymous inner class is no longer emitted per stage.
    - Centralises the metric-recording listener bodies in one place. For `SortMergeJoinExec`, this lets the JIT compile the shared body once instead of once per generated stage class.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Existing test suites cover both paths with whole-stage codegen on and off:
    - `OuterJoinSuite` (SMJ full-outer codegen path).
    - `InnerJoinSuite` (SMJ codegen path with spill).
    - ShuffledHashJoin full-outer non-unique-key path tests in `OuterJoinSuite`.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    Generated-by: Claude Code
    
    Closes #56074 from gengliangwang/SPARK-57026-listener-helpers.
    
    Authored-by: Gengliang Wang <[email protected]>
    Signed-off-by: Gengliang Wang <[email protected]>
    (cherry picked from commit be8a32d67656bfff1ac75422498db8314a37e939)
    Signed-off-by: Gengliang Wang <[email protected]>
---
 .../spark/sql/execution/joins/JoinHelper.java      | 35 ++++++++++++++++++++--
 .../sql/execution/joins/ShuffledHashJoinExec.scala | 12 +++-----
 .../sql/execution/joins/SortMergeJoinExec.scala    | 18 ++---------
 3 files changed, 39 insertions(+), 26 deletions(-)

diff --git 
a/sql/core/src/main/java/org/apache/spark/sql/execution/joins/JoinHelper.java 
b/sql/core/src/main/java/org/apache/spark/sql/execution/joins/JoinHelper.java
index 91156b2600fd..041bfa04081f 100644
--- 
a/sql/core/src/main/java/org/apache/spark/sql/execution/joins/JoinHelper.java
+++ 
b/sql/core/src/main/java/org/apache/spark/sql/execution/joins/JoinHelper.java
@@ -17,12 +17,17 @@
 
 package org.apache.spark.sql.execution.joins;
 
+import org.apache.spark.TaskContext;
+import org.apache.spark.sql.execution.ExternalAppendOnlyUnsafeRowArray;
+import org.apache.spark.sql.execution.metric.SQLMetric;
 import org.apache.spark.util.collection.BitSet;
+import org.apache.spark.util.collection.OpenHashSet;
 
 /**
  * Static helpers shared by join operators in this package, used both from 
whole-stage codegen and
- * from interpreted execution paths. Hoisting recurring snippets here keeps 
the generated Java
- * source smaller and lets the JIT compile the bodies once instead of once per 
stage.
+ * from interpreted execution paths. Hoisting recurring snippets here 
(especially the ones that
+ * would otherwise be emitted as anonymous inner classes per generated stage) 
keeps the generated
+ * Java source smaller and lets the JIT compile the bodies once instead of 
once per stage.
  */
 public final class JoinHelper {
 
@@ -44,4 +49,30 @@ public final class JoinHelper {
     }
     return new BitSet(bufferSize);
   }
+
+  /**
+   * Register a task-completion listener that adds the final spill size of 
{@code matches} to
+   * {@code spillSize}. Replaces an anonymous {@code TaskCompletionListener} 
that would otherwise
+   * be generated per {@code SortMergeJoinExec} whole-stage class.
+   */
+  public static void recordSpillSizeOnTaskCompletion(
+      ExternalAppendOnlyUnsafeRowArray matches, SQLMetric spillSize) {
+    TaskContext.get().addTaskCompletionListener(context -> {
+      spillSize.add(matches.spillSize());
+    });
+  }
+
+  /**
+   * Register a task-completion listener that adds the estimated memory 
footprint of
+   * {@code matchedRows} (the bit-set plus the data array) to {@code metric}. 
Used by
+   * {@code ShuffledHashJoinExec} to track {@code buildDataSize} for its 
matched-row tracker.
+   */
+  public static void recordOpenHashSetMemoryUsageOnTaskCompletion(
+      OpenHashSet<?> matchedRows, SQLMetric metric) {
+    TaskContext.get().addTaskCompletionListener(context -> {
+      long bitSetEstimatedSize = matchedRows.getBitSet().capacity() / 8L;
+      long dataEstimatedSize = matchedRows.capacity() * 8L;
+      metric.add(bitSetEstimatedSize + dataEstimatedSize);
+    });
+  }
 }
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
index 0f90f443ad41..8d65a082984f 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
@@ -270,14 +270,10 @@ case class ShuffledHashJoinExec private (
       buildNullRow: GenericInternalRow,
       isFullOuterJoin: Boolean): Iterator[InternalRow] = {
     val matchedRows = new OpenHashSet[Long]
-    TaskContext.get().addTaskCompletionListener[Unit](_ => {
-      // At the end of the task, update the task's memory usage for this
-      // [[OpenHashSet]] to track matched rows, which has two parts:
-      // [[OpenHashSet._bitset]] and [[OpenHashSet._data]].
-      val bitSetEstimatedSize = matchedRows.getBitSet.capacity / 8
-      val dataEstimatedSize = matchedRows.capacity * 8
-      longMetric("buildDataSize") += bitSetEstimatedSize + dataEstimatedSize
-    })
+    // At the end of the task, update the task's memory usage for this 
OpenHashSet that tracks
+    // matched rows (its underlying bit-set plus data array).
+    JoinHelper.recordOpenHashSetMemoryUsageOnTaskCompletion(
+      matchedRows, longMetric("buildDataSize"))
 
     def markRowMatched(keyIndex: Int, valueIndex: Int): Unit = {
       val rowIndex: Long = (keyIndex.toLong << 32) | valueIndex
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
index 985fc518742c..b206fb528dcd 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
@@ -440,13 +440,6 @@ case class SortMergeJoinExec(
 
   override def needCopyResult: Boolean = true
 
-  /**
-   * This is called by generated Java class, should be public.
-   */
-  def getTaskContext(): TaskContext = {
-    TaskContext.get()
-  }
-
   override def doProduce(ctx: CodegenContext): String = {
     // Specialize `doProduce` code for full outer join, because full outer 
join needs to
     // buffer both sides of join.
@@ -591,16 +584,9 @@ case class SortMergeJoinExec(
     }
 
     val initJoin = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "initJoin")
+    val helperCls = classOf[JoinHelper].getName
     val addHookToRecordMetrics =
-      s"""
-         |$thisPlan.getTaskContext().addTaskCompletionListener(
-         |  new org.apache.spark.util.TaskCompletionListener() {
-         |    @Override
-         |    public void onTaskCompletion(org.apache.spark.TaskContext 
context) {
-         |      ${metricTerm(ctx, "spillSize")}.add($matches.spillSize());
-         |    }
-         |});
-       """.stripMargin
+      s"$helperCls.recordSpillSizeOnTaskCompletion($matches, ${metricTerm(ctx, 
"spillSize")});"
 
     s"""
        |if (!$initJoin) {


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to