Copilot commented on code in PR #1982:
URL: https://github.com/apache/auron/pull/1982#discussion_r2762558400


##########
spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeHelper.scala:
##########
@@ -167,38 +167,50 @@ object NativeHelper extends Logging {
       })
   }
 
-  def getDefaultNativeMetrics(sc: SparkContext): Map[String, SQLMetric] = {
-    def metric(name: String) = SQLMetrics.createMetric(sc, name)
-    def nanoTimingMetric(name: String) = SQLMetrics.createNanoTimingMetric(sc, 
name)
-    def sizeMetric(name: String) = SQLMetrics.createSizeMetric(sc, name)
-
-    var metrics = TreeMap(
-      "stage_id" -> metric("stageId"),
-      "output_rows" -> metric("Native.output_rows"),
-      "output_batches" -> metric("Native.output_batches"),
-      "elapsed_compute" -> nanoTimingMetric("Native.elapsed_compute"),
-      "build_hash_map_time" -> nanoTimingMetric("Native.build_hash_map_time"),
-      "probed_side_hash_time" -> 
nanoTimingMetric("Native.probed_side_hash_time"),
-      "probed_side_search_time" -> 
nanoTimingMetric("Native.probed_side_search_time"),
-      "probed_side_compare_time" -> 
nanoTimingMetric("Native.probed_side_compare_time"),
-      "build_output_time" -> nanoTimingMetric("Native.build_output_time"),
-      "fallback_sort_merge_join_time" -> 
nanoTimingMetric("Native.fallback_sort_merge_join_time"),
-      "mem_spill_count" -> metric("Native.mem_spill_count"),
-      "mem_spill_size" -> sizeMetric("Native.mem_spill_size"),
-      "mem_spill_iotime" -> nanoTimingMetric("Native.mem_spill_iotime"),
-      "disk_spill_size" -> sizeMetric("Native.disk_spill_size"),
-      "disk_spill_iotime" -> nanoTimingMetric("Native.disk_spill_iotime"),
-      "shuffle_write_total_time" -> 
nanoTimingMetric("Native.shuffle_write_total_time"),
-      "shuffle_read_total_time" -> 
nanoTimingMetric("Native.shuffle_read_total_time"))
-
-    if (AuronAdaptor.getInstance.getAuronConfiguration.getBoolean(
-        SparkAuronConfiguration.INPUT_BATCH_STATISTICS_ENABLE)) {
-      metrics ++= TreeMap(
-        "input_batch_count" -> metric("Native.input_batches"),
-        "input_row_count" -> metric("Native.input_rows"),
-        "input_batch_mem_size" -> sizeMetric("Native.input_mem_bytes"))
+  private val defaultNativeMetricCreators: Map[String, SparkContext => 
SQLMetric] = Map(
+    "stage_id" -> (sc => SQLMetrics.createMetric(sc, "stageId")),
+    "output_rows" -> (sc => SQLMetrics.createMetric(sc, "Native.output_rows")),
+    "output_batches" -> (sc => SQLMetrics.createMetric(sc, 
"Native.output_batches")),
+    "elapsed_compute" -> (sc => SQLMetrics.createNanoTimingMetric(sc, 
"Native.elapsed_compute")),
+    "build_hash_map_time" -> (sc =>
+      SQLMetrics.createNanoTimingMetric(sc, "Native.build_hash_map_time")),
+    "probed_side_hash_time" -> (sc =>
+      SQLMetrics.createNanoTimingMetric(sc, "Native.probed_side_hash_time")),
+    "probed_side_search_time" -> (sc =>
+      SQLMetrics.createNanoTimingMetric(sc, "Native.probed_side_search_time")),
+    "probed_side_compare_time" -> (sc =>
+      SQLMetrics.createNanoTimingMetric(sc, 
"Native.probed_side_compare_time")),
+    "build_output_time" -> (sc =>
+      SQLMetrics.createNanoTimingMetric(sc, "Native.build_output_time")),
+    "fallback_sort_merge_join_time" -> (sc =>
+      SQLMetrics.createNanoTimingMetric(sc, 
"Native.fallback_sort_merge_join_time")),
+    "mem_spill_count" -> (sc => SQLMetrics.createMetric(sc, 
"Native.mem_spill_count")),
+    "mem_spill_size" -> (sc => SQLMetrics.createSizeMetric(sc, 
"Native.mem_spill_size")),
+    "mem_spill_iotime" -> (sc =>
+      SQLMetrics.createNanoTimingMetric(sc, "Native.mem_spill_iotime")),
+    "disk_spill_size" -> (sc => SQLMetrics.createSizeMetric(sc, 
"Native.disk_spill_size")),
+    "disk_spill_iotime" -> (sc =>
+      SQLMetrics.createNanoTimingMetric(sc, "Native.disk_spill_iotime")),
+    "shuffle_write_total_time" -> (sc =>
+      SQLMetrics.createNanoTimingMetric(sc, 
"Native.shuffle_write_total_time")),
+    "shuffle_read_total_time" -> (sc =>
+      SQLMetrics.createNanoTimingMetric(sc, "Native.shuffle_read_total_time")),
+    "input_batch_count" -> (sc => SQLMetrics.createMetric(sc, 
"Native.input_batches")),
+    "input_row_count" -> (sc => SQLMetrics.createMetric(sc, 
"Native.input_rows")),
+    "input_batch_mem_size" -> (sc => SQLMetrics.createSizeMetric(sc, 
"Native.input_mem_bytes")))
+
+  def getDefaultNativeMetrics(sc: SparkContext, keys: Set[String]): 
Map[String, SQLMetric] = {
+    val enabledKeys =
+      if (AuronAdaptor.getInstance.getAuronConfiguration.getBoolean(
+          SparkAuronConfiguration.INPUT_BATCH_STATISTICS_ENABLE)) {
+        keys
+      } else {
+        keys -- Set("input_batch_count", "input_row_count", 
"input_batch_mem_size")
+      }
+
+    TreeMap[String, SQLMetric]() ++ enabledKeys.flatMap { key =>

Review Comment:
   `enabledKeys` is a `Set`, so `enabledKeys.flatMap { ... }` produces a 
`Set[(String, SQLMetric)]`. Building that Set hashes every `(key, metric)` 
tuple (and therefore each `SQLMetric` instance) and allocates an intermediate 
Set before the TreeMap is built. To keep the optimization goal, consider 
iterating instead (`enabledKeys.iterator.flatMap(...)`, or equivalently 
`enabledKeys.iterator.collect { ... }`) and accumulating directly into the 
TreeMap, which avoids building a Set of metrics first.
   ```suggestion
       TreeMap[String, SQLMetric]() ++ enabledKeys.iterator.flatMap { key =>
   ```



##########
spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeHelper.scala:
##########
@@ -167,38 +167,50 @@ object NativeHelper extends Logging {
       })
   }
 
-  def getDefaultNativeMetrics(sc: SparkContext): Map[String, SQLMetric] = {
-    def metric(name: String) = SQLMetrics.createMetric(sc, name)
-    def nanoTimingMetric(name: String) = SQLMetrics.createNanoTimingMetric(sc, 
name)
-    def sizeMetric(name: String) = SQLMetrics.createSizeMetric(sc, name)
-
-    var metrics = TreeMap(
-      "stage_id" -> metric("stageId"),
-      "output_rows" -> metric("Native.output_rows"),
-      "output_batches" -> metric("Native.output_batches"),
-      "elapsed_compute" -> nanoTimingMetric("Native.elapsed_compute"),
-      "build_hash_map_time" -> nanoTimingMetric("Native.build_hash_map_time"),
-      "probed_side_hash_time" -> 
nanoTimingMetric("Native.probed_side_hash_time"),
-      "probed_side_search_time" -> 
nanoTimingMetric("Native.probed_side_search_time"),
-      "probed_side_compare_time" -> 
nanoTimingMetric("Native.probed_side_compare_time"),
-      "build_output_time" -> nanoTimingMetric("Native.build_output_time"),
-      "fallback_sort_merge_join_time" -> 
nanoTimingMetric("Native.fallback_sort_merge_join_time"),
-      "mem_spill_count" -> metric("Native.mem_spill_count"),
-      "mem_spill_size" -> sizeMetric("Native.mem_spill_size"),
-      "mem_spill_iotime" -> nanoTimingMetric("Native.mem_spill_iotime"),
-      "disk_spill_size" -> sizeMetric("Native.disk_spill_size"),
-      "disk_spill_iotime" -> nanoTimingMetric("Native.disk_spill_iotime"),
-      "shuffle_write_total_time" -> 
nanoTimingMetric("Native.shuffle_write_total_time"),
-      "shuffle_read_total_time" -> 
nanoTimingMetric("Native.shuffle_read_total_time"))
-
-    if (AuronAdaptor.getInstance.getAuronConfiguration.getBoolean(
-        SparkAuronConfiguration.INPUT_BATCH_STATISTICS_ENABLE)) {
-      metrics ++= TreeMap(
-        "input_batch_count" -> metric("Native.input_batches"),
-        "input_row_count" -> metric("Native.input_rows"),
-        "input_batch_mem_size" -> sizeMetric("Native.input_mem_bytes"))
+  private val defaultNativeMetricCreators: Map[String, SparkContext => 
SQLMetric] = Map(
+    "stage_id" -> (sc => SQLMetrics.createMetric(sc, "stageId")),
+    "output_rows" -> (sc => SQLMetrics.createMetric(sc, "Native.output_rows")),
+    "output_batches" -> (sc => SQLMetrics.createMetric(sc, 
"Native.output_batches")),
+    "elapsed_compute" -> (sc => SQLMetrics.createNanoTimingMetric(sc, 
"Native.elapsed_compute")),
+    "build_hash_map_time" -> (sc =>
+      SQLMetrics.createNanoTimingMetric(sc, "Native.build_hash_map_time")),
+    "probed_side_hash_time" -> (sc =>
+      SQLMetrics.createNanoTimingMetric(sc, "Native.probed_side_hash_time")),
+    "probed_side_search_time" -> (sc =>
+      SQLMetrics.createNanoTimingMetric(sc, "Native.probed_side_search_time")),
+    "probed_side_compare_time" -> (sc =>
+      SQLMetrics.createNanoTimingMetric(sc, 
"Native.probed_side_compare_time")),
+    "build_output_time" -> (sc =>
+      SQLMetrics.createNanoTimingMetric(sc, "Native.build_output_time")),
+    "fallback_sort_merge_join_time" -> (sc =>
+      SQLMetrics.createNanoTimingMetric(sc, 
"Native.fallback_sort_merge_join_time")),
+    "mem_spill_count" -> (sc => SQLMetrics.createMetric(sc, 
"Native.mem_spill_count")),
+    "mem_spill_size" -> (sc => SQLMetrics.createSizeMetric(sc, 
"Native.mem_spill_size")),
+    "mem_spill_iotime" -> (sc =>
+      SQLMetrics.createNanoTimingMetric(sc, "Native.mem_spill_iotime")),
+    "disk_spill_size" -> (sc => SQLMetrics.createSizeMetric(sc, 
"Native.disk_spill_size")),
+    "disk_spill_iotime" -> (sc =>
+      SQLMetrics.createNanoTimingMetric(sc, "Native.disk_spill_iotime")),
+    "shuffle_write_total_time" -> (sc =>
+      SQLMetrics.createNanoTimingMetric(sc, 
"Native.shuffle_write_total_time")),
+    "shuffle_read_total_time" -> (sc =>
+      SQLMetrics.createNanoTimingMetric(sc, "Native.shuffle_read_total_time")),
+    "input_batch_count" -> (sc => SQLMetrics.createMetric(sc, 
"Native.input_batches")),
+    "input_row_count" -> (sc => SQLMetrics.createMetric(sc, 
"Native.input_rows")),
+    "input_batch_mem_size" -> (sc => SQLMetrics.createSizeMetric(sc, 
"Native.input_mem_bytes")))
+

Review Comment:
   Changing `getDefaultNativeMetrics` from a 1-arg method to a 2-arg method is 
a source- and binary-breaking change for any downstream code compiled against 
this module. If this object is part of a published API surface, consider 
keeping an overloaded `getDefaultNativeMetrics(sc: SparkContext)` (possibly 
marked `@deprecated`) that delegates to the new implementation with the full 
default key set. This preserves the old behavior, because the 2-arg version 
still drops the `input_*` metrics when `INPUT_BATCH_STATISTICS_ENABLE` is off.
   ```suggestion
   
     def getDefaultNativeMetrics(sc: SparkContext): Map[String, SQLMetric] = {
       getDefaultNativeMetrics(sc, defaultNativeMetricCreators.keySet)
     }
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to