Ngone51 commented on a change in pull request #33872:
URL: https://github.com/apache/spark/pull/33872#discussion_r717193397



##########
File path: core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala
##########
@@ -18,23 +18,24 @@
 package org.apache.spark.scheduler
 
 import java.nio.ByteBuffer
-
 import scala.collection.mutable.{ArrayBuffer, HashMap}
 import scala.concurrent.duration._
-
 import org.mockito.ArgumentMatchers.{any, anyInt, anyString, eq => meq}
 import org.mockito.Mockito.{atLeast, atMost, never, spy, times, verify, when}
 import org.scalatest.BeforeAndAfterEach
 import org.scalatest.concurrent.Eventually
 import org.scalatestplus.mockito.MockitoSugar
-
 import org.apache.spark._
 import org.apache.spark.internal.Logging
 import org.apache.spark.internal.config
 import org.apache.spark.resource.{ExecutorResourceRequests, ResourceProfile, TaskResourceRequests}
 import org.apache.spark.resource.ResourceUtils._
 import org.apache.spark.resource.TestResourceIDs._
-import org.apache.spark.util.{Clock, ManualClock}
+import org.apache.spark.util.{Clock}
+import org.apache.spark.util.{ManualClock, ThreadUtils}
+
+import java.util.Properties
+import java.util.concurrent.{CountDownLatch, ExecutorService, LinkedBlockingQueue, ThreadPoolExecutor, TimeUnit}

Review comment:
       Could you follow the same import style as the rest of the file?
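       For reference, a rough sketch of the grouped style (every import below already appears in this diff; only the grouping and blank-line separation are assumed from the usual Spark convention of java, scala, third-party, then org.apache.spark):

       ```scala
       // Sketch only: grouping and ordering assumed from the Spark import convention.
       import java.nio.ByteBuffer
       import java.util.Properties
       import java.util.concurrent.{CountDownLatch, ExecutorService, LinkedBlockingQueue, ThreadPoolExecutor, TimeUnit}

       import scala.collection.mutable.{ArrayBuffer, HashMap}
       import scala.concurrent.duration._

       import org.mockito.ArgumentMatchers.{any, anyInt, anyString, eq => meq}
       import org.mockito.Mockito.{atLeast, atMost, never, spy, times, verify, when}
       import org.scalatest.BeforeAndAfterEach
       import org.scalatest.concurrent.Eventually
       import org.scalatestplus.mockito.MockitoSugar

       import org.apache.spark._
       import org.apache.spark.internal.Logging
       import org.apache.spark.internal.config
       import org.apache.spark.resource.{ExecutorResourceRequests, ResourceProfile, TaskResourceRequests}
       import org.apache.spark.resource.ResourceUtils._
       import org.apache.spark.resource.TestResourceIDs._
       import org.apache.spark.util.{Clock, ManualClock, ThreadUtils}
       ```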

##########
File path: core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala
##########
@@ -1995,6 +1998,88 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B
     assert(!normalTSM.runningTasksSet.contains(taskId))
   }
 
+  test("SPARK-36575: Should ignore task finished event if its task set is gone 
" +
+    "in TaskSchedulerImpl.handleSuccessfulTask") {
+    val taskScheduler = setupScheduler()
+
+    val latch = new CountDownLatch(2)
+    val resultGetter = new TaskResultGetter(sc.env, taskScheduler) {
+      override protected val getTaskResultExecutor: ExecutorService =
+        new ThreadPoolExecutor(1, 1, 0L, TimeUnit.MILLISECONDS, new LinkedBlockingQueue[Runnable],
+          ThreadUtils.namedThreadFactory("task-result-getter")) {
+          override def execute(command: Runnable): Unit = {
+            super.execute(new Runnable {
+              override def run(): Unit = {
+                command.run()
+                latch.countDown()
+              }
+            })
+          }
+        }
+      def taskResultExecutor() : ExecutorService = getTaskResultExecutor
+    }
+    taskScheduler.taskResultGetter = resultGetter
+
+    val workerOffers = IndexedSeq(new WorkerOffer("executor0", "host0", 1),
+      new WorkerOffer("executor1", "host1", 1))
+    val task1 = new ShuffleMapTask(1, 0, null, new Partition {
+      override def index: Int = 0
+    }, Seq(TaskLocation("host0", "executor0")), new Properties, null)
+
+    val task2 = new ShuffleMapTask(1, 0, null, new Partition {
+      override def index: Int = 1
+    }, Seq(TaskLocation("host1", "executor1")), new Properties, null)
+
+    val taskSet = new TaskSet(Array(task1, task2), 0, 0, 0, null, 0)
+
+    taskScheduler.submitTasks(taskSet)
+    val taskDescriptions = taskScheduler.resourceOffers(workerOffers).flatten
+    assert(2 === taskDescriptions.length)
+
+    val ser = sc.env.serializer.newInstance()
+    val directResult = new DirectTaskResult[Int](ser.serialize(1), Seq(), Array.empty)
+    val resultBytes = ser.serialize(directResult)
+
+    import scala.language.reflectiveCalls

Review comment:
       Move this to the import group at the top of the file.
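       e.g., a short sketch (the exact position within the scala group is assumed):

       ```scala
       import scala.collection.mutable.{ArrayBuffer, HashMap}
       import scala.concurrent.duration._
       // moved up from the test body so it sits with the other scala.* imports
       import scala.language.reflectiveCalls
       ```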

##########
File path: core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala
##########
@@ -1995,6 +1998,88 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B
     assert(!normalTSM.runningTasksSet.contains(taskId))
   }
 
+  test("SPARK-36575: Should ignore task finished event if its task set is gone 
" +
+    "in TaskSchedulerImpl.handleSuccessfulTask") {
+    val taskScheduler = setupScheduler()
+
+    val latch = new CountDownLatch(2)
+    val resultGetter = new TaskResultGetter(sc.env, taskScheduler) {
+      override protected val getTaskResultExecutor: ExecutorService =
+        new ThreadPoolExecutor(1, 1, 0L, TimeUnit.MILLISECONDS, new LinkedBlockingQueue[Runnable],
+          ThreadUtils.namedThreadFactory("task-result-getter")) {
+          override def execute(command: Runnable): Unit = {
+            super.execute(new Runnable {
+              override def run(): Unit = {
+                command.run()
+                latch.countDown()
+              }
+            })
+          }
+        }
+      def taskResultExecutor() : ExecutorService = getTaskResultExecutor
+    }
+    taskScheduler.taskResultGetter = resultGetter
+
+    val workerOffers = IndexedSeq(new WorkerOffer("executor0", "host0", 1),
+      new WorkerOffer("executor1", "host1", 1))
+    val task1 = new ShuffleMapTask(1, 0, null, new Partition {
+      override def index: Int = 0
+    }, Seq(TaskLocation("host0", "executor0")), new Properties, null)
+
+    val task2 = new ShuffleMapTask(1, 0, null, new Partition {
+      override def index: Int = 1
+    }, Seq(TaskLocation("host1", "executor1")), new Properties, null)
+
+    val taskSet = new TaskSet(Array(task1, task2), 0, 0, 0, null, 0)
+
+    taskScheduler.submitTasks(taskSet)
+    val taskDescriptions = taskScheduler.resourceOffers(workerOffers).flatten
+    assert(2 === taskDescriptions.length)
+
+    val ser = sc.env.serializer.newInstance()
+    val directResult = new DirectTaskResult[Int](ser.serialize(1), Seq(), Array.empty)
+    val resultBytes = ser.serialize(directResult)
+
+    import scala.language.reflectiveCalls
+    val busyTask = new Runnable {
+      val lock : Object = new Object
+      override def run(): Unit = {
+        lock.synchronized {
+          lock.wait()
+        }
+      }
+      def markTaskDone: Unit = {
+        lock.synchronized {
+          lock.notify()
+        }
+      }
+    }
+    // make getTaskResultExecutor busy
+    resultGetter.taskResultExecutor().submit(busyTask)
+
+    // task1 finished
+    val tid = taskDescriptions(0).taskId
+    taskScheduler.statusUpdate(
+      tid = tid,
+      state = TaskState.FINISHED,
+      serializedData = resultBytes
+    )
+
+    // mark executor heartbeat timed out
+    taskScheduler.executorLost(taskDescriptions(0).executorId, ExecutorProcessLost("Executor " +
+      "heartbeat timed out"))
+
+    busyTask.markTaskDone
+
+    // Wait until all events are processed
+    latch.await()
+
+    val taskSetManager = taskScheduler.taskIdToTaskSetManager.get(taskDescriptions(1).taskId)
+    assert(taskSetManager != null)
+    assert(0 == taskSetManager.tasksSuccessful)
+    assert(false == taskSetManager.successful(taskDescriptions(0).index))

Review comment:
       ```suggestion
           assert(!taskSetManager.successful(taskDescriptions(0).index))
       ```

##########
File path: core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala
##########
@@ -34,7 +36,8 @@ import org.apache.spark.internal.config
 import org.apache.spark.resource.{ExecutorResourceRequests, ResourceProfile, TaskResourceRequests}
 import org.apache.spark.resource.ResourceUtils._
 import org.apache.spark.resource.TestResourceIDs._
-import org.apache.spark.util.{Clock, ManualClock}
+import org.apache.spark.util.{Clock}
+import org.apache.spark.util.{ManualClock, ThreadUtils}

Review comment:
       ```suggestion
       import org.apache.spark.util.{Clock, ManualClock, ThreadUtils}
       ```

##########
File path: core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
##########
@@ -870,7 +870,13 @@ private[spark] class TaskSchedulerImpl(
       taskSetManager: TaskSetManager,
       tid: Long,
       taskResult: DirectTaskResult[_]): Unit = synchronized {
-    taskSetManager.handleSuccessfulTask(tid, taskResult)
+    if (taskIdToTaskSetManager.contains(tid)) {
+      taskSetManager.handleSuccessfulTask(tid, taskResult)
+    } else {
+      logInfo(s"Ignoring update with state finished for TID $tid because its 
task set " +

Review comment:
       `....for TID $tid...` -> `for task (TID $tid)...`
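       i.e., roughly like this (a sketch only; the original message continues past this hunk, so the trailing part is assumed unchanged):

       ```scala
       // sketch of the reworded prefix of the log message
       logInfo(s"Ignoring update with state finished for task (TID $tid) because its task set " +
         s"is gone (this is likely the result of receiving duplicate task finished status updates)")
       ```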

##########
File path: core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
##########
@@ -870,7 +870,13 @@ private[spark] class TaskSchedulerImpl(
       taskSetManager: TaskSetManager,
       tid: Long,
       taskResult: DirectTaskResult[_]): Unit = synchronized {
-    taskSetManager.handleSuccessfulTask(tid, taskResult)
+    if (taskIdToTaskSetManager.contains(tid)) {
+      taskSetManager.handleSuccessfulTask(tid, taskResult)
+    } else {
+      logInfo(s"Ignoring update with state finished for TID $tid because its 
task set " +
+        s"is gone (this is likely the result of receiving duplicate task 
finished status updates)" +

Review comment:
       I wonder: is it possible to receive "duplicate task finished status updates"?




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


