This is an automated email from the ASF dual-hosted git repository.
gengliangwang pushed a commit to branch branch-4.x
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-4.x by this push:
new 5b28a6750bd2 [SPARK-57027][SQL] SortMergeJoinExec: skip
statically-dead branches in codegen
5b28a6750bd2 is described below
commit 5b28a6750bd28af9e9dc5ce314e4f5cb1ffb596f
Author: Gengliang Wang <[email protected]>
AuthorDate: Sun May 31 13:03:42 2026 -0700
[SPARK-57027][SQL] SortMergeJoinExec: skip statically-dead branches in
codegen
### What changes were proposed in this pull request?
This is a sub-task of
[SPARK-56908](https://issues.apache.org/jira/browse/SPARK-56908).
Two statically-dead patterns in `SortMergeJoinExec` codegen:
1. `genComparison` emits
```
comp = 0;
if (comp == 0) { comp = compare(k1); }
if (comp == 0) { comp = compare(k2); }
```
The first `if (comp == 0)` is always true (we just assigned 0). Emit
`comp = compare(k1);` directly and wrap only the subsequent keys in the
`if (comp == 0)` guard. `genComparison` has five call sites in the SMJ codegen
(two in `genScanner`, three in `codegenFullOuter`). For single-key joins
(the common case), each call collapses to one line.
2. `genScanner` and `codegenFullOuter` emit `if (k1IsNull || k2IsNull ||
...) { handler }`. When every key `ExprCode` has `isNull == FalseLiteral`, the
disjunction is statically `false` and the whole block (including its
`handleStreamedAnyNull` / "join with null row" handler) is dead. Detect this
and omit the block; when only some keys are statically non-nullable, drop their
`FalseLiteral` operands from the disjunction. This typically applies to
fact/dimension joins on numeric keys where Spark has already proved
non-nullability.
### Why are the changes needed?
This produces smaller generated Java per SMJ stage. The JIT would eliminate
the dead code at runtime anyway, so the win is smaller generated source, more
headroom under the 64KB method limit, and slightly faster Janino compilation.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Existing test suites cover both paths with whole-stage codegen on and off:
- `OuterJoinSuite` (SMJ full-outer codegen + interpreted scanner).
- `InnerJoinSuite` (SMJ codegen and non-codegen paths).
- `ExistenceJoinSuite` (SMJ existence path).
### Was this patch authored or co-authored using generative AI tooling?
Generated-by: Claude Code
Closes #56075 from gengliangwang/SPARK-57027-smj-dead-branches.
Authored-by: Gengliang Wang <[email protected]>
Signed-off-by: Gengliang Wang <[email protected]>
(cherry picked from commit 29d73467c23f9b14e96ba13845b8b41e58cc13f3)
Signed-off-by: Gengliang Wang <[email protected]>
---
.../sql/execution/joins/SortMergeJoinExec.scala | 103 +++++++++++++++------
1 file changed, 77 insertions(+), 26 deletions(-)
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
index b206fb528dcd..51604cdfedf1 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
@@ -191,7 +191,13 @@ case class SortMergeJoinExec(
}
private def genComparison(ctx: CodegenContext, a: Seq[ExprCode], b:
Seq[ExprCode]): String = {
- val comparisons = a.zip(b).zipWithIndex.map { case ((l, r), i) =>
+ // The first key compare always runs, so emit it unguarded. Each
subsequent key compare runs
+ // only when previous keys were equal (comp == 0).
+ val pairs = a.zip(b).zipWithIndex
+ val firstCompare = pairs.headOption.map { case ((l, r), i) =>
+ s"comp = ${ctx.genComp(leftKeys(i).dataType, l.value, r.value)};"
+ }.getOrElse("comp = 0;")
+ val restCompares = pairs.drop(1).map { case ((l, r), i) =>
s"""
|if (comp == 0) {
| comp = ${ctx.genComp(leftKeys(i).dataType, l.value, r.value)};
@@ -199,8 +205,8 @@ case class SortMergeJoinExec(
""".stripMargin.trim
}
s"""
- |comp = 0;
- |${comparisons.mkString("\n")}
+ |$firstCompare
+ |${restCompares.mkString("\n")}
""".stripMargin
}
@@ -216,11 +222,18 @@ case class SortMergeJoinExec(
val streamedRow = ctx.addMutableState("InternalRow", "streamedRow",
forceInline = true)
val bufferedRow = ctx.addMutableState("InternalRow", "bufferedRow",
forceInline = true)
- // Create variables for join keys from both sides.
+ // Create variables for join keys from both sides. Filter out
`FalseLiteral` `isNull`
+ // terms before building the disjunction so the emitted check has no
statically-dead
+ // `false` operands. When every key is statically non-nullable, the
disjunction is
+ // empty and we skip emitting the check (and the dead handler branch)
entirely.
val streamedKeyVars = createJoinKey(ctx, streamedRow, streamedKeys,
streamedOutput)
- val streamedAnyNull = streamedKeyVars.map(_.isNull).mkString(" || ")
+ val nullableStreamedIsNulls = streamedKeyVars.map(_.isNull).filter(_ !=
FalseLiteral)
+ val streamedKeysNullable = nullableStreamedIsNulls.nonEmpty
+ val streamedAnyNull = nullableStreamedIsNulls.mkString(" || ")
val bufferedKeyTmpVars = createJoinKey(ctx, bufferedRow, bufferedKeys,
bufferedOutput)
- val bufferedAnyNull = bufferedKeyTmpVars.map(_.isNull).mkString(" || ")
+ val nullableBufferedIsNulls = bufferedKeyTmpVars.map(_.isNull).filter(_ !=
FalseLiteral)
+ val bufferedKeysNullable = nullableBufferedIsNulls.nonEmpty
+ val bufferedAnyNull = nullableBufferedIsNulls.mkString(" || ")
// Copy the buffered key as class members so they could be used in next
function call.
val bufferedKeyVars = copyKeys(ctx, bufferedKeyTmpVars)
@@ -287,6 +300,27 @@ case class SortMergeJoinExec(
s"$matches.add((UnsafeRow) $bufferedRow);"
}
+ val checkStreamedAnyNull = if (streamedKeysNullable) {
+ s"""
+ |if ($streamedAnyNull) {
+ | $handleStreamedAnyNull
+ |}
+ """.stripMargin
+ } else {
+ ""
+ }
+
+ val checkBufferedAnyNull = if (bufferedKeysNullable) {
+ s"""
+ |if ($bufferedAnyNull) {
+ | $bufferedRow = null;
+ | continue;
+ |}
+ """.stripMargin
+ } else {
+ ""
+ }
+
// Generate a function to scan both streamed and buffered sides to find a
match.
// Return whether a match is found.
//
@@ -329,9 +363,7 @@ case class SortMergeJoinExec(
| if (!streamedIter.hasNext()) return false;
| $streamedRow = (InternalRow) streamedIter.next();
| ${streamedKeyVars.map(_.code).mkString("\n")}
- | if ($streamedAnyNull) {
- | $handleStreamedAnyNull
- | }
+ | ${checkStreamedAnyNull.trim}
| if (!$matches.isEmpty()) {
| ${genComparison(ctx, streamedKeyVars, matchedKeyVars)}
| if (comp == 0) {
@@ -348,10 +380,7 @@ case class SortMergeJoinExec(
| }
| $bufferedRow = (InternalRow) bufferedIter.next();
| ${bufferedKeyTmpVars.map(_.code).mkString("\n")}
- | if ($bufferedAnyNull) {
- | $bufferedRow = null;
- | continue;
- | }
+ | ${checkBufferedAnyNull.trim}
| ${bufferedKeyVars.map(_.code).mkString("\n")}
| }
| ${genComparison(ctx, streamedKeyVars, bufferedKeyVars)}
@@ -788,11 +817,17 @@ case class SortMergeJoinExec(
val leftInputRow = ctx.addMutableState("InternalRow", "leftInputRow",
forceInline = true)
val rightInputRow = ctx.addMutableState("InternalRow", "rightInputRow",
forceInline = true)
- // Create variables for join keys from both sides.
+ // Create variables for join keys from both sides. As in `genScanner`,
drop FalseLiteral
+ // `isNull` terms before joining the disjunction so the emitted check has
no dead `false`
+ // operands; omit the check entirely when every key is statically
non-nullable.
val leftKeyVars = createJoinKey(ctx, leftInputRow, leftKeys, left.output)
- val leftAnyNull = leftKeyVars.map(_.isNull).mkString(" || ")
+ val nullableLeftIsNulls = leftKeyVars.map(_.isNull).filter(_ !=
FalseLiteral)
+ val leftKeysNullable = nullableLeftIsNulls.nonEmpty
+ val leftAnyNull = nullableLeftIsNulls.mkString(" || ")
val rightKeyVars = createJoinKey(ctx, rightInputRow, rightKeys,
right.output)
- val rightAnyNull = rightKeyVars.map(_.isNull).mkString(" || ")
+ val nullableRightIsNulls = rightKeyVars.map(_.isNull).filter(_ !=
FalseLiteral)
+ val rightKeysNullable = nullableRightIsNulls.nonEmpty
+ val rightAnyNull = nullableRightIsNulls.mkString(" || ")
val matchedKeyVars = copyKeys(ctx, leftKeyVars)
val leftMatchedKeyVars = createJoinKey(ctx, leftInputRow, leftKeys,
left.output)
val rightMatchedKeyVars = createJoinKey(ctx, rightInputRow, rightKeys,
right.output)
@@ -866,6 +901,30 @@ case class SortMergeJoinExec(
// - Step 3: Buffer rows with same join keys from both sides into
`leftBuffer` and
// `rightBuffer`. Reset bit sets for both buffers accordingly
(`leftMatched` and
// `rightMatched`).
+ val checkLeftAnyNull = if (leftKeysNullable) {
+ s"""
+ |if ($leftAnyNull) {
+ | // The left row join key is null, join it with null row
+ | $outputLeftNoMatch
+ | return;
+ |}
+ """.stripMargin
+ } else {
+ ""
+ }
+
+ val checkRightAnyNull = if (rightKeysNullable) {
+ s"""
+ |if ($rightAnyNull) {
+ | // The right row join key is null, join it with null row
+ | $outputRightNoMatch
+ | return;
+ |}
+ """.stripMargin
+ } else {
+ ""
+ }
+
val findNextJoinRowsFuncName = ctx.freshName("findNextJoinRows")
ctx.addNewFunction(findNextJoinRowsFuncName,
s"""
@@ -884,18 +943,10 @@ case class SortMergeJoinExec(
| }
|
| ${leftKeyVars.map(_.code).mkString("\n")}
- | if ($leftAnyNull) {
- | // The left row join key is null, join it with null row
- | $outputLeftNoMatch
- | return;
- | }
+ | ${checkLeftAnyNull.trim}
|
| ${rightKeyVars.map(_.code).mkString("\n")}
- | if ($rightAnyNull) {
- | // The right row join key is null, join it with null row
- | $outputRightNoMatch
- | return;
- | }
+ | ${checkRightAnyNull.trim}
|
| ${genComparison(ctx, leftKeyVars, rightKeyVars)}
| if (comp < 0) {
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]