0lai0 commented on code in PR #3999:
URL: https://github.com/apache/datafusion-comet/pull/3999#discussion_r3114877914
##########
spark/src/test/scala/org/apache/spark/sql/comet/CometTaskMetricsSuite.scala:
##########
@@ -100,6 +105,73 @@ class CometTaskMetricsSuite extends CometTestBase with
AdaptiveSparkPlanHelper {
}
}
+ test("native parquet write reports task-level output metrics") {
+ withParquetTable((0 until 5000).map(i => (i, (i + 1).toLong)), "tbl") {
+ withTempPath { dir =>
+ val outPath = new File(dir, "written").getAbsolutePath
+ val outputBytes = mutable.ArrayBuffer.empty[Long]
+ val outputRecords = mutable.ArrayBuffer.empty[Long]
+ val targetStageIds = mutable.HashSet.empty[Int]
+ val jobGroupId =
s"native-write-metrics-${java.util.UUID.randomUUID().toString}"
+
+ val listener = new SparkListener {
+ override def onJobStart(jobStart: SparkListenerJobStart): Unit = {
+ val isTargetJob = Option(jobStart.properties)
+ .flatMap(props =>
Option(props.getProperty(SparkContext.SPARK_JOB_GROUP_ID)))
+ .contains(jobGroupId)
+ if (isTargetJob) {
+ targetStageIds.synchronized {
+ targetStageIds ++= jobStart.stageInfos.map(_.stageId)
+ }
+ }
+ }
+
+ override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
+ val isTargetStage = targetStageIds.synchronized {
+ targetStageIds.contains(taskEnd.stageId)
+ }
+ if (isTargetStage) {
+ val om = taskEnd.taskMetrics.outputMetrics
+ if (om.bytesWritten > 0) {
+ outputBytes.synchronized {
+ outputBytes += om.bytesWritten
+ outputRecords += om.recordsWritten
+ }
+ }
+ }
+ }
+ }
+ spark.sparkContext.addSparkListener(listener)
+
+ try {
+ spark.sparkContext.listenerBus.waitUntilEmpty()
+
+ withSQLConf(
+ CometConf.COMET_NATIVE_PARQUET_WRITE_ENABLED.key -> "true",
+ CometConf.COMET_EXEC_ENABLED.key -> "true",
+ CometConf.getOperatorAllowIncompatConfigKey(
+ classOf[DataWritingCommandExec]) -> "true",
+ SQLConf.SESSION_LOCAL_TIMEZONE.key -> "America/Halifax") {
+ spark.sparkContext.setJobGroup(jobGroupId, "native parquet write
output metrics")
+ try {
+ sql("SELECT * FROM tbl").write.parquet(outPath)
+ } finally {
+ spark.sparkContext.clearJobGroup()
+ }
+ }
+
+ spark.sparkContext.listenerBus.waitUntilEmpty()
+
+ assert(outputBytes.nonEmpty, "No task reported
outputMetrics.bytesWritten")
Review Comment:
Updated. The test now asserts recordsWritten with exact equality, and uses the
existing suite convention for approximating bytes (a ratio within 0.7–1.3),
consistent with the other input-metrics checks in CometTaskMetricsSuite.
--
This is an automated message from the Apache Git Service.
To respond to this message, please log on to GitHub and follow the
URL above to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]