This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new c9cfaac90fd4 [SPARK-46452][SQL] Add a new API in DataWriter to write an iterator of records
c9cfaac90fd4 is described below

commit c9cfaac90fd423c3a38e295234e24744b946cb02
Author: allisonwang-db <allison.w...@databricks.com>
AuthorDate: Wed Dec 20 19:17:21 2023 +0800

    [SPARK-46452][SQL] Add a new API in DataWriter to write an iterator of records
    
    ### What changes were proposed in this pull request?
    
    This PR proposes to add a new method in `DataWriter` that supports writing an iterator of records:
    ```java
    void writeAll(Iterator<T> records) throws IOException
    ```
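    
    For illustration, a connector can now override this method to consume the whole partition iterator at once, for example to buffer rows and flush them in batches. The sketch below is hypothetical: `BufferedRowWriter`, `BATCH_SIZE`, and `flush()` are made-up names used only for illustration; only the `DataWriter` interface and the `writeAll` signature above come from this change.
    
    ```java
    import java.io.IOException;
    import java.util.ArrayList;
    import java.util.Iterator;
    import java.util.List;
    
    import org.apache.spark.sql.catalyst.InternalRow;
    import org.apache.spark.sql.connector.write.DataWriter;
    import org.apache.spark.sql.connector.write.WriterCommitMessage;
    
    // Hypothetical writer that overrides writeAll to flush rows in fixed-size batches.
    class BufferedRowWriter implements DataWriter<InternalRow> {
      private static final int BATCH_SIZE = 1000;
      private final List<InternalRow> buffer = new ArrayList<>();
    
      @Override
      public void write(InternalRow record) throws IOException {
        buffer.add(record.copy());
        if (buffer.size() >= BATCH_SIZE) {
          flush();
        }
      }
    
      @Override
      public void writeAll(Iterator<InternalRow> records) throws IOException {
        // The writer sees the whole partition iterator and is free to batch or
        // hand it off (e.g. to an external process) instead of going row by row.
        while (records.hasNext()) {
          write(records.next());
        }
        flush();
      }
    
      @Override
      public WriterCommitMessage commit() throws IOException {
        flush();
        return new WriterCommitMessage() {}; // placeholder commit message
      }
    
      @Override
      public void abort() {
        buffer.clear();
      }
    
      @Override
      public void close() {
        buffer.clear();
      }
    
      private void flush() throws IOException {
        // A real writer would push the buffered rows to files or a remote sink here.
        buffer.clear();
      }
    }
    ```
    
    The default implementation added by this PR simply calls `write` for each record, so existing `DataWriter` implementations keep their current behavior without any change.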
    
    ### Why are the changes needed?
    
    To make the API more flexible and support more use cases (e.g. Python data sources). See https://github.com/apache/spark/pull/43791
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes. This PR introduces a new method in `DataWriter`.
    
    ### How was this patch tested?
    
    Existing unit tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No
    
    Closes #44410 from allisonwang-db/spark-46452-dsv2-write-all.
    
    Authored-by: allisonwang-db <allison.w...@databricks.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../spark/sql/connector/write/DataWriter.java      |  18 +++
 .../datasources/v2/WriteToDataSourceV2Exec.scala   | 121 ++++++++++++---------
 2 files changed, 88 insertions(+), 51 deletions(-)

diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/DataWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/DataWriter.java
index 6a1cee181bc2..d6e94fe2ca8b 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/DataWriter.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/DataWriter.java
@@ -19,6 +19,7 @@ package org.apache.spark.sql.connector.write;
 
 import java.io.Closeable;
 import java.io.IOException;
+import java.util.Iterator;
 
 import org.apache.spark.annotation.Evolving;
 import org.apache.spark.sql.connector.metric.CustomTaskMetric;
@@ -74,6 +75,23 @@ public interface DataWriter<T> extends Closeable {
    */
   void write(T record) throws IOException;
 
+  /**
+   * Writes all records provided by the given iterator. By default, it calls the {@link #write}
+   * method for each record in the iterator.
+   * <p>
+   * If this method fails (by throwing an exception), {@link #abort()} will be called and this
+   * data writer is considered to have failed.
+   *
+   * @throws IOException if failure happens during disk/network IO like writing files.
+   *
+   * @since 4.0.0
+   */
+  default void writeAll(Iterator<T> records) throws IOException {
+    while (records.hasNext()) {
+      write(records.next());
+    }
+  }
+
   /**
   * Commits this writer after all records are written successfully, returns a commit message which
   * will be sent back to driver side and passed to
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala
index 2527f201f3a8..97c1f7ced508 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala
@@ -421,7 +421,7 @@ trait V2TableWriteExec extends V2CommandExec with UnaryExecNode {
 
 trait WritingSparkTask[W <: DataWriter[InternalRow]] extends Logging with Serializable {
 
-  protected def write(writer: W, row: InternalRow): Unit
+  protected def write(writer: W, iter: java.util.Iterator[InternalRow]): Unit
 
   def run(
       writerFactory: DataWriterFactory,
@@ -436,19 +436,11 @@ trait WritingSparkTask[W <: DataWriter[InternalRow]] extends Logging with Serial
     val attemptId = context.attemptNumber()
     val dataWriter = writerFactory.createWriter(partId, taskId).asInstanceOf[W]
 
-    var count = 0L
+    val iterWithMetrics = IteratorWithMetrics(iter, dataWriter, customMetrics)
+
     // write the data and commit this writer.
     Utils.tryWithSafeFinallyAndFailureCallbacks(block = {
-      while (iter.hasNext) {
-        if (count % CustomMetrics.NUM_ROWS_PER_UPDATE == 0) {
-          CustomMetrics.updateMetrics(
-            dataWriter.currentMetricsValues.toImmutableArraySeq, customMetrics)
-        }
-
-        // Count is here.
-        count += 1
-        write(dataWriter, iter.next())
-      }
+      write(dataWriter, iterWithMetrics)
 
       CustomMetrics.updateMetrics(
         dataWriter.currentMetricsValues.toImmutableArraySeq, customMetrics)
@@ -476,7 +468,7 @@ trait WritingSparkTask[W <: DataWriter[InternalRow]] extends Logging with Serial
       logInfo(s"Committed partition $partId (task $taskId, attempt $attemptId, 
" +
         s"stage $stageId.$stageAttempt)")
 
-      DataWritingSparkTaskResult(count, msg)
+      DataWritingSparkTaskResult(iterWithMetrics.count, msg)
 
     })(catchBlock = {
       // If there is an error, abort this writer
@@ -489,11 +481,30 @@ trait WritingSparkTask[W <: DataWriter[InternalRow]] extends Logging with Serial
       dataWriter.close()
     })
   }
+
+  private case class IteratorWithMetrics(
+      iter: Iterator[InternalRow],
+      dataWriter: W,
+      customMetrics: Map[String, SQLMetric]) extends java.util.Iterator[InternalRow] {
+    var count = 0L
+
+    override def hasNext: Boolean = iter.hasNext
+
+    override def next(): InternalRow = {
+      if (count % CustomMetrics.NUM_ROWS_PER_UPDATE == 0) {
+        CustomMetrics.updateMetrics(
+          dataWriter.currentMetricsValues.toImmutableArraySeq, customMetrics)
+      }
+      count += 1
+      iter.next()
+    }
+  }
 }
 
 object DataWritingSparkTask extends WritingSparkTask[DataWriter[InternalRow]] {
-  override protected def write(writer: DataWriter[InternalRow], row: InternalRow): Unit = {
-    writer.write(row)
+  override protected def write(
+      writer: DataWriter[InternalRow], iter: java.util.Iterator[InternalRow]): Unit = {
+    writer.writeAll(iter)
   }
 }
 
@@ -503,25 +514,29 @@ case class DeltaWritingSparkTask(
   private lazy val rowProjection = projections.rowProjection.orNull
   private lazy val rowIdProjection = projections.rowIdProjection
 
-  override protected def write(writer: DeltaWriter[InternalRow], row: InternalRow): Unit = {
-    val operation = row.getInt(0)
+  override protected def write(
+      writer: DeltaWriter[InternalRow], iter: java.util.Iterator[InternalRow]): Unit = {
+    while (iter.hasNext) {
+      val row = iter.next()
+      val operation = row.getInt(0)
 
-    operation match {
-      case DELETE_OPERATION =>
-        rowIdProjection.project(row)
-        writer.delete(null, rowIdProjection)
+      operation match {
+        case DELETE_OPERATION =>
+          rowIdProjection.project(row)
+          writer.delete(null, rowIdProjection)
 
-      case UPDATE_OPERATION =>
-        rowProjection.project(row)
-        rowIdProjection.project(row)
-        writer.update(null, rowIdProjection, rowProjection)
+        case UPDATE_OPERATION =>
+          rowProjection.project(row)
+          rowIdProjection.project(row)
+          writer.update(null, rowIdProjection, rowProjection)
 
-      case INSERT_OPERATION =>
-        rowProjection.project(row)
-        writer.insert(rowProjection)
+        case INSERT_OPERATION =>
+          rowProjection.project(row)
+          writer.insert(rowProjection)
 
-      case other =>
-        throw new SparkException(s"Unexpected operation ID: $other")
+        case other =>
+          throw new SparkException(s"Unexpected operation ID: $other")
+      }
     }
   }
 }
@@ -533,27 +548,31 @@ case class DeltaWithMetadataWritingSparkTask(
   private lazy val rowIdProjection = projections.rowIdProjection
   private lazy val metadataProjection = projections.metadataProjection.orNull
 
-  override protected def write(writer: DeltaWriter[InternalRow], row: InternalRow): Unit = {
-    val operation = row.getInt(0)
-
-    operation match {
-      case DELETE_OPERATION =>
-        rowIdProjection.project(row)
-        metadataProjection.project(row)
-        writer.delete(metadataProjection, rowIdProjection)
-
-      case UPDATE_OPERATION =>
-        rowProjection.project(row)
-        rowIdProjection.project(row)
-        metadataProjection.project(row)
-        writer.update(metadataProjection, rowIdProjection, rowProjection)
-
-      case INSERT_OPERATION =>
-        rowProjection.project(row)
-        writer.insert(rowProjection)
-
-      case other =>
-        throw new SparkException(s"Unexpected operation ID: $other")
+  override protected def write(
+      writer: DeltaWriter[InternalRow], iter: java.util.Iterator[InternalRow]): Unit = {
+    while (iter.hasNext) {
+      val row = iter.next()
+      val operation = row.getInt(0)
+
+      operation match {
+        case DELETE_OPERATION =>
+          rowIdProjection.project(row)
+          metadataProjection.project(row)
+          writer.delete(metadataProjection, rowIdProjection)
+
+        case UPDATE_OPERATION =>
+          rowProjection.project(row)
+          rowIdProjection.project(row)
+          metadataProjection.project(row)
+          writer.update(metadataProjection, rowIdProjection, rowProjection)
+
+        case INSERT_OPERATION =>
+          rowProjection.project(row)
+          writer.insert(rowProjection)
+
+        case other =>
+          throw new SparkException(s"Unexpected operation ID: $other")
+      }
     }
   }
 }

