This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new a452e1bc2189 [SPARK-50977][CORE] Enhance availability of logic performing aggregation of accumulator results
a452e1bc2189 is described below

commit a452e1bc2189d4dda50df5c36a49e4d23e6db758
Author: Costas Zarifis <[email protected]>
AuthorDate: Tue Jan 28 09:39:12 2025 +0900

    [SPARK-50977][CORE] Enhance availability of logic performing aggregation of accumulator results
    
    ### What changes were proposed in this pull request?
    
    In this PR we introduce a minor refactor that moves the logic for aggregating accumulator results out of SQLMetrics in sql/core and into a new MetricUtils object in core, making that functionality more widely available.
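
    As an illustrative sketch (not part of the patch), a call site only needs the new import; the `stringValue` signature is unchanged. Note that `MetricUtils` is `private[spark]`, so callers must live under Spark's own namespace:

    ```scala
    package org.apache.spark.util  // assumed placement: MetricUtils is private[spark]

    object StringValueExample {
      def main(args: Array[String]): Unit = {
        // Per-task byte counts for a "size" metric; an empty maxMetrics array
        // marks the value as driver-side, so taskInfo renders as "(driver)".
        val taskMetrics = Array(1024L, 2048L, 4096L)
        println(MetricUtils.stringValue("size", taskMetrics, Array.empty[Long]))
        // Prints roughly:
        //   total (min, med, max (stageId: taskId))
        //   7.0 KiB (1024.0 B, 2.0 KiB, 4.0 KiB (driver))
      }
    }
    ```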
    
    ### Why are the changes needed?
    These changes make the aggregation logic accessible from other Spark modules, which enables various memory and disk optimizations.
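
    For example (a hedged sketch, using the names introduced in the new MetricUtils below), any Spark-internal module can now render an "average" metric. Stored average values are fixed-point tenths, since stringValue divides by baseForAvgMetric = 10:

    ```scala
    package org.apache.spark.util  // assumed placement: MetricUtils is private[spark]

    object AverageMetricExample {
      def main(args: Array[String]): Unit = {
        // maxMetrics layout consumed by stringValue: Array(maxValue, stageId, attemptId, taskId).
        val text = MetricUtils.stringValue("average", Array(15L, 25L, 35L), Array(35L, 0L, 0L, 2L))
        println(text)
        // Prints:
        //   (min, med, max (stageId: taskId)):
        //   (1.5, 2.5, 3.5 (stage 0.0: task 2))
      }
    }
    ```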
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    This is a minor refactor. No new code has been added, so no new tests are needed.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #49618 from costas-db/costas-zarifis_data/refactorMetrics.
    
    Lead-authored-by: Costas Zarifis <[email protected]>
    Co-authored-by: Costas Zarifis <[email protected]>
    Co-authored-by: Hyukjin Kwon <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 .../scala/org/apache/spark/util/MetricUtils.scala  | 106 +++++++++++++++++++++
 .../datasources/v2/python/PythonCustomMetric.scala |   5 +-
 .../spark/sql/execution/metric/SQLMetrics.scala    |  86 +----------------
 .../sql/execution/ui/SQLAppStatusListener.scala    |   8 +-
 .../execution/ui/SQLAppStatusListenerSuite.scala   |   6 +-
 5 files changed, 121 insertions(+), 90 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/util/MetricUtils.scala b/core/src/main/scala/org/apache/spark/util/MetricUtils.scala
new file mode 100644
index 000000000000..a6166f2129d1
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/MetricUtils.scala
@@ -0,0 +1,106 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+import java.text.NumberFormat
+import java.util.{Arrays, Locale}
+
+import scala.concurrent.duration._
+
+import org.apache.spark.SparkException
+import org.apache.spark.util.Utils
+
+private[spark] object MetricUtils {
+
+  val SUM_METRIC: String = "sum"
+  val SIZE_METRIC: String = "size"
+  val TIMING_METRIC: String = "timing"
+  val NS_TIMING_METRIC: String = "nsTiming"
+  val AVERAGE_METRIC: String = "average"
+  private val baseForAvgMetric: Int = 10
+  private val METRICS_NAME_SUFFIX = "(min, med, max (stageId: taskId))"
+
+  private def toNumberFormat(value: Long): String = {
+    val numberFormat = NumberFormat.getNumberInstance(Locale.US)
+    numberFormat.format(value.toDouble / baseForAvgMetric)
+  }
+
+  def metricNeedsMax(metricsType: String): Boolean = {
+    metricsType != SUM_METRIC
+  }
+
+  /**
+   * A function that defines how we aggregate the final accumulator results among all tasks,
+   * and represent it in string for a SQL physical operator.
+    */
+  def stringValue(metricsType: String, values: Array[Long], maxMetrics: Array[Long]): String = {
+    // taskInfo = "(driver)" OR (stage ${stageId}.${attemptId}: task $taskId)
+    val taskInfo = if (maxMetrics.isEmpty) {
+      "(driver)"
+    } else {
+      s"(stage ${maxMetrics(1)}.${maxMetrics(2)}: task ${maxMetrics(3)})"
+    }
+    if (metricsType == SUM_METRIC) {
+      val numberFormat = NumberFormat.getIntegerInstance(Locale.US)
+      numberFormat.format(values.sum)
+    } else if (metricsType == AVERAGE_METRIC) {
+      val validValues = values.filter(_ > 0)
+      // When there are only 1 metrics value (or None), no need to display max/min/median. This is
+      // common for driver-side SQL metrics.
+      if (validValues.length <= 1) {
+        toNumberFormat(validValues.headOption.getOrElse(0))
+      } else {
+        val Seq(min, med, max) = {
+          Arrays.sort(validValues)
+          Seq(
+            toNumberFormat(validValues(0)),
+            toNumberFormat(validValues(validValues.length / 2)),
+            toNumberFormat(validValues(validValues.length - 1)))
+        }
+        s"$METRICS_NAME_SUFFIX:\n($min, $med, $max $taskInfo)"
+      }
+    } else {
+      val strFormat: Long => String = if (metricsType == SIZE_METRIC) {
+        Utils.bytesToString
+      } else if (metricsType == TIMING_METRIC) {
+        Utils.msDurationToString
+      } else if (metricsType == NS_TIMING_METRIC) {
+        duration => Utils.msDurationToString(duration.nanos.toMillis)
+      } else {
+        throw SparkException.internalError(s"unexpected metrics type: $metricsType")
+      }
+
+      val validValues = values.filter(_ >= 0)
+      // When there are only 1 metrics value (or None), no need to display max/min/median. This is
+      // common for driver-side SQL metrics.
+      if (validValues.length <= 1) {
+        strFormat(validValues.headOption.getOrElse(0))
+      } else {
+        val Seq(sum, min, med, max) = {
+          Arrays.sort(validValues)
+          Seq(
+            strFormat(validValues.sum),
+            strFormat(validValues(0)),
+            strFormat(validValues(validValues.length / 2)),
+            strFormat(validValues(validValues.length - 1)))
+        }
+        s"total $METRICS_NAME_SUFFIX\n$sum ($min, $med, $max $taskInfo)"
+      }
+    }
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonCustomMetric.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonCustomMetric.scala
index 7551cd04f20f..2db2ff74374c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonCustomMetric.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonCustomMetric.scala
@@ -17,8 +17,9 @@
 package org.apache.spark.sql.execution.datasources.v2.python
 
 import org.apache.spark.sql.connector.metric.{CustomMetric, CustomTaskMetric}
-import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
+import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.execution.python.PythonSQLMetrics
+import org.apache.spark.util.MetricUtils
 
 
 class PythonCustomMetric(
@@ -28,7 +29,7 @@ class PythonCustomMetric(
   def this() = this(null, null)
 
   override def aggregateTaskMetrics(taskMetrics: Array[Long]): String = {
-    SQLMetrics.stringValue("size", taskMetrics, Array.empty[Long])
+    MetricUtils.stringValue("size", taskMetrics, Array.empty[Long])
   }
 }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala
index f0c1c0900c7f..065c8db7ac6f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala
@@ -17,19 +17,14 @@
 
 package org.apache.spark.sql.execution.metric
 
-import java.text.NumberFormat
-import java.util.{Arrays, Locale}
-
-import scala.concurrent.duration._
-
 import com.google.common.cache.{CacheBuilder, CacheLoader, LoadingCache}
 
-import org.apache.spark.{SparkContext, SparkException}
+import org.apache.spark.SparkContext
 import org.apache.spark.scheduler.AccumulableInfo
 import org.apache.spark.sql.connector.metric.CustomMetric
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.execution.ui.SparkListenerDriverAccumUpdates
-import org.apache.spark.util.{AccumulatorContext, AccumulatorV2, Utils}
+import org.apache.spark.util.{AccumulatorContext, AccumulatorV2}
 import org.apache.spark.util.AccumulatorContext.internOption
 
 /**
@@ -72,7 +67,7 @@ class SQLMetric(
 
   // This is used to filter out metrics. Metrics with value equal to initValue should
   // be filtered out, since they are either invalid or safe to filter without changing
-  // the aggregation defined in [[SQLMetrics.stringValue]].
+  // the aggregation defined in [[MetricUtils.stringValue]].
   // Note that we don't use 0 here since we may want to collect 0 metrics for
   // calculating min, max, etc. See SPARK-11013.
   override def isZero: Boolean = _value == initValue
@@ -106,8 +101,8 @@ class SQLMetric(
       SQLMetrics.cachedSQLAccumIdentifier)
   }
 
-  // We should provide the raw value which can be -1, so that `SQLMetrics.stringValue` can correctly
-  // filter out the invalid -1 values.
+  // We should provide the raw value which can be -1, so that `MetricUtils.stringValue` can
+  // correctly filter out the invalid -1 values.
   override def toInfoUpdate: AccumulableInfo = {
     AccumulableInfo(id, name, internOption(Some(_value)), None, true, true,
       SQLMetrics.cachedSQLAccumIdentifier)
@@ -203,77 +198,6 @@ object SQLMetrics {
     acc
   }
 
-  private def toNumberFormat(value: Long): String = {
-    val numberFormat = NumberFormat.getNumberInstance(Locale.US)
-    numberFormat.format(value.toDouble / baseForAvgMetric)
-  }
-
-  def metricNeedsMax(metricsType: String): Boolean = {
-    metricsType != SUM_METRIC
-  }
-
-  private val METRICS_NAME_SUFFIX = "(min, med, max (stageId: taskId))"
-
-  /**
-   * A function that defines how we aggregate the final accumulator results among all tasks,
-   * and represent it in string for a SQL physical operator.
-    */
-  def stringValue(metricsType: String, values: Array[Long], maxMetrics: Array[Long]): String = {
-    // taskInfo = "(driver)" OR (stage ${stageId}.${attemptId}: task $taskId)
-    val taskInfo = if (maxMetrics.isEmpty) {
-      "(driver)"
-    } else {
-      s"(stage ${maxMetrics(1)}.${maxMetrics(2)}: task ${maxMetrics(3)})"
-    }
-    if (metricsType == SUM_METRIC) {
-      val numberFormat = NumberFormat.getIntegerInstance(Locale.US)
-      numberFormat.format(values.sum)
-    } else if (metricsType == AVERAGE_METRIC) {
-      val validValues = values.filter(_ > 0)
-      // When there are only 1 metrics value (or None), no need to display max/min/median. This is
-      // common for driver-side SQL metrics.
-      if (validValues.length <= 1) {
-        toNumberFormat(validValues.headOption.getOrElse(0))
-      } else {
-        val Seq(min, med, max) = {
-          Arrays.sort(validValues)
-          Seq(
-            toNumberFormat(validValues(0)),
-            toNumberFormat(validValues(validValues.length / 2)),
-            toNumberFormat(validValues(validValues.length - 1)))
-        }
-        s"$METRICS_NAME_SUFFIX:\n($min, $med, $max $taskInfo)"
-      }
-    } else {
-      val strFormat: Long => String = if (metricsType == SIZE_METRIC) {
-        Utils.bytesToString
-      } else if (metricsType == TIMING_METRIC) {
-        Utils.msDurationToString
-      } else if (metricsType == NS_TIMING_METRIC) {
-        duration => Utils.msDurationToString(duration.nanos.toMillis)
-      } else {
-        throw SparkException.internalError(s"unexpected metrics type: $metricsType")
-      }
-
-      val validValues = values.filter(_ >= 0)
-      // When there are only 1 metrics value (or None), no need to display max/min/median. This is
-      // common for driver-side SQL metrics.
-      if (validValues.length <= 1) {
-        strFormat(validValues.headOption.getOrElse(0))
-      } else {
-        val Seq(sum, min, med, max) = {
-          Arrays.sort(validValues)
-          Seq(
-            strFormat(validValues.sum),
-            strFormat(validValues(0)),
-            strFormat(validValues(validValues.length / 2)),
-            strFormat(validValues(validValues.length - 1)))
-        }
-        s"total $METRICS_NAME_SUFFIX\n$sum ($min, $med, $max $taskInfo)"
-      }
-    }
-  }
-
   def postDriverMetricsUpdatedByValue(
       sc: SparkContext,
       executionId: String,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala
index 3c8c7edfeb06..f680860231f0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala
@@ -35,7 +35,7 @@ import org.apache.spark.sql.execution.SQLExecution
 import org.apache.spark.sql.execution.metric._
 import org.apache.spark.sql.internal.StaticSQLConf._
 import org.apache.spark.status.{ElementTrackingStore, KVUtils, LiveEntity}
-import org.apache.spark.util.Utils
+import org.apache.spark.util.{MetricUtils, Utils}
 import org.apache.spark.util.collection.OpenHashMap
 
 class SQLAppStatusListener(
@@ -235,7 +235,7 @@ class SQLAppStatusListener(
         }
       }.getOrElse(
         // Built-in SQLMetric
-        SQLMetrics.stringValue(m.metricType, _, _)
+        MetricUtils.stringValue(m.metricType, _, _)
       )
       (m.accumulatorId, metricAggMethod)
     }.toMap
@@ -554,7 +554,7 @@ private class LiveStageMetrics(
   /**
    * Task metrics values for the stage. Maps the metric ID to the metric values for each
    * index. For each metric ID, there will be the same number of values as the number
-   * of indices. This relies on `SQLMetrics.stringValue` treating 0 as a neutral value,
+   * of indices. This relies on `MetricUtils.stringValue` treating 0 as a neutral value,
    * independent of the actual metric type.
    */
   private val taskMetrics = new ConcurrentHashMap[Long, Array[Long]]()
@@ -601,7 +601,7 @@ private class LiveStageMetrics(
         val metricValues = taskMetrics.computeIfAbsent(acc.id, _ => new Array(numTasks))
         metricValues(taskIdx) = value
 
-        if (SQLMetrics.metricNeedsMax(accumIdsToMetricType(acc.id))) {
+        if (MetricUtils.metricNeedsMax(accumIdsToMetricType(acc.id))) {
           val maxMetricsTaskId = metricsIdToMaxTaskValue.computeIfAbsent(acc.id, _ => Array(value,
             taskId))
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala
index 45fce208c078..800a58f0c1d6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala
@@ -57,7 +57,7 @@ import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
 import org.apache.spark.status.{AppStatusStore, ElementTrackingStore}
-import org.apache.spark.util.{AccumulatorMetadata, JsonProtocol, LongAccumulator, SerializableConfiguration, Utils}
+import org.apache.spark.util.{AccumulatorMetadata, JsonProtocol, LongAccumulator, MetricUtils, SerializableConfiguration, Utils}
 import org.apache.spark.util.kvstore.InMemoryStore
 
 
@@ -597,9 +597,9 @@ abstract class SQLAppStatusListenerSuite extends SharedSparkSession with JsonTes
     val metrics = statusStore.executionMetrics(execId)
     val driverMetric = physicalPlan.metrics("dummy")
     val driverMetric2 = physicalPlan.metrics("dummy2")
-    val expectedValue = SQLMetrics.stringValue(driverMetric.metricType,
+    val expectedValue = MetricUtils.stringValue(driverMetric.metricType,
       Array(expectedAccumValue), Array.empty[Long])
-    val expectedValue2 = SQLMetrics.stringValue(driverMetric2.metricType,
+    val expectedValue2 = MetricUtils.stringValue(driverMetric2.metricType,
       Array(expectedAccumValue2), Array.empty[Long])
 
     assert(metrics.contains(driverMetric.id))


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
