xuanyuanking closed pull request #23378: [SPARK-26225][SQL] Track decoding time 
for each record in scan node
URL: https://github.com/apache/spark/pull/23378
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below (since GitHub does not display the original diff after the merge):

diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
index 322ffffca564b..dc15cee29c47e 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
@@ -298,7 +298,7 @@ case class FileSourceScanExec(
   }
 
   private lazy val inputRDD: RDD[InternalRow] = {
-    val readFile: (PartitionedFile) => Iterator[InternalRow] =
+    val readFunc: (PartitionedFile) => Iterator[InternalRow] =
       relation.fileFormat.buildReaderWithPartitionValues(
         sparkSession = relation.sparkSession,
         dataSchema = relation.dataSchema,
@@ -308,6 +308,15 @@ case class FileSourceScanExec(
         options = relation.options,
         hadoopConf = 
relation.sparkSession.sessionState.newHadoopConfWithOptions(relation.options))
 
+    // Wrap the read function to measure the time spent decoding records.
+    val readFile: (PartitionedFile) => Iterator[InternalRow] = file => {
+      val startTime = System.nanoTime()
+      val rowIter = readFunc.apply(file)
+      val timeNs = System.nanoTime() - startTime
+      longMetric("recordDecodingTime").add(timeNs)
+      rowIter
+    }
+
     val readRDD = relation.bucketSpec match {
       case Some(bucketing) if 
relation.sparkSession.sessionState.conf.bucketingEnabled =>
         createBucketedReadRDD(bucketing, readFile, selectedPartitions, 
relation)
@@ -326,7 +335,9 @@ case class FileSourceScanExec(
     Map("numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of 
output rows"),
       "numFiles" -> SQLMetrics.createMetric(sparkContext, "number of files"),
       "metadataTime" -> SQLMetrics.createMetric(sparkContext, "metadata time"),
-      "scanTime" -> SQLMetrics.createTimingMetric(sparkContext, "scan time"))
+      "scanTime" -> SQLMetrics.createTimingMetric(sparkContext, "scan time"),
+      "recordDecodingTime" ->
+        SQLMetrics.createNanoTimingMetric(sparkContext, "record decoding 
time"))
 
   protected override def doExecute(): RDD[InternalRow] = {
     if (supportsBatch) {
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
index 47265df4831df..6989d7c034263 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
@@ -26,7 +26,7 @@ import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.expressions.aggregate.{Final, Partial}
 import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
-import org.apache.spark.sql.execution.{FilterExec, RangeExec, SparkPlan, 
WholeStageCodegenExec}
+import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.aggregate.HashAggregateExec
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
@@ -552,11 +552,14 @@ class SQLMetricsSuite extends SparkFunSuite with 
SQLMetricsTestUtils with Shared
       // The execution plan only has 1 FileScan node.
       val df = spark.sql(
         "SELECT * FROM testDataForScan WHERE p = 1")
-      testSparkPlanMetrics(df, 1, Map(
-        0L -> (("Scan parquet default.testdataforscan", Map(
-          "number of output rows" -> 3L,
-          "number of files" -> 2L))))
-      )
+      df.collect()
+      val metrics = df.queryExecution.executedPlan.collectLeaves()
+        .head.asInstanceOf[FileSourceScanExec].metrics
+      // Check deterministic metrics.
+      assert(metrics("numFiles").value == 2)
+      assert(metrics("numOutputRows").value == 3)
+      // Check decoding time metric changed.
+      assert(metrics("recordDecodingTime").value > 0)
     }
   }
 }


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to