This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 27aba95bdf76 [SPARK-52451][CONNECT][SQL] Make WriteOperation in SparkConnectPlanner side effect free
27aba95bdf76 is described below

commit 27aba95bdf762b6d1730ed890698cf14d9c4585f
Author: Yihong He <heyihong...@gmail.com>
AuthorDate: Tue Aug 5 00:06:08 2025 +0800

    [SPARK-52451][CONNECT][SQL] Make WriteOperation in SparkConnectPlanner side effect free
    
    ### What changes were proposed in this pull request?
    
    This PR refactors the Spark Connect execution flow so that `WriteOperation` handling is side-effect free, by separating the transformation phase from the execution phase. The key changes are:
    
    1. **Unified execution flow**: Both `ROOT` and `COMMAND` plans are now routed through `SparkConnectPlanExecution.handlePlan()`, instead of through separate handlers in `ExecuteThreadRunner`.
    2. **Pure transformation phase**: Introduced `transformCommand()`, which converts a `WriteOperation` into a `LogicalPlan` without executing it. Other command types return `None` and are still handled by `process()`. It relies on the new `DataFrameWriter` methods (`saveCommand()`, `saveAsTableCommand()`, `insertIntoCommand()`), which return logical plans instead of executing immediately. `SparkConnectPlanExecution` then executes the returned plan.
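    3. **DataFrameWriter refactoring**: The public `save()`, `saveAsTable()`, and `insertInto()` methods now build their plan with the new `saveCommand()`, `saveAsTableCommand()`, and `insertIntoCommand()` methods, and then execute it. A new `SaveAsV1TableCommand` moves the V1 save-as-table logic previously in `DataFrameWriter` into a runnable command. That logic covers the table-existence and save-mode checks, including the check against overwriting a table that is being read. It now runs when the plan is executed rather than when the plan is built.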
    
    ### Why are the changes needed?
    
    The current implementation has two issues:
    
    1. **Side effects in transformation**: The `handleWriteOperation` method 
both transforms and executes write operations, making it difficult to reason 
about the transformation logic independently.
    
    2. **Code duplication**: Separate handling paths for `ROOT` and `COMMAND` 
operations in `ExecuteThreadRunner` create unnecessary complexity and potential 
inconsistencies.
    
    ### Does this PR introduce any user-facing change?
    
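    No. This is an internal refactoring that keeps the public API unchanged, and existing Spark Connect client code continues to work without any changes. There are minor observable differences:
    - An unsupported `Plan` op type now fails with an `InvalidInputErrors.invalidOneOfField` error instead of `UnsupportedOperationException`.
    - V1 `saveAsTable` now executes as a single `SaveAsV1TableCommand`.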
    
    ### How was this patch tested?
    
    `build/sbt "connect-client-jvm/testOnly *ClientE2ETestSuite"`
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    Cursor 1.3.5
    
    Closes #51727 from heyihong/SPARK-52451.
    
    Authored-by: Yihong He <heyihong...@gmail.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../connect/execution/ExecuteThreadRunner.scala    |  29 +---
 .../execution/SparkConnectPlanExecution.scala      |  51 +++---
 .../sql/connect/planner/SparkConnectPlanner.scala  |  58 +++++--
 .../apache/spark/sql/classic/DataFrameWriter.scala | 179 ++++++++-------------
 .../apache/spark/sql/execution/command/ddl.scala   |  64 +++++++-
 .../spark/sql/util/DataFrameCallbackSuite.scala    |   9 +-
 6 files changed, 214 insertions(+), 176 deletions(-)

diff --git 
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala
 
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala
index fdb0ef363124..93853805b437 100644
--- 
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala
+++ 
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala
@@ -28,7 +28,7 @@ import org.apache.spark.SparkSQLException
 import org.apache.spark.connect.proto
 import org.apache.spark.internal.{Logging, LogKeys, MDC}
 import org.apache.spark.sql.connect.common.ProtoUtils
-import org.apache.spark.sql.connect.planner.SparkConnectPlanner
+import org.apache.spark.sql.connect.planner.InvalidInputErrors
 import org.apache.spark.sql.connect.service.{ExecuteHolder, ExecuteSessionTag, 
SparkConnectService}
 import org.apache.spark.sql.connect.utils.ErrorUtils
 import org.apache.spark.util.Utils
@@ -218,11 +218,13 @@ private[connect] class ExecuteThreadRunner(executeHolder: 
ExecuteHolder) extends
       session.sparkContext.setLocalProperty("callSite.long", 
Utils.abbreviate(debugString, 2048))
 
       executeHolder.request.getPlan.getOpTypeCase match {
-        case proto.Plan.OpTypeCase.COMMAND => 
handleCommand(executeHolder.request)
-        case proto.Plan.OpTypeCase.ROOT => handlePlan(executeHolder.request)
-        case _ =>
-          throw new UnsupportedOperationException(
-            s"${executeHolder.request.getPlan.getOpTypeCase} not supported.")
+        case proto.Plan.OpTypeCase.ROOT | proto.Plan.OpTypeCase.COMMAND =>
+          val execution = new SparkConnectPlanExecution(executeHolder)
+          execution.handlePlan(executeHolder.responseObserver)
+        case other =>
+          throw InvalidInputErrors.invalidOneOfField(
+            other,
+            executeHolder.request.getPlan.getDescriptorForType)
       }
 
       val observedMetrics: Map[String, Seq[(Option[String], Any)]] = {
@@ -304,21 +306,6 @@ private[connect] class ExecuteThreadRunner(executeHolder: 
ExecuteHolder) extends
       
proto.StreamingQueryListenerBusCommand.CommandCase.ADD_LISTENER_BUS_LISTENER
   }
 
-  private def handlePlan(request: proto.ExecutePlanRequest): Unit = {
-    val responseObserver = executeHolder.responseObserver
-
-    val execution = new SparkConnectPlanExecution(executeHolder)
-    execution.handlePlan(responseObserver)
-  }
-
-  private def handleCommand(request: proto.ExecutePlanRequest): Unit = {
-    val responseObserver = executeHolder.responseObserver
-
-    val command = request.getPlan.getCommand
-    val planner = new SparkConnectPlanner(executeHolder)
-    planner.process(command = command, responseObserver = responseObserver)
-  }
-
   private def requestString(request: Message) = {
     try {
       Utils.redact(
diff --git 
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala
 
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala
index 65b9863ca954..9050a84fda56 100644
--- 
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala
+++ 
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala
@@ -32,10 +32,10 @@ import org.apache.spark.sql.classic.{DataFrame, Dataset}
 import org.apache.spark.sql.connect.common.DataTypeProtoConverter
 import 
org.apache.spark.sql.connect.common.LiteralValueProtoConverter.toLiteralProto
 import 
org.apache.spark.sql.connect.config.Connect.CONNECT_GRPC_ARROW_MAX_BATCH_SIZE
-import org.apache.spark.sql.connect.planner.SparkConnectPlanner
+import org.apache.spark.sql.connect.planner.{InvalidInputErrors, 
SparkConnectPlanner}
 import org.apache.spark.sql.connect.service.ExecuteHolder
 import org.apache.spark.sql.connect.utils.MetricGenerator
-import org.apache.spark.sql.execution.{DoNotCleanup, LocalTableScanExec, 
RemoveShuffleFiles, SkipMigration, SQLExecution}
+import org.apache.spark.sql.execution.{DoNotCleanup, LocalTableScanExec, 
QueryExecution, RemoveShuffleFiles, SkipMigration, SQLExecution}
 import org.apache.spark.sql.execution.arrow.ArrowConverters
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.StructType
@@ -53,10 +53,6 @@ private[execution] class 
SparkConnectPlanExecution(executeHolder: ExecuteHolder)
 
   def handlePlan(responseObserver: 
ExecuteResponseObserver[proto.ExecutePlanResponse]): Unit = {
     val request = executeHolder.request
-    if (request.getPlan.getOpTypeCase != proto.Plan.OpTypeCase.ROOT) {
-      throw new IllegalStateException(
-        s"Illegal operation type ${request.getPlan.getOpTypeCase} to be 
handled here.")
-    }
     val planner = new SparkConnectPlanner(executeHolder)
     val tracker = executeHolder.eventsManager.createQueryPlanningTracker()
     val conf = session.sessionState.conf
@@ -68,19 +64,36 @@ private[execution] class 
SparkConnectPlanExecution(executeHolder: ExecuteHolder)
       } else {
         DoNotCleanup
       }
-    val dataframe =
-      Dataset.ofRows(
-        sessionHolder.session,
-        planner.transformRelation(request.getPlan.getRoot, cachePlan = true),
-        tracker,
-        shuffleCleanupMode)
-    responseObserver.onNext(createSchemaResponse(request.getSessionId, 
dataframe.schema))
-    processAsArrowBatches(dataframe, responseObserver, executeHolder)
-    
responseObserver.onNext(MetricGenerator.createMetricsResponse(sessionHolder, 
dataframe))
-    createObservedMetricsResponse(
-      request.getSessionId,
-      executeHolder.allObservationAndPlanIds,
-      dataframe).foreach(responseObserver.onNext)
+    request.getPlan.getOpTypeCase match {
+      case proto.Plan.OpTypeCase.ROOT =>
+        val dataframe =
+          Dataset.ofRows(
+            sessionHolder.session,
+            planner.transformRelation(request.getPlan.getRoot, cachePlan = 
true),
+            tracker,
+            shuffleCleanupMode)
+        responseObserver.onNext(createSchemaResponse(request.getSessionId, 
dataframe.schema))
+        processAsArrowBatches(dataframe, responseObserver, executeHolder)
+        
responseObserver.onNext(MetricGenerator.createMetricsResponse(sessionHolder, 
dataframe))
+        createObservedMetricsResponse(
+          request.getSessionId,
+          executeHolder.allObservationAndPlanIds,
+          dataframe).foreach(responseObserver.onNext)
+      case proto.Plan.OpTypeCase.COMMAND =>
+        val command = request.getPlan.getCommand
+        planner.transformCommand(command, tracker) match {
+          case Some(plan) =>
+            val qe =
+              new QueryExecution(session, plan, tracker, shuffleCleanupMode = 
shuffleCleanupMode)
+            qe.assertCommandExecuted()
+            executeHolder.eventsManager.postFinished()
+          case None =>
+            planner.process(command, responseObserver)
+        }
+      case other =>
+        throw InvalidInputErrors.invalidOneOfField(other, 
request.getPlan.getDescriptorForType)
+    }
+
   }
 
   type Batch = (Array[Byte], Long)
diff --git 
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
 
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index bcd643a30253..7320c6e3918c 100644
--- 
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ 
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -59,7 +59,7 @@ import 
org.apache.spark.sql.catalyst.streaming.InternalOutputModes
 import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin, TreePattern}
 import org.apache.spark.sql.catalyst.types.DataTypeUtils
 import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, 
CharVarcharUtils}
-import org.apache.spark.sql.classic.{Catalog, Dataset, MergeIntoWriter, 
RelationalGroupedDataset, SparkSession, TypedAggUtils, UserDefinedFunctionUtils}
+import org.apache.spark.sql.classic.{Catalog, DataFrameWriter, Dataset, 
MergeIntoWriter, RelationalGroupedDataset, SparkSession, TypedAggUtils, 
UserDefinedFunctionUtils}
 import org.apache.spark.sql.classic.ClassicConversions._
 import org.apache.spark.sql.connect.client.arrow.ArrowSerializer
 import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, 
ForeachWriterPacket, LiteralValueProtoConverter, StorageLevelProtoConverter, 
StreamingListenerPacket, UdfPacket}
@@ -2646,6 +2646,17 @@ class SparkConnectPlanner(
     process(command, new MockObserver())
   }
 
+  def transformCommand(
+      command: proto.Command,
+      tracker: QueryPlanningTracker): Option[LogicalPlan] = {
+    command.getCommandTypeCase match {
+      case proto.Command.CommandTypeCase.WRITE_OPERATION =>
+        Some(transformWriteOperation(command.getWriteOperation, tracker))
+      case _ =>
+        None
+    }
+  }
+
   def process(
       command: proto.Command,
       responseObserver: StreamObserver[ExecutePlanResponse]): Unit = {
@@ -3078,23 +3089,15 @@ class SparkConnectPlanner(
     executeHolder.eventsManager.postFinished()
   }
 
-  /**
-   * Transforms the write operation and executes it.
-   *
-   * The input write operation contains a reference to the input plan and 
transforms it to the
-   * corresponding logical plan. Afterwards, creates the DataFrameWriter and 
translates the
-   * parameters of the WriteOperation into the corresponding methods calls.
-   *
-   * @param writeOperation
-   */
-  private def handleWriteOperation(writeOperation: proto.WriteOperation): Unit 
= {
+  private def transformWriteOperation(
+      writeOperation: proto.WriteOperation,
+      tracker: QueryPlanningTracker): LogicalPlan = {
     // Transform the input plan into the logical plan.
     val plan = transformRelation(writeOperation.getInput)
     // And create a Dataset from the plan.
-    val tracker = executeHolder.eventsManager.createQueryPlanningTracker()
     val dataset = Dataset.ofRows(session, plan, tracker)
 
-    val w = dataset.write
+    val w = dataset.write.asInstanceOf[DataFrameWriter[_]]
     if (writeOperation.getMode != 
proto.WriteOperation.SaveMode.SAVE_MODE_UNSPECIFIED) {
       w.mode(SaveModeConverter.toSaveMode(writeOperation.getMode))
     }
@@ -3129,20 +3132,41 @@ class SparkConnectPlanner(
     }
 
     writeOperation.getSaveTypeCase match {
-      case proto.WriteOperation.SaveTypeCase.SAVETYPE_NOT_SET => w.save()
-      case proto.WriteOperation.SaveTypeCase.PATH => 
w.save(writeOperation.getPath)
+      case proto.WriteOperation.SaveTypeCase.SAVETYPE_NOT_SET => 
w.saveCommand(None)
+      case proto.WriteOperation.SaveTypeCase.PATH =>
+        w.saveCommand(Some(writeOperation.getPath))
       case proto.WriteOperation.SaveTypeCase.TABLE =>
         val tableName = writeOperation.getTable.getTableName
         writeOperation.getTable.getSaveMethod match {
           case 
proto.WriteOperation.SaveTable.TableSaveMethod.TABLE_SAVE_METHOD_SAVE_AS_TABLE 
=>
-            w.saveAsTable(tableName)
+            w.saveAsTableCommand(tableName)
           case 
proto.WriteOperation.SaveTable.TableSaveMethod.TABLE_SAVE_METHOD_INSERT_INTO =>
-            w.insertInto(tableName)
+            w.insertIntoCommand(tableName)
           case other => throw InvalidInputErrors.invalidEnum(other)
         }
       case other =>
         throw InvalidInputErrors.invalidOneOfField(other, 
writeOperation.getDescriptorForType)
     }
+  }
+
+  private def runCommand(command: LogicalPlan, tracker: QueryPlanningTracker): 
Unit = {
+    val qe = new QueryExecution(session, command, tracker)
+    qe.assertCommandExecuted()
+  }
+
+  /**
+   * Transforms the write operation and executes it.
+   *
+   * The input write operation contains a reference to the input plan and 
transforms it to the
+   * corresponding logical plan. Afterwards, creates the DataFrameWriter and 
translates the
+   * parameters of the WriteOperation into the corresponding methods calls.
+   *
+   * @param writeOperation
+   */
+  private def handleWriteOperation(writeOperation: proto.WriteOperation): Unit 
= {
+    val tracker = executeHolder.eventsManager.createQueryPlanningTracker()
+    runCommand(transformWriteOperation(writeOperation, tracker), tracker)
+
     executeHolder.eventsManager.postFinished()
   }
 
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/classic/DataFrameWriter.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/classic/DataFrameWriter.scala
index 501b4985128d..737dc5b1e21e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/classic/DataFrameWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/DataFrameWriter.scala
@@ -25,7 +25,7 @@ import org.apache.spark.annotation.Stable
 import org.apache.spark.sql
 import org.apache.spark.sql.SaveMode
 import org.apache.spark.sql.catalyst.TableIdentifier
-import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, 
NoSuchTableException, UnresolvedIdentifier, UnresolvedRelation}
+import org.apache.spark.sql.catalyst.analysis.{NoSuchTableException, 
UnresolvedIdentifier, UnresolvedRelation}
 import org.apache.spark.sql.catalyst.catalog._
 import org.apache.spark.sql.catalyst.expressions.Literal
 import org.apache.spark.sql.catalyst.plans.logical._
@@ -36,8 +36,8 @@ import 
org.apache.spark.sql.connector.catalog.TableWritePrivilege._
 import org.apache.spark.sql.connector.expressions.{ClusterByTransform, 
FieldReference, IdentityTransform, Transform}
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.execution.QueryExecution
-import org.apache.spark.sql.execution.command.DDLUtils
-import org.apache.spark.sql.execution.datasources.{CreateTable, DataSource, 
DataSourceUtils, LogicalRelation}
+import org.apache.spark.sql.execution.command.{DDLUtils, SaveAsV1TableCommand}
+import org.apache.spark.sql.execution.datasources.{DataSource, DataSourceUtils}
 import org.apache.spark.sql.execution.datasources.v2._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.internal.SQLConf.PartitionOverwriteMode
@@ -115,7 +115,9 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) 
extends sql.DataFram
         extraOptions.contains("path")) {
       throw QueryCompilationErrors.pathOptionNotSetCorrectlyWhenWritingError()
     }
-    saveInternal(Some(path))
+    runCommand(df.sparkSession) {
+      saveCommand(Some(path))
+    }
   }
 
   /**
@@ -123,9 +125,13 @@ final class DataFrameWriter[T] private[sql](ds: 
Dataset[T]) extends sql.DataFram
    *
    * @since 1.4.0
    */
-  def save(): Unit = saveInternal(None)
+  def save(): Unit = {
+    runCommand(df.sparkSession) {
+      saveCommand(None)
+    }
+  }
 
-  private def saveInternal(path: Option[String]): Unit = {
+  private[sql] def saveCommand(path: Option[String]): LogicalPlan = {
     if (source.toLowerCase(Locale.ROOT) == DDLUtils.HIVE_PROVIDER) {
       throw 
QueryCompilationErrors.cannotOperateOnHiveDataSourceFilesError("write")
     }
@@ -179,23 +185,19 @@ final class DataFrameWriter[T] private[sql](ds: 
Dataset[T]) extends sql.DataFram
                 // Streaming also uses the data source V2 API. So it may be 
that the data source
                 // implements v2, but has no v2 implementation for batch 
writes. In that case, we
                 // fall back to saving as though it's a V1 source.
-                return saveToV1Source(path)
+                return saveToV1SourceCommand(path)
               }
           }
 
           val relation = DataSourceV2Relation.create(table, catalog, ident, 
dsOptions)
           checkPartitioningMatchesV2Table(table)
           if (curmode == SaveMode.Append) {
-            runCommand(df.sparkSession) {
-              AppendData.byName(relation, df.logicalPlan, finalOptions)
-            }
+            AppendData.byName(relation, df.logicalPlan, finalOptions)
           } else {
             // Truncate the table. TableCapabilityCheck will throw a nice 
exception if this
             // isn't supported
-            runCommand(df.sparkSession) {
-              OverwriteByExpression.byName(
-                relation, df.logicalPlan, Literal(true), finalOptions)
-            }
+            OverwriteByExpression.byName(
+              relation, df.logicalPlan, Literal(true), finalOptions)
           }
 
         case createMode =>
@@ -215,16 +217,14 @@ final class DataFrameWriter[T] private[sql](ds: 
Dataset[T]) extends sql.DataFram
                 serde = None,
                 external = false,
                 constraints = Seq.empty)
-              runCommand(df.sparkSession) {
-                CreateTableAsSelect(
-                  UnresolvedIdentifier(
-                    catalog.name +: ident.namespace.toImmutableArraySeq :+ 
ident.name),
-                  partitioningAsV2,
-                  df.queryExecution.analyzed,
-                  tableSpec,
-                  finalOptions,
-                  ignoreIfExists = createMode == SaveMode.Ignore)
-              }
+              CreateTableAsSelect(
+                UnresolvedIdentifier(
+                  catalog.name +: ident.namespace.toImmutableArraySeq :+ 
ident.name),
+                partitioningAsV2,
+                df.queryExecution.analyzed,
+                tableSpec,
+                finalOptions,
+                ignoreIfExists = createMode == SaveMode.Ignore)
             case _: TableProvider =>
               if (getTable.supports(BATCH_WRITE)) {
                 throw 
QueryCompilationErrors.writeWithSaveModeUnsupportedBySourceError(
@@ -233,13 +233,13 @@ final class DataFrameWriter[T] private[sql](ds: 
Dataset[T]) extends sql.DataFram
                 // Streaming also uses the data source V2 API. So it may be 
that the data source
                 // implements v2, but has no v2 implementation for batch 
writes. In that case, we
                 // fallback to saving as though it's a V1 source.
-                saveToV1Source(path)
+                saveToV1SourceCommand(path)
               }
           }
       }
 
     } else {
-      saveToV1Source(path)
+      saveToV1SourceCommand(path)
     }
   }
 
@@ -251,7 +251,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) 
extends sql.DataFram
     }
   }
 
-  private def saveToV1Source(path: Option[String]): Unit = {
+  private def saveToV1SourceCommand(path: Option[String]): LogicalPlan = {
     partitioningColumns.foreach { columns =>
       extraOptions = extraOptions + (
         DataSourceUtils.PARTITIONING_COLUMNS_KEY ->
@@ -266,13 +266,11 @@ final class DataFrameWriter[T] private[sql](ds: 
Dataset[T]) extends sql.DataFram
     val optionsWithPath = getOptionsWithPath(path)
 
     // Code path for data source v1.
-    runCommand(df.sparkSession) {
-      DataSource(
-        sparkSession = df.sparkSession,
-        className = source,
-        partitionColumns = partitioningColumns.getOrElse(Nil),
-        options = optionsWithPath.originalMap).planForWriting(curmode, 
df.logicalPlan)
-    }
+    DataSource(
+      sparkSession = df.sparkSession,
+      className = source,
+      partitionColumns = partitioningColumns.getOrElse(Nil),
+      options = optionsWithPath.originalMap).planForWriting(curmode, 
df.logicalPlan)
   }
 
   /**
@@ -304,6 +302,12 @@ final class DataFrameWriter[T] private[sql](ds: 
Dataset[T]) extends sql.DataFram
    * @since 1.4.0
    */
   def insertInto(tableName: String): Unit = {
+    runCommand(df.sparkSession) {
+      insertIntoCommand(tableName)
+    }
+  }
+
+  private[sql] def insertIntoCommand(tableName: String): LogicalPlan = {
     import df.sparkSession.sessionState.analyzer.{AsTableIdentifier, 
NonSessionCatalogAndIdentifier, SessionCatalogAndIdentifier}
     import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
 
@@ -318,30 +322,30 @@ final class DataFrameWriter[T] private[sql](ds: 
Dataset[T]) extends sql.DataFram
 
     session.sessionState.sqlParser.parseMultipartIdentifier(tableName) match {
       case NonSessionCatalogAndIdentifier(catalog, ident) =>
-        insertInto(catalog, ident)
+        insertIntoCommand(catalog, ident)
 
       case SessionCatalogAndIdentifier(catalog, ident)
           if canUseV2 && ident.namespace().length <= 1 =>
-        insertInto(catalog, ident)
+        insertIntoCommand(catalog, ident)
 
       case AsTableIdentifier(tableIdentifier) =>
-        insertInto(tableIdentifier)
+        insertIntoCommand(tableIdentifier)
       case other =>
         throw 
QueryCompilationErrors.cannotFindCatalogToHandleIdentifierError(other.quoted)
     }
   }
 
-  private def insertInto(catalog: CatalogPlugin, ident: Identifier): Unit = {
+  private def insertIntoCommand(catalog: CatalogPlugin, ident: Identifier): 
LogicalPlan = {
     import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
 
     val table = catalog.asTableCatalog.loadTable(ident, 
getWritePrivileges.toSet.asJava) match {
       case _: V1Table =>
-        return insertInto(TableIdentifier(ident.name(), 
ident.namespace().headOption))
+        return insertIntoCommand(TableIdentifier(ident.name(), 
ident.namespace().headOption))
       case t =>
         DataSourceV2Relation.create(t, Some(catalog), Some(ident))
     }
 
-    val command = curmode match {
+    curmode match {
       case SaveMode.Append | SaveMode.ErrorIfExists | SaveMode.Ignore =>
         AppendData.byPosition(table, df.logicalPlan, extraOptions.toMap)
 
@@ -356,22 +360,16 @@ final class DataFrameWriter[T] private[sql](ds: 
Dataset[T]) extends sql.DataFram
           OverwriteByExpression.byPosition(table, df.logicalPlan, 
Literal(true), extraOptions.toMap)
         }
     }
-
-    runCommand(df.sparkSession) {
-      command
-    }
   }
 
-  private def insertInto(tableIdent: TableIdentifier): Unit = {
-    runCommand(df.sparkSession) {
-      InsertIntoStatement(
-        table = 
UnresolvedRelation(tableIdent).requireWritePrivileges(getWritePrivileges),
-        partitionSpec = Map.empty[String, Option[String]],
-        Nil,
-        query = df.logicalPlan,
-        overwrite = curmode == SaveMode.Overwrite,
-        ifPartitionNotExists = false)
-    }
+  private def insertIntoCommand(tableIdent: TableIdentifier): LogicalPlan = {
+    InsertIntoStatement(
+      table = 
UnresolvedRelation(tableIdent).requireWritePrivileges(getWritePrivileges),
+      partitionSpec = Map.empty[String, Option[String]],
+      Nil,
+      query = df.logicalPlan,
+      overwrite = curmode == SaveMode.Overwrite,
+      ifPartitionNotExists = false)
   }
 
   private def getWritePrivileges: Seq[TableWritePrivilege] = curmode match {
@@ -430,6 +428,12 @@ final class DataFrameWriter[T] private[sql](ds: 
Dataset[T]) extends sql.DataFram
    * @since 1.4.0
    */
   def saveAsTable(tableName: String): Unit = {
+    runCommand(df.sparkSession) {
+      saveAsTableCommand(tableName)
+    }
+  }
+
+  private[sql] def saveAsTableCommand(tableName: String): LogicalPlan = {
     import df.sparkSession.sessionState.analyzer.{AsTableIdentifier, 
NonSessionCatalogAndIdentifier, SessionCatalogAndIdentifier}
     import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
 
@@ -440,30 +444,29 @@ final class DataFrameWriter[T] private[sql](ds: 
Dataset[T]) extends sql.DataFram
 
     session.sessionState.sqlParser.parseMultipartIdentifier(tableName) match {
       case nameParts @ NonSessionCatalogAndIdentifier(catalog, ident) =>
-        saveAsTable(catalog.asTableCatalog, ident, nameParts)
+        saveAsTableCommand(catalog.asTableCatalog, ident, nameParts)
 
       case nameParts @ SessionCatalogAndIdentifier(catalog, ident)
           if canUseV2 && ident.namespace().length <= 1 =>
-        saveAsTable(catalog.asTableCatalog, ident, nameParts)
+        saveAsTableCommand(catalog.asTableCatalog, ident, nameParts)
 
       case AsTableIdentifier(tableIdentifier) =>
-        saveAsTable(tableIdentifier)
+        saveAsV1TableCommand(tableIdentifier)
 
       case other =>
         throw 
QueryCompilationErrors.cannotFindCatalogToHandleIdentifierError(other.quoted)
     }
   }
 
-
-  private def saveAsTable(
-      catalog: TableCatalog, ident: Identifier, nameParts: Seq[String]): Unit 
= {
+  private def saveAsTableCommand(
+      catalog: TableCatalog, ident: Identifier, nameParts: Seq[String]): 
LogicalPlan = {
     val tableOpt = try Option(catalog.loadTable(ident, 
getWritePrivileges.toSet.asJava)) catch {
       case _: NoSuchTableException => None
     }
 
-    val command = (curmode, tableOpt) match {
+    (curmode, tableOpt) match {
       case (_, Some(_: V1Table)) =>
-        return saveAsTable(TableIdentifier(ident.name(), 
ident.namespace().headOption))
+        saveAsV1TableCommand(TableIdentifier(ident.name(), 
ident.namespace().headOption))
 
       case (SaveMode.Append, Some(table)) =>
         checkPartitioningMatchesV2Table(table)
@@ -512,56 +515,9 @@ final class DataFrameWriter[T] private[sql](ds: 
Dataset[T]) extends sql.DataFram
           writeOptions = extraOptions.toMap,
           other == SaveMode.Ignore)
     }
-
-    runCommand(df.sparkSession) {
-      command
-    }
-  }
-
-  private def saveAsTable(tableIdent: TableIdentifier): Unit = {
-    val catalog = df.sparkSession.sessionState.catalog
-    val qualifiedIdent = catalog.qualifyIdentifier(tableIdent)
-    val tableExists = catalog.tableExists(qualifiedIdent)
-
-    (tableExists, curmode) match {
-      case (true, SaveMode.Ignore) =>
-        // Do nothing
-
-      case (true, SaveMode.ErrorIfExists) =>
-        throw QueryCompilationErrors.tableAlreadyExistsError(qualifiedIdent)
-
-      case (true, SaveMode.Overwrite) =>
-        // Get all input data source or hive relations of the query.
-        val srcRelations = df.logicalPlan.collect {
-          case l: LogicalRelation => l.relation
-          case relation: HiveTableRelation => relation.tableMeta.identifier
-        }
-
-        val tableRelation = 
df.sparkSession.table(qualifiedIdent).queryExecution.analyzed
-        EliminateSubqueryAliases(tableRelation) match {
-          // check if the table is a data source table (the relation is a 
BaseRelation).
-          case l: LogicalRelation if srcRelations.contains(l.relation) =>
-            throw 
QueryCompilationErrors.cannotOverwriteTableThatIsBeingReadFromError(
-              qualifiedIdent)
-          // check hive table relation when overwrite mode
-          case relation: HiveTableRelation
-              if srcRelations.contains(relation.tableMeta.identifier) =>
-            throw 
QueryCompilationErrors.cannotOverwriteTableThatIsBeingReadFromError(
-              qualifiedIdent)
-          case _ => // OK
-        }
-
-        // Drop the existing table
-        catalog.dropTable(qualifiedIdent, ignoreIfNotExists = true, purge = 
false)
-        createTable(qualifiedIdent)
-        // Refresh the cache of the table in the catalog.
-        catalog.refreshTable(qualifiedIdent)
-
-      case _ => createTable(qualifiedIdent)
-    }
   }
 
-  private def createTable(tableIdent: TableIdentifier): Unit = {
+  private def saveAsV1TableCommand(tableIdent: TableIdentifier): 
SaveAsV1TableCommand = {
     val storage = DataSource.buildStorageFormatFromOptions(extraOptions.toMap)
     val tableType = if (storage.locationUri.isDefined) {
       CatalogTableType.EXTERNAL
@@ -586,8 +542,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) 
extends sql.DataFram
       bucketSpec = getBucketSpec,
       properties = properties)
 
-    runCommand(df.sparkSession)(
-      CreateTable(tableDesc, curmode, Some(df.logicalPlan)))
+    SaveAsV1TableCommand(tableDesc, curmode, df.logicalPlan)
   }
 
   /** Converts the provided partitioning and bucketing information to 
DataSourceV2 Transforms. */
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
index d76d8cf1cb71..2969194ae2f4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
@@ -30,9 +30,9 @@ import org.apache.hadoop.mapred.{FileInputFormat, JobConf}
 
 import org.apache.spark.internal.{Logging, LogKeys, MDC}
 import org.apache.spark.internal.config.RDD_PARALLEL_LISTING_THRESHOLD
-import org.apache.spark.sql.{Row, SparkSession}
+import org.apache.spark.sql.{Row, SaveMode, SparkSession}
 import org.apache.spark.sql.catalyst.TableIdentifier
-import org.apache.spark.sql.catalyst.analysis.Resolver
+import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, 
Resolver}
 import org.apache.spark.sql.catalyst.catalog._
 import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
 import org.apache.spark.sql.catalyst.expressions.Attribute
@@ -46,7 +46,7 @@ import org.apache.spark.sql.connector.catalog.CatalogManager.SESSION_CATALOG_NAM
 import org.apache.spark.sql.connector.catalog.SupportsNamespaces._
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import 
org.apache.spark.sql.errors.QueryExecutionErrors.hiveTableWithAnsiIntervalsError
-import org.apache.spark.sql.execution.datasources.{DataSource, 
DataSourceUtils, FileFormat, HadoopFsRelation, LogicalRelationWithTable}
+import org.apache.spark.sql.execution.datasources.{CreateTable, DataSource, 
DataSourceUtils, FileFormat, HadoopFsRelation, LogicalRelation, 
LogicalRelationWithTable}
 import org.apache.spark.sql.execution.datasources.v2.FileDataSourceV2
 import org.apache.spark.sql.internal.{HiveSerDe, SQLConf}
 import org.apache.spark.sql.types._
@@ -956,6 +956,64 @@ case class AlterTableSetLocationCommand(
   }
 }
 
+/**
+ * A command that saves a query as a V1 table.
+ */
+private[sql] case class SaveAsV1TableCommand(
+    tableDesc: CatalogTable,
+    mode: SaveMode,
+    query: LogicalPlan) extends LeafRunnableCommand {
+  override def run(sparkSession: SparkSession): Seq[Row] = {
+    val catalog = sparkSession.sessionState.catalog
+    val qualifiedIdent = catalog.qualifyIdentifier(tableDesc.identifier)
+    val tableDescWithQualifiedIdent = tableDesc.copy(identifier = 
qualifiedIdent)
+    val tableExists = catalog.tableExists(qualifiedIdent)
+
+    (tableExists, mode) match {
+      case (true, SaveMode.Ignore) =>
+        // Do nothing
+
+      case (true, SaveMode.ErrorIfExists) =>
+        throw QueryCompilationErrors.tableAlreadyExistsError(qualifiedIdent)
+
+      case (true, SaveMode.Overwrite) =>
+        // Get all input data source or hive relations of the query.
+        val srcRelations = query.collect {
+          case l: LogicalRelation => l.relation
+          case relation: HiveTableRelation => relation.tableMeta.identifier
+        }
+
+        val tableRelation = 
sparkSession.table(qualifiedIdent).queryExecution.analyzed
+        EliminateSubqueryAliases(tableRelation) match {
+          // check if the table is a data source table (the relation is a 
BaseRelation).
+          case l: LogicalRelation if srcRelations.contains(l.relation) =>
+            throw 
QueryCompilationErrors.cannotOverwriteTableThatIsBeingReadFromError(
+              qualifiedIdent)
+          // check hive table relation when overwrite mode
+          case relation: HiveTableRelation
+              if srcRelations.contains(relation.tableMeta.identifier) =>
+            throw 
QueryCompilationErrors.cannotOverwriteTableThatIsBeingReadFromError(
+              qualifiedIdent)
+          case _ => // OK
+        }
+
+        // Drop the existing table
+        catalog.dropTable(qualifiedIdent, ignoreIfNotExists = true, purge = 
false)
+        runCommand(sparkSession, CreateTable(tableDescWithQualifiedIdent, 
mode, Some(query)))
+        // Refresh the cache of the table in the catalog.
+        catalog.refreshTable(qualifiedIdent)
+
+      case _ =>
+        runCommand(sparkSession, CreateTable(tableDescWithQualifiedIdent, 
mode, Some(query)))
+    }
+    Seq.empty[Row]
+  }
+
+  private def runCommand(session: SparkSession, command: LogicalPlan): Unit = {
+    val qe = session.sessionState.executePlan(command)
+    qe.assertCommandExecuted()
+  }
+}
 
 object DDLUtils extends Logging {
   val HIVE_PROVIDER = "hive"
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
index 1ec9aca857e2..25c82717a515 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
@@ -256,11 +256,12 @@ class DataFrameCallbackSuite extends QueryTest
     withTable("tab") {
       spark.range(10).select($"id", $"id" % 5 as 
"p").write.partitionBy("p").saveAsTable("tab")
       sparkContext.listenerBus.waitUntilEmpty()
-      // CTAS would derive 3 query executions
-      // 1. CreateDataSourceTableAsSelectCommand
+      // CTAS would derive 4 query executions
+      // 1. DropTable
       // 2. InsertIntoHadoopFsRelationCommand
-      // 3. CommandResultExec
-      assert(commands.length == 6)
+      // 3. CreateDataSourceTableAsSelectCommand
+      // 4. SaveAsV1TableCommand
+      assert(commands.length == 7)
       assert(commands(5)._1 == "command")
       assert(commands(5)._2.isInstanceOf[CreateDataSourceTableAsSelectCommand])
       assert(commands(5)._2.asInstanceOf[CreateDataSourceTableAsSelectCommand]


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org


Reply via email to