This is an automated email from the ASF dual-hosted git repository.
ptoth pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new c94ce2c23f27 [SPARK-55302][SQL] Fix custom metrics in case of `KeyGroupedPartitioning`
c94ce2c23f27 is described below
commit c94ce2c23f27b6e9e3256681947e91e82698c8bc
Author: Peter Toth <[email protected]>
AuthorDate: Mon Feb 2 19:44:44 2026 +0100
[SPARK-55302][SQL] Fix custom metrics in case of `KeyGroupedPartitioning`
### What changes were proposed in this pull request?
This PR adds a new `initMetricsValues()` method to `PartitionReader` to initialize the
custom metrics returned by `currentMetricsValues()`. With `KeyGroupedPartitioning`,
multiple input partitions are grouped together, so multiple `PartitionReader`s belong to
one output partition. Each `PartitionReader` therefore needs to be initialized with the
metrics accumulated by the previous `PartitionReader` of the same partition group in
order to report the correct values.
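For illustration, here is a minimal sketch of a reader that keeps its row counter alive across the readers of one partition group. The `CountingReader` class is hypothetical; the `rows_read` metric name mirrors the test reader updated in this patch, and `initMetricsValues()` is the method added here:

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.metric.CustomTaskMetric
import org.apache.spark.sql.connector.read.PartitionReader

// Hypothetical reader: counts the rows it returns and lets the count be seeded
// from the previous reader of the same partition group.
class CountingReader(rows: Iterator[InternalRow]) extends PartitionReader[InternalRow] {
  private var rowsRead: Long = 0L
  private var current: InternalRow = _

  override def next(): Boolean = {
    val hasNext = rows.hasNext
    if (hasNext) {
      current = rows.next()
      rowsRead += 1
    }
    hasNext
  }

  override def get(): InternalRow = current

  override def close(): Unit = {}

  // Called with the metrics reported by the previous reader of the same group,
  // so the counter continues instead of restarting from zero.
  override def initMetricsValues(metrics: Array[CustomTaskMetric]): Unit = {
    metrics.foreach { m =>
      if (m.name() == "rows_read") rowsRead = m.value()
    }
  }

  override def currentMetricsValues(): Array[CustomTaskMetric] = {
    Array(new CustomTaskMetric {
      override def name(): String = "rows_read"
      override def value(): Long = rowsRead
    })
  }
}
```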
### Why are the changes needed?
To calculate custom metrics correctly when key-grouped partitioning groups several input
partitions into one Spark partition. `DataSourceRDD` updates the SQL metrics from the
current reader's values, so a reader that starts counting from zero effectively drops
what the previous readers of the same group already reported.
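For context, a minimal sketch of the driver side: the task-level values are matched by name to a `CustomMetric` declared by the scan and then aggregated. The `RowsReadMetric` class name is illustrative; it mirrors the `CustomSumMetric` added to the in-memory test table in this patch:

```scala
import org.apache.spark.sql.connector.metric.{CustomMetric, CustomSumMetric}

// Illustrative driver-side metric: Spark matches it to the readers' CustomTaskMetric
// by name ("rows_read") and sums the values reported by the tasks.
class RowsReadMetric extends CustomSumMetric {
  override def name(): String = "rows_read"
  override def description(): String = "number of rows read"
}

// The Scan advertises it so it shows up in the SQL UI, e.g.:
//   override def supportedCustomMetrics(): Array[CustomMetric] = Array(new RowsReadMetric)
```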
### Does this PR introduce _any_ user-facing change?
It fixes the custom metrics calculation for scans that use key-grouped partitioning.
### How was this patch tested?
A new UT is added.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #54081 from peter-toth/SPARK-55302-fix-kgp-custom-metrics.
Authored-by: Peter Toth <[email protected]>
Signed-off-by: Peter Toth <[email protected]>
---
.../spark/sql/connector/read/PartitionReader.java | 9 +++++++
.../sql/connector/catalog/InMemoryBaseTable.scala | 30 +++++++++++++++++++++-
.../execution/datasources/v2/DataSourceRDD.scala | 21 ++++++++++-----
.../connector/KeyGroupedPartitioningSuite.scala | 18 +++++++++++++
.../datasources/InMemoryTableMetricSuite.scala | 22 +---------------
.../apache/spark/sql/test/SharedSparkSession.scala | 23 +++++++++++++++++
6 files changed, 95 insertions(+), 28 deletions(-)
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/PartitionReader.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/PartitionReader.java
index 5286bbf9f85a..c12bc14a49c4 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/PartitionReader.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/PartitionReader.java
@@ -58,4 +58,13 @@ public interface PartitionReader<T> extends Closeable {
CustomTaskMetric[] NO_METRICS = {};
return NO_METRICS;
}
+
+ /**
+ * Sets the initial value of metrics before fetching any data from the reader. This is called
+ * when multiple {@link PartitionReader}s are grouped into one partition in case of
+ * {@link org.apache.spark.sql.connector.read.partitioning.KeyGroupedPartitioning} and the reader
+ * is initialized with the metrics returned by the previous reader that belongs to the same
+ * partition. By default, this method does nothing.
+ */
+ default void initMetricsValues(CustomTaskMetric[] metrics) {}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala
index 18fe80c2e924..ebb4eef80f15 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala
@@ -543,6 +543,10 @@ abstract class InMemoryBaseTable(
}
new BufferedRowsReaderFactory(metadataColumns.toSeq, nonMetadataColumns, tableSchema)
}
+
+ override def supportedCustomMetrics(): Array[CustomMetric] = {
+ Array(new RowsReadCustomMetric)
+ }
}
case class InMemoryBatchScan(
@@ -830,10 +834,13 @@ private class BufferedRowsReader(
}
private var index: Int = -1
+ private var rowsRead: Long = 0
override def next(): Boolean = {
index += 1
- index < partition.rows.length
+ val hasNext = index < partition.rows.length
+ if (hasNext) rowsRead += 1
+ hasNext
}
override def get(): InternalRow = {
@@ -976,6 +983,22 @@ private class BufferedRowsReader(
private def castElement(elem: Any, toType: DataType, fromType: DataType): Any =
Cast(Literal(elem, fromType), toType, None, EvalMode.TRY).eval(null)
+
+ override def initMetricsValues(metrics: Array[CustomTaskMetric]): Unit = {
+ metrics.foreach { m =>
+ m.name match {
+ case "rows_read" => rowsRead = m.value()
+ }
+ }
+ }
+
+ override def currentMetricsValues(): Array[CustomTaskMetric] = {
+ val metric = new CustomTaskMetric {
+ override def name(): String = "rows_read"
+ override def value(): Long = rowsRead
+ }
+ Array(metric)
+ }
}
private class BufferedRowsWriterFactory(schema: StructType)
@@ -1044,6 +1067,11 @@ class InMemoryCustomDriverTaskMetric(value: Long) extends CustomTaskMetric {
override def value(): Long = value
}
+class RowsReadCustomMetric extends CustomSumMetric {
+ override def name(): String = "rows_read"
+ override def description(): String = "number of rows read"
+}
+
case class Commit(id: Long, writeSummary: Option[WriteSummary] = None)
sealed trait Operation
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala
index 6a07d3c3931a..fbf5c06fe051 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala
@@ -24,6 +24,7 @@ import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.connector.metric.CustomTaskMetric
import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader,
PartitionReaderFactory}
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.execution.metric.{CustomMetrics, SQLMetric}
@@ -97,7 +98,8 @@ class DataSourceRDD(
}
// Once we advance to the next partition, update the metric callback for early finish
- partitionMetricCallback.advancePartition(iter, reader)
+ val previousMetrics = partitionMetricCallback.advancePartition(iter, reader)
+ previousMetrics.foreach(reader.initMetricsValues)
currentIter = Some(iter)
hasNext
@@ -118,19 +120,26 @@ private class PartitionMetricCallback
private var iter: MetricsIterator[_] = null
private var reader: PartitionReader[_] = null
- def advancePartition(iter: MetricsIterator[_], reader: PartitionReader[_]): Unit = {
- execute()
+ def advancePartition(
+ iter: MetricsIterator[_],
+ reader: PartitionReader[_]): Option[Array[CustomTaskMetric]] = {
+ val metrics = execute()
this.iter = iter
this.reader = reader
+
+ metrics
}
- def execute(): Unit = {
+ def execute(): Option[Array[CustomTaskMetric]] = {
if (iter != null && reader != null) {
- CustomMetrics
- .updateMetrics(reader.currentMetricsValues.toImmutableArraySeq, customMetrics)
+ val metrics = reader.currentMetricsValues
+ CustomMetrics.updateMetrics(metrics.toImmutableArraySeq, customMetrics)
iter.forceUpdateMetrics()
reader.close()
+ Some(metrics)
+ } else {
+ None
}
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
index a1b1b8444719..7c07d08d80af 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
@@ -2823,4 +2823,22 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase {
checkAnswer(df, Seq(Row(1, "aa", 40.0, 42.0)))
}
}
+
+ test("SPARK-55302: Custom metrics of grouped partitions") {
+ val items_partitions = Array(identity("id"))
+ createTable(items, itemsColumns, items_partitions)
+
+ sql(s"INSERT INTO testcat.ns.$items VALUES " +
+ "(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " +
+ "(4, 'bb', 10.0, cast('2021-01-01' as timestamp)), " +
+ "(4, 'cc', 15.5, cast('2021-02-01' as timestamp))")
+
+ val metrics = runAndFetchMetrics {
+ val df = sql(s"SELECT * FROM testcat.ns.$items")
+ val scans = collectScans(df.queryExecution.executedPlan)
+ assert(scans(0).inputRDD.partitions.length === 2, "items scan should have 2 partition groups")
+ df.collect()
+ }
+ assert(metrics("number of rows read") == "3")
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/InMemoryTableMetricSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/InMemoryTableMetricSuite.scala
index 7e8a95f4d0cd..502424d58d2c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/InMemoryTableMetricSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/InMemoryTableMetricSuite.scala
@@ -19,7 +19,6 @@ package org.apache.spark.sql.execution.datasources
import java.util.Collections
import org.scalatest.BeforeAndAfter
-import org.scalatest.time.SpanSugar._
import org.apache.spark.sql.QueryTest
import org.apache.spark.sql.connector.catalog.{Column, Identifier,
InMemoryTable, InMemoryTableCatalog}
@@ -54,27 +53,8 @@ class InMemoryTableMetricSuite
Array(Column.create("i", IntegerType)),
Array.empty[Transform], Collections.emptyMap[String, String])
- func("testcat.table_name")
+ val metrics = runAndFetchMetrics(func("testcat.table_name"))
- // Wait until the new execution is started and being tracked.
- eventually(timeout(10.seconds), interval(10.milliseconds)) {
- assert(statusStore.executionsCount() >= oldCount)
- }
-
- // Wait for listener to finish computing the metrics for the execution.
- eventually(timeout(10.seconds), interval(10.milliseconds)) {
- assert(statusStore.executionsList().nonEmpty &&
- statusStore.executionsList().last.metricValues != null)
- }
-
- val exec = statusStore.executionsList().last
- val execId = exec.executionId
- val sqlMetrics = exec.metrics.map { metric =>
- metric.accumulatorId -> metric.name
- }.toMap
- val metrics = statusStore.executionMetrics(execId).map { case (k, v) =>
- sqlMetrics(k) -> v
- }
checker(metrics)
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala
index 245219c1756d..720b13b812e0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala
@@ -54,6 +54,29 @@ trait SharedSparkSession extends SQLTestUtils with SharedSparkSessionBase {
doThreadPostAudit()
}
}
+
+ def runAndFetchMetrics(func: => Unit): Map[String, String] = {
+ val statusStore = spark.sharedState.statusStore
+ val oldCount = statusStore.executionsList().size
+
+ func
+
+ // Wait until the new execution is started and being tracked.
+ eventually(timeout(10.seconds), interval(10.milliseconds)) {
+ assert(statusStore.executionsCount() >= oldCount)
+ }
+
+ // Wait for listener to finish computing the metrics for the execution.
+ eventually(timeout(10.seconds), interval(10.milliseconds)) {
+ assert(statusStore.executionsList().nonEmpty &&
+ statusStore.executionsList().last.metricValues != null)
+ }
+
+ val exec = statusStore.executionsList().last
+ val execId = exec.executionId
+ val sqlMetrics = exec.metrics.map { metric => metric.accumulatorId -> metric.name }.toMap
+ statusStore.executionMetrics(execId).map { case (k, v) => sqlMetrics(k) -> v }
+ }
}
/**
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]