This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch branch-4.0
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-4.0 by this push:
new e5624320176d [SPARK-51668][SQL] Report metrics for failed writes to V2 data sources
e5624320176d is described below
commit e5624320176dba38dc6b3df22cc4648b7eabe498
Author: Ole Sasse <[email protected]>
AuthorDate: Tue Apr 1 17:47:19 2025 +0800
[SPARK-51668][SQL] Report metrics for failed writes to V2 data sources
### What changes were proposed in this pull request?
Always post driver metrics for data source V2 writes
### Why are the changes needed?
All metrics that have been collected are otherwise lost when the command aborts.
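In essence, the physical write is wrapped in try/finally so that driver metrics are posted whether the command commits or aborts; the diff below applies this pattern in WriteToDataSourceV2Exec.scala and V1FallbackWriters.scala. A minimal self-contained sketch of that control flow, with doWrite and postDriverMetrics as hypothetical stand-ins for the real helpers:

    object MetricsOnAbortSketch {
      // Hypothetical stand-ins for the real batch write and metric reporting.
      def doWrite(fail: Boolean): Unit =
        if (fail) throw new IllegalStateException("mock write failure")
      def postDriverMetrics(): Unit = println("driver metrics posted")

      def run(fail: Boolean): Unit =
        try {
          doWrite(fail)          // may throw when the write aborts
        } finally {
          postDriverMetrics()    // runs on success and on failure alike
        }

      def main(args: Array[String]): Unit = {
        run(fail = false)                      // write succeeds, metrics posted
        try run(fail = true) catch {           // metrics are still posted before
          case _: IllegalStateException => ()  // the exception propagates
        }
      }
    }

Previously the metric update ran only after a successful write, so an abort dropped any metrics the driver had collected.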
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Added new tests
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #50413 from olaky/custom-metrics-reporting-when-commands-abort.
Authored-by: Ole Sasse <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
(cherry picked from commit 6d606139d5787e1b41d75bd4886038b061a9deb0)
Signed-off-by: Wenchen Fan <[email protected]>
---
.../sql/connector/catalog/InMemoryTable.scala | 42 +++++++++++++---------
.../datasources/v2/V1FallbackWriters.scala | 20 ++++++-----
.../datasources/v2/WriteToDataSourceV2Exec.scala | 7 ++--
.../scala/org/apache/spark/sql/QueryTest.scala | 4 ++-
.../spark/sql/connector/V1WriteFallbackSuite.scala | 12 ++++++-
.../datasources/InMemoryTableMetricSuite.scala | 17 ++++++++-
6 files changed, 72 insertions(+), 30 deletions(-)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala
index c27b8fea059f..f8eb32f6e924 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala
@@ -61,23 +61,31 @@ class InMemoryTable(
override def withData(
data: Array[BufferedRows],
- writeSchema: StructType): InMemoryTable = dataMap.synchronized {
- data.foreach(_.rows.foreach { row =>
- val key = getKey(row, writeSchema)
- dataMap += dataMap.get(key)
- .map { splits =>
- val newSplits = if (splits.last.rows.size >= numRowsPerSplit) {
- splits :+ new BufferedRows(key)
- } else {
- splits
+ writeSchema: StructType): InMemoryTable = {
+ dataMap.synchronized {
+ data.foreach(_.rows.foreach { row =>
+ val key = getKey(row, writeSchema)
+ dataMap += dataMap.get(key)
+ .map { splits =>
+ val newSplits = if (splits.last.rows.size >= numRowsPerSplit) {
+ splits :+ new BufferedRows(key)
+ } else {
+ splits
+ }
+ newSplits.last.withRow(row)
+ key -> newSplits
}
- newSplits.last.withRow(row)
- key -> newSplits
- }
- .getOrElse(key -> Seq(new BufferedRows(key).withRow(row)))
- addPartitionKey(key)
- })
- this
+ .getOrElse(key -> Seq(new BufferedRows(key).withRow(row)))
+ addPartitionKey(key)
+ })
+
+ if (data.exists(_.rows.exists(row => row.numFields == 1 &&
+ row.getInt(0) == InMemoryTable.uncommittableValue()))) {
+ throw new IllegalArgumentException(s"Test only mock write failure")
+ }
+
+ this
+ }
}
override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = {
@@ -166,6 +174,8 @@ object InMemoryTable {
}
}
+ def uncommittableValue(): Int = Int.MaxValue / 2
+
private def splitAnd(filter: Filter): Seq[Filter] = {
filter match {
case And(left, right) => splitAnd(left) ++ splitAnd(right)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V1FallbackWriters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V1FallbackWriters.scala
index 358f35e11d65..3eadffb8f0ae 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V1FallbackWriters.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V1FallbackWriters.scala
@@ -69,17 +69,19 @@ sealed trait V1FallbackWriters extends LeafV2CommandExec with SupportsV1Write {
def write: V1Write
override def run(): Seq[InternalRow] = {
- writeWithV1(write.toInsertableRelation)
- refreshCache()
+ try {
+ writeWithV1(write.toInsertableRelation)
+ refreshCache()
- write.reportDriverMetrics().foreach { customTaskMetric =>
- metrics.get(customTaskMetric.name()).foreach(_.set(customTaskMetric.value()))
- }
-
- val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
- SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, metrics.values.toSeq)
+ Nil
+ } finally {
+ write.reportDriverMetrics().foreach { customTaskMetric =>
+ metrics.get(customTaskMetric.name()).foreach(_.set(customTaskMetric.value()))
+ }
- Nil
+ val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
+ SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, metrics.values.toSeq)
+ }
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala
index 016d6b5411ac..230864c0e267 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala
@@ -356,8 +356,11 @@ trait V2ExistingTableWriteExec extends V2TableWriteExec {
}.toMap
override protected def run(): Seq[InternalRow] = {
- val writtenRows = writeWithV2(write.toBatch)
- postDriverMetrics()
+ val writtenRows = try {
+ writeWithV2(write.toBatch)
+ } finally {
+ postDriverMetrics()
+ }
refreshCache()
writtenRows
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
index d2d119d1f581..71af79895d20 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -459,7 +459,9 @@ object QueryTest extends Assertions {
override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = {
capturedQueryExecutions = capturedQueryExecutions :+ qe
}
- override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {}
+ override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {
+ capturedQueryExecutions = capturedQueryExecutions :+ qe
+ }
}
spark.sparkContext.listenerBus.waitUntilEmpty(15000)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala
index 68c2a01c69ae..e396232eb70f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala
@@ -234,6 +234,14 @@ class V1WriteFallbackSuite extends QueryTest with SharedSparkSession with Before
df2.writeTo("test").overwrite(lit(true))
}
assert(overwritePlan.metrics("numOutputRows").value === 1)
+
+ val failingPlan = captureWrite(session) {
+ assertThrows[IllegalStateException] {
+ val df3 = session.createDataFrame(Seq((3, "this-value-fails-the-write")))
+ df3.writeTo("test").overwrite(lit(true))
+ }
+ }
+ assert(failingPlan.metrics("numOutputRows").value === 1)
} finally {
SparkSession.setActiveSession(spark)
SparkSession.setDefaultSession(spark)
@@ -435,7 +443,9 @@ class InMemoryTableWithV1Fallback(
writeMetrics = Array(V1WriteTaskMetric("numOutputRows", rows.length))
rows.groupBy(getPartitionValues).foreach { case (partition, elements) =>
- if (dataMap.contains(partition) && mode == "append") {
+ if (elements.exists(_.toSeq.contains("this-value-fails-the-write"))) {
+ throw new IllegalStateException("Test only mock write failure")
+ } else if (dataMap.contains(partition) && mode == "append") {
dataMap.put(partition, dataMap(partition) ++ elements)
} else if (dataMap.contains(partition)) {
throw new IllegalStateException("Partition was not removed properly")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/InMemoryTableMetricSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/InMemoryTableMetricSuite.scala
index 7094404b3c1d..7e8a95f4d0cd 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/InMemoryTableMetricSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/InMemoryTableMetricSuite.scala
@@ -22,7 +22,7 @@ import org.scalatest.BeforeAndAfter
import org.scalatest.time.SpanSugar._
import org.apache.spark.sql.QueryTest
-import org.apache.spark.sql.connector.catalog.{Column, Identifier, InMemoryTableCatalog}
+import org.apache.spark.sql.connector.catalog.{Column, Identifier, InMemoryTable, InMemoryTableCatalog}
import org.apache.spark.sql.connector.expressions.Transform
import org.apache.spark.sql.functions.lit
import org.apache.spark.sql.test.SharedSparkSession
@@ -100,4 +100,19 @@ class InMemoryTableMetricSuite
assert(metrics.get("number of rows from driver").contains("3"))
})
}
+
+ test("Report metrics for aborted command") {
+ testMetricOnDSv2(table => {
+ assertThrows[IllegalArgumentException] {
+ val df = spark
+ .range(start = InMemoryTable.uncommittableValue(),
+ end = InMemoryTable.uncommittableValue() + 1)
+ .toDF("i")
+ val v2Writer = df.writeTo(table)
+ v2Writer.overwrite(lit(true))
+ }
+ }, metrics => {
+ assert(metrics.get("number of rows from driver").contains("1"))
+ })
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]