This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 28da1d853477 [SPARK-46383] Reduce Driver Heap Usage by Reducing the Lifespan of `TaskInfo.accumulables()`
28da1d853477 is described below

commit 28da1d853477b306774798d8aa738901221fb804
Author: Utkarsh <utkarsh.agar...@databricks.com>
AuthorDate: Fri Jan 12 10:28:22 2024 +0800

    [SPARK-46383] Reduce Driver Heap Usage by Reducing the Lifespan of `TaskInfo.accumulables()`

    ### What changes were proposed in this pull request?

    `AccumulableInfo` is one of the top heap consumers in the driver's heap dumps for stages with many tasks. For a stage with a large number of tasks (**_O(100k)_**), we saw **30%** of the heap usage stemming from `TaskInfo.accumulables()`.

    ![image](https://github.com/apache/spark/assets/10495099/13ef5d07-abfc-47fd-81b6-705f599db011)

    Today, the `TaskSetManager` keeps the `TaskInfo` objects ([ref1](https://github.com/apache/spark/blob/c1ba963e64a22dea28e17b1ed954e6d03d38da1e/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala#L134), [ref2](https://github.com/apache/spark/blob/c1ba963e64a22dea28e17b1ed954e6d03d38da1e/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala#L192)), and in turn the task metrics (`AccumulableInfo`), around for every task attempt until the stage is completed. This [...]

    This PR is an opt-in change (disabled by default) that reduces the driver's heap usage for stages with many tasks by no longer referencing the task metrics of completed tasks. Once a task completes in `TaskSetManager`, we no longer keep its metrics around: we clone the `TaskInfo` object and empty out the metrics of the clone. The cloned `TaskInfo` is retained by the `TaskSetManager`, while the original `TaskInfo` object with the metrics is sent over to the `DAGScheduler`. [...]

    ### Config to gate changes

    The changes in this PR are guarded by the Spark conf `spark.scheduler.dropTaskInfoAccumulablesOnTaskCompletion.enabled`, which can be used for rollback or staged rollouts. A minimal opt-in sketch follows.
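    A minimal sketch of opting in from application code: the config key and its default (false) come from this patch, while the master, app name, and job below are purely illustrative.

    ```scala
    import org.apache.spark.{SparkConf, SparkContext}

    // Opt in to dropping TaskInfo accumulables in the TaskSetManager once a task
    // completes. The conf is internal and off by default.
    val conf = new SparkConf()
      .setMaster("local[4]")           // illustrative master
      .setAppName("SPARK-46383-demo")  // illustrative app name
      .set("spark.scheduler.dropTaskInfoAccumulablesOnTaskCompletion.enabled", "true")
    val sc = new SparkContext(conf)

    // SparkListeners still receive the original TaskInfo (with accumulables) on task end;
    // only the copy retained by the TaskSetManager has its accumulables emptied.
    sc.parallelize(1 to 1000, 100).count()
    sc.stop()
    ```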
    ### Why are the changes disabled by default?

    The PR introduces a breaking change wherein `TaskInfo.accumulables()` is empty for `Resubmitted` tasks upon the loss of an executor. See https://github.com/apache/spark/pull/44321#pullrequestreview-1785137821 for details.

    ### Why are the changes needed?

    Reduce the driver's heap usage, especially for stages with many tasks.

    ### Benchmarking

    On a cluster running a scan stage with 100k tasks, the TaskSetManager's heap usage dropped from 1.1 GB to 37 MB. This **reduced the total driver heap usage by 38%**, down to 2 GB from 3.5 GB.

    **BEFORE**
    ![image](https://github.com/databricks/runtime/assets/10495099/7c1599f3-3587-48a1-b019-84115b1bb90d)

    **WITH FIX**
    <img width="1386" alt="image" src="https://github.com/databricks/runtime/assets/10495099/b85129c8-dc10-4ee2-898d-61c8e7449616">

    ### Does this PR introduce _any_ user-facing change?

    No

    ### How was this patch tested?

    Added new tests and did benchmarking on a cluster.

    ### Was this patch authored or co-authored using generative AI tooling?

    Generated-by: Github Copilot

    Closes #44321 from utkarsh39/SPARK-46383.

Authored-by: Utkarsh <utkarsh.agar...@databricks.com>
Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../org/apache/spark/internal/config/package.scala | 10 +++
 .../org/apache/spark/scheduler/TaskInfo.scala      | 10 ++-
 .../apache/spark/scheduler/TaskSetManager.scala    | 71 +++++++++++++++++++---
 .../spark/scheduler/SparkListenerSuite.scala       | 35 +++++++++++
 .../spark/scheduler/TaskSetManagerSuite.scala      | 51 ++++++++++++++++
 5 files changed, 169 insertions(+), 8 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala
index 2823b7cdb602..bbd79c8b9653 100644
--- a/core/src/main/scala/org/apache/spark/internal/config/package.scala
+++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala
@@ -2620,4 +2620,14 @@ package object config {
     .stringConf
     .toSequence
     .createWithDefault("org.apache.spark.sql.connect.client" :: Nil)
+
+  private[spark] val DROP_TASK_INFO_ACCUMULABLES_ON_TASK_COMPLETION =
+    ConfigBuilder("spark.scheduler.dropTaskInfoAccumulablesOnTaskCompletion.enabled")
+      .internal()
+      .doc("If true, the task info accumulables will be cleared upon task completion in " +
+        "TaskSetManager. This reduces the heap usage of the driver by only referencing the " +
+        "task info accumulables for the active tasks and not for completed tasks.")
+      .version("4.0.0")
+      .booleanConf
+      .createWithDefault(false)
 }
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala
index 2d4624828a94..9ed95870d240 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala
@@ -45,7 +45,7 @@ class TaskInfo(
     val executorId: String,
     val host: String,
     val taskLocality: TaskLocality.TaskLocality,
-    val speculative: Boolean) {
+    val speculative: Boolean) extends Cloneable {
 
   /**
    * This api doesn't contains partitionId, please use the new api.
@@ -83,6 +83,14 @@ class TaskInfo(
     _accumulables = newAccumulables
   }
 
+  override def clone(): TaskInfo = super.clone().asInstanceOf[TaskInfo]
+
+  private[scheduler] def cloneWithEmptyAccumulables(): TaskInfo = {
+    val cloned = clone()
+    cloned.setAccumulables(Nil)
+    cloned
+  }
+
   /**
    * The time when the task has completed successfully (including the time to remotely fetch
    * results, if necessary).
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
index e15ba28eeda0..390689cb8f72 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
@@ -256,6 +256,9 @@ private[spark] class TaskSetManager(
 
   private[scheduler] var emittedTaskSizeWarning = false
 
+  private[scheduler] val dropTaskInfoAccumulablesOnTaskCompletion =
+    conf.get(DROP_TASK_INFO_ACCUMULABLES_ON_TASK_COMPLETION)
+
   /** Add a task to all the pending-task lists that it should be on. */
   private[spark] def addPendingTask(
       index: Int,
@@ -785,6 +788,11 @@ private[spark] class TaskSetManager(
     // SPARK-37300: when the task was already finished state, just ignore it,
     // so that there won't cause successful and tasksSuccessful wrong result.
     if (info.finished) {
+      if (dropTaskInfoAccumulablesOnTaskCompletion) {
+        // SPARK-46383: Clear out the accumulables for a completed task to reduce accumulable
+        // lifetime.
+        info.setAccumulables(Nil)
+      }
       return
     }
     val index = info.index
@@ -802,6 +810,8 @@ private[spark] class TaskSetManager(
       // Handle this task as a killed task
       handleFailedTask(tid, TaskState.KILLED,
         TaskKilled("Finish but did not commit due to another attempt succeeded"))
+      // SPARK-46383: Not clearing the accumulables here because they are already cleared in
+      // handleFailedTask.
       return
     }
@@ -844,11 +854,50 @@ private[spark] class TaskSetManager(
     // "result.value()" in "TaskResultGetter.enqueueSuccessfulTask" before reaching here.
     // Note: "result.value()" only deserializes the value when it's called at the first time, so
     // here "result.value()" just returns the value and won't block other threads.
-    sched.dagScheduler.taskEnded(tasks(index), Success, result.value(), result.accumUpdates,
-      result.metricPeaks, info)
+
+    emptyTaskInfoAccumulablesAndNotifyDagScheduler(tid, tasks(index), Success, result.value(),
+      result.accumUpdates, result.metricPeaks)
     maybeFinishTaskSet()
   }
 
+  /**
+   * A wrapper around [[DAGScheduler.taskEnded()]] that empties out the accumulables for the
+   * TaskInfo object, corresponding to the completed task, referenced by this class.
+   *
+   * SPARK-46383: For the completed task, we ship the original TaskInfo to the DAGScheduler and only
+   * retain a cloned TaskInfo in this class. We then set the accumulables to Nil for the TaskInfo
+   * object that corresponds to the completed task.
+   * We do this to release references to `TaskInfo.accumulables()` as the TaskInfo
+   * objects held by this class are long-lived and have a heavy memory footprint on the driver.
+   *
+   * This is safe as the TaskInfo accumulables are not needed once they are shipped to the
+   * DAGScheduler where they are aggregated. Additionally, the original TaskInfo, and not a
+   * clone, must be sent to the DAGScheduler as this TaskInfo object is sent to the
+   * DAGScheduler on multiple events during the task's lifetime. Users can install
+   * SparkListeners that compare the TaskInfo objects across these SparkListener events and
+   * thus the TaskInfo object sent to the DAGScheduler must always reference the same TaskInfo
+   * object.
+   */
+  private def emptyTaskInfoAccumulablesAndNotifyDagScheduler(
+      taskId: Long,
+      task: Task[_],
+      reason: TaskEndReason,
+      result: Any,
+      accumUpdates: Seq[AccumulatorV2[_, _]],
+      metricPeaks: Array[Long]): Unit = {
+    val taskInfoWithAccumulables = taskInfos(taskId)
+    if (dropTaskInfoAccumulablesOnTaskCompletion) {
+      val index = taskInfoWithAccumulables.index
+      val clonedTaskInfo = taskInfoWithAccumulables.cloneWithEmptyAccumulables()
+      // Update this task's taskInfo while preserving its position in the list
+      taskAttempts(index) =
+        taskAttempts(index).map { i => if (i eq taskInfoWithAccumulables) clonedTaskInfo else i }
+      taskInfos(taskId) = clonedTaskInfo
+    }
+    sched.dagScheduler.taskEnded(task, reason, result, accumUpdates, metricPeaks,
+      taskInfoWithAccumulables)
+  }
+
   private[scheduler] def markPartitionCompleted(partitionId: Int): Unit = {
     partitionToIndex.get(partitionId).foreach { index =>
       if (!successful(index)) {
@@ -872,6 +921,11 @@ private[spark] class TaskSetManager(
     // SPARK-37300: when the task was already finished state, just ignore it,
     // so that there won't cause copiesRunning wrong result.
     if (info.finished) {
+      if (dropTaskInfoAccumulablesOnTaskCompletion) {
+        // SPARK-46383: Clear out the accumulables for a completed task to reduce accumulable
+        // lifetime.
+        info.setAccumulables(Nil)
+      }
       return
     }
     removeRunningTask(tid)
@@ -906,7 +960,8 @@
       if (ef.className == classOf[NotSerializableException].getName) {
         // If the task result wasn't serializable, there's no point in trying to re-execute it.
         logError(s"$task had a not serializable result: ${ef.description}; not retrying")
-        sched.dagScheduler.taskEnded(tasks(index), reason, null, accumUpdates, metricPeaks, info)
+        emptyTaskInfoAccumulablesAndNotifyDagScheduler(tid, tasks(index), reason, null,
+          accumUpdates, metricPeaks)
         abort(s"$task had a not serializable result: ${ef.description}")
         return
       }
@@ -915,7 +970,8 @@
         // re-execute it.
         logError("Task %s in stage %s (TID %d) can not write to output file: %s; not retrying"
           .format(info.id, taskSet.id, tid, ef.description))
-        sched.dagScheduler.taskEnded(tasks(index), reason, null, accumUpdates, metricPeaks, info)
+        emptyTaskInfoAccumulablesAndNotifyDagScheduler(tid, tasks(index), reason, null,
+          accumUpdates, metricPeaks)
         abort("Task %s in stage %s (TID %d) can not write to output file: %s".format(
           info.id, taskSet.id, tid, ef.description))
         return
@@ -968,7 +1024,8 @@
       isZombie = true
     }
 
-    sched.dagScheduler.taskEnded(tasks(index), reason, null, accumUpdates, metricPeaks, info)
+    emptyTaskInfoAccumulablesAndNotifyDagScheduler(tid, tasks(index), reason, null,
+      accumUpdates, metricPeaks)
 
     if (!isZombie && reason.countTowardsTaskFailures) {
       assert (null != failureReason)
@@ -1084,8 +1141,8 @@
           addPendingTask(index)
           // Tell the DAGScheduler that this task was resubmitted so that it doesn't think our
           // stage finishes when a total of tasks.size tasks finish.
-          sched.dagScheduler.taskEnded(
-            tasks(index), Resubmitted, null, Seq.empty, Array.empty, info)
+          emptyTaskInfoAccumulablesAndNotifyDagScheduler(tid,
+            tasks(index), Resubmitted, null, Seq.empty, Array.empty)
         }
       }
     }
diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
index 29e27e96908f..34b2a40d1e3b 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.scheduler
 
 import java.io.{Externalizable, ObjectInput, ObjectOutput}
+import java.util.{Collections, IdentityHashMap}
 import java.util.concurrent.Semaphore
 
 import scala.collection.mutable
@@ -289,6 +290,19 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match
     stageInfo.rddInfos.forall(_.numPartitions == 4) should be {true}
   }
 
+  test("SPARK-46383: Track TaskInfo objects") {
+    // Test that the same TaskInfo object is sent to the `DAGScheduler` in the `onTaskStart` and
+    // `onTaskEnd` events.
+    val conf = new SparkConf().set(DROP_TASK_INFO_ACCUMULABLES_ON_TASK_COMPLETION, true)
+    sc = new SparkContext("local", "SparkListenerSuite", conf)
+    val listener = new SaveActiveTaskInfos
+    sc.addSparkListener(listener)
+    val rdd1 = sc.parallelize(1 to 100, 4)
+    sc.runJob(rdd1, (items: Iterator[Int]) => items.size, Seq(0, 1))
+    sc.listenerBus.waitUntilEmpty()
+    listener.taskInfos.size should be { 0 }
+  }
+
   test("local metrics") {
     sc = new SparkContext("local", "SparkListenerSuite")
     val listener = new SaveStageAndTaskInfo
@@ -643,6 +657,27 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match
     }
   }
 
+  /**
+   * A simple listener that tracks task infos for all active tasks.
+   */
+  private class SaveActiveTaskInfos extends SparkListener {
+    // Use a set based on IdentityHashMap instead of a HashSet to track unique references of
+    // TaskInfo objects.
+    val taskInfos = Collections.newSetFromMap[TaskInfo](new IdentityHashMap)
+
+    override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = {
+      val info = taskStart.taskInfo
+      if (info != null) {
+        taskInfos.add(info)
+      }
+    }
+
+    override def onTaskEnd(task: SparkListenerTaskEnd): Unit = {
+      val info = task.taskInfo
+      taskInfos.remove(info)
+    }
+  }
+
   /**
    * A simple listener that saves the task indices for all task events.
    */
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
index 26b38bfcc9ab..c55de278a6d2 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
@@ -62,6 +62,12 @@ class FakeDAGScheduler(sc: SparkContext, taskScheduler: FakeTaskScheduler)
       accumUpdates: Seq[AccumulatorV2[_, _]],
       metricPeaks: Array[Long],
       taskInfo: TaskInfo): Unit = {
+    // Set task accumulables emulating DAGScheduler behavior to enable tests related to
+    // `TaskInfo.accumulables`.
+    accumUpdates.foreach(acc =>
+      taskInfo.setAccumulables(
+        acc.toInfo(Some(acc.value), Some(acc.value)) +: taskInfo.accumulables)
+    )
     taskScheduler.endedTasks(taskInfo.index) = reason
   }
 
@@ -230,6 +236,51 @@ class TaskSetManagerSuite
     super.afterEach()
   }
 
+  test("SPARK-46383: TaskInfo accumulables are cleared upon task completion") {
+    val conf = new SparkConf().
+      set(config.DROP_TASK_INFO_ACCUMULABLES_ON_TASK_COMPLETION, true)
+    sc = new SparkContext("local", "test", conf)
+    sched = new FakeTaskScheduler(sc, ("exec1", "host1"))
+    val taskSet = FakeTask.createTaskSet(2)
+    val clock = new ManualClock
+    val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock = clock)
+    val accumUpdates = taskSet.tasks.head.metrics.internalAccums
+
+    // Offer a host. This will launch the first task.
+    val taskOption = manager.resourceOffer("exec1", "host1", NO_PREF)._1
+    assert(taskOption.isDefined)
+
+    clock.advance(1)
+    // Tell it the first task has finished successfully
+    manager.handleSuccessfulTask(0, createTaskResult(0, accumUpdates))
+    assert(sched.endedTasks(0) === Success)
+
+    // Only one task was launched and it completed successfully, thus the TaskInfo accumulables
+    // should be empty.
+    assert(!manager.taskInfos.exists(t => !t._2.accumulables.isEmpty))
+    assert(manager.taskAttempts.flatMap(t => t.filter(!_.accumulables.isEmpty)).isEmpty)
+
+    // Fail the second task (MAX_TASK_FAILURES - 1) times.
+    (1 to manager.maxTaskFailures - 1).foreach { index =>
+      val offerResult = manager.resourceOffer("exec1", "host1", ANY)._1
+      assert(offerResult.isDefined,
+        "Expect resource offer on iteration %s to return a task".format(index))
+      assert(offerResult.get.index === 1)
+      manager.handleFailedTask(offerResult.get.taskId, TaskState.FINISHED, TaskResultLost)
+    }
+
+    clock.advance(1)
+    // Successfully finish the second task.
+    val taskOption1 = manager.resourceOffer("exec1", "host1", ANY)._1
+    manager.handleSuccessfulTask(taskOption1.get.taskId, createTaskResult(1, accumUpdates))
+    assert(sched.endedTasks(1) === Success)
+    // The TaskInfo accumulables should be empty as the second task has now completed successfully.
+    assert(!manager.taskInfos.exists(t => !t._2.accumulables.isEmpty))
+    assert(manager.taskAttempts.flatMap(t => t.filter(!_.accumulables.isEmpty)).isEmpty)
+
+    assert(sched.finishedManagers.contains(manager))
+  }
+
   test("TaskSet with no preferences") {
     sc = new SparkContext("local", "test")
     sched = new FakeTaskScheduler(sc, ("exec1", "host1"))

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org