This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 976f8875edd4 [SPARK-50286][SQL] Correctly propagate SQL options to WriteBuilder
976f8875edd4 is described below

commit 976f8875edd4669439880a06be041e262a2427f4
Author: Cheng Pan <[email protected]>
AuthorDate: Mon Nov 25 21:52:12 2024 +0800

    [SPARK-50286][SQL] Correctly propagate SQL options to WriteBuilder
    
    ### What changes were proposed in this pull request?
    
    SPARK-49098 introduced SQL syntax that lets users set table options in DSv2 write cases, but unfortunately, the options set via SQL are not propagated correctly to the underlying DSv2 `WriteBuilder`.
    
    ```
    INSERT INTO $t1 WITH (`write.split-size` = 10) SELECT ...
    ```
    
    ```
    df.writeTo(t1).option("write.split-size", "10").append()
    ```
    
    From the user's perspective, the two statements above are equivalent, but the internal implementations differ slightly. Both construct an
    
    ```
    AppendData(r: DataSourceV2Relation, ..., writeOptions, ...)
    ```
    
    but the SQL `options` are carried by `r.options`, while the `DataFrame` API `options` are carried by `writeOptions`. Currently, only the latter is propagated to the `WriteBuilder`; the former is silently dropped. This PR fixes the issue by merging those two sets of `options`.
    
    Currently, `options` propagation is inconsistent across the `DataFrame`, `DataFrameV2`, and SQL paths (see the sketch after this list):
    - DataFrame API: the same `options` are carried by both `writeOptions` and `DataSourceV2Relation`
    - DataFrameV2 API: `options` are carried only by `writeOptions`
    - SQL: `options` are only carried by `DataSourceV2Relation`
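    
    The fix encodes this invariant directly: the two sources either agree, or at most one of them is non-empty, so merging with `++` can never silently shadow a user-specified value. Condensed from the new `V2Writes.mergeOptions` helper in the diff below:
    
    ```
    private def mergeOptions(
        commandOptions: Map[String, String],
        dsOptions: Map[String, String]): Map[String, String] = {
      // DataFrame API: the same options are carried by both sides
      // DataFrameV2 API: options are carried only by the command
      // SQL: options are carried only by the DataSourceV2Relation
      assert(commandOptions == dsOptions || commandOptions.isEmpty || dsOptions.isEmpty)
      commandOptions ++ dsOptions
    }
    ```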
    
    Note that `SessionConfigSupport` only takes effect for the `DataFrame` and `DataFrameV2` APIs; it is not considered at all in the SQL read/write path in the current codebase.
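    
    For illustration only, a provider opts into that behavior by mixing in `SessionConfigSupport` (the `MyDataSource` name and `myds` prefix below are hypothetical); session confs named `spark.datasource.myds.<key>` are then injected as `<key>` options, but only on the `DataFrame`/`DataFrameV2` paths:
    
    ```
    import java.util
    
    import org.apache.spark.sql.connector.catalog.{SessionConfigSupport, Table, TableProvider}
    import org.apache.spark.sql.connector.expressions.Transform
    import org.apache.spark.sql.types.StructType
    import org.apache.spark.sql.util.CaseInsensitiveStringMap
    
    class MyDataSource extends TableProvider with SessionConfigSupport {
      // session confs under spark.datasource.myds.* become read/write options
      override def keyPrefix(): String = "myds"
    
      override def inferSchema(options: CaseInsensitiveStringMap): StructType =
        new StructType().add("id", "bigint")
    
      override def getTable(
          schema: StructType,
          partitioning: Array[Transform],
          properties: util.Map[String, String]): Table =
        throw new UnsupportedOperationException("illustrative sketch only")
    }
    ```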
    
    ### Why are the changes needed?
    
    This change correctly propagates SQL options to the `WriteBuilder`, completing the feature added in SPARK-49098 so that DSv2 implementations such as Iceberg can benefit.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No, it's an unreleased feature.
    
    ### How was this patch tested?
    
    The UTs added in SPARK-36680 and SPARK-49098 are also updated to check that SQL `options` are correctly propagated to the physical plan.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #48822 from pan3793/SPARK-50286.
    
    Authored-by: Cheng Pan <[email protected]>
    Signed-off-by: Wenchen Fan <[email protected]>
---
 .../sql/connector/catalog/InMemoryBaseTable.scala  |  35 ++-
 .../catalog/InMemoryRowLevelOperationTable.scala   |   4 +-
 .../sql/connector/catalog/InMemoryTable.scala      |  14 +-
 .../catalog/InMemoryTableWithV2Filter.scala        |  27 +-
 .../sql/execution/datasources/v2/V2Writes.scala    |  36 ++-
 .../scala/org/apache/spark/sql/QueryTest.scala     |  10 +-
 .../sql/connector/DataSourceV2OptionSuite.scala    | 327 +++++++++++++++++++++
 .../spark/sql/connector/DataSourceV2SQLSuite.scala |  93 +-----
 .../spark/sql/connector/V1WriteFallbackSuite.scala |   6 +-
 9 files changed, 406 insertions(+), 146 deletions(-)

diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala
index 497ef848ac78..ab17b93ad614 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala
@@ -295,7 +295,7 @@ abstract class InMemoryBaseTable(
     TableCapability.TRUNCATE)
 
   override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = {
-    new InMemoryScanBuilder(schema)
+    new InMemoryScanBuilder(schema, options)
   }
 
   private def canEvaluate(filter: Filter): Boolean = {
@@ -309,8 +309,10 @@ abstract class InMemoryBaseTable(
     }
   }
 
-  class InMemoryScanBuilder(tableSchema: StructType) extends ScanBuilder
-      with SupportsPushDownRequiredColumns with SupportsPushDownFilters {
+  class InMemoryScanBuilder(
+      tableSchema: StructType,
+      options: CaseInsensitiveStringMap) extends ScanBuilder
+    with SupportsPushDownRequiredColumns with SupportsPushDownFilters {
     private var schema: StructType = tableSchema
     private var postScanFilters: Array[Filter] = Array.empty
     private var evaluableFilters: Array[Filter] = Array.empty
@@ -318,7 +320,7 @@ abstract class InMemoryBaseTable(
 
     override def build: Scan = {
       val scan = InMemoryBatchScan(
-        data.map(_.asInstanceOf[InputPartition]).toImmutableArraySeq, schema, tableSchema)
+        data.map(_.asInstanceOf[InputPartition]).toImmutableArraySeq, schema, tableSchema, options)
       if (evaluableFilters.nonEmpty) {
         scan.filter(evaluableFilters)
       }
@@ -442,7 +444,8 @@ abstract class InMemoryBaseTable(
   case class InMemoryBatchScan(
       var _data: Seq[InputPartition],
       readSchema: StructType,
-      tableSchema: StructType)
+      tableSchema: StructType,
+      options: CaseInsensitiveStringMap)
    extends BatchScanBaseClass(_data, readSchema, tableSchema) with SupportsRuntimeFiltering {
 
     override def filterAttributes(): Array[NamedReference] = {
@@ -474,17 +477,17 @@ abstract class InMemoryBaseTable(
     }
   }
 
-  abstract class InMemoryWriterBuilder() extends SupportsTruncate with SupportsDynamicOverwrite
-    with SupportsStreamingUpdateAsAppend {
+  abstract class InMemoryWriterBuilder(val info: LogicalWriteInfo)
+    extends SupportsTruncate with SupportsDynamicOverwrite with SupportsStreamingUpdateAsAppend {
 
-    protected var writer: BatchWrite = Append
-    protected var streamingWriter: StreamingWrite = StreamingAppend
+    protected var writer: BatchWrite = new Append(info)
+    protected var streamingWriter: StreamingWrite = new StreamingAppend(info)
 
     override def overwriteDynamicPartitions(): WriteBuilder = {
-      if (writer != Append) {
+      if (!writer.isInstanceOf[Append]) {
         throw new IllegalArgumentException(s"Unsupported writer type: $writer")
       }
-      writer = DynamicOverwrite
+      writer = new DynamicOverwrite(info)
       streamingWriter = new StreamingNotSupportedOperation("overwriteDynamicPartitions")
       this
     }
@@ -529,13 +532,13 @@ abstract class InMemoryBaseTable(
     override def abort(messages: Array[WriterCommitMessage]): Unit = {}
   }
 
-  protected object Append extends TestBatchWrite {
+  class Append(val info: LogicalWriteInfo) extends TestBatchWrite {
     override def commit(messages: Array[WriterCommitMessage]): Unit = dataMap.synchronized {
       withData(messages.map(_.asInstanceOf[BufferedRows]))
     }
   }
 
-  private object DynamicOverwrite extends TestBatchWrite {
+  class DynamicOverwrite(val info: LogicalWriteInfo) extends TestBatchWrite {
     override def commit(messages: Array[WriterCommitMessage]): Unit = dataMap.synchronized {
       val newData = messages.map(_.asInstanceOf[BufferedRows])
       dataMap --= newData.flatMap(_.rows.map(getKey))
@@ -543,7 +546,7 @@ abstract class InMemoryBaseTable(
     }
   }
 
-  protected object TruncateAndAppend extends TestBatchWrite {
+  class TruncateAndAppend(val info: LogicalWriteInfo) extends TestBatchWrite {
     override def commit(messages: Array[WriterCommitMessage]): Unit = dataMap.synchronized {
       dataMap.clear()
       withData(messages.map(_.asInstanceOf[BufferedRows]))
@@ -572,7 +575,7 @@ abstract class InMemoryBaseTable(
       s"${operation} isn't supported for streaming query.")
   }
 
-  private object StreamingAppend extends TestStreamingWrite {
+  class StreamingAppend(val info: LogicalWriteInfo) extends TestStreamingWrite {
     override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {
       dataMap.synchronized {
         withData(messages.map(_.asInstanceOf[BufferedRows]))
@@ -580,7 +583,7 @@ abstract class InMemoryBaseTable(
     }
   }
 
-  protected object StreamingTruncateAndAppend extends TestStreamingWrite {
+  class StreamingTruncateAndAppend(val info: LogicalWriteInfo) extends TestStreamingWrite {
     override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {
       dataMap.synchronized {
         dataMap.clear()
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTable.scala
index 4abe4c8b3e3f..3a684dc57c02 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTable.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTable.scala
@@ -59,7 +59,7 @@ class InMemoryRowLevelOperationTable(
     }
 
     override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = {
-      new InMemoryScanBuilder(schema) {
+      new InMemoryScanBuilder(schema, options) {
         override def build: Scan = {
           val scan = super.build()
           configuredScan = scan.asInstanceOf[InMemoryBatchScan]
@@ -115,7 +115,7 @@ class InMemoryRowLevelOperationTable(
     override def rowId(): Array[NamedReference] = Array(PK_COLUMN_REF)
 
     override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = {
-      new InMemoryScanBuilder(schema)
+      new InMemoryScanBuilder(schema, options)
     }
 
     override def newWriteBuilder(info: LogicalWriteInfo): DeltaWriteBuilder =
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala
index af04816e6b6f..c27b8fea059f 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala
@@ -84,23 +84,23 @@ class InMemoryTable(
     InMemoryBaseTable.maybeSimulateFailedTableWrite(new CaseInsensitiveStringMap(properties))
     InMemoryBaseTable.maybeSimulateFailedTableWrite(info.options)
 
-    new InMemoryWriterBuilderWithOverWrite()
+    new InMemoryWriterBuilderWithOverWrite(info)
   }
 
-  private class InMemoryWriterBuilderWithOverWrite() extends InMemoryWriterBuilder
-    with SupportsOverwrite {
+  class InMemoryWriterBuilderWithOverWrite(override val info: LogicalWriteInfo)
+    extends InMemoryWriterBuilder(info) with SupportsOverwrite {
 
     override def truncate(): WriteBuilder = {
-      if (writer != Append) {
+      if (!writer.isInstanceOf[Append]) {
         throw new IllegalArgumentException(s"Unsupported writer type: $writer")
       }
-      writer = TruncateAndAppend
-      streamingWriter = StreamingTruncateAndAppend
+      writer = new TruncateAndAppend(info)
+      streamingWriter = new StreamingTruncateAndAppend(info)
       this
     }
 
     override def overwrite(filters: Array[Filter]): WriteBuilder = {
-      if (writer != Append) {
+      if (!writer.isInstanceOf[Append]) {
         throw new IllegalArgumentException(s"Unsupported writer type: $writer")
       }
       writer = new Overwrite(filters)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableWithV2Filter.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableWithV2Filter.scala
index 20ada0d622bc..9b7a90774f91 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableWithV2Filter.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableWithV2Filter.scala
@@ -47,19 +47,22 @@ class InMemoryTableWithV2Filter(
   }
 
   override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = {
-    new InMemoryV2FilterScanBuilder(schema)
+    new InMemoryV2FilterScanBuilder(schema, options)
   }
 
-  class InMemoryV2FilterScanBuilder(tableSchema: StructType)
-    extends InMemoryScanBuilder(tableSchema) {
+  class InMemoryV2FilterScanBuilder(
+     tableSchema: StructType,
+     options: CaseInsensitiveStringMap)
+    extends InMemoryScanBuilder(tableSchema, options) {
     override def build: Scan = InMemoryV2FilterBatchScan(
-      data.map(_.asInstanceOf[InputPartition]).toImmutableArraySeq, schema, tableSchema)
+      data.map(_.asInstanceOf[InputPartition]).toImmutableArraySeq, schema, tableSchema, options)
   }
 
   case class InMemoryV2FilterBatchScan(
       var _data: Seq[InputPartition],
       readSchema: StructType,
-      tableSchema: StructType)
+      tableSchema: StructType,
+      options: CaseInsensitiveStringMap)
    extends BatchScanBaseClass(_data, readSchema, tableSchema) with SupportsRuntimeV2Filtering {
 
     override def filterAttributes(): Array[NamedReference] = {
@@ -93,21 +96,21 @@ class InMemoryTableWithV2Filter(
     InMemoryBaseTable.maybeSimulateFailedTableWrite(new CaseInsensitiveStringMap(properties))
     InMemoryBaseTable.maybeSimulateFailedTableWrite(info.options)
 
-    new InMemoryWriterBuilderWithOverWrite()
+    new InMemoryWriterBuilderWithOverWrite(info)
   }
 
-  private class InMemoryWriterBuilderWithOverWrite() extends InMemoryWriterBuilder
-    with SupportsOverwriteV2 {
+  class InMemoryWriterBuilderWithOverWrite(override val info: LogicalWriteInfo)
+    extends InMemoryWriterBuilder(info) with SupportsOverwriteV2 {
 
     override def truncate(): WriteBuilder = {
-      assert(writer == Append)
-      writer = TruncateAndAppend
-      streamingWriter = StreamingTruncateAndAppend
+      assert(writer.isInstanceOf[Append])
+      writer = new TruncateAndAppend(info)
+      streamingWriter = new StreamingTruncateAndAppend(info)
       this
     }
 
     override def overwrite(predicates: Array[Predicate]): WriteBuilder = {
-      assert(writer == Append)
+      assert(writer.isInstanceOf[Append])
       writer = new Overwrite(predicates)
       streamingWriter = new StreamingNotSupportedOperation(
         s"overwrite (${predicates.mkString("filters(", ", ", ")")})")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2Writes.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2Writes.scala
index 319cc1c73157..17b2579ca873 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2Writes.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2Writes.scala
@@ -19,6 +19,8 @@ package org.apache.spark.sql.execution.datasources.v2
 
 import java.util.{Optional, UUID}
 
+import scala.jdk.CollectionConverters._
+
 import org.apache.spark.sql.catalyst.expressions.PredicateHelper
 import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic, Project, ReplaceData, WriteDelta}
 import org.apache.spark.sql.catalyst.rules.Rule
@@ -44,7 +46,8 @@ object V2Writes extends Rule[LogicalPlan] with PredicateHelper {
 
   override def apply(plan: LogicalPlan): LogicalPlan = plan transformDown {
     case a @ AppendData(r: DataSourceV2Relation, query, options, _, None, _) =>
-      val writeBuilder = newWriteBuilder(r.table, options, query.schema)
+      val writeOptions = mergeOptions(options, r.options.asScala.toMap)
+      val writeBuilder = newWriteBuilder(r.table, writeOptions, query.schema)
       val write = writeBuilder.build()
       val newQuery = DistributionAndOrderingUtils.prepareQuery(write, query, r.funCatalog)
       a.copy(write = Some(write), query = newQuery)
@@ -61,7 +64,8 @@ object V2Writes extends Rule[LogicalPlan] with PredicateHelper {
       }.toArray
 
       val table = r.table
-      val writeBuilder = newWriteBuilder(table, options, query.schema)
+      val writeOptions = mergeOptions(options, r.options.asScala.toMap)
+      val writeBuilder = newWriteBuilder(table, writeOptions, query.schema)
       val write = writeBuilder match {
         case builder: SupportsTruncate if isTruncate(predicates) =>
           builder.truncate().build()
@@ -76,7 +80,8 @@ object V2Writes extends Rule[LogicalPlan] with PredicateHelper {
 
     case o @ OverwritePartitionsDynamic(r: DataSourceV2Relation, query, options, _, None) =>
       val table = r.table
-      val writeBuilder = newWriteBuilder(table, options, query.schema)
+      val writeOptions = mergeOptions(options, r.options.asScala.toMap)
+      val writeBuilder = newWriteBuilder(table, writeOptions, query.schema)
       val write = writeBuilder match {
         case builder: SupportsDynamicOverwrite =>
           builder.overwriteDynamicPartitions().build()
@@ -87,31 +92,44 @@ object V2Writes extends Rule[LogicalPlan] with PredicateHelper {
       o.copy(write = Some(write), query = newQuery)
 
     case WriteToMicroBatchDataSource(
-        relation, table, query, queryId, writeOptions, outputMode, Some(batchId)) =>
-
+        relationOpt, table, query, queryId, options, outputMode, Some(batchId)) =>
+      val writeOptions = mergeOptions(
+        options, relationOpt.map(r => r.options.asScala.toMap).getOrElse(Map.empty))
       val writeBuilder = newWriteBuilder(table, writeOptions, query.schema, queryId)
       val write = buildWriteForMicroBatch(table, writeBuilder, outputMode)
       val microBatchWrite = new MicroBatchWrite(batchId, write.toStreaming)
       val customMetrics = write.supportedCustomMetrics.toImmutableArraySeq
-      val funCatalogOpt = relation.flatMap(_.funCatalog)
+      val funCatalogOpt = relationOpt.flatMap(_.funCatalog)
       val newQuery = DistributionAndOrderingUtils.prepareQuery(write, query, funCatalogOpt)
-      WriteToDataSourceV2(relation, microBatchWrite, newQuery, customMetrics)
+      WriteToDataSourceV2(relationOpt, microBatchWrite, newQuery, customMetrics)
 
     case rd @ ReplaceData(r: DataSourceV2Relation, _, query, _, _, None) =>
       val rowSchema = DataTypeUtils.fromAttributes(rd.dataInput)
-      val writeBuilder = newWriteBuilder(r.table, Map.empty, rowSchema)
+      val writeOptions = mergeOptions(Map.empty, r.options.asScala.toMap)
+      val writeBuilder = newWriteBuilder(r.table, writeOptions, rowSchema)
       val write = writeBuilder.build()
       val newQuery = DistributionAndOrderingUtils.prepareQuery(write, query, r.funCatalog)
       // project away any metadata columns that could be used for distribution and ordering
       rd.copy(write = Some(write), query = Project(rd.dataInput, newQuery))
 
     case wd @ WriteDelta(r: DataSourceV2Relation, _, query, _, projections, None) =>
-      val deltaWriteBuilder = newDeltaWriteBuilder(r.table, Map.empty, projections)
+      val writeOptions = mergeOptions(Map.empty, r.options.asScala.toMap)
+      val deltaWriteBuilder = newDeltaWriteBuilder(r.table, writeOptions, projections)
       val deltaWrite = deltaWriteBuilder.build()
       val newQuery = DistributionAndOrderingUtils.prepareQuery(deltaWrite, query, r.funCatalog)
       wd.copy(write = Some(deltaWrite), query = newQuery)
   }
 
+  private def mergeOptions(
+      commandOptions: Map[String, String],
+      dsOptions: Map[String, String]): Map[String, String] = {
+    // for DataFrame API cases, same options are carried by both Command and DataSourceV2Relation
+    // for DataFrameV2 API cases, options are only carried by Command
+    // for SQL cases, options are only carried by DataSourceV2Relation
+    assert(commandOptions == dsOptions || commandOptions.isEmpty || dsOptions.isEmpty)
+    commandOptions ++ dsOptions
+  }
+
   private def buildWriteForMicroBatch(
       table: SupportsWrite,
       writeBuilder: WriteBuilder,
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
index 30180d48da71..b59c83c23d3c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -27,7 +27,7 @@ import org.scalatest.Assertions
 import org.apache.spark.sql.catalyst.ExtendedAnalysisException
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.util._
-import org.apache.spark.sql.execution.{QueryExecution, SparkPlan, SQLExecution}
+import org.apache.spark.sql.execution.{QueryExecution, SQLExecution}
 import org.apache.spark.sql.execution.columnar.InMemoryRelation
 import org.apache.spark.sql.util.QueryExecutionListener
 import org.apache.spark.storage.StorageLevel
@@ -449,12 +449,12 @@ object QueryTest extends Assertions {
     }
   }
 
-  def withPhysicalPlansCaptured(spark: SparkSession, thunk: => Unit): Seq[SparkPlan] = {
-    var capturedPlans = Seq.empty[SparkPlan]
+  def withQueryExecutionsCaptured(spark: SparkSession)(thunk: => Unit): Seq[QueryExecution] = {
+    var capturedQueryExecutions = Seq.empty[QueryExecution]
 
     val listener = new QueryExecutionListener {
       override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = {
-        capturedPlans = capturedPlans :+ qe.executedPlan
+        capturedQueryExecutions = capturedQueryExecutions :+ qe
       }
       override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {}
     }
@@ -468,7 +468,7 @@ object QueryTest extends Assertions {
       spark.listenerManager.unregister(listener)
     }
 
-    capturedPlans
+    capturedQueryExecutions
   }
 }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2OptionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2OptionSuite.scala
new file mode 100644
index 000000000000..70291336ba31
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2OptionSuite.scala
@@ -0,0 +1,327 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector
+
+import org.apache.spark.sql.{AnalysisException, Row}
+import org.apache.spark.sql.QueryTest.withQueryExecutionsCaptured
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.connector.catalog.InMemoryBaseTable
+import org.apache.spark.sql.execution.CommandResultExec
+import org.apache.spark.sql.execution.datasources.v2._
+import org.apache.spark.sql.functions.lit
+
+class DataSourceV2OptionSuite extends DatasourceV2SQLBase {
+  import testImplicits._
+
+  private val catalogAndNamespace = "testcat.ns1.ns2."
+
+  test("SPARK-36680: Supports Dynamic Table Options for SQL Select") {
+    val t1 = s"${catalogAndNamespace}table"
+    withTable(t1) {
+      sql(s"CREATE TABLE $t1 (id bigint, data string)")
+      sql(s"INSERT INTO $t1 VALUES (1, 'a'), (2, 'b')")
+
+      var df = sql(s"SELECT * FROM $t1")
+      var collected = df.queryExecution.optimizedPlan.collect {
+        case scan: DataSourceV2ScanRelation =>
+          assert(scan.relation.options.isEmpty)
+      }
+      assert (collected.size == 1)
+      checkAnswer(df, Seq(Row(1, "a"), Row(2, "b")))
+
+      df = sql(s"SELECT * FROM $t1 WITH (`split-size` = 5)")
+      collected = df.queryExecution.optimizedPlan.collect {
+        case scan: DataSourceV2ScanRelation =>
+          assert(scan.relation.options.get("split-size") == "5")
+      }
+      assert (collected.size == 1)
+      checkAnswer(df, Seq(Row(1, "a"), Row(2, "b")))
+
+      collected = df.queryExecution.executedPlan.collect {
+        case BatchScanExec(_, scan: InMemoryBaseTable#InMemoryBatchScan, _, _, _, _) =>
+          assert(scan.options.get("split-size") === "5")
+      }
+      assert (collected.size == 1)
+
+      val noValues = intercept[AnalysisException](
+        sql(s"SELECT * FROM $t1 WITH (`split-size`)"))
+      assert(noValues.message.contains(
+        "Operation not allowed: Values must be specified for key(s): [split-size]"))
+    }
+  }
+
+  test("SPARK-50286: Propagate options for DataFrameReader") {
+    val t1 = s"${catalogAndNamespace}table"
+    withTable(t1) {
+      sql(s"CREATE TABLE $t1 (id bigint, data string)")
+      sql(s"INSERT INTO $t1 VALUES (1, 'a'), (2, 'b')")
+
+      var df = spark.table(t1)
+      var collected = df.queryExecution.optimizedPlan.collect {
+        case scan: DataSourceV2ScanRelation =>
+          assert(scan.relation.options.isEmpty)
+      }
+      assert (collected.size == 1)
+      checkAnswer(df, Seq(Row(1, "a"), Row(2, "b")))
+
+      df = spark.read.option("split-size", "5").table(t1)
+      collected = df.queryExecution.optimizedPlan.collect {
+        case scan: DataSourceV2ScanRelation =>
+          assert(scan.relation.options.get("split-size") == "5")
+      }
+      assert (collected.size == 1)
+      checkAnswer(df, Seq(Row(1, "a"), Row(2, "b")))
+
+      collected = df.queryExecution.executedPlan.collect {
+        case BatchScanExec(_, scan: InMemoryBaseTable#InMemoryBatchScan, _, _, _, _) =>
+          assert(scan.options.get("split-size") === "5")
+      }
+      assert (collected.size == 1)
+    }
+  }
+
+  test("SPARK-49098, SPARK-50286: Supports Dynamic Table Options for SQL Insert") {
+    val t1 = s"${catalogAndNamespace}table"
+    withTable(t1) {
+      sql(s"CREATE TABLE $t1 (id bigint, data string)")
+      val df = sql(s"INSERT INTO $t1 WITH (`write.split-size` = 10) VALUES (1, 'a'), (2, 'b')")
+
+      var collected = df.queryExecution.optimizedPlan.collect {
+        case CommandResult(_, AppendData(relation: DataSourceV2Relation, _, _, _, _, _), _, _) =>
+          assert(relation.options.get("write.split-size") == "10")
+      }
+      assert (collected.size == 1)
+
+      collected = df.queryExecution.executedPlan.collect {
+        case CommandResultExec(
+          _, AppendDataExec(_, _, write),
+          _) =>
+          val append = write.toBatch.asInstanceOf[InMemoryBaseTable#Append]
+          assert(append.info.options.get("write.split-size") === "10")
+      }
+      assert (collected.size == 1)
+
+      val insertResult = sql(s"SELECT * FROM $t1")
+      checkAnswer(insertResult, Seq(Row(1, "a"), Row(2, "b")))
+    }
+  }
+
+  test("SPARK-50286: Propagate options for DataFrameWriter Append") {
+    val t1 = s"${catalogAndNamespace}table"
+    withTable(t1) {
+      sql(s"CREATE TABLE $t1 (id bigint, data string)")
+      val captured = withQueryExecutionsCaptured(spark) {
+        Seq(1 -> "a", 2 -> "b").toDF("id", "data")
+          .write
+          .option("write.split-size", "10")
+          .mode("append")
+          .insertInto(t1)
+      }
+      assert(captured.size === 1)
+      val qe = captured.head
+      var collected = qe.optimizedPlan.collect {
+        case AppendData(_: DataSourceV2Relation, _, writeOptions, _, _, _) =>
+          assert(writeOptions("write.split-size") == "10")
+      }
+      assert (collected.size == 1)
+
+      collected = qe.executedPlan.collect {
+        case AppendDataExec(_, _, write) =>
+          val append = write.toBatch.asInstanceOf[InMemoryBaseTable#Append]
+          assert(append.info.options.get("write.split-size") === "10")
+      }
+      assert (collected.size == 1)
+    }
+  }
+
+  test("SPARK-50286: Propagate options for DataFrameWriterV2 Append") {
+    val t1 = s"${catalogAndNamespace}table"
+    withTable(t1) {
+      sql(s"CREATE TABLE $t1 (id bigint, data string)")
+      val captured = withQueryExecutionsCaptured(spark) {
+        Seq(1 -> "a", 2 -> "b").toDF("id", "data")
+          .writeTo(t1)
+          .option("write.split-size", "10")
+          .append()
+      }
+      assert(captured.size === 1)
+      val qe = captured.head
+      var collected = qe.optimizedPlan.collect {
+        case AppendData(_: DataSourceV2Relation, _, writeOptions, _, _, _) =>
+          assert(writeOptions("write.split-size") == "10")
+      }
+      assert (collected.size == 1)
+
+      collected = qe.executedPlan.collect {
+        case AppendDataExec(_, _, write) =>
+          val append = write.toBatch.asInstanceOf[InMemoryBaseTable#Append]
+          assert(append.info.options.get("write.split-size") === "10")
+      }
+      assert (collected.size == 1)
+    }
+  }
+
+  test("SPARK-49098, SPARK-50286: Supports Dynamic Table Options for SQL Insert Overwrite") {
+    val t1 = s"${catalogAndNamespace}table"
+    withTable(t1) {
+      sql(s"CREATE TABLE $t1 (id bigint, data string)")
+      sql(s"INSERT INTO $t1 VALUES (1, 'a'), (2, 'b')")
+
+      val df = sql(s"INSERT OVERWRITE $t1 WITH (`write.split-size` = 10) " +
+        s"VALUES (3, 'c'), (4, 'd')")
+      var collected = df.queryExecution.optimizedPlan.collect {
+        case CommandResult(_,
+          OverwriteByExpression(relation: DataSourceV2Relation, _, _, _, _, _, _),
+          _, _) =>
+          assert(relation.options.get("write.split-size") === "10")
+      }
+      assert (collected.size == 1)
+
+      collected = df.queryExecution.executedPlan.collect {
+        case CommandResultExec(
+          _, OverwriteByExpressionExec(_, _, write),
+          _) =>
+          val append = write.toBatch.asInstanceOf[InMemoryBaseTable#TruncateAndAppend]
+          assert(append.info.options.get("write.split-size") === "10")
+      }
+      assert (collected.size == 1)
+
+      val insertResult = sql(s"SELECT * FROM $t1")
+      checkAnswer(insertResult, Seq(Row(3, "c"), Row(4, "d")))
+    }
+  }
+
+  test("SPARK-50286: Propagate options for DataFrameWriterV2 OverwritePartitions") {
+    val t1 = s"${catalogAndNamespace}table"
+    withTable(t1) {
+      sql(s"CREATE TABLE $t1 (id bigint, data string)")
+      sql(s"INSERT INTO $t1 VALUES (1, 'a'), (2, 'b')")
+
+      val captured = withQueryExecutionsCaptured(spark) {
+        Seq(3 -> "c", 4 -> "d").toDF("id", "data")
+          .writeTo(t1)
+          .option("write.split-size", "10")
+          .overwritePartitions()
+      }
+      assert(captured.size === 1)
+      val qe = captured.head
+      var collected = qe.optimizedPlan.collect {
+        case OverwritePartitionsDynamic(_: DataSourceV2Relation, _, writeOptions, _, _) =>
+          assert(writeOptions("write.split-size") === "10")
+      }
+      assert (collected.size == 1)
+
+      collected = qe.executedPlan.collect {
+        case OverwritePartitionsDynamicExec(_, _, write) =>
+          val dynOverwrite = write.toBatch.asInstanceOf[InMemoryBaseTable#DynamicOverwrite]
+          assert(dynOverwrite.info.options.get("write.split-size") === "10")
+      }
+      assert (collected.size == 1)
+    }
+  }
+
+  test("SPARK-49098, SPARK-50286: Supports Dynamic Table Options for SQL Insert Replace") {
+    val t1 = s"${catalogAndNamespace}table"
+    withTable(t1) {
+      sql(s"CREATE TABLE $t1 (id bigint, data string)")
+      sql(s"INSERT INTO $t1 VALUES (1, 'a'), (2, 'b')")
+
+      val df = sql(s"INSERT INTO $t1 WITH (`write.split-size` = 10) " +
+        s"REPLACE WHERE TRUE " +
+        s"VALUES (3, 'c'), (4, 'd')")
+      var collected = df.queryExecution.optimizedPlan.collect {
+        case CommandResult(_,
+          OverwriteByExpression(relation: DataSourceV2Relation, _, _, _, _, _, _),
+          _, _) =>
+          assert(relation.options.get("write.split-size") == "10")
+      }
+      assert (collected.size == 1)
+
+      collected = df.queryExecution.executedPlan.collect {
+        case CommandResultExec(
+          _, OverwriteByExpressionExec(_, _, write),
+          _) =>
+          val append = write.toBatch.asInstanceOf[InMemoryBaseTable#TruncateAndAppend]
+          assert(append.info.options.get("write.split-size") === "10")
+      }
+      assert (collected.size == 1)
+
+      val insertResult = sql(s"SELECT * FROM $t1")
+      checkAnswer(insertResult, Seq(Row(3, "c"), Row(4, "d")))
+    }
+  }
+
+  test("SPARK-50286: Propagate options for DataFrameWriter Overwrite") {
+    val t1 = s"${catalogAndNamespace}table"
+    withTable(t1) {
+      sql(s"CREATE TABLE $t1 (id bigint, data string)")
+      val captured = withQueryExecutionsCaptured(spark) {
+        Seq(1 -> "a", 2 -> "b").toDF("id", "data")
+          .write
+          .option("write.split-size", "10")
+          .mode("overwrite")
+          .insertInto(t1)
+      }
+      assert(captured.size === 1)
+
+      val qe = captured.head
+      var collected = qe.optimizedPlan.collect {
+        case OverwriteByExpression(_: DataSourceV2Relation, _, _, writeOptions, _, _, _) =>
+          assert(writeOptions("write.split-size") === "10")
+      }
+      assert (collected.size == 1)
+
+      collected = qe.executedPlan.collect {
+        case OverwriteByExpressionExec(_, _, write) =>
+          val append = write.toBatch.asInstanceOf[InMemoryBaseTable#TruncateAndAppend]
+          assert(append.info.options.get("write.split-size") === "10")
+      }
+      assert (collected.size == 1)
+    }
+  }
+
+  test("SPARK-50286: Propagate options for DataFrameWriterV2 Overwrite") {
+    val t1 = s"${catalogAndNamespace}table"
+    withTable(t1) {
+      sql(s"CREATE TABLE $t1 (id bigint, data string)")
+      sql(s"INSERT INTO $t1 VALUES (1, 'a'), (2, 'b')")
+
+      val captured = withQueryExecutionsCaptured(spark) {
+        Seq(3 -> "c", 4 -> "d").toDF("id", "data")
+          .writeTo(t1)
+          .option("write.split-size", "10")
+          .overwrite(lit(true))
+      }
+      assert(captured.size === 1)
+      val qe = captured.head
+
+      var collected = qe.optimizedPlan.collect {
+        case OverwriteByExpression(_: DataSourceV2Relation, _, _, writeOptions, _, _, _) =>
+          assert(writeOptions("write.split-size") === "10")
+      }
+      assert (collected.size == 1)
+
+      collected = qe.executedPlan.collect {
+        case OverwriteByExpressionExec(_, _, write) =>
+          val append = write.toBatch.asInstanceOf[InMemoryBaseTable#TruncateAndAppend]
+          assert(append.info.options.get("write.split-size") === "10")
+      }
+      assert (collected.size == 1)
+    }
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
index 510ea49b5841..6a659fa6e3ee 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
@@ -32,7 +32,7 @@ import org.apache.spark.sql.catalyst.CurrentUserContext.CURRENT_USER
 import org.apache.spark.sql.catalyst.analysis.{CannotReplaceMissingTableException, NoSuchNamespaceException, TableAlreadyExistsException}
 import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable, CatalogTableType, CatalogUtils}
 import org.apache.spark.sql.catalyst.parser.ParseException
-import org.apache.spark.sql.catalyst.plans.logical.{AppendData, ColumnStat, CommandResult, OverwriteByExpression}
+import org.apache.spark.sql.catalyst.plans.logical.ColumnStat
 import org.apache.spark.sql.catalyst.statsEstimation.StatsEstimationTestBase
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.connector.catalog.{Column => ColumnV2, _}
@@ -44,7 +44,6 @@ import org.apache.spark.sql.execution.FilterExec
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
 import org.apache.spark.sql.execution.columnar.InMemoryRelation
 import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelationWithTable}
-import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
 import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation
 import org.apache.spark.sql.execution.streaming.MemoryStream
 import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf}
@@ -3634,96 +3633,6 @@ class DataSourceV2SQLSuiteV1Filter
     }
   }
 
-
-  test("SPARK-36680: Supports Dynamic Table Options for Spark SQL") {
-    val t1 = s"${catalogAndNamespace}table"
-    withTable(t1) {
-      sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format")
-      sql(s"INSERT INTO $t1 VALUES (1, 'a'), (2, 'b')")
-
-      var df = sql(s"SELECT * FROM $t1")
-      var collected = df.queryExecution.optimizedPlan.collect {
-        case scan: DataSourceV2ScanRelation =>
-          assert(scan.relation.options.isEmpty)
-      }
-      assert (collected.size == 1)
-      checkAnswer(df, Seq(Row(1, "a"), Row(2, "b")))
-
-      df = sql(s"SELECT * FROM $t1 WITH (`split-size` = 5)")
-      collected = df.queryExecution.optimizedPlan.collect {
-        case scan: DataSourceV2ScanRelation =>
-          assert(scan.relation.options.get("split-size") == "5")
-      }
-      assert (collected.size == 1)
-      checkAnswer(df, Seq(Row(1, "a"), Row(2, "b")))
-
-      val noValues = intercept[AnalysisException](
-        sql(s"SELECT * FROM $t1 WITH (`split-size`)"))
-      assert(noValues.message.contains(
-        "Operation not allowed: Values must be specified for key(s): [split-size]"))
-    }
-  }
-
-  test("SPARK-36680: Supports Dynamic Table Options for Insert") {
-    val t1 = s"${catalogAndNamespace}table"
-    withTable(t1) {
-      sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format")
-      val df = sql(s"INSERT INTO $t1 WITH (`write.split-size` = 10) VALUES (1, 'a'), (2, 'b')")
-
-      val collected = df.queryExecution.optimizedPlan.collect {
-        case CommandResult(_, AppendData(relation: DataSourceV2Relation, _, _, _, _, _), _, _) =>
-          assert(relation.options.get("write.split-size") == "10")
-      }
-      assert (collected.size == 1)
-
-      val insertResult = sql(s"SELECT * FROM $t1")
-      checkAnswer(insertResult, Seq(Row(1, "a"), Row(2, "b")))
-    }
-  }
-
-  test("SPARK-36680: Supports Dynamic Table Options for Insert Overwrite") {
-    val t1 = s"${catalogAndNamespace}table"
-    withTable(t1) {
-      sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format")
-      sql(s"INSERT INTO $t1 WITH (`write.split-size` = 10) VALUES (1, 'a'), (2, 'b')")
-
-      val df = sql(s"INSERT OVERWRITE $t1 WITH (`write.split-size` = 10) " +
-        s"VALUES (3, 'c'), (4, 'd')")
-      val collected = df.queryExecution.optimizedPlan.collect {
-        case CommandResult(_,
-          OverwriteByExpression(relation: DataSourceV2Relation, _, _, _, _, _, _),
-          _, _) =>
-          assert(relation.options.get("write.split-size") == "10")
-      }
-      assert (collected.size == 1)
-
-      val insertResult = sql(s"SELECT * FROM $t1")
-      checkAnswer(insertResult, Seq(Row(3, "c"), Row(4, "d")))
-    }
-  }
-
-  test("SPARK-36680: Supports Dynamic Table Options for Insert Replace") {
-    val t1 = s"${catalogAndNamespace}table"
-    withTable(t1) {
-      sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format")
-      sql(s"INSERT INTO $t1 WITH (`write.split-size` = 10) VALUES (1, 'a'), (2, 'b')")
-
-      val df = sql(s"INSERT INTO $t1 WITH (`write.split-size` = 10) " +
-        s"REPLACE WHERE TRUE " +
-        s"VALUES (3, 'c'), (4, 'd')")
-      val collected = df.queryExecution.optimizedPlan.collect {
-        case CommandResult(_,
-          OverwriteByExpression(relation: DataSourceV2Relation, _, _, _, _, _, _),
-          _, _) =>
-          assert(relation.options.get("write.split-size") == "10")
-      }
-      assert (collected.size == 1)
-
-      val insertResult = sql(s"SELECT * FROM $t1")
-      checkAnswer(insertResult, Seq(Row(3, "c"), Row(4, "d")))
-    }
-  }
-
   test("SPARK-49183: custom spark_catalog generates location for managed tables") {
     // Reset CatalogManager to clear the materialized `spark_catalog` instance, so that we can
     // configure a new implementation.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala
index 04fc7e23ebb2..68c2a01c69ae 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala
@@ -24,7 +24,7 @@ import org.scalatest.BeforeAndAfter
 
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row, SaveMode, SparkSession, SQLContext}
-import org.apache.spark.sql.QueryTest.withPhysicalPlansCaptured
+import org.apache.spark.sql.QueryTest.withQueryExecutionsCaptured
 import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException
 import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
 import org.apache.spark.sql.catalyst.rules.Rule
@@ -213,8 +213,8 @@ class V1WriteFallbackSuite extends QueryTest with SharedSparkSession with Before
         .getOrCreate()
 
       def captureWrite(sparkSession: SparkSession)(thunk: => Unit): SparkPlan = {
-        val physicalPlans = withPhysicalPlansCaptured(sparkSession, thunk)
-        val v1FallbackWritePlans = physicalPlans.filter {
+        val queryExecutions = withQueryExecutionsCaptured(sparkSession)(thunk)
+        val v1FallbackWritePlans = queryExecutions.map(_.executedPlan).filter {
           case _: AppendDataExecV1 | _: OverwriteByExpressionExecV1 => true
           case _ => false
         }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

