This is an automated email from the ASF dual-hosted git repository.

ptoth pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new c94ce2c23f27 [SPARK-55302][SQL] Fix custom metrics in case of `KeyGroupedPartitioning`
c94ce2c23f27 is described below

commit c94ce2c23f27b6e9e3256681947e91e82698c8bc
Author: Peter Toth <[email protected]>
AuthorDate: Mon Feb 2 19:44:44 2026 +0100

    [SPARK-55302][SQL] Fix custom metrics in case of `KeyGroupedPartitioning`
    
    ### What changes were proposed in this pull request?
    This PR adds a new `initMetricsValues()` method to `PartitionReader` to initialize the
    custom metrics returned by `currentMetricsValues()`. With `KeyGroupedPartitioning`, multiple
    input partitions are grouped together, so multiple `PartitionReader`s belong to one output
    partition. Each `PartitionReader` needs to be initialized with the metrics calculated by the
    previous `PartitionReader` of the same partition group in order to report the right values.
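
    For illustration, a data source reader could carry the metric over roughly like this (a
    minimal sketch, not part of this PR; `ExampleRowCountingReader` is a made-up placeholder and
    only the `PartitionReader`, `CustomTaskMetric`, and new `initMetricsValues` APIs are real):
    ```scala
    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.connector.metric.CustomTaskMetric
    import org.apache.spark.sql.connector.read.PartitionReader

    // Hypothetical reader that counts the rows it has returned so far.
    class ExampleRowCountingReader(rows: Iterator[InternalRow])
      extends PartitionReader[InternalRow] {

      private var rowsRead: Long = 0L
      private var current: InternalRow = _

      override def next(): Boolean = {
        val hasNext = rows.hasNext
        if (hasNext) {
          current = rows.next()
          rowsRead += 1
        }
        hasNext
      }

      override def get(): InternalRow = current

      override def close(): Unit = {}

      // Seeded with the metrics of the previous reader of the same partition group,
      // so the running count continues instead of restarting from 0.
      override def initMetricsValues(metrics: Array[CustomTaskMetric]): Unit = {
        metrics.foreach { m =>
          if (m.name() == "rows_read") rowsRead = m.value()
        }
      }

      override def currentMetricsValues(): Array[CustomTaskMetric] = {
        Array(new CustomTaskMetric {
          override def name(): String = "rows_read"
          override def value(): Long = rowsRead
        })
      }
    }
    ```
    When `DataSourceRDD` advances to the next reader of the same group, it passes the previous
    reader's `currentMetricsValues()` to `initMetricsValues()`, as the `DataSourceRDD.scala`
    change below shows.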
    
    ### Why are the changes needed?
    To calculate custom metrics correctly.
    
    ### Does this PR introduce _any_ user-facing change?
    No user-facing behavior changes, other than custom metrics now being calculated correctly.
    
    ### How was this patch tested?
    A new unit test is added.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No.
    
    Closes #54081 from peter-toth/SPARK-55302-fix-kgp-custom-metrics.
    
    Authored-by: Peter Toth <[email protected]>
    Signed-off-by: Peter Toth <[email protected]>
---
 .../spark/sql/connector/read/PartitionReader.java  |  9 +++++++
 .../sql/connector/catalog/InMemoryBaseTable.scala  | 30 +++++++++++++++++++++-
 .../execution/datasources/v2/DataSourceRDD.scala   | 21 ++++++++++-----
 .../connector/KeyGroupedPartitioningSuite.scala    | 18 +++++++++++++
 .../datasources/InMemoryTableMetricSuite.scala     | 22 +---------------
 .../apache/spark/sql/test/SharedSparkSession.scala | 23 +++++++++++++++++
 6 files changed, 95 insertions(+), 28 deletions(-)

diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/PartitionReader.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/PartitionReader.java
index 5286bbf9f85a..c12bc14a49c4 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/PartitionReader.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/PartitionReader.java
@@ -58,4 +58,13 @@ public interface PartitionReader<T> extends Closeable {
     CustomTaskMetric[] NO_METRICS = {};
     return NO_METRICS;
   }
+
+  /**
+   * Sets the initial value of metrics before fetching any data from the reader. This is called
+   * when multiple {@link PartitionReader}s are grouped into one partition in case of
+   * {@link org.apache.spark.sql.connector.read.partitioning.KeyGroupedPartitioning} and the reader
+   * is initialized with the metrics returned by the previous reader that belongs to the same
+   * partition. By default, this method does nothing.
+   */
+  default void initMetricsValues(CustomTaskMetric[] metrics) {}
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala
index 18fe80c2e924..ebb4eef80f15 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala
@@ -543,6 +543,10 @@ abstract class InMemoryBaseTable(
       }
       new BufferedRowsReaderFactory(metadataColumns.toSeq, nonMetadataColumns, tableSchema)
     }
+
+    override def supportedCustomMetrics(): Array[CustomMetric] = {
+      Array(new RowsReadCustomMetric)
+    }
   }
 
   case class InMemoryBatchScan(
@@ -830,10 +834,13 @@ private class BufferedRowsReader(
   }
 
   private var index: Int = -1
+  private var rowsRead: Long = 0
 
   override def next(): Boolean = {
     index += 1
-    index < partition.rows.length
+    val hasNext = index < partition.rows.length
+    if (hasNext) rowsRead += 1
+    hasNext
   }
 
   override def get(): InternalRow = {
@@ -976,6 +983,22 @@ private class BufferedRowsReader(
 
   private def castElement(elem: Any, toType: DataType, fromType: DataType): Any =
     Cast(Literal(elem, fromType), toType, None, EvalMode.TRY).eval(null)
+
+  override def initMetricsValues(metrics: Array[CustomTaskMetric]): Unit = {
+    metrics.foreach { m =>
+      m.name match {
+        case "rows_read" => rowsRead = m.value()
+      }
+    }
+  }
+
+  override def currentMetricsValues(): Array[CustomTaskMetric] = {
+    val metric = new CustomTaskMetric {
+      override def name(): String = "rows_read"
+      override def value(): Long = rowsRead
+    }
+    Array(metric)
+  }
 }
 
 private class BufferedRowsWriterFactory(schema: StructType)
@@ -1044,6 +1067,11 @@ class InMemoryCustomDriverTaskMetric(value: Long) extends CustomTaskMetric {
   override def value(): Long = value
 }
 
+class RowsReadCustomMetric extends CustomSumMetric {
+  override def name(): String = "rows_read"
+  override def description(): String = "number of rows read"
+}
+
 case class Commit(id: Long, writeSummary: Option[WriteSummary] = None)
 
 sealed trait Operation
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala
index 6a07d3c3931a..fbf5c06fe051 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala
@@ -24,6 +24,7 @@ import org.apache.spark.deploy.SparkHadoopUtil
 import org.apache.spark.internal.Logging
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.connector.metric.CustomTaskMetric
 import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory}
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.execution.metric.{CustomMetrics, SQLMetric}
@@ -97,7 +98,8 @@ class DataSourceRDD(
           }
 
           // Once we advance to the next partition, update the metric callback for early finish
-          partitionMetricCallback.advancePartition(iter, reader)
+          val previousMetrics = partitionMetricCallback.advancePartition(iter, reader)
+          previousMetrics.foreach(reader.initMetricsValues)
 
           currentIter = Some(iter)
           hasNext
@@ -118,19 +120,26 @@ private class PartitionMetricCallback
   private var iter: MetricsIterator[_] = null
   private var reader: PartitionReader[_] = null
 
-  def advancePartition(iter: MetricsIterator[_], reader: PartitionReader[_]): Unit = {
-    execute()
+  def advancePartition(
+      iter: MetricsIterator[_],
+      reader: PartitionReader[_]): Option[Array[CustomTaskMetric]] = {
+    val metrics = execute()
 
     this.iter = iter
     this.reader = reader
+
+    metrics
   }
 
-  def execute(): Unit = {
+  def execute(): Option[Array[CustomTaskMetric]] = {
     if (iter != null && reader != null) {
-      CustomMetrics
-        .updateMetrics(reader.currentMetricsValues.toImmutableArraySeq, customMetrics)
+      val metrics = reader.currentMetricsValues
+      CustomMetrics.updateMetrics(metrics.toImmutableArraySeq, customMetrics)
       iter.forceUpdateMetrics()
       reader.close()
+      Some(metrics)
+    } else {
+      None
     }
   }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
index a1b1b8444719..7c07d08d80af 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
@@ -2823,4 +2823,22 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase {
       checkAnswer(df, Seq(Row(1, "aa", 40.0, 42.0)))
     }
   }
+
+  test("SPARK-55302: Custom metrics of grouped partitions") {
+    val items_partitions = Array(identity("id"))
+    createTable(items, itemsColumns, items_partitions)
+
+    sql(s"INSERT INTO testcat.ns.$items VALUES " +
+      "(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " +
+      "(4, 'bb', 10.0, cast('2021-01-01' as timestamp)), " +
+      "(4, 'cc', 15.5, cast('2021-02-01' as timestamp))")
+
+    val metrics = runAndFetchMetrics {
+      val df = sql(s"SELECT * FROM testcat.ns.$items")
+      val scans = collectScans(df.queryExecution.executedPlan)
+      assert(scans(0).inputRDD.partitions.length === 2, "items scan should have 2 partition groups")
+      df.collect()
+    }
+    assert(metrics("number of rows read") == "3")
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/InMemoryTableMetricSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/InMemoryTableMetricSuite.scala
index 7e8a95f4d0cd..502424d58d2c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/InMemoryTableMetricSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/InMemoryTableMetricSuite.scala
@@ -19,7 +19,6 @@ package org.apache.spark.sql.execution.datasources
 import java.util.Collections
 
 import org.scalatest.BeforeAndAfter
-import org.scalatest.time.SpanSugar._
 
 import org.apache.spark.sql.QueryTest
 import org.apache.spark.sql.connector.catalog.{Column, Identifier, InMemoryTable, InMemoryTableCatalog}
@@ -54,27 +53,8 @@ class InMemoryTableMetricSuite
         Array(Column.create("i", IntegerType)),
         Array.empty[Transform], Collections.emptyMap[String, String])
 
-      func("testcat.table_name")
+      val metrics = runAndFetchMetrics(func("testcat.table_name"))
 
-      // Wait until the new execution is started and being tracked.
-      eventually(timeout(10.seconds), interval(10.milliseconds)) {
-        assert(statusStore.executionsCount() >= oldCount)
-      }
-
-      // Wait for listener to finish computing the metrics for the execution.
-      eventually(timeout(10.seconds), interval(10.milliseconds)) {
-        assert(statusStore.executionsList().nonEmpty &&
-          statusStore.executionsList().last.metricValues != null)
-      }
-
-      val exec = statusStore.executionsList().last
-      val execId = exec.executionId
-      val sqlMetrics = exec.metrics.map { metric =>
-        metric.accumulatorId -> metric.name
-      }.toMap
-      val metrics = statusStore.executionMetrics(execId).map { case (k, v) =>
-        sqlMetrics(k) -> v
-      }
       checker(metrics)
     }
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala
index 245219c1756d..720b13b812e0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala
@@ -54,6 +54,29 @@ trait SharedSparkSession extends SQLTestUtils with SharedSparkSessionBase {
       doThreadPostAudit()
     }
   }
+
+  def runAndFetchMetrics(func: => Unit): Map[String, String] = {
+    val statusStore = spark.sharedState.statusStore
+    val oldCount = statusStore.executionsList().size
+
+    func
+
+    // Wait until the new execution is started and being tracked.
+    eventually(timeout(10.seconds), interval(10.milliseconds)) {
+      assert(statusStore.executionsCount() >= oldCount)
+    }
+
+    // Wait for listener to finish computing the metrics for the execution.
+    eventually(timeout(10.seconds), interval(10.milliseconds)) {
+      assert(statusStore.executionsList().nonEmpty &&
+        statusStore.executionsList().last.metricValues != null)
+    }
+
+    val exec = statusStore.executionsList().last
+    val execId = exec.executionId
+    val sqlMetrics = exec.metrics.map { metric => metric.accumulatorId -> metric.name }.toMap
+    statusStore.executionMetrics(execId).map { case (k, v) => sqlMetrics(k) -> v }
+  }
 }
 
 /**


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
