This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 28da1d853477 [SPARK-46383] Reduce Driver Heap Usage by Reducing the Lifespan of `TaskInfo.accumulables()`
28da1d853477 is described below

commit 28da1d853477b306774798d8aa738901221fb804
Author: Utkarsh <utkarsh.agar...@databricks.com>
AuthorDate: Fri Jan 12 10:28:22 2024 +0800

    [SPARK-46383] Reduce Driver Heap Usage by Reducing the Lifespan of `TaskInfo.accumulables()`
    
    ### What changes were proposed in this pull request?
    `AccumulableInfo` is one of the top heap consumers in the driver's heap dumps for stages with many tasks. For a stage with a large number of tasks (**_O(100k)_**), we saw **30%** of the heap usage stemming from `TaskInfo.accumulables()`.
    
    
    ![image](https://github.com/apache/spark/assets/10495099/13ef5d07-abfc-47fd-81b6-705f599db011)
    
    The `TaskSetManager` today keeps around the TaskInfo objects ([ref1](https://github.com/apache/spark/blob/c1ba963e64a22dea28e17b1ed954e6d03d38da1e/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala#L134), [ref2](https://github.com/apache/spark/blob/c1ba963e64a22dea28e17b1ed954e6d03d38da1e/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala#L192)) and in turn the task metrics (`AccumulableInfo`) for every task attempt until the stage is completed. This [...]
    
    This PR is an opt-in change (disabled by default) to reduce the driver's heap usage for stages with many tasks by no longer referencing the task metrics of completed tasks. Once a task is completed in `TaskSetManager`, we no longer keep its metrics around. Upon task completion, we clone the `TaskInfo` object and empty out the metrics for the clone. The cloned `TaskInfo` is retained by the `TaskSetManager` while the original `TaskInfo` object with the metrics is sent over to the `DAGSc [...]
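    
    In sketch form, this is the clone-and-swap idea (a simplified, self-contained illustration using a hypothetical `Info` stand-in for `TaskInfo`; the actual implementation is `emptyTaskInfoAccumulablesAndNotifyDagScheduler` in the `TaskSetManager` diff below):
    
    ```scala
    import scala.collection.mutable
    
    // Hypothetical stand-in for TaskInfo: the clone keeps everything except the accumulables.
    case class Info(taskId: Long, accumulables: Seq[String]) {
      def cloneWithEmptyAccumulables(): Info = copy(accumulables = Nil)
    }
    
    val taskInfos = mutable.Map(0L -> Info(0L, Seq("bytesRead", "recordsRead")))
    
    def onTaskCompleted(taskId: Long): Info = {
      val original = taskInfos(taskId)                          // still holds the accumulables
      taskInfos(taskId) = original.cloneWithEmptyAccumulables()  // scheduler retains only the light clone
      original                                                   // original (with metrics) goes to the DAGScheduler
    }
    ```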
    
    ### Config to gate changes
    The changes in the PR are guarded with the Spark conf `spark.scheduler.dropTaskInfoAccumulablesOnTaskCompletion.enabled`, which can be used for rollback or staged rollouts.
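    
    For example, it can be enabled through `SparkConf` (or the equivalent `--conf` flag on `spark-submit`); the app name and master below are only placeholders:
    
    ```scala
    import org.apache.spark.{SparkConf, SparkContext}
    
    // Opt in to dropping TaskInfo accumulables on task completion (internal flag, off by default).
    val conf = new SparkConf()
      .setAppName("accumulables-lifespan-demo")
      .setMaster("local[*]")
      .set("spark.scheduler.dropTaskInfoAccumulablesOnTaskCompletion.enabled", "true")
    val sc = new SparkContext(conf)
    ```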
    
    ### Why are the changes disabled by default?
    The PR introduces a breaking change wherein the `TaskInfo.accumulables()` are empty for `Resubmitted` tasks upon the loss of an executor. Read https://github.com/apache/spark/pull/44321#pullrequestreview-1785137821 for details.
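    
    To illustrate the change in behavior (a hypothetical listener, not part of this PR): with the flag enabled, a `SparkListener` inspecting task-end events for `Resubmitted` tasks now sees empty accumulables where it previously could see the task's accumulables:
    
    ```scala
    import org.apache.spark.Resubmitted
    import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}
    
    class ResubmittedAccumulablesListener extends SparkListener {
      override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
        if (taskEnd.reason == Resubmitted) {
          // With the flag enabled, the accumulables of a resubmitted (previously completed)
          // task have already been cleared by the TaskSetManager, so this is empty.
          assert(taskEnd.taskInfo.accumulables.isEmpty)
        }
      }
    }
    ```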
    
    ### Why are the changes needed?
    
    Reduce the driver's heap usage, especially for stages with many tasks.
    
    ### Benchmarking
    On a cluster running a scan stage with 100k tasks, the TaskSetManager's heap usage dropped from 1.1 GB to 37 MB. This **reduced the total driver's heap usage by 38%**, down to 2 GB from 3.5 GB.
    
    **BEFORE**
    
    
    ![image](https://github.com/databricks/runtime/assets/10495099/7c1599f3-3587-48a1-b019-84115b1bb90d)
    **WITH FIX**
    <img width="1386" alt="image" src="https://github.com/databricks/runtime/assets/10495099/b85129c8-dc10-4ee2-898d-61c8e7449616">
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    Added new tests and did benchmarking on a cluster.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    Generated-by: Github Copilot
    
    Closes #44321 from utkarsh39/SPARK-46383.
    
    Authored-by: Utkarsh <utkarsh.agar...@databricks.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../org/apache/spark/internal/config/package.scala | 10 +++
 .../org/apache/spark/scheduler/TaskInfo.scala      | 10 ++-
 .../apache/spark/scheduler/TaskSetManager.scala    | 71 +++++++++++++++++++---
 .../spark/scheduler/SparkListenerSuite.scala       | 35 +++++++++++
 .../spark/scheduler/TaskSetManagerSuite.scala      | 51 ++++++++++++++++
 5 files changed, 169 insertions(+), 8 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala
index 2823b7cdb602..bbd79c8b9653 100644
--- a/core/src/main/scala/org/apache/spark/internal/config/package.scala
+++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala
@@ -2620,4 +2620,14 @@ package object config {
       .stringConf
       .toSequence
       .createWithDefault("org.apache.spark.sql.connect.client" :: Nil)
+
+  private[spark] val DROP_TASK_INFO_ACCUMULABLES_ON_TASK_COMPLETION =
+    ConfigBuilder("spark.scheduler.dropTaskInfoAccumulablesOnTaskCompletion.enabled")
+      .internal()
+      .doc("If true, the task info accumulables will be cleared upon task 
completion in " +
+        "TaskSetManager. This reduces the heap usage of the driver by only 
referencing the " +
+        "task info accumulables for the active tasks and not for completed 
tasks.")
+      .version("4.0.0")
+      .booleanConf
+      .createWithDefault(false)
 }
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala
index 2d4624828a94..9ed95870d240 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala
@@ -45,7 +45,7 @@ class TaskInfo(
     val executorId: String,
     val host: String,
     val taskLocality: TaskLocality.TaskLocality,
-    val speculative: Boolean) {
+    val speculative: Boolean) extends Cloneable {
 
   /**
    * This api doesn't contains partitionId, please use the new api.
@@ -83,6 +83,14 @@ class TaskInfo(
     _accumulables = newAccumulables
   }
 
+  override def clone(): TaskInfo = super.clone().asInstanceOf[TaskInfo]
+
+  private[scheduler] def cloneWithEmptyAccumulables(): TaskInfo = {
+    val cloned = clone()
+    cloned.setAccumulables(Nil)
+    cloned
+  }
+
   /**
    * The time when the task has completed successfully (including the time to remotely fetch
    * results, if necessary).
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
index e15ba28eeda0..390689cb8f72 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
@@ -256,6 +256,9 @@ private[spark] class TaskSetManager(
 
   private[scheduler] var emittedTaskSizeWarning = false
 
+  private[scheduler] val dropTaskInfoAccumulablesOnTaskCompletion =
+    conf.get(DROP_TASK_INFO_ACCUMULABLES_ON_TASK_COMPLETION)
+
   /** Add a task to all the pending-task lists that it should be on. */
   private[spark] def addPendingTask(
       index: Int,
@@ -785,6 +788,11 @@ private[spark] class TaskSetManager(
     // SPARK-37300: when the task was already finished state, just ignore it,
     // so that there won't cause successful and tasksSuccessful wrong result.
     if(info.finished) {
+      if (dropTaskInfoAccumulablesOnTaskCompletion) {
+        // SPARK-46383: Clear out the accumulables for a completed task to reduce accumulable
+        // lifetime.
+        info.setAccumulables(Nil)
+      }
       return
     }
     val index = info.index
@@ -802,6 +810,8 @@ private[spark] class TaskSetManager(
       // Handle this task as a killed task
       handleFailedTask(tid, TaskState.KILLED,
         TaskKilled("Finish but did not commit due to another attempt 
succeeded"))
+      // SPARK-46383: Not clearing the accumulables here because they are 
already cleared in
+      // handleFailedTask.
       return
     }
 
@@ -844,11 +854,50 @@ private[spark] class TaskSetManager(
     // "result.value()" in "TaskResultGetter.enqueueSuccessfulTask" before 
reaching here.
     // Note: "result.value()" only deserializes the value when it's called at 
the first time, so
     // here "result.value()" just returns the value and won't block other 
threads.
-    sched.dagScheduler.taskEnded(tasks(index), Success, result.value(), 
result.accumUpdates,
-      result.metricPeaks, info)
+
+    emptyTaskInfoAccumulablesAndNotifyDagScheduler(tid, tasks(index), Success, 
result.value(),
+      result.accumUpdates, result.metricPeaks)
     maybeFinishTaskSet()
   }
 
+  /**
+   * A wrapper around [[DAGScheduler.taskEnded()]] that empties out the accumulables for the
+   * TaskInfo object, corresponding to the completed task, referenced by this class.
+   *
+   * SPARK-46383: For the completed task, we ship the original TaskInfo to the DAGScheduler and only
+   * retain a cloned TaskInfo in this class. We then set the accumulables to Nil for the TaskInfo
+   * object that corresponds to the completed task.
+   * We do this to release references to `TaskInfo.accumulables()` as the TaskInfo
+   * objects held by this class are long-lived and have a heavy memory footprint on the driver.
+   *
+   * This is safe as the TaskInfo accumulables are not needed once they are shipped to the
+   * DAGScheduler where they are aggregated. Additionally, the original TaskInfo, and not a
+   * clone, must be sent to the DAGScheduler as this TaskInfo object is sent to the
+   * DAGScheduler on multiple events during the task's lifetime. Users can install
+   * SparkListeners that compare the TaskInfo objects across these SparkListener events and
+   * thus the TaskInfo object sent to the DAGScheduler must always reference the same TaskInfo
+   * object.
+   */
+  private def emptyTaskInfoAccumulablesAndNotifyDagScheduler(
+      taskId: Long,
+      task: Task[_],
+      reason: TaskEndReason,
+      result: Any,
+      accumUpdates: Seq[AccumulatorV2[_, _]],
+      metricPeaks: Array[Long]): Unit = {
+    val taskInfoWithAccumulables = taskInfos(taskId);
+    if (dropTaskInfoAccumulablesOnTaskCompletion) {
+      val index = taskInfoWithAccumulables.index
+      val clonedTaskInfo = taskInfoWithAccumulables.cloneWithEmptyAccumulables()
+      // Update this task's taskInfo while preserving its position in the list
+      taskAttempts(index) =
+        taskAttempts(index).map { i => if (i eq taskInfoWithAccumulables) clonedTaskInfo else i }
+      taskInfos(taskId) = clonedTaskInfo
+    }
+    sched.dagScheduler.taskEnded(task, reason, result, accumUpdates, metricPeaks,
+      taskInfoWithAccumulables)
+  }
+
   private[scheduler] def markPartitionCompleted(partitionId: Int): Unit = {
     partitionToIndex.get(partitionId).foreach { index =>
       if (!successful(index)) {
@@ -872,6 +921,11 @@ private[spark] class TaskSetManager(
     // SPARK-37300: when the task was already finished state, just ignore it,
     // so that there won't cause copiesRunning wrong result.
     if (info.finished) {
+      if (dropTaskInfoAccumulablesOnTaskCompletion) {
+        // SPARK-46383: Clear out the accumulables for a completed task to reduce accumulable
+        // lifetime.
+        info.setAccumulables(Nil)
+      }
       return
     }
     removeRunningTask(tid)
@@ -906,7 +960,8 @@ private[spark] class TaskSetManager(
         if (ef.className == classOf[NotSerializableException].getName) {
           // If the task result wasn't serializable, there's no point in trying to re-execute it.
           logError(s"$task had a not serializable result: ${ef.description}; not retrying")
-          sched.dagScheduler.taskEnded(tasks(index), reason, null, accumUpdates, metricPeaks, info)
+          emptyTaskInfoAccumulablesAndNotifyDagScheduler(tid, tasks(index), reason, null,
+            accumUpdates, metricPeaks)
           abort(s"$task had a not serializable result: ${ef.description}")
           return
         }
@@ -915,7 +970,8 @@ private[spark] class TaskSetManager(
           // re-execute it.
           logError("Task %s in stage %s (TID %d) can not write to output file: 
%s; not retrying"
             .format(info.id, taskSet.id, tid, ef.description))
-          sched.dagScheduler.taskEnded(tasks(index), reason, null, 
accumUpdates, metricPeaks, info)
+          emptyTaskInfoAccumulablesAndNotifyDagScheduler(tid, tasks(index), 
reason, null,
+            accumUpdates, metricPeaks)
           abort("Task %s in stage %s (TID %d) can not write to output file: 
%s".format(
             info.id, taskSet.id, tid, ef.description))
           return
@@ -968,7 +1024,8 @@ private[spark] class TaskSetManager(
       isZombie = true
     }
 
-    sched.dagScheduler.taskEnded(tasks(index), reason, null, accumUpdates, metricPeaks, info)
+    emptyTaskInfoAccumulablesAndNotifyDagScheduler(tid, tasks(index), reason, null,
+      accumUpdates, metricPeaks)
 
     if (!isZombie && reason.countTowardsTaskFailures) {
       assert (null != failureReason)
@@ -1084,8 +1141,8 @@ private[spark] class TaskSetManager(
           addPendingTask(index)
           // Tell the DAGScheduler that this task was resubmitted so that it doesn't think our
           // stage finishes when a total of tasks.size tasks finish.
-          sched.dagScheduler.taskEnded(
-            tasks(index), Resubmitted, null, Seq.empty, Array.empty, info)
+          emptyTaskInfoAccumulablesAndNotifyDagScheduler(tid,
+            tasks(index), Resubmitted, null, Seq.empty, Array.empty)
         }
       }
     }
diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
index 29e27e96908f..34b2a40d1e3b 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.scheduler
 
 import java.io.{Externalizable, ObjectInput, ObjectOutput}
+import java.util.{Collections, IdentityHashMap}
 import java.util.concurrent.Semaphore
 
 import scala.collection.mutable
@@ -289,6 +290,19 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match
     stageInfo.rddInfos.forall(_.numPartitions == 4) should be {true}
   }
 
+  test("SPARK-46383: Track TaskInfo objects") {
+    // Test that the same TaskInfo object is sent to the `DAGScheduler` in the `onTaskStart` and
+    // `onTaskEnd` events.
+    val conf = new SparkConf().set(DROP_TASK_INFO_ACCUMULABLES_ON_TASK_COMPLETION, true)
+    sc = new SparkContext("local", "SparkListenerSuite", conf)
+    val listener = new SaveActiveTaskInfos
+    sc.addSparkListener(listener)
+    val rdd1 = sc.parallelize(1 to 100, 4)
+    sc.runJob(rdd1, (items: Iterator[Int]) => items.size, Seq(0, 1))
+    sc.listenerBus.waitUntilEmpty()
+    listener.taskInfos.size should be { 0 }
+  }
+
   test("local metrics") {
     sc = new SparkContext("local", "SparkListenerSuite")
     val listener = new SaveStageAndTaskInfo
@@ -643,6 +657,27 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match
     }
   }
 
+  /**
+   * A simple listener that tracks task infos for all active tasks.
+   */
+  private class SaveActiveTaskInfos extends SparkListener {
+    // Use a set based on IdentityHashMap instead of a HashSet to track unique references of
+    // TaskInfo objects.
+    val taskInfos = Collections.newSetFromMap[TaskInfo](new IdentityHashMap)
+
+    override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = {
+      val info = taskStart.taskInfo
+      if (info != null) {
+        taskInfos.add(info)
+      }
+    }
+
+    override def onTaskEnd(task: SparkListenerTaskEnd): Unit = {
+      val info = task.taskInfo
+      taskInfos.remove(info)
+    }
+  }
+
   /**
    * A simple listener that saves the task indices for all task events.
    */
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
index 26b38bfcc9ab..c55de278a6d2 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
@@ -62,6 +62,12 @@ class FakeDAGScheduler(sc: SparkContext, taskScheduler: FakeTaskScheduler)
       accumUpdates: Seq[AccumulatorV2[_, _]],
       metricPeaks: Array[Long],
       taskInfo: TaskInfo): Unit = {
+    // Set task accumulables emulating DAGScheduler behavior to enable tests related to
+    // `TaskInfo.accumulables`.
+    accumUpdates.foreach(acc =>
+      taskInfo.setAccumulables(
+        acc.toInfo(Some(acc.value), Some(acc.value)) +: taskInfo.accumulables)
+    )
     taskScheduler.endedTasks(taskInfo.index) = reason
   }
 
@@ -230,6 +236,51 @@ class TaskSetManagerSuite
     super.afterEach()
   }
 
+  test("SPARK-46383: TaskInfo accumulables are cleared upon task completion") {
+    val conf = new SparkConf().
+      set(config.DROP_TASK_INFO_ACCUMULABLES_ON_TASK_COMPLETION, true)
+    sc = new SparkContext("local", "test", conf)
+    sched = new FakeTaskScheduler(sc, ("exec1", "host1"))
+    val taskSet = FakeTask.createTaskSet(2)
+    val clock = new ManualClock
+    val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock = clock)
+    val accumUpdates = taskSet.tasks.head.metrics.internalAccums
+
+    // Offer a host. This will launch the first task.
+    val taskOption = manager.resourceOffer("exec1", "host1", NO_PREF)._1
+    assert(taskOption.isDefined)
+
+    clock.advance(1)
+    // Tell it the first task has finished successfully
+    manager.handleSuccessfulTask(0, createTaskResult(0, accumUpdates))
+    assert(sched.endedTasks(0) === Success)
+
+    // Only one task was launched and it completed successfully, thus the TaskInfo accumulables
+    // should be empty.
+    assert(!manager.taskInfos.exists(t => !t._2.accumulables.isEmpty))
+    assert(manager.taskAttempts.flatMap(t => t.filter(!_.accumulables.isEmpty)).isEmpty)
+
+    // Fail the second task (MAX_TASK_FAILURES - 1) times.
+    (1 to manager.maxTaskFailures - 1).foreach { index =>
+      val offerResult = manager.resourceOffer("exec1", "host1", ANY)._1
+      assert(offerResult.isDefined,
+        "Expect resource offer on iteration %s to return a task".format(index))
+      assert(offerResult.get.index === 1)
+      manager.handleFailedTask(offerResult.get.taskId, TaskState.FINISHED, TaskResultLost)
+    }
+
+    clock.advance(1)
+    // Successfully finish the second task.
+    val taskOption1 = manager.resourceOffer("exec1", "host1", ANY)._1
+    manager.handleSuccessfulTask(taskOption1.get.taskId, createTaskResult(1, accumUpdates))
+    assert(sched.endedTasks(1) === Success)
+    // The TaskInfo accumulables should be empty as the second task has now completed successfully.
+    assert(!manager.taskInfos.exists(t => !t._2.accumulables.isEmpty))
+    assert(manager.taskAttempts.flatMap(t => t.filter(!_.accumulables.isEmpty)).isEmpty)
+
+    assert(sched.finishedManagers.contains(manager))
+  }
+
   test("TaskSet with no preferences") {
     sc = new SparkContext("local", "test")
     sched = new FakeTaskScheduler(sc, ("exec1", "host1"))


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
