venkata91 commented on a change in pull request #33896:
URL: https://github.com/apache/spark/pull/33896#discussion_r772564100



##########
File path: core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
##########
@@ -3847,11 +3886,248 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti
 
     // Job ended successfully.
     assert(results === Map(0 -> 11, 1 -> 12))
+  }
+
+  test("SPARK-33701: shuffle adaptive merge finalization") {
+    initPushBasedShuffleConfs(conf)
+    conf.set(config.SHUFFLE_MERGER_LOCATIONS_MIN_STATIC_THRESHOLD, 3)
+    DAGSchedulerSuite.clearMergerLocs
+    DAGSchedulerSuite.addMergerLocs(Seq("host1", "host2", "host3", "host4", "host5"))
+    val parts = 2
+
+    val shuffleMapRdd1 = new MyRDD(sc, parts, Nil)
+    val shuffleDep1 = new ShuffleDependency(shuffleMapRdd1, new HashPartitioner(parts))
+    val shuffleMapRdd2 = new MyRDD(sc, parts, Nil)
+    val shuffleDep2 = new ShuffleDependency(shuffleMapRdd2, new HashPartitioner(parts))
+    val reduceRdd = new MyRDD(sc, parts, List(shuffleDep1, shuffleDep2),
+      tracker = mapOutputTracker)
+
+    // Submit a reduce job that depends on both shuffles; this will create the map stages
+    submit(reduceRdd, (0 until parts).toArray)
+
+    val taskResults = taskSets(0).tasks.zipWithIndex.map {
+      case (_, idx) =>
+        (Success, makeMapStatus("host" + ('A' + idx).toChar, parts))
+    }.toSeq
+    for ((result, i) <- taskResults.zipWithIndex) {
+      runEvent(makeCompletionEvent(taskSets(0).tasks(i), result._1, result._2))
+    }
+    val shuffleStage1 = scheduler.stageIdToStage(0).asInstanceOf[ShuffleMapStage]
+    assert(shuffleStage1.shuffleDep.shuffleMergeEnabled)
+    val finalizeTask1 = shuffleStage1.shuffleDep.getFinalizeTask.get
+      .asInstanceOf[DummyScheduledFuture]
+    assert(finalizeTask1.delay == 10 && finalizeTask1.registerMergeResults)
+    assert(shuffleStage1.shuffleDep.shuffleMergeFinalized)
+
+    complete(taskSets(1), taskSets(1).tasks.zipWithIndex.map {
+      case (_, idx) =>
+        (Success, makeMapStatus("host" + ('A' + idx).toChar, parts, 10))
+    }.toSeq)
+    val shuffleStage2 = scheduler.stageIdToStage(1).asInstanceOf[ShuffleMapStage]
+    // Verify finalize task is set with default delay of 10s and merge results
+    // are marked for registration
+    assert(shuffleStage2.shuffleDep.getFinalizeTask.nonEmpty)
+    val finalizeTask2 = shuffleStage2.shuffleDep.getFinalizeTask.get
+      .asInstanceOf[DummyScheduledFuture]
+    assert(finalizeTask2.delay == 10 && finalizeTask2.registerMergeResults)
+
+    assert(mapOutputTracker.
+      getNumAvailableMergeResults(shuffleStage1.shuffleDep.shuffleId) == parts)
+    assert(mapOutputTracker.
+      getNumAvailableMergeResults(shuffleStage2.shuffleDep.shuffleId) == parts)
+    completeNextResultStageWithSuccess(2, 0)
+    assert(results === Map(0 -> 42, 1 -> 42))
+
     results.clear()
     assertDataStructuresEmpty()
   }
 
-  /**
+  test("SPARK-33701: check shuffle merge finalization triggering after 
minimum" +
+    " threshold push complete") {
+    initPushBasedShuffleConfs(conf)
+    conf.set(config.PUSH_BASED_SHUFFLE_SIZE_MIN_SHUFFLE_SIZE_TO_WAIT, 10L)
+    conf.set(config.SHUFFLE_MERGER_LOCATIONS_MIN_STATIC_THRESHOLD, 5)
+    conf.set(config.PUSH_BASED_SHUFFLE_MIN_PUSH_RATIO, 0.5)
+    DAGSchedulerSuite.clearMergerLocs
+    DAGSchedulerSuite.addMergerLocs(Seq("host1", "host2", "host3", "host4", "host5"))
+    val parts = 4
+
+    val shuffleMapRdd1 = new MyRDD(sc, parts, Nil)
+    val shuffleDep1 = new ShuffleDependency(shuffleMapRdd1, new HashPartitioner(parts))
+    val shuffleMapRdd2 = new MyRDD(sc, parts, Nil)
+    val shuffleDep2 = new ShuffleDependency(shuffleMapRdd2, new HashPartitioner(parts))
+    val reduceRdd = new MyRDD(sc, parts, List(shuffleDep1, shuffleDep2),
+      tracker = mapOutputTracker)
+
+    // Submit a reduce job that depends on both shuffles; this will create the map stages
+    submit(reduceRdd, (0 until parts).toArray)
+
+    val taskResults = taskSets(0).tasks.zipWithIndex.map {
+      case (_, idx) =>
+        (Success, makeMapStatus("host" + ('A' + idx).toChar, parts))
+    }.toSeq
+
+    runEvent(makeCompletionEvent(taskSets(0).tasks(0), taskResults(0)._1, taskResults(0)._2))
+    runEvent(makeCompletionEvent(taskSets(0).tasks(1), taskResults(0)._1, taskResults(0)._2))
+
+    val shuffleStage1 = scheduler.stageIdToStage(0).asInstanceOf[ShuffleMapStage]
+    assert(shuffleStage1.shuffleDep.shuffleMergeEnabled)
+
+    pushComplete(shuffleStage1.shuffleDep.shuffleId, 0, 0)
+    pushComplete(shuffleStage1.shuffleDep.shuffleId, 0, 1)
+
+    // Minimum push complete for 2 tasks, schedule merge finalization
+    val finalizeTask = shuffleStage1.shuffleDep.getFinalizeTask.get
+      .asInstanceOf[DummyScheduledFuture]
+    assert(finalizeTask.registerMergeResults && finalizeTask.delay == 0)
+
+    runEvent(makeCompletionEvent(taskSets(0).tasks(2), taskResults(0)._1, taskResults(0)._2))
+    runEvent(makeCompletionEvent(taskSets(0).tasks(3), taskResults(0)._1, taskResults(0)._2))
+
+    completeShuffleMapStageSuccessfully(1, 0, parts)
+
+    completeNextResultStageWithSuccess(2, 0)
+    assert(results === Map(0 -> 42, 1 -> 42, 2 -> 42, 3 -> 42))
+
+    results.clear()
+    assertDataStructuresEmpty()
+  }
+
+  test("SPARK-33701: check shuffle merge finalization with stage 
cancellation") {

Review comment:
       I added a comment; please take a look.
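
       For anyone reading along, here is a minimal sketch of the two scheduling outcomes these tests pin down (the helper and field names are mine for illustration; only `delay` and `registerMergeResults` mirror what the tests read off `DummyScheduledFuture`):

       ```scala
       // Hypothetical model of the finalization-scheduling decision asserted above.
       object FinalizeSketch extends App {
         case class FinalizeRequest(delay: Long, registerMergeResults: Boolean)

         // Default path: all map tasks finish without the push-ratio trigger firing,
         // so wait the default finalize timeout (10s) and register merge results.
         // Early path: enough pushes completed, so finalize immediately.
         def scheduleFinalize(minPushRatioCrossed: Boolean, defaultDelaySec: Long = 10): FinalizeRequest =
           if (minPushRatioCrossed) FinalizeRequest(delay = 0, registerMergeResults = true)
           else FinalizeRequest(delay = defaultDelaySec, registerMergeResults = true)

         // First test: finalize scheduled with the 10s default once all maps finish.
         assert(scheduleFinalize(minPushRatioCrossed = false).delay == 10)
         // Second test: zero-delay finalize after the push-ratio threshold is hit.
         assert(scheduleFinalize(minPushRatioCrossed = true).delay == 0)
       }
       ```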

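       And the early-trigger arithmetic the second test relies on, assuming a ceiling-based threshold (my reading of the behavior the test exercises, not the PR's exact code):

       ```scala
       // Illustrative threshold check: finalization triggers early once completed
       // pushes reach PUSH_BASED_SHUFFLE_MIN_PUSH_RATIO of the map tasks.
       object MinPushRatioSketch extends App {
         def crossesMinPushRatio(pushesCompleted: Int, numMapTasks: Int, minPushRatio: Double): Boolean =
           pushesCompleted >= math.ceil(numMapTasks * minPushRatio).toInt

         assert(!crossesMinPushRatio(1, 4, minPushRatio = 0.5))
         assert(crossesMinPushRatio(2, 4, minPushRatio = 0.5))
       }
       ```

       With the ratio at 0.5 and parts = 4, pushes from tasks 0 and 1 completing is exactly enough, which is why the test asserts `finalizeTask.delay == 0` right after the second `pushComplete`.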


