This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new eec21d0c803f [SPARK-55035] Perform shuffle cleanup in child executions
eec21d0c803f is described below

commit eec21d0c803f298b4da78d9b9914330c5a9c3831
Author: Bo Zhang <[email protected]>
AuthorDate: Mon Jan 19 20:35:04 2026 +0800

    [SPARK-55035] Perform shuffle cleanup in child executions
    
    ### What changes were proposed in this pull request?
    For queries with multiple layers of executions (e.g., CTAS), this change performs shuffle dependency cleanup in child executions as well.
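    
    For illustration, a CTAS whose inner SELECT involves a shuffle runs that SELECT as a child execution. A minimal sketch, mirroring the shape of the new unit test below (table and view names are illustrative only):
    
    ```scala
    // The repartition introduces a shuffle inside the child (inner SELECT) execution.
    spark.range(100).repartition(10).createOrReplaceTempView("source")
    // The CTAS command is the root execution; its inner SELECT runs as a child
    // execution, and that shuffle's files are now eligible for cleanup too.
    spark.sql("CREATE TABLE t USING parquet AS SELECT * FROM source WHERE id < 50")
    ```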
    
    The shuffle cleanup mode for a child execution is determined in the following order (see the sketch after this list):
    1. Input arg when constructing this QueryExecution
    2. The cleanup mode of the root execution
    3. SQLConf from SparkSession
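    
    A minimal sketch of this fallback chain, condensed from the `QueryExecution` change in the diff below (the enclosing class is elided):
    
    ```scala
    lazy val shuffleCleanupMode: ShuffleCleanupMode =
      shuffleCleanupModeOpt.getOrElse(                    // 1. constructor arg, if given
        getShuffleCleanupModeFromRootExecution.getOrElse( // 2. root execution's mode
          QueryExecution.determineShuffleCleanupMode(     // 3. fall back to SQLConf
            sparkSession.sessionState.conf)))
    ```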
    
    ### Why are the changes needed?
    To make shuffle cleanup more effective.
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    Added a new unit test.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No
    
    Closes #53799 from bozhang2820/spark-55035.
    
    Authored-by: Bo Zhang <[email protected]>
    Signed-off-by: Wenchen Fan <[email protected]>
---
 .../execution/SparkConnectPlanExecution.scala      |  2 +-
 .../apache/spark/sql/classic/DataFrameWriter.scala |  3 +-
 .../spark/sql/classic/DataFrameWriterV2.scala      |  4 +--
 .../org/apache/spark/sql/classic/Dataset.scala     |  4 +--
 .../spark/sql/execution/QueryExecution.scala       | 35 ++++++++++++++++++----
 .../spark/sql/execution/command/CommandUtils.scala |  4 +--
 .../streaming/runtime/IncrementalExecution.scala   |  4 +--
 .../sql/internal/BaseSessionStateBuilder.scala     |  3 +-
 .../spark/sql/execution/QueryExecutionSuite.scala  | 19 ++++++++++++
 .../sql/hive/thriftserver/SparkSQLDriver.scala     |  2 +-
 10 files changed, 63 insertions(+), 17 deletions(-)

diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala
index 7dad774eed0f..4332074228d9 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala
@@ -95,7 +95,7 @@ private[execution] class SparkConnectPlanExecution(executeHolder: ExecuteHolder)
               session,
               transformer(tracker),
               tracker,
-              shuffleCleanupMode = shuffleCleanupMode)
+              shuffleCleanupModeOpt = Some(shuffleCleanupMode))
             qe.assertCommandExecuted()
             executeHolder.eventsManager.postFinished()
           case None =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/classic/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/classic/DataFrameWriter.scala
index 52012a862942..a29fcc3d1eca 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/classic/DataFrameWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/DataFrameWriter.scala
@@ -588,7 +588,8 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) extends sql.DataFram
    */
   private def runCommand(session: SparkSession)(command: LogicalPlan): Unit = {
     val qe = new QueryExecution(session, command, df.queryExecution.tracker,
-      shuffleCleanupMode = QueryExecution.determineShuffleCleanupMode(session.sessionState.conf))
+      shuffleCleanupModeOpt =
+        Some(QueryExecution.determineShuffleCleanupMode(session.sessionState.conf)))
     qe.assertCommandExecuted()
   }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/classic/DataFrameWriterV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/classic/DataFrameWriterV2.scala
index 7d481b86eb77..19cc9b76beae 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/classic/DataFrameWriterV2.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/DataFrameWriterV2.scala
@@ -228,8 +228,8 @@ final class DataFrameWriterV2[T] private[sql](table: String, ds: Dataset[T])
    */
   private def runCommand(command: LogicalPlan): Unit = {
     val qe = new QueryExecution(sparkSession, command, df.queryExecution.tracker,
-      shuffleCleanupMode =
-        QueryExecution.determineShuffleCleanupMode(sparkSession.sessionState.conf))
+      shuffleCleanupModeOpt =
+        Some(QueryExecution.determineShuffleCleanupMode(sparkSession.sessionState.conf)))
     qe.assertCommandExecuted()
   }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.scala
index d02b63b49ca5..088df782a541 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.scala
@@ -121,7 +121,7 @@ private[sql] object Dataset {
       shuffleCleanupMode: ShuffleCleanupMode): DataFrame =
     sparkSession.withActive {
       val qe = new QueryExecution(
-        sparkSession, logicalPlan, shuffleCleanupMode = shuffleCleanupMode)
+        sparkSession, logicalPlan, shuffleCleanupModeOpt = Some(shuffleCleanupMode))
       if (!qe.isLazyAnalysis) qe.assertAnalyzed()
       new Dataset[Row](qe, () => RowEncoder.encoderFor(qe.analyzed.schema))
     }
@@ -134,7 +134,7 @@ private[sql] object Dataset {
       shuffleCleanupMode: ShuffleCleanupMode = DoNotCleanup)
     : DataFrame = sparkSession.withActive {
     val qe = new QueryExecution(
-      sparkSession, logicalPlan, tracker, shuffleCleanupMode = shuffleCleanupMode)
+      sparkSession, logicalPlan, tracker, shuffleCleanupModeOpt = Some(shuffleCleanupMode))
     if (!qe.isLazyAnalysis) qe.assertAnalyzed()
     new Dataset[Row](qe, () => RowEncoder.encoderFor(qe.analyzed.schema))
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
index 28bdd0e1ef4e..f08b561d6ef9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
@@ -40,6 +40,7 @@ import org.apache.spark.sql.catalyst.rules.{PlanChangeLogger, Rule}
 import org.apache.spark.sql.catalyst.util.StringUtils.PlanStringConcat
 import org.apache.spark.sql.catalyst.util.truncatedString
 import org.apache.spark.sql.classic.SparkSession
+import org.apache.spark.sql.execution.SQLExecution.EXECUTION_ROOT_ID_KEY
 import org.apache.spark.sql.execution.adaptive.{AdaptiveExecutionContext, InsertAdaptiveSparkPlan}
 import org.apache.spark.sql.execution.bucketing.{CoalesceBucketsInJoin, DisableUnnecessaryBucketedScan}
 import org.apache.spark.sql.execution.datasources.v2.V2TableRefreshUtil
@@ -66,7 +67,7 @@ class QueryExecution(
     val logical: LogicalPlan,
     val tracker: QueryPlanningTracker = new QueryPlanningTracker,
     val mode: CommandExecutionMode.Value = CommandExecutionMode.ALL,
-    val shuffleCleanupMode: ShuffleCleanupMode = DoNotCleanup,
+    val shuffleCleanupModeOpt: Option[ShuffleCleanupMode] = None,
     val refreshPhaseEnabled: Boolean = true,
     val queryId: UUID = UUIDv7Generator.generate()) extends Logging {
 
@@ -467,6 +468,32 @@ class QueryExecution(
     Utils.redact(sparkSession.sessionState.conf.stringRedactionPattern, message)
   }
 
+  /**
+   * Determine the shuffle cleanup mode, based on the following order:
+   * 1. input arg when constructing this QueryExecution
+   * 2. the cleanup mode of the root execution
+   * 3. SQLConf from SparkSession
+   */
+  lazy val shuffleCleanupMode: ShuffleCleanupMode =
+    shuffleCleanupModeOpt.getOrElse(
+      getShuffleCleanupModeFromRootExecution.getOrElse(
+        QueryExecution.determineShuffleCleanupMode(sparkSession.sessionState.conf)))
+
+  private def getShuffleCleanupModeFromRootExecution: Option[ShuffleCleanupMode] = {
+    val rootExecutionIdStr = sparkSession.sparkContext.getLocalProperty(EXECUTION_ROOT_ID_KEY)
+    if (rootExecutionIdStr != null) {
+      val rootExecutionId = rootExecutionIdStr.toLong
+      val rootExecution = SQLExecution.getQueryExecution(rootExecutionId)
+      if (rootExecution != null) {
+        rootExecution.shuffleCleanupModeOpt
+      } else {
+        None
+      }
+    } else {
+      None
+    }
+  }
+
   def extendedExplainInfo(append: String => Unit, plan: SparkPlan): Unit = {
     val generators = sparkSession.sessionState.conf.getConf(SQLConf.EXTENDED_EXPLAIN_PROVIDERS)
       .getOrElse(Seq.empty)
@@ -581,7 +608,7 @@ object QueryExecution {
       sparkSession,
       logical,
       mode = CommandExecutionMode.ALL,
-      shuffleCleanupMode = determineShuffleCleanupMode(sparkSession.sessionState.conf),
+      shuffleCleanupModeOpt = Some(determineShuffleCleanupMode(sparkSession.sessionState.conf)),
       refreshPhaseEnabled = refreshPhaseEnabled)
   }
 
@@ -771,13 +798,11 @@ object QueryExecution {
       mode: CommandExecutionMode.Value = CommandExecutionMode.SKIP,
       shuffleCleanupModeOpt: Option[ShuffleCleanupMode] = None)
     : (QueryExecution, Array[InternalRow]) = {
-    val shuffleCleanupMode = shuffleCleanupModeOpt.getOrElse(
-      determineShuffleCleanupMode(sparkSession.sessionState.conf))
     val qe = new QueryExecution(
       sparkSession,
       command,
       mode = mode,
-      shuffleCleanupMode = shuffleCleanupMode,
+      shuffleCleanupModeOpt = shuffleCleanupModeOpt,
       refreshPhaseEnabled = refreshPhaseEnabled)
     val result = QueryExecution.withInternalError(s"Executed $name failed.") {
       SQLExecution.withNewExecutionId(qe, Some(name)) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala
index e1ff1ae73094..23055037ac4c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala
@@ -308,7 +308,7 @@ object CommandUtils extends Logging {
 
     val namedExpressions = expressions.map(e => Alias(e, e.toString)())
     val statsRow = new QueryExecution(sparkSession, Aggregate(Nil, namedExpressions, relation),
-      shuffleCleanupMode = RemoveShuffleFiles).executedPlan.executeTake(1).head
+      shuffleCleanupModeOpt = Some(RemoveShuffleFiles)).executedPlan.executeTake(1).head
 
     val rowCount = statsRow.getLong(0)
     val columnStats = columns.zipWithIndex.map { case (attr, i) =>
@@ -345,7 +345,7 @@ object CommandUtils extends Logging {
       }
 
       val percentilesRow = new QueryExecution(sparkSession, Aggregate(Nil, namedExprs, relation),
-        shuffleCleanupMode = RemoveShuffleFiles).executedPlan.executeTake(1).head
+        shuffleCleanupModeOpt = Some(RemoveShuffleFiles)).executedPlan.executeTake(1).head
       attrsToGenHistogram.zipWithIndex.foreach { case (attr, i) =>
         val percentiles = percentilesRow.getArray(i)
         // When there is no non-null value, `percentiles` is null. In such case, there is no
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/IncrementalExecution.scala
index 594329173b3b..169ab6f606da 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/IncrementalExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/IncrementalExecution.scala
@@ -83,8 +83,8 @@ class IncrementalExecution(
     mode: CommandExecutionMode.Value = CommandExecutionMode.ALL,
     val isTerminatingTrigger: Boolean = false)
   extends QueryExecution(sparkSession, logicalPlan, mode = mode,
-    shuffleCleanupMode =
-      QueryExecution.determineShuffleCleanupMode(sparkSession.sessionState.conf),
+    shuffleCleanupModeOpt =
+      Some(QueryExecution.determineShuffleCleanupMode(sparkSession.sessionState.conf)),
     queryId = queryId) with Logging {
 
   // Modified planner with stateful operations.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
index 7e3a6b9dbb7e..d2d72ee21c72 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
@@ -415,7 +415,8 @@ abstract class BaseSessionStateBuilder(
   protected def createQueryExecution:
     (LogicalPlan, CommandExecutionMode.Value) => QueryExecution =
       (plan, mode) => new QueryExecution(session, plan, mode = mode,
-        shuffleCleanupMode = QueryExecution.determineShuffleCleanupMode(session.sessionState.conf))
+        shuffleCleanupModeOpt =
+          Some(QueryExecution.determineShuffleCleanupMode(session.sessionState.conf)))
 
   /**
    * Interface to start and stop streaming queries.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala
index 0d0a0f2f3100..af3a0f3e3710 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala
@@ -410,6 +410,25 @@ class QueryExecutionSuite extends SharedSparkSession {
     }
   }
 
+  test("SPARK-55035: Shuffle cleanup performed in child executions") {
+    val sourceDF = spark.range(100).repartition(10)
+    sourceDF.createOrReplaceTempView("source")
+
+    val createTablePlan = spark.sessionState.sqlParser.parsePlan(
+      """
+        CREATE TABLE child_exec_test
+        USING parquet
+        AS SELECT * FROM source WHERE id < 50
+      """)
+    val df = Dataset.ofRows(spark, createTablePlan, RemoveShuffleFiles)
+    df.collect()
+
+    val blockManager = spark.sparkContext.env.blockManager
+    assert(blockManager.migratableResolver.getStoredShuffles().isEmpty)
+    assert(blockManager.diskBlockManager.getAllBlocks().isEmpty)
+    cleanupShuffles()
+  }
+
   test("SPARK-35378: Return UnsafeRow in CommandResultExecCheck execute 
methods") {
     val plan = spark.sql("SHOW FUNCTIONS").queryExecution.executedPlan
     assert(plan.isInstanceOf[CommandResultExec])
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala
index 7a220d516757..baa8f8cc07cf 100644
--- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala
@@ -80,7 +80,7 @@ private[hive] class SparkSQLDriver(val sparkSession: SparkSession = SparkSQLEnv.
       val execution = new QueryExecution(
         sparkSession.asInstanceOf[org.apache.spark.sql.classic.SparkSession],
         logicalPlan,
-        shuffleCleanupMode = shuffleCleanupMode)
+        shuffleCleanupModeOpt = Some(shuffleCleanupMode))
 
       // the above execution already has an execution ID, therefore we don't need to
       // wrap it again with a new execution ID when getting Hive result.


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
