This is an automated email from the ASF dual-hosted git repository. wenchen pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 27aba95bdf76 [SPARK-52451][CONNECT][SQL] Make WriteOperation in SparkConnectPlanner side effect free 27aba95bdf76 is described below commit 27aba95bdf762b6d1730ed890698cf14d9c4585f Author: Yihong He <heyihong...@gmail.com> AuthorDate: Tue Aug 5 00:06:08 2025 +0800 [SPARK-52451][CONNECT][SQL] Make WriteOperation in SparkConnectPlanner side effect free ### What changes were proposed in this pull request? This PR refactors the Spark Connect execution flow to make `WriteOperation` handling side-effect free by separating the transformation and execution phases. The key changes include: 1. **Unified execution flow**: Consolidated `ROOT` and `COMMAND` operations through `SparkConnectPlanExecution.handlePlan()` instead of separate handlers 2. **Pure transformation phase**: Introduced `transformCommand()` that converts `WriteOperation` to `LogicalPlan` without side effects. It leverages the new DataFrameWriter methods (saveCommand(), saveAsTableCommand(), insertIntoCommand()), which return logical plans instead of executing immediately. 3. **New `SaveAsV1TableCommand`**: In addition to the plan-returning DataFrameWriter methods above, the refactor introduces a new SaveAsV1TableCommand that encapsulates the V1 save-as-table logic. ### Why are the changes needed? The current implementation has two issues: 1. **Side effects in transformation**: The `handleWriteOperation` method both transforms and executes write operations, making it difficult to reason about the transformation logic independently. 2. **Code duplication**: Separate handling paths for `ROOT` and `COMMAND` operations in `ExecuteThreadRunner` create unnecessary complexity and potential inconsistencies. ### Does this PR introduce any user-facing change? No. This is a purely internal refactoring that maintains the same external behavior and API. 
All existing Spark Connect client code will continue to work without any changes. ### How was this patch tested? `build/sbt "connect-client-jvm/testOnly *ClientE2ETestSuite"` ### Was this patch authored or co-authored using generative AI tooling? Cursor 1.3.5 Closes #51727 from heyihong/SPARK-52451. Authored-by: Yihong He <heyihong...@gmail.com> Signed-off-by: Wenchen Fan <wenc...@databricks.com> --- .../connect/execution/ExecuteThreadRunner.scala | 29 +--- .../execution/SparkConnectPlanExecution.scala | 51 +++--- .../sql/connect/planner/SparkConnectPlanner.scala | 58 +++++-- .../apache/spark/sql/classic/DataFrameWriter.scala | 179 ++++++++------------- .../apache/spark/sql/execution/command/ddl.scala | 64 +++++++- .../spark/sql/util/DataFrameCallbackSuite.scala | 9 +- 6 files changed, 214 insertions(+), 176 deletions(-) diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala index fdb0ef363124..93853805b437 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala @@ -28,7 +28,7 @@ import org.apache.spark.SparkSQLException import org.apache.spark.connect.proto import org.apache.spark.internal.{Logging, LogKeys, MDC} import org.apache.spark.sql.connect.common.ProtoUtils -import org.apache.spark.sql.connect.planner.SparkConnectPlanner +import org.apache.spark.sql.connect.planner.InvalidInputErrors import org.apache.spark.sql.connect.service.{ExecuteHolder, ExecuteSessionTag, SparkConnectService} import org.apache.spark.sql.connect.utils.ErrorUtils import org.apache.spark.util.Utils @@ -218,11 +218,13 @@ private[connect] class ExecuteThreadRunner(executeHolder: ExecuteHolder) extends session.sparkContext.setLocalProperty("callSite.long", 
Utils.abbreviate(debugString, 2048)) executeHolder.request.getPlan.getOpTypeCase match { - case proto.Plan.OpTypeCase.COMMAND => handleCommand(executeHolder.request) - case proto.Plan.OpTypeCase.ROOT => handlePlan(executeHolder.request) - case _ => - throw new UnsupportedOperationException( - s"${executeHolder.request.getPlan.getOpTypeCase} not supported.") + case proto.Plan.OpTypeCase.ROOT | proto.Plan.OpTypeCase.COMMAND => + val execution = new SparkConnectPlanExecution(executeHolder) + execution.handlePlan(executeHolder.responseObserver) + case other => + throw InvalidInputErrors.invalidOneOfField( + other, + executeHolder.request.getPlan.getDescriptorForType) } val observedMetrics: Map[String, Seq[(Option[String], Any)]] = { @@ -304,21 +306,6 @@ private[connect] class ExecuteThreadRunner(executeHolder: ExecuteHolder) extends proto.StreamingQueryListenerBusCommand.CommandCase.ADD_LISTENER_BUS_LISTENER } - private def handlePlan(request: proto.ExecutePlanRequest): Unit = { - val responseObserver = executeHolder.responseObserver - - val execution = new SparkConnectPlanExecution(executeHolder) - execution.handlePlan(responseObserver) - } - - private def handleCommand(request: proto.ExecutePlanRequest): Unit = { - val responseObserver = executeHolder.responseObserver - - val command = request.getPlan.getCommand - val planner = new SparkConnectPlanner(executeHolder) - planner.process(command = command, responseObserver = responseObserver) - } - private def requestString(request: Message) = { try { Utils.redact( diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala index 65b9863ca954..9050a84fda56 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala +++ 
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala @@ -32,10 +32,10 @@ import org.apache.spark.sql.classic.{DataFrame, Dataset} import org.apache.spark.sql.connect.common.DataTypeProtoConverter import org.apache.spark.sql.connect.common.LiteralValueProtoConverter.toLiteralProto import org.apache.spark.sql.connect.config.Connect.CONNECT_GRPC_ARROW_MAX_BATCH_SIZE -import org.apache.spark.sql.connect.planner.SparkConnectPlanner +import org.apache.spark.sql.connect.planner.{InvalidInputErrors, SparkConnectPlanner} import org.apache.spark.sql.connect.service.ExecuteHolder import org.apache.spark.sql.connect.utils.MetricGenerator -import org.apache.spark.sql.execution.{DoNotCleanup, LocalTableScanExec, RemoveShuffleFiles, SkipMigration, SQLExecution} +import org.apache.spark.sql.execution.{DoNotCleanup, LocalTableScanExec, QueryExecution, RemoveShuffleFiles, SkipMigration, SQLExecution} import org.apache.spark.sql.execution.arrow.ArrowConverters import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.StructType @@ -53,10 +53,6 @@ private[execution] class SparkConnectPlanExecution(executeHolder: ExecuteHolder) def handlePlan(responseObserver: ExecuteResponseObserver[proto.ExecutePlanResponse]): Unit = { val request = executeHolder.request - if (request.getPlan.getOpTypeCase != proto.Plan.OpTypeCase.ROOT) { - throw new IllegalStateException( - s"Illegal operation type ${request.getPlan.getOpTypeCase} to be handled here.") - } val planner = new SparkConnectPlanner(executeHolder) val tracker = executeHolder.eventsManager.createQueryPlanningTracker() val conf = session.sessionState.conf @@ -68,19 +64,36 @@ private[execution] class SparkConnectPlanExecution(executeHolder: ExecuteHolder) } else { DoNotCleanup } - val dataframe = - Dataset.ofRows( - sessionHolder.session, - planner.transformRelation(request.getPlan.getRoot, cachePlan = true), - tracker, - shuffleCleanupMode) - 
responseObserver.onNext(createSchemaResponse(request.getSessionId, dataframe.schema)) - processAsArrowBatches(dataframe, responseObserver, executeHolder) - responseObserver.onNext(MetricGenerator.createMetricsResponse(sessionHolder, dataframe)) - createObservedMetricsResponse( - request.getSessionId, - executeHolder.allObservationAndPlanIds, - dataframe).foreach(responseObserver.onNext) + request.getPlan.getOpTypeCase match { + case proto.Plan.OpTypeCase.ROOT => + val dataframe = + Dataset.ofRows( + sessionHolder.session, + planner.transformRelation(request.getPlan.getRoot, cachePlan = true), + tracker, + shuffleCleanupMode) + responseObserver.onNext(createSchemaResponse(request.getSessionId, dataframe.schema)) + processAsArrowBatches(dataframe, responseObserver, executeHolder) + responseObserver.onNext(MetricGenerator.createMetricsResponse(sessionHolder, dataframe)) + createObservedMetricsResponse( + request.getSessionId, + executeHolder.allObservationAndPlanIds, + dataframe).foreach(responseObserver.onNext) + case proto.Plan.OpTypeCase.COMMAND => + val command = request.getPlan.getCommand + planner.transformCommand(command, tracker) match { + case Some(plan) => + val qe = + new QueryExecution(session, plan, tracker, shuffleCleanupMode = shuffleCleanupMode) + qe.assertCommandExecuted() + executeHolder.eventsManager.postFinished() + case None => + planner.process(command, responseObserver) + } + case other => + throw InvalidInputErrors.invalidOneOfField(other, request.getPlan.getDescriptorForType) + } + } type Batch = (Array[Byte], Long) diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index bcd643a30253..7320c6e3918c 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ 
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -59,7 +59,7 @@ import org.apache.spark.sql.catalyst.streaming.InternalOutputModes import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin, TreePattern} import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CharVarcharUtils} -import org.apache.spark.sql.classic.{Catalog, Dataset, MergeIntoWriter, RelationalGroupedDataset, SparkSession, TypedAggUtils, UserDefinedFunctionUtils} +import org.apache.spark.sql.classic.{Catalog, DataFrameWriter, Dataset, MergeIntoWriter, RelationalGroupedDataset, SparkSession, TypedAggUtils, UserDefinedFunctionUtils} import org.apache.spark.sql.classic.ClassicConversions._ import org.apache.spark.sql.connect.client.arrow.ArrowSerializer import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, ForeachWriterPacket, LiteralValueProtoConverter, StorageLevelProtoConverter, StreamingListenerPacket, UdfPacket} @@ -2646,6 +2646,17 @@ class SparkConnectPlanner( process(command, new MockObserver()) } + def transformCommand( + command: proto.Command, + tracker: QueryPlanningTracker): Option[LogicalPlan] = { + command.getCommandTypeCase match { + case proto.Command.CommandTypeCase.WRITE_OPERATION => + Some(transformWriteOperation(command.getWriteOperation, tracker)) + case _ => + None + } + } + def process( command: proto.Command, responseObserver: StreamObserver[ExecutePlanResponse]): Unit = { @@ -3078,23 +3089,15 @@ class SparkConnectPlanner( executeHolder.eventsManager.postFinished() } - /** - * Transforms the write operation and executes it. - * - * The input write operation contains a reference to the input plan and transforms it to the - * corresponding logical plan. Afterwards, creates the DataFrameWriter and translates the - * parameters of the WriteOperation into the corresponding methods calls. 
- * - * @param writeOperation - */ - private def handleWriteOperation(writeOperation: proto.WriteOperation): Unit = { + private def transformWriteOperation( + writeOperation: proto.WriteOperation, + tracker: QueryPlanningTracker): LogicalPlan = { // Transform the input plan into the logical plan. val plan = transformRelation(writeOperation.getInput) // And create a Dataset from the plan. - val tracker = executeHolder.eventsManager.createQueryPlanningTracker() val dataset = Dataset.ofRows(session, plan, tracker) - val w = dataset.write + val w = dataset.write.asInstanceOf[DataFrameWriter[_]] if (writeOperation.getMode != proto.WriteOperation.SaveMode.SAVE_MODE_UNSPECIFIED) { w.mode(SaveModeConverter.toSaveMode(writeOperation.getMode)) } @@ -3129,20 +3132,41 @@ class SparkConnectPlanner( } writeOperation.getSaveTypeCase match { - case proto.WriteOperation.SaveTypeCase.SAVETYPE_NOT_SET => w.save() - case proto.WriteOperation.SaveTypeCase.PATH => w.save(writeOperation.getPath) + case proto.WriteOperation.SaveTypeCase.SAVETYPE_NOT_SET => w.saveCommand(None) + case proto.WriteOperation.SaveTypeCase.PATH => + w.saveCommand(Some(writeOperation.getPath)) case proto.WriteOperation.SaveTypeCase.TABLE => val tableName = writeOperation.getTable.getTableName writeOperation.getTable.getSaveMethod match { case proto.WriteOperation.SaveTable.TableSaveMethod.TABLE_SAVE_METHOD_SAVE_AS_TABLE => - w.saveAsTable(tableName) + w.saveAsTableCommand(tableName) case proto.WriteOperation.SaveTable.TableSaveMethod.TABLE_SAVE_METHOD_INSERT_INTO => - w.insertInto(tableName) + w.insertIntoCommand(tableName) case other => throw InvalidInputErrors.invalidEnum(other) } case other => throw InvalidInputErrors.invalidOneOfField(other, writeOperation.getDescriptorForType) } + } + + private def runCommand(command: LogicalPlan, tracker: QueryPlanningTracker): Unit = { + val qe = new QueryExecution(session, command, tracker) + qe.assertCommandExecuted() + } + + /** + * Transforms the write operation and 
executes it. + * + * The input write operation contains a reference to the input plan and transforms it to the + * corresponding logical plan. Afterwards, creates the DataFrameWriter and translates the + * parameters of the WriteOperation into the corresponding methods calls. + * + * @param writeOperation + */ + private def handleWriteOperation(writeOperation: proto.WriteOperation): Unit = { + val tracker = executeHolder.eventsManager.createQueryPlanningTracker() + runCommand(transformWriteOperation(writeOperation, tracker), tracker) + executeHolder.eventsManager.postFinished() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/classic/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/classic/DataFrameWriter.scala index 501b4985128d..737dc5b1e21e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/classic/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/DataFrameWriter.scala @@ -25,7 +25,7 @@ import org.apache.spark.annotation.Stable import org.apache.spark.sql import org.apache.spark.sql.SaveMode import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, NoSuchTableException, UnresolvedIdentifier, UnresolvedRelation} +import org.apache.spark.sql.catalyst.analysis.{NoSuchTableException, UnresolvedIdentifier, UnresolvedRelation} import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.catalyst.plans.logical._ @@ -36,8 +36,8 @@ import org.apache.spark.sql.connector.catalog.TableWritePrivilege._ import org.apache.spark.sql.connector.expressions.{ClusterByTransform, FieldReference, IdentityTransform, Transform} import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.QueryExecution -import org.apache.spark.sql.execution.command.DDLUtils -import org.apache.spark.sql.execution.datasources.{CreateTable, DataSource, 
DataSourceUtils, LogicalRelation} +import org.apache.spark.sql.execution.command.{DDLUtils, SaveAsV1TableCommand} +import org.apache.spark.sql.execution.datasources.{DataSource, DataSourceUtils} import org.apache.spark.sql.execution.datasources.v2._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf.PartitionOverwriteMode @@ -115,7 +115,9 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) extends sql.DataFram extraOptions.contains("path")) { throw QueryCompilationErrors.pathOptionNotSetCorrectlyWhenWritingError() } - saveInternal(Some(path)) + runCommand(df.sparkSession) { + saveCommand(Some(path)) + } } /** @@ -123,9 +125,13 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) extends sql.DataFram * * @since 1.4.0 */ - def save(): Unit = saveInternal(None) + def save(): Unit = { + runCommand(df.sparkSession) { + saveCommand(None) + } + } - private def saveInternal(path: Option[String]): Unit = { + private[sql] def saveCommand(path: Option[String]): LogicalPlan = { if (source.toLowerCase(Locale.ROOT) == DDLUtils.HIVE_PROVIDER) { throw QueryCompilationErrors.cannotOperateOnHiveDataSourceFilesError("write") } @@ -179,23 +185,19 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) extends sql.DataFram // Streaming also uses the data source V2 API. So it may be that the data source // implements v2, but has no v2 implementation for batch writes. In that case, we // fall back to saving as though it's a V1 source. - return saveToV1Source(path) + return saveToV1SourceCommand(path) } } val relation = DataSourceV2Relation.create(table, catalog, ident, dsOptions) checkPartitioningMatchesV2Table(table) if (curmode == SaveMode.Append) { - runCommand(df.sparkSession) { - AppendData.byName(relation, df.logicalPlan, finalOptions) - } + AppendData.byName(relation, df.logicalPlan, finalOptions) } else { // Truncate the table. 
TableCapabilityCheck will throw a nice exception if this // isn't supported - runCommand(df.sparkSession) { - OverwriteByExpression.byName( - relation, df.logicalPlan, Literal(true), finalOptions) - } + OverwriteByExpression.byName( + relation, df.logicalPlan, Literal(true), finalOptions) } case createMode => @@ -215,16 +217,14 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) extends sql.DataFram serde = None, external = false, constraints = Seq.empty) - runCommand(df.sparkSession) { - CreateTableAsSelect( - UnresolvedIdentifier( - catalog.name +: ident.namespace.toImmutableArraySeq :+ ident.name), - partitioningAsV2, - df.queryExecution.analyzed, - tableSpec, - finalOptions, - ignoreIfExists = createMode == SaveMode.Ignore) - } + CreateTableAsSelect( + UnresolvedIdentifier( + catalog.name +: ident.namespace.toImmutableArraySeq :+ ident.name), + partitioningAsV2, + df.queryExecution.analyzed, + tableSpec, + finalOptions, + ignoreIfExists = createMode == SaveMode.Ignore) case _: TableProvider => if (getTable.supports(BATCH_WRITE)) { throw QueryCompilationErrors.writeWithSaveModeUnsupportedBySourceError( @@ -233,13 +233,13 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) extends sql.DataFram // Streaming also uses the data source V2 API. So it may be that the data source // implements v2, but has no v2 implementation for batch writes. In that case, we // fallback to saving as though it's a V1 source. 
- saveToV1Source(path) + saveToV1SourceCommand(path) } } } } else { - saveToV1Source(path) + saveToV1SourceCommand(path) } } @@ -251,7 +251,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) extends sql.DataFram } } - private def saveToV1Source(path: Option[String]): Unit = { + private def saveToV1SourceCommand(path: Option[String]): LogicalPlan = { partitioningColumns.foreach { columns => extraOptions = extraOptions + ( DataSourceUtils.PARTITIONING_COLUMNS_KEY -> @@ -266,13 +266,11 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) extends sql.DataFram val optionsWithPath = getOptionsWithPath(path) // Code path for data source v1. - runCommand(df.sparkSession) { - DataSource( - sparkSession = df.sparkSession, - className = source, - partitionColumns = partitioningColumns.getOrElse(Nil), - options = optionsWithPath.originalMap).planForWriting(curmode, df.logicalPlan) - } + DataSource( + sparkSession = df.sparkSession, + className = source, + partitionColumns = partitioningColumns.getOrElse(Nil), + options = optionsWithPath.originalMap).planForWriting(curmode, df.logicalPlan) } /** @@ -304,6 +302,12 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) extends sql.DataFram * @since 1.4.0 */ def insertInto(tableName: String): Unit = { + runCommand(df.sparkSession) { + insertIntoCommand(tableName) + } + } + + private[sql] def insertIntoCommand(tableName: String): LogicalPlan = { import df.sparkSession.sessionState.analyzer.{AsTableIdentifier, NonSessionCatalogAndIdentifier, SessionCatalogAndIdentifier} import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ @@ -318,30 +322,30 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) extends sql.DataFram session.sessionState.sqlParser.parseMultipartIdentifier(tableName) match { case NonSessionCatalogAndIdentifier(catalog, ident) => - insertInto(catalog, ident) + insertIntoCommand(catalog, ident) case SessionCatalogAndIdentifier(catalog, ident) if canUseV2 && 
ident.namespace().length <= 1 => - insertInto(catalog, ident) + insertIntoCommand(catalog, ident) case AsTableIdentifier(tableIdentifier) => - insertInto(tableIdentifier) + insertIntoCommand(tableIdentifier) case other => throw QueryCompilationErrors.cannotFindCatalogToHandleIdentifierError(other.quoted) } } - private def insertInto(catalog: CatalogPlugin, ident: Identifier): Unit = { + private def insertIntoCommand(catalog: CatalogPlugin, ident: Identifier): LogicalPlan = { import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ val table = catalog.asTableCatalog.loadTable(ident, getWritePrivileges.toSet.asJava) match { case _: V1Table => - return insertInto(TableIdentifier(ident.name(), ident.namespace().headOption)) + return insertIntoCommand(TableIdentifier(ident.name(), ident.namespace().headOption)) case t => DataSourceV2Relation.create(t, Some(catalog), Some(ident)) } - val command = curmode match { + curmode match { case SaveMode.Append | SaveMode.ErrorIfExists | SaveMode.Ignore => AppendData.byPosition(table, df.logicalPlan, extraOptions.toMap) @@ -356,22 +360,16 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) extends sql.DataFram OverwriteByExpression.byPosition(table, df.logicalPlan, Literal(true), extraOptions.toMap) } } - - runCommand(df.sparkSession) { - command - } } - private def insertInto(tableIdent: TableIdentifier): Unit = { - runCommand(df.sparkSession) { - InsertIntoStatement( - table = UnresolvedRelation(tableIdent).requireWritePrivileges(getWritePrivileges), - partitionSpec = Map.empty[String, Option[String]], - Nil, - query = df.logicalPlan, - overwrite = curmode == SaveMode.Overwrite, - ifPartitionNotExists = false) - } + private def insertIntoCommand(tableIdent: TableIdentifier): LogicalPlan = { + InsertIntoStatement( + table = UnresolvedRelation(tableIdent).requireWritePrivileges(getWritePrivileges), + partitionSpec = Map.empty[String, Option[String]], + Nil, + query = df.logicalPlan, + overwrite = curmode == 
SaveMode.Overwrite, + ifPartitionNotExists = false) } private def getWritePrivileges: Seq[TableWritePrivilege] = curmode match { @@ -430,6 +428,12 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) extends sql.DataFram * @since 1.4.0 */ def saveAsTable(tableName: String): Unit = { + runCommand(df.sparkSession) { + saveAsTableCommand(tableName) + } + } + + private[sql] def saveAsTableCommand(tableName: String): LogicalPlan = { import df.sparkSession.sessionState.analyzer.{AsTableIdentifier, NonSessionCatalogAndIdentifier, SessionCatalogAndIdentifier} import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ @@ -440,30 +444,29 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) extends sql.DataFram session.sessionState.sqlParser.parseMultipartIdentifier(tableName) match { case nameParts @ NonSessionCatalogAndIdentifier(catalog, ident) => - saveAsTable(catalog.asTableCatalog, ident, nameParts) + saveAsTableCommand(catalog.asTableCatalog, ident, nameParts) case nameParts @ SessionCatalogAndIdentifier(catalog, ident) if canUseV2 && ident.namespace().length <= 1 => - saveAsTable(catalog.asTableCatalog, ident, nameParts) + saveAsTableCommand(catalog.asTableCatalog, ident, nameParts) case AsTableIdentifier(tableIdentifier) => - saveAsTable(tableIdentifier) + saveAsV1TableCommand(tableIdentifier) case other => throw QueryCompilationErrors.cannotFindCatalogToHandleIdentifierError(other.quoted) } } - - private def saveAsTable( - catalog: TableCatalog, ident: Identifier, nameParts: Seq[String]): Unit = { + private def saveAsTableCommand( + catalog: TableCatalog, ident: Identifier, nameParts: Seq[String]): LogicalPlan = { val tableOpt = try Option(catalog.loadTable(ident, getWritePrivileges.toSet.asJava)) catch { case _: NoSuchTableException => None } - val command = (curmode, tableOpt) match { + (curmode, tableOpt) match { case (_, Some(_: V1Table)) => - return saveAsTable(TableIdentifier(ident.name(), ident.namespace().headOption)) + 
saveAsV1TableCommand(TableIdentifier(ident.name(), ident.namespace().headOption)) case (SaveMode.Append, Some(table)) => checkPartitioningMatchesV2Table(table) @@ -512,56 +515,9 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) extends sql.DataFram writeOptions = extraOptions.toMap, other == SaveMode.Ignore) } - - runCommand(df.sparkSession) { - command - } - } - - private def saveAsTable(tableIdent: TableIdentifier): Unit = { - val catalog = df.sparkSession.sessionState.catalog - val qualifiedIdent = catalog.qualifyIdentifier(tableIdent) - val tableExists = catalog.tableExists(qualifiedIdent) - - (tableExists, curmode) match { - case (true, SaveMode.Ignore) => - // Do nothing - - case (true, SaveMode.ErrorIfExists) => - throw QueryCompilationErrors.tableAlreadyExistsError(qualifiedIdent) - - case (true, SaveMode.Overwrite) => - // Get all input data source or hive relations of the query. - val srcRelations = df.logicalPlan.collect { - case l: LogicalRelation => l.relation - case relation: HiveTableRelation => relation.tableMeta.identifier - } - - val tableRelation = df.sparkSession.table(qualifiedIdent).queryExecution.analyzed - EliminateSubqueryAliases(tableRelation) match { - // check if the table is a data source table (the relation is a BaseRelation). - case l: LogicalRelation if srcRelations.contains(l.relation) => - throw QueryCompilationErrors.cannotOverwriteTableThatIsBeingReadFromError( - qualifiedIdent) - // check hive table relation when overwrite mode - case relation: HiveTableRelation - if srcRelations.contains(relation.tableMeta.identifier) => - throw QueryCompilationErrors.cannotOverwriteTableThatIsBeingReadFromError( - qualifiedIdent) - case _ => // OK - } - - // Drop the existing table - catalog.dropTable(qualifiedIdent, ignoreIfNotExists = true, purge = false) - createTable(qualifiedIdent) - // Refresh the cache of the table in the catalog. 
- catalog.refreshTable(qualifiedIdent) - - case _ => createTable(qualifiedIdent) - } } - private def createTable(tableIdent: TableIdentifier): Unit = { + private def saveAsV1TableCommand(tableIdent: TableIdentifier): SaveAsV1TableCommand = { val storage = DataSource.buildStorageFormatFromOptions(extraOptions.toMap) val tableType = if (storage.locationUri.isDefined) { CatalogTableType.EXTERNAL @@ -586,8 +542,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) extends sql.DataFram bucketSpec = getBucketSpec, properties = properties) - runCommand(df.sparkSession)( - CreateTable(tableDesc, curmode, Some(df.logicalPlan))) + SaveAsV1TableCommand(tableDesc, curmode, df.logicalPlan) } /** Converts the provided partitioning and bucketing information to DataSourceV2 Transforms. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala index d76d8cf1cb71..2969194ae2f4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala @@ -30,9 +30,9 @@ import org.apache.hadoop.mapred.{FileInputFormat, JobConf} import org.apache.spark.internal.{Logging, LogKeys, MDC} import org.apache.spark.internal.config.RDD_PARALLEL_LISTING_THRESHOLD -import org.apache.spark.sql.{Row, SparkSession} +import org.apache.spark.sql.{Row, SaveMode, SparkSession} import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.analysis.Resolver +import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, Resolver} import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.Attribute @@ -46,7 +46,7 @@ import org.apache.spark.sql.connector.catalog.CatalogManager.SESSION_CATALOG_NAM import 
org.apache.spark.sql.connector.catalog.SupportsNamespaces._ import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.errors.QueryExecutionErrors.hiveTableWithAnsiIntervalsError -import org.apache.spark.sql.execution.datasources.{DataSource, DataSourceUtils, FileFormat, HadoopFsRelation, LogicalRelationWithTable} +import org.apache.spark.sql.execution.datasources.{CreateTable, DataSource, DataSourceUtils, FileFormat, HadoopFsRelation, LogicalRelation, LogicalRelationWithTable} import org.apache.spark.sql.execution.datasources.v2.FileDataSourceV2 import org.apache.spark.sql.internal.{HiveSerDe, SQLConf} import org.apache.spark.sql.types._ @@ -956,6 +956,64 @@ case class AlterTableSetLocationCommand( } } +/** + * A command that saves a query as a V1 table. + */ +private[sql] case class SaveAsV1TableCommand( + tableDesc: CatalogTable, + mode: SaveMode, + query: LogicalPlan) extends LeafRunnableCommand { + override def run(sparkSession: SparkSession): Seq[Row] = { + val catalog = sparkSession.sessionState.catalog + val qualifiedIdent = catalog.qualifyIdentifier(tableDesc.identifier) + val tableDescWithQualifiedIdent = tableDesc.copy(identifier = qualifiedIdent) + val tableExists = catalog.tableExists(qualifiedIdent) + + (tableExists, mode) match { + case (true, SaveMode.Ignore) => + // Do nothing + + case (true, SaveMode.ErrorIfExists) => + throw QueryCompilationErrors.tableAlreadyExistsError(qualifiedIdent) + + case (true, SaveMode.Overwrite) => + // Get all input data source or hive relations of the query. + val srcRelations = query.collect { + case l: LogicalRelation => l.relation + case relation: HiveTableRelation => relation.tableMeta.identifier + } + + val tableRelation = sparkSession.table(qualifiedIdent).queryExecution.analyzed + EliminateSubqueryAliases(tableRelation) match { + // check if the table is a data source table (the relation is a BaseRelation). 
+ case l: LogicalRelation if srcRelations.contains(l.relation) => + throw QueryCompilationErrors.cannotOverwriteTableThatIsBeingReadFromError( + qualifiedIdent) + // check hive table relation when overwrite mode + case relation: HiveTableRelation + if srcRelations.contains(relation.tableMeta.identifier) => + throw QueryCompilationErrors.cannotOverwriteTableThatIsBeingReadFromError( + qualifiedIdent) + case _ => // OK + } + + // Drop the existing table + catalog.dropTable(qualifiedIdent, ignoreIfNotExists = true, purge = false) + runCommand(sparkSession, CreateTable(tableDescWithQualifiedIdent, mode, Some(query))) + // Refresh the cache of the table in the catalog. + catalog.refreshTable(qualifiedIdent) + + case _ => + runCommand(sparkSession, CreateTable(tableDescWithQualifiedIdent, mode, Some(query))) + } + Seq.empty[Row] + } + + private def runCommand(session: SparkSession, command: LogicalPlan): Unit = { + val qe = session.sessionState.executePlan(command) + qe.assertCommandExecuted() + } +} object DDLUtils extends Logging { val HIVE_PROVIDER = "hive" diff --git a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala index 1ec9aca857e2..25c82717a515 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala @@ -256,11 +256,12 @@ class DataFrameCallbackSuite extends QueryTest withTable("tab") { spark.range(10).select($"id", $"id" % 5 as "p").write.partitionBy("p").saveAsTable("tab") sparkContext.listenerBus.waitUntilEmpty() - // CTAS would derive 3 query executions - // 1. CreateDataSourceTableAsSelectCommand + // CTAS would derive 4 query executions + // 1. DropTable // 2. InsertIntoHadoopFsRelationCommand - // 3. CommandResultExec - assert(commands.length == 6) + // 3. CreateDataSourceTableAsSelectCommand + // 4. 
SaveAsV1TableCommand + assert(commands.length == 7) assert(commands(5)._1 == "command") assert(commands(5)._2.isInstanceOf[CreateDataSourceTableAsSelectCommand]) assert(commands(5)._2.asInstanceOf[CreateDataSourceTableAsSelectCommand] --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org