comphead commented on code in PR #3842:
URL: https://github.com/apache/datafusion-comet/pull/3842#discussion_r3016561532


##########
spark/src/test/scala/org/apache/spark/sql/comet/CometTaskMetricsSuite.scala:
##########
@@ -91,4 +94,66 @@ class CometTaskMetricsSuite extends CometTestBase with 
AdaptiveSparkPlanHelper {
       }
     }
   }
+
+  test("native_datafusion scan reports task-level input metrics matching 
Spark") {
+    withParquetTable((0 until 10000).map(i => (i, (i + 1).toLong)), "tbl") {
+      // Collect baseline input metrics from vanilla Spark (Comet disabled)
+      val (sparkBytes, sparkRecords) = 
collectInputMetrics(CometConf.COMET_ENABLED.key -> "false")
+
+      // Collect input metrics from Comet native_datafusion scan
+      val (cometBytes, cometRecords) = collectInputMetrics(
+        CometConf.COMET_NATIVE_SCAN_IMPL.key -> 
CometConf.SCAN_NATIVE_DATAFUSION)
+
+      // Records must match exactly
+      assert(
+        cometRecords == sparkRecords,
+        s"recordsRead mismatch: comet=$cometRecords, spark=$sparkRecords")
+
+      // Bytes should be in the same ballpark -- both read the same Parquet 
file(s),
+      // but the exact byte count can differ due to reader implementation 
details
+      // (e.g. footer reads, page headers, buffering granularity).
+      assert(sparkBytes > 0, s"Spark bytesRead should be > 0, got $sparkBytes")
+      assert(cometBytes > 0, s"Comet bytesRead should be > 0, got $cometBytes")
+      val ratio = cometBytes.toDouble / sparkBytes.toDouble
+      assert(
+        ratio >= 0.8 && ratio <= 1.2,
+        s"bytesRead ratio out of range: comet=$cometBytes, spark=$sparkBytes, 
ratio=$ratio")
+    }
+  }
+
+  /**
+   * Runs `SELECT * FROM tbl` with the given SQL config overrides and returns 
the aggregated
+   * (bytesRead, recordsRead) across all tasks.
+   */
+  private def collectInputMetrics(confs: (String, String)*): (Long, Long) = {
+    val inputMetricsList = mutable.ArrayBuffer.empty[InputMetrics]
+
+    val listener = new SparkListener {
+      override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
+        val im = taskEnd.taskMetrics.inputMetrics
+        inputMetricsList.synchronized {
+          inputMetricsList += im
+        }
+      }
+    }
+
+    spark.sparkContext.addSparkListener(listener)
+    try {
+      // Drain any earlier events
+      spark.sparkContext.listenerBus.waitUntilEmpty()
+
+      withSQLConf(confs: _*) {
+        sql("SELECT * FROM tbl").collect()

Review Comment:
   Thanks @martin-g, why would the filter be needed? I'd prefer to keep the repro
as simple as possible.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to