cloud-fan commented on code in PR #56243:
URL: https://github.com/apache/spark/pull/56243#discussion_r3377986111


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CoalesceShufflePartitions.scala:
##########
@@ -90,9 +90,43 @@ case class CoalesceShufflePartitions(session: SparkSession) extends AQEShuffleRe
       }
     }
 
+    // For groups that feed a partitioned join (SortMergeJoin/ShuffledHashJoin), enforce a
+    // minimum partition count to avoid eliminating join parallelism.
+    // Design choice: we use a pre-coalesce floor (Option A) rather than post-coalesce skew
+    // re-checking (Option B). Option A is simpler and avoids re-running skew detection after
+    // coalescing. Option B would be more robust for edge cases but adds significant complexity
+    // and can be explored as a follow-up.
+    val adjustedMinNumPartitionsByGroup = coalesceGroups.zip(minNumPartitionsByGroup).map {
+      case (group, minNum) if group.feedsJoin =>
+        // mapStats is always Some here because coalescing runs after stage materialization in
+        // AQE. If stats are unexpectedly absent, flatMap safely contributes 0 bytes and the
+        // floor is effectively skipped (totalSize <= advisorySize), preserving correctness.
+        val totalSize = group.shuffleStages.flatMap(
+          _.shuffleStage.mapStats.map(_.bytesByPartitionId.sum)).sum
+        val advisorySize = advisoryPartitionSize(group)
+        if (totalSize <= advisorySize) {
+          // Tiny data: all join data fits in one advisory-sized partition, so coalescing
+          // to 1 is fine -- no parallelism benefit from multiple partitions.
+          minNum
+        } else if (conf.getConf(SQLConf.COALESCE_PARTITIONS_MIN_PARTITION_NUM).isDefined) {
+          // User explicitly set COALESCE_PARTITIONS_MIN_PARTITION_NUM — respect that intent.
+          // minNum already incorporates the config value, so no additional floor is needed.
+          minNum
+        } else {
+          // Compute a data-aware floor: the number of partitions needed to keep each partition
+          // at or below the advisory target size. The max(2, ...) ensures we never collapse to
+          // a single reducer for join data -- for totalSize between advisorySize and
+          // 2*advisorySize, ceil gives 1 or 2, so the floor of 2 prevents single-partition joins.
+          val joinFloor = math.max(2, math.ceil(totalSize.toDouble / advisorySize).toInt)

Review Comment:
   I don't think this floor changes behavior, and I'd question the approach even if it did.
   
   **It's a no-op.** `coalescePartitions` is already data-aware: `targetSize = min(ceil(totalSize/minNum), advisory)` (`ShufflePartitionsUtil.scala:60-62`), so a group coalesces to a *single* partition only when `targetSize >= totalSize`, i.e. only when `totalSize <= advisory` (for `minNum = 1`). But this floor is applied only in the `else`-branch where `totalSize > advisory`, and the tiny-data branch above returns `minNum` unchanged. Those regimes are disjoint: in the tiny-data case (where single-partition coalescing actually happens) the floor is skipped, so a join still coalesces to 1; where the floor engages, `coalescePartitions` with `minNum = 1` already picks `targetSize = advisory` and yields `ceil(totalSize/advisory) >= 2` partitions, and raising `minNum` to `ceil(totalSize/advisory)` reproduces the same count. None of the added tests fail on master as far as I can tell; could you add one that does?
   
   **Even setting that aside, input size is the wrong cost model for a join.** Coalescing a `<= advisory` join to one reducer is the rule working as designed (one reducer handles the target input size). A reducer's work on a join isn't its input bytes (`left + right`): join output can be a multiple of the inputs depending on key fanout/selectivity, which AQE doesn't know before the join runs. The codebase already models output blow-up for the explosive join types by shrinking the advisory *target* to `MIN_PARTITION_SIZE` (`advisoryPartitionSize`, L166-169), not via a count floor; SMJ/SHJ are deliberately excluded because equi-join output isn't generally explosive.
   
   And `COALESCE_PARTITIONS_MIN_PARTITION_NUM` already enforces a hard floor that *does* work in the tiny-data case this skips (`targetSize = totalSize/N < totalSize` forces >= N partitions). Its real gap is the proportional split across groups (L77-91), which this doesn't address either.
   
   Minor, within this block: `math.max(2, ...)` is dead. In this branch `totalSize > advisory`, so `ceil(totalSize/advisory)` is always `>= 2`. The comment above ("ceil gives 1 or 2") is also wrong: `ceil` of a value in `(1, 2]` is always `2`.
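
   The dead-branch point can be checked mechanically with a small sketch. `FloorCheck` and its methods are hypothetical helpers that reproduce only the `joinFloor` expression from this hunk; the sizes are illustrative.

   ```scala
   // Hypothetical helper: compares the PR's joinFloor expression with and
   // without the max(2, ...) wrapper, for integer byte counts.
   object FloorCheck {
     // The bare ceiling from the joinFloor expression.
     def rawCeil(totalSize: Long, advisory: Long): Int =
       math.ceil(totalSize.toDouble / advisory).toInt

     // The expression as written in the PR, including max(2, ...).
     def joinFloor(totalSize: Long, advisory: Long): Int =
       math.max(2, rawCeil(totalSize, advisory))

     // True if the wrapper never changes the result for totalSize in (advisory, limit],
     // i.e. the only range where the PR's else-branch runs.
     def maxIsDead(advisory: Long, limit: Long): Boolean =
       ((advisory + 1) to limit).forall(t => joinFloor(t, advisory) == rawCeil(t, advisory))
   }
   ```

   `FloorCheck.maxIsDead(100L, 1000L)` holds, and `rawCeil(101L, 100L)` and `rawCeil(200L, 100L)` are both 2, i.e. every ratio in `(1, 2]` already ceils to 2.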



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CoalesceShufflePartitions.scala:
##########
@@ -171,7 +205,11 @@ case class CoalesceShufflePartitions(session: SparkSession) extends AQEShuffleRe
       if (shuffleStages.forall(s => isSupported(s.shuffleStage.shuffle))) {
         // The recursion stops here, we need to call `p.exists(isExplodingJoin)` and find out if
         // there is any exploding join in this sub-plan-tree.
-        Seq(CoalesceGroup(shuffleStages, hasExplodingJoin || p.exists(isExplodingJoin)))
+        // `isPartitionedJoin(p)` catches the case where `p` itself is the join node.
+        // `p.exists(isPartitionedJoin)` catches cases where the join is below `p`
+        // (e.g., Project(SortMergeJoin(...))).
+        Seq(CoalesceGroup(shuffleStages, hasExplodingJoin || p.exists(isExplodingJoin),
+          feedsJoin = isPartitionedJoin(p) || p.exists(isPartitionedJoin)))

Review Comment:
   `isPartitionedJoin(p)` here is redundant: `TreeNode.exists` is `f(this) || children.exists(...)` (`TreeNode.scala:256`), so `p.exists(isPartitionedJoin)` already evaluates the predicate on `p` itself. This can just be:
   ```suggestion
             feedsJoin = p.exists(isPartitionedJoin)))
   ```
   The comment above (distinguishing "`p` itself" from "join below `p`") then describes a distinction that doesn't operationally exist (`exists` covers both), so it's worth dropping or rewording.
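
   A toy tree shows why the extra disjunct adds nothing. This is a sketch, not Spark code: `Node`, `Leaf`, `ProjectNode`, `JoinNode`, and the stand-in predicate are hypothetical, and only the `exists` shape (`f(this) || children.exists(...)`) mirrors what `TreeNode` is described as doing above.

   ```scala
   // Hypothetical toy tree (not Spark's TreeNode). Only the shape of `exists`
   // mirrors the review: test the node itself, then recurse into children.
   sealed trait Node {
     def children: Seq[Node]
     def exists(f: Node => Boolean): Boolean = f(this) || children.exists(_.exists(f))
   }
   final case class Leaf(name: String) extends Node {
     val children: Seq[Node] = Nil
   }
   final case class ProjectNode(child: Node) extends Node {
     val children: Seq[Node] = Seq(child)
   }
   final case class JoinNode(left: Node, right: Node) extends Node {
     val children: Seq[Node] = Seq(left, right)
   }

   object ExistsDemo {
     // Stand-in for isPartitionedJoin: true only for the hypothetical join node.
     val isPartitionedJoin: Node => Boolean = {
       case _: JoinNode => true
       case _ => false
     }
   }
   ```

   Both `JoinNode(...).exists(pred)` (join at the root) and `ProjectNode(JoinNode(...)).exists(pred)` (join below the root) are true, so `pred(p) || p.exists(pred)` always equals `p.exists(pred)`.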



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CoalesceShufflePartitions.scala:
##########
@@ -90,9 +90,43 @@ case class CoalesceShufflePartitions(session: SparkSession) extends AQEShuffleRe
       }
     }
 
+    // For groups that feed a partitioned join (SortMergeJoin/ShuffledHashJoin), enforce a
+    // minimum partition count to avoid eliminating join parallelism.
+    // Design choice: we use a pre-coalesce floor (Option A) rather than post-coalesce skew
+    // re-checking (Option B). Option A is simpler and avoids re-running skew detection after
+    // coalescing. Option B would be more robust for edge cases but adds significant complexity
+    // and can be explored as a follow-up.
+    val adjustedMinNumPartitionsByGroup = coalesceGroups.zip(minNumPartitionsByGroup).map {
+      case (group, minNum) if group.feedsJoin =>
+        // mapStats is always Some here because coalescing runs after stage materialization in
+        // AQE. If stats are unexpectedly absent, flatMap safely contributes 0 bytes and the
+        // floor is effectively skipped (totalSize <= advisorySize), preserving correctness.
+        val totalSize = group.shuffleStages.flatMap(
+          _.shuffleStage.mapStats.map(_.bytesByPartitionId.sum)).sum
+        val advisorySize = advisoryPartitionSize(group)
+        if (totalSize <= advisorySize) {
+          // Tiny data: all join data fits in one advisory-sized partition, so coalescing
+          // to 1 is fine -- no parallelism benefit from multiple partitions.
+          minNum
+        } else if (conf.getConf(SQLConf.COALESCE_PARTITIONS_MIN_PARTITION_NUM).isDefined) {
+          // User explicitly set COALESCE_PARTITIONS_MIN_PARTITION_NUM — respect that intent.

Review Comment:
   Non-ASCII em-dash trips scalastyle's `nonascii` rule.
   ```suggestion
             // User explicitly set COALESCE_PARTITIONS_MIN_PARTITION_NUM -- respect that intent.
   ```



##########
sql/core/src/test/scala/org/apache/spark/sql/execution/CoalesceShufflePartitionsSuite.scala:
##########
@@ -569,6 +569,236 @@ class CoalesceShufflePartitionsSuite extends SparkFunSuite with SQLConfHelper
     }
     withSparkSession(test, 100, None)
   }
+
+  test("SPARK-56145: CoalesceShufflePartitions should not coalesce join to 1 partition") {
+    val test: SparkSession => Unit = { spark: SparkSession =>
+      spark.sessionState.conf.unsetConf(SQLConf.COALESCE_PARTITIONS_MIN_PARTITION_NUM)
+      spark.conf.set(SQLConf.COALESCE_PARTITIONS_PARALLELISM_FIRST.key, "false")
+      spark.conf.set(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key, "-1")
+      spark.conf.set(SQLConf.SHUFFLE_PARTITIONS.key, "5")
+      spark.conf.set(SQLConf.COALESCE_PARTITIONS_INITIAL_PARTITION_NUM.key, "10")
+      spark.conf.set(SQLConf.COALESCE_PARTITIONS_MIN_PARTITION_SIZE.key, "1")
+
+      // Use enough data so totalSize > advisorySize (100000 bytes)
+      val df1 = spark.range(0, 10000, 1, numInputPartitions)
+        .selectExpr("id % 500 as key1", "id as value1")
+      val df2 = spark.range(0, 10000, 1, numInputPartitions)
+        .selectExpr("id % 500 as key2", "id as value2")
+
+      val join = df1.join(df2, col("key1") === col("key2")).select(col("key1"), col("value2"))
+      join.collect()
+
+      val finalPlan = stripAQEPlan(join.queryExecution.executedPlan)
+      val shuffleReads = finalPlan.collect {
+        case r @ CoalescedShuffleRead() => r
+      }
+
+      // After fix: join should NOT be coalesced to 1 partition
+      assert(shuffleReads.nonEmpty, "Expected coalesced shuffle reads")
+      val numPartitions = shuffleReads.head.outputPartitioning.numPartitions
+      assert(numPartitions > 1,
+        s"Join stage should not be coalesced to 1 partition, got $numPartitions")
+    }
+    // Advisory size 100000: with ~320KB join data, floor = ceil(320000/100000) = 4
+    withSparkSession(test, 100000, None)
+  }
+
+

Review Comment:
   Nit: doubled blank lines here (and again around L663-664); please use a single blank line between tests.



-- 
This is an automated message from the Apache Git Service.