This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-4.0
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-4.0 by this push:
     new e5624320176d [SPARK-51668][SQL] Report metrics for failed writes to V2 data sources
e5624320176d is described below

commit e5624320176dba38dc6b3df22cc4648b7eabe498
Author: Ole Sasse <[email protected]>
AuthorDate: Tue Apr 1 17:47:19 2025 +0800

    [SPARK-51668][SQL] Report metrics for failed writes to V2 data sources
    
    ### What changes were proposed in this pull request?
    
    Always post driver metrics for data source V2 writes, including writes that fail (the metrics are now posted from a `finally` block).
    
    ### Why are the changes needed?
    
    Otherwise, all metrics collected so far are lost when the command aborts.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    Added new tests in `V1WriteFallbackSuite` and `InMemoryTableMetricSuite`.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No
    
    Closes #50413 from olaky/custom-metrics-reporting-when-commands-abort.
    
    Authored-by: Ole Sasse <[email protected]>
    Signed-off-by: Wenchen Fan <[email protected]>
    (cherry picked from commit 6d606139d5787e1b41d75bd4886038b061a9deb0)
    Signed-off-by: Wenchen Fan <[email protected]>
---
 .../sql/connector/catalog/InMemoryTable.scala      | 42 +++++++++++++---------
 .../datasources/v2/V1FallbackWriters.scala         | 20 ++++++-----
 .../datasources/v2/WriteToDataSourceV2Exec.scala   |  7 ++--
 .../scala/org/apache/spark/sql/QueryTest.scala     |  4 ++-
 .../spark/sql/connector/V1WriteFallbackSuite.scala | 12 ++++++-
 .../datasources/InMemoryTableMetricSuite.scala     | 17 ++++++++-
 6 files changed, 72 insertions(+), 30 deletions(-)

diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala
index c27b8fea059f..f8eb32f6e924 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala
@@ -61,23 +61,31 @@ class InMemoryTable(
 
   override def withData(
       data: Array[BufferedRows],
-      writeSchema: StructType): InMemoryTable = dataMap.synchronized {
-    data.foreach(_.rows.foreach { row =>
-      val key = getKey(row, writeSchema)
-      dataMap += dataMap.get(key)
-        .map { splits =>
-          val newSplits = if (splits.last.rows.size >= numRowsPerSplit) {
-            splits :+ new BufferedRows(key)
-          } else {
-            splits
+      writeSchema: StructType): InMemoryTable = {
+    dataMap.synchronized {
+      data.foreach(_.rows.foreach { row =>
+        val key = getKey(row, writeSchema)
+        dataMap += dataMap.get(key)
+          .map { splits =>
+            val newSplits = if (splits.last.rows.size >= numRowsPerSplit) {
+              splits :+ new BufferedRows(key)
+            } else {
+              splits
+            }
+            newSplits.last.withRow(row)
+            key -> newSplits
           }
-          newSplits.last.withRow(row)
-          key -> newSplits
-        }
-        .getOrElse(key -> Seq(new BufferedRows(key).withRow(row)))
-      addPartitionKey(key)
-    })
-    this
+          .getOrElse(key -> Seq(new BufferedRows(key).withRow(row)))
+        addPartitionKey(key)
+      })
+
+      if (data.exists(_.rows.exists(row => row.numFields == 1 &&
+          row.getInt(0) == InMemoryTable.uncommittableValue()))) {
+        throw new IllegalArgumentException(s"Test only mock write failure")
+      }
+
+      this
+    }
   }
 
   override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = {
@@ -166,6 +174,8 @@ object InMemoryTable {
     }
   }
 
+  def uncommittableValue(): Int = Int.MaxValue / 2
+
   private def splitAnd(filter: Filter): Seq[Filter] = {
     filter match {
       case And(left, right) => splitAnd(left) ++ splitAnd(right)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V1FallbackWriters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V1FallbackWriters.scala
index 358f35e11d65..3eadffb8f0ae 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V1FallbackWriters.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V1FallbackWriters.scala
@@ -69,17 +69,19 @@ sealed trait V1FallbackWriters extends LeafV2CommandExec with SupportsV1Write {
   def write: V1Write
 
   override def run(): Seq[InternalRow] = {
-    writeWithV1(write.toInsertableRelation)
-    refreshCache()
+    try {
+      writeWithV1(write.toInsertableRelation)
+      refreshCache()
 
-    write.reportDriverMetrics().foreach { customTaskMetric =>
-      
metrics.get(customTaskMetric.name()).foreach(_.set(customTaskMetric.value()))
-    }
-
-    val executionId = 
sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
-    SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, 
metrics.values.toSeq)
+      Nil
+    } finally {
+      write.reportDriverMetrics().foreach { customTaskMetric =>
+        
metrics.get(customTaskMetric.name()).foreach(_.set(customTaskMetric.value()))
+      }
 
-    Nil
+      val executionId = 
sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
+      SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, 
metrics.values.toSeq)
+    }
   }
 }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala
index 016d6b5411ac..230864c0e267 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala
@@ -356,8 +356,11 @@ trait V2ExistingTableWriteExec extends V2TableWriteExec {
     }.toMap
 
   override protected def run(): Seq[InternalRow] = {
-    val writtenRows = writeWithV2(write.toBatch)
-    postDriverMetrics()
+    val writtenRows = try {
+      writeWithV2(write.toBatch)
+    } finally {
+      postDriverMetrics()
+    }
     refreshCache()
     writtenRows
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
index d2d119d1f581..71af79895d20 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -459,7 +459,9 @@ object QueryTest extends Assertions {
       override def onSuccess(funcName: String, qe: QueryExecution, durationNs: 
Long): Unit = {
         capturedQueryExecutions = capturedQueryExecutions :+ qe
       }
-      override def onFailure(funcName: String, qe: QueryExecution, exception: 
Exception): Unit = {}
+      override def onFailure(funcName: String, qe: QueryExecution, exception: 
Exception): Unit = {
+        capturedQueryExecutions = capturedQueryExecutions :+ qe
+      }
     }
 
     spark.sparkContext.listenerBus.waitUntilEmpty(15000)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala
index 68c2a01c69ae..e396232eb70f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala
@@ -234,6 +234,14 @@ class V1WriteFallbackSuite extends QueryTest with SharedSparkSession with Before
         df2.writeTo("test").overwrite(lit(true))
       }
       assert(overwritePlan.metrics("numOutputRows").value === 1)
+
+      val failingPlan = captureWrite(session) {
+        assertThrows[IllegalStateException] {
+          val df3 = session.createDataFrame(Seq((3, 
"this-value-fails-the-write")))
+          df3.writeTo("test").overwrite(lit(true))
+        }
+      }
+      assert(failingPlan.metrics("numOutputRows").value === 1)
     } finally {
       SparkSession.setActiveSession(spark)
       SparkSession.setDefaultSession(spark)
@@ -435,7 +443,9 @@ class InMemoryTableWithV1Fallback(
           writeMetrics = Array(V1WriteTaskMetric("numOutputRows", rows.length))
 
           rows.groupBy(getPartitionValues).foreach { case (partition, 
elements) =>
-            if (dataMap.contains(partition) && mode == "append") {
+            if 
(elements.exists(_.toSeq.contains("this-value-fails-the-write"))) {
+              throw new IllegalStateException("Test only mock write failure")
+            } else if (dataMap.contains(partition) && mode == "append") {
               dataMap.put(partition, dataMap(partition) ++ elements)
             } else if (dataMap.contains(partition)) {
               throw new IllegalStateException("Partition was not removed 
properly")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/InMemoryTableMetricSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/InMemoryTableMetricSuite.scala
index 7094404b3c1d..7e8a95f4d0cd 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/InMemoryTableMetricSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/InMemoryTableMetricSuite.scala
@@ -22,7 +22,7 @@ import org.scalatest.BeforeAndAfter
 import org.scalatest.time.SpanSugar._
 
 import org.apache.spark.sql.QueryTest
-import org.apache.spark.sql.connector.catalog.{Column, Identifier, 
InMemoryTableCatalog}
+import org.apache.spark.sql.connector.catalog.{Column, Identifier, 
InMemoryTable, InMemoryTableCatalog}
 import org.apache.spark.sql.connector.expressions.Transform
 import org.apache.spark.sql.functions.lit
 import org.apache.spark.sql.test.SharedSparkSession
@@ -100,4 +100,19 @@ class InMemoryTableMetricSuite
       assert(metrics.get("number of rows from driver").contains("3"))
     })
   }
+
+  test("Report metrics for aborted command") {
+    testMetricOnDSv2(table => {
+      assertThrows[IllegalArgumentException] {
+        val df = spark
+          .range(start = InMemoryTable.uncommittableValue(),
+            end = InMemoryTable.uncommittableValue() + 1)
+          .toDF("i")
+        val v2Writer = df.writeTo(table)
+        v2Writer.overwrite(lit(true))
+      }
+    }, metrics => {
+      assert(metrics.get("number of rows from driver").contains("1"))
+    })
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to