This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch branch-4.0
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-4.0 by this push:
new d3fad477295e [SPARK-50977][CORE] Enhance availability of logic
performing aggregation of accumulator results
d3fad477295e is described below
commit d3fad477295e332a4d6c4414b6d098b3a9850a80
Author: Costas Zarifis <[email protected]>
AuthorDate: Tue Jan 28 09:39:12 2025 +0900
[SPARK-50977][CORE] Enhance availability of logic performing aggregation of
accumulator results
### What changes were proposed in this pull request?
In this PR we introduce a minor refactor that enhances the availability of
the functionality used to perform aggregation of accumulator results.
### Why are the changes needed?
These changes make the aggregation logic accessible from other modules,
which enables various memory and disk optimizations.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
This is a minor refactor that moves existing code into a new module. No new
logic has been added, therefore no new tests are needed.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #49618 from costas-db/costas-zarifis_data/refactorMetrics.
Lead-authored-by: Costas Zarifis <[email protected]>
Co-authored-by: Costas Zarifis
<[email protected]>
Co-authored-by: Hyukjin Kwon <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
(cherry picked from commit a452e1bc2189d4dda50df5c36a49e4d23e6db758)
Signed-off-by: Hyukjin Kwon <[email protected]>
---
.../scala/org/apache/spark/util/MetricUtils.scala | 106 +++++++++++++++++++++
.../datasources/v2/python/PythonCustomMetric.scala | 5 +-
.../spark/sql/execution/metric/SQLMetrics.scala | 86 +----------------
.../sql/execution/ui/SQLAppStatusListener.scala | 8 +-
.../execution/ui/SQLAppStatusListenerSuite.scala | 6 +-
5 files changed, 121 insertions(+), 90 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/util/MetricUtils.scala
b/core/src/main/scala/org/apache/spark/util/MetricUtils.scala
new file mode 100644
index 000000000000..a6166f2129d1
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/MetricUtils.scala
@@ -0,0 +1,106 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+import java.text.NumberFormat
+import java.util.{Arrays, Locale}
+
+import scala.concurrent.duration._
+
+import org.apache.spark.SparkException
+import org.apache.spark.util.Utils
+
+private[spark] object MetricUtils {
+
+ val SUM_METRIC: String = "sum"
+ val SIZE_METRIC: String = "size"
+ val TIMING_METRIC: String = "timing"
+ val NS_TIMING_METRIC: String = "nsTiming"
+ val AVERAGE_METRIC: String = "average"
+ private val baseForAvgMetric: Int = 10
+ private val METRICS_NAME_SUFFIX = "(min, med, max (stageId: taskId))"
+
+ private def toNumberFormat(value: Long): String = {
+ val numberFormat = NumberFormat.getNumberInstance(Locale.US)
+ numberFormat.format(value.toDouble / baseForAvgMetric)
+ }
+
+ def metricNeedsMax(metricsType: String): Boolean = {
+ metricsType != SUM_METRIC
+ }
+
+/**
+ * A function that defines how we aggregate the final accumulator results
among all tasks,
+ * and represent it in string for a SQL physical operator.
+ */
+ def stringValue(metricsType: String, values: Array[Long], maxMetrics:
Array[Long]): String = {
+ // taskInfo = "(driver)" OR (stage ${stageId}.${attemptId}: task $taskId)
+ val taskInfo = if (maxMetrics.isEmpty) {
+ "(driver)"
+ } else {
+ s"(stage ${maxMetrics(1)}.${maxMetrics(2)}: task ${maxMetrics(3)})"
+ }
+ if (metricsType == SUM_METRIC) {
+ val numberFormat = NumberFormat.getIntegerInstance(Locale.US)
+ numberFormat.format(values.sum)
+ } else if (metricsType == AVERAGE_METRIC) {
+ val validValues = values.filter(_ > 0)
+ // When there are only 1 metrics value (or None), no need to display
max/min/median. This is
+ // common for driver-side SQL metrics.
+ if (validValues.length <= 1) {
+ toNumberFormat(validValues.headOption.getOrElse(0))
+ } else {
+ val Seq(min, med, max) = {
+ Arrays.sort(validValues)
+ Seq(
+ toNumberFormat(validValues(0)),
+ toNumberFormat(validValues(validValues.length / 2)),
+ toNumberFormat(validValues(validValues.length - 1)))
+ }
+ s"$METRICS_NAME_SUFFIX:\n($min, $med, $max $taskInfo)"
+ }
+ } else {
+ val strFormat: Long => String = if (metricsType == SIZE_METRIC) {
+ Utils.bytesToString
+ } else if (metricsType == TIMING_METRIC) {
+ Utils.msDurationToString
+ } else if (metricsType == NS_TIMING_METRIC) {
+ duration => Utils.msDurationToString(duration.nanos.toMillis)
+ } else {
+ throw SparkException.internalError(s"unexpected metrics type:
$metricsType")
+ }
+
+ val validValues = values.filter(_ >= 0)
+ // When there are only 1 metrics value (or None), no need to display
max/min/median. This is
+ // common for driver-side SQL metrics.
+ if (validValues.length <= 1) {
+ strFormat(validValues.headOption.getOrElse(0))
+ } else {
+ val Seq(sum, min, med, max) = {
+ Arrays.sort(validValues)
+ Seq(
+ strFormat(validValues.sum),
+ strFormat(validValues(0)),
+ strFormat(validValues(validValues.length / 2)),
+ strFormat(validValues(validValues.length - 1)))
+ }
+ s"total $METRICS_NAME_SUFFIX\n$sum ($min, $med, $max $taskInfo)"
+ }
+ }
+ }
+}
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonCustomMetric.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonCustomMetric.scala
index 7551cd04f20f..2db2ff74374c 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonCustomMetric.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonCustomMetric.scala
@@ -17,8 +17,9 @@
package org.apache.spark.sql.execution.datasources.v2.python
import org.apache.spark.sql.connector.metric.{CustomMetric, CustomTaskMetric}
-import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
+import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.execution.python.PythonSQLMetrics
+import org.apache.spark.util.MetricUtils
class PythonCustomMetric(
@@ -28,7 +29,7 @@ class PythonCustomMetric(
def this() = this(null, null)
override def aggregateTaskMetrics(taskMetrics: Array[Long]): String = {
- SQLMetrics.stringValue("size", taskMetrics, Array.empty[Long])
+ MetricUtils.stringValue("size", taskMetrics, Array.empty[Long])
}
}
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala
index f0c1c0900c7f..065c8db7ac6f 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala
@@ -17,19 +17,14 @@
package org.apache.spark.sql.execution.metric
-import java.text.NumberFormat
-import java.util.{Arrays, Locale}
-
-import scala.concurrent.duration._
-
import com.google.common.cache.{CacheBuilder, CacheLoader, LoadingCache}
-import org.apache.spark.{SparkContext, SparkException}
+import org.apache.spark.SparkContext
import org.apache.spark.scheduler.AccumulableInfo
import org.apache.spark.sql.connector.metric.CustomMetric
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.execution.ui.SparkListenerDriverAccumUpdates
-import org.apache.spark.util.{AccumulatorContext, AccumulatorV2, Utils}
+import org.apache.spark.util.{AccumulatorContext, AccumulatorV2}
import org.apache.spark.util.AccumulatorContext.internOption
/**
@@ -72,7 +67,7 @@ class SQLMetric(
// This is used to filter out metrics. Metrics with value equal to initValue
should
// be filtered out, since they are either invalid or safe to filter without
changing
- // the aggregation defined in [[SQLMetrics.stringValue]].
+ // the aggregation defined in [[MetricUtils.stringValue]].
// Note that we don't use 0 here since we may want to collect 0 metrics for
// calculating min, max, etc. See SPARK-11013.
override def isZero: Boolean = _value == initValue
@@ -106,8 +101,8 @@ class SQLMetric(
SQLMetrics.cachedSQLAccumIdentifier)
}
- // We should provide the raw value which can be -1, so that
`SQLMetrics.stringValue` can correctly
- // filter out the invalid -1 values.
+ // We should provide the raw value which can be -1, so that
`MetricUtils.stringValue` can
+ // correctly filter out the invalid -1 values.
override def toInfoUpdate: AccumulableInfo = {
AccumulableInfo(id, name, internOption(Some(_value)), None, true, true,
SQLMetrics.cachedSQLAccumIdentifier)
@@ -203,77 +198,6 @@ object SQLMetrics {
acc
}
- private def toNumberFormat(value: Long): String = {
- val numberFormat = NumberFormat.getNumberInstance(Locale.US)
- numberFormat.format(value.toDouble / baseForAvgMetric)
- }
-
- def metricNeedsMax(metricsType: String): Boolean = {
- metricsType != SUM_METRIC
- }
-
- private val METRICS_NAME_SUFFIX = "(min, med, max (stageId: taskId))"
-
- /**
- * A function that defines how we aggregate the final accumulator results
among all tasks,
- * and represent it in string for a SQL physical operator.
- */
- def stringValue(metricsType: String, values: Array[Long], maxMetrics:
Array[Long]): String = {
- // taskInfo = "(driver)" OR (stage ${stageId}.${attemptId}: task $taskId)
- val taskInfo = if (maxMetrics.isEmpty) {
- "(driver)"
- } else {
- s"(stage ${maxMetrics(1)}.${maxMetrics(2)}: task ${maxMetrics(3)})"
- }
- if (metricsType == SUM_METRIC) {
- val numberFormat = NumberFormat.getIntegerInstance(Locale.US)
- numberFormat.format(values.sum)
- } else if (metricsType == AVERAGE_METRIC) {
- val validValues = values.filter(_ > 0)
- // When there are only 1 metrics value (or None), no need to display
max/min/median. This is
- // common for driver-side SQL metrics.
- if (validValues.length <= 1) {
- toNumberFormat(validValues.headOption.getOrElse(0))
- } else {
- val Seq(min, med, max) = {
- Arrays.sort(validValues)
- Seq(
- toNumberFormat(validValues(0)),
- toNumberFormat(validValues(validValues.length / 2)),
- toNumberFormat(validValues(validValues.length - 1)))
- }
- s"$METRICS_NAME_SUFFIX:\n($min, $med, $max $taskInfo)"
- }
- } else {
- val strFormat: Long => String = if (metricsType == SIZE_METRIC) {
- Utils.bytesToString
- } else if (metricsType == TIMING_METRIC) {
- Utils.msDurationToString
- } else if (metricsType == NS_TIMING_METRIC) {
- duration => Utils.msDurationToString(duration.nanos.toMillis)
- } else {
- throw SparkException.internalError(s"unexpected metrics type:
$metricsType")
- }
-
- val validValues = values.filter(_ >= 0)
- // When there are only 1 metrics value (or None), no need to display
max/min/median. This is
- // common for driver-side SQL metrics.
- if (validValues.length <= 1) {
- strFormat(validValues.headOption.getOrElse(0))
- } else {
- val Seq(sum, min, med, max) = {
- Arrays.sort(validValues)
- Seq(
- strFormat(validValues.sum),
- strFormat(validValues(0)),
- strFormat(validValues(validValues.length / 2)),
- strFormat(validValues(validValues.length - 1)))
- }
- s"total $METRICS_NAME_SUFFIX\n$sum ($min, $med, $max $taskInfo)"
- }
- }
- }
-
def postDriverMetricsUpdatedByValue(
sc: SparkContext,
executionId: String,
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala
index 3c8c7edfeb06..f680860231f0 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala
@@ -35,7 +35,7 @@ import org.apache.spark.sql.execution.SQLExecution
import org.apache.spark.sql.execution.metric._
import org.apache.spark.sql.internal.StaticSQLConf._
import org.apache.spark.status.{ElementTrackingStore, KVUtils, LiveEntity}
-import org.apache.spark.util.Utils
+import org.apache.spark.util.{MetricUtils, Utils}
import org.apache.spark.util.collection.OpenHashMap
class SQLAppStatusListener(
@@ -235,7 +235,7 @@ class SQLAppStatusListener(
}
}.getOrElse(
// Built-in SQLMetric
- SQLMetrics.stringValue(m.metricType, _, _)
+ MetricUtils.stringValue(m.metricType, _, _)
)
(m.accumulatorId, metricAggMethod)
}.toMap
@@ -554,7 +554,7 @@ private class LiveStageMetrics(
/**
* Task metrics values for the stage. Maps the metric ID to the metric
values for each
* index. For each metric ID, there will be the same number of values as the
number
- * of indices. This relies on `SQLMetrics.stringValue` treating 0 as a
neutral value,
+ * of indices. This relies on `MetricUtils.stringValue` treating 0 as a
neutral value,
* independent of the actual metric type.
*/
private val taskMetrics = new ConcurrentHashMap[Long, Array[Long]]()
@@ -601,7 +601,7 @@ private class LiveStageMetrics(
val metricValues = taskMetrics.computeIfAbsent(acc.id, _ => new
Array(numTasks))
metricValues(taskIdx) = value
- if (SQLMetrics.metricNeedsMax(accumIdsToMetricType(acc.id))) {
+ if (MetricUtils.metricNeedsMax(accumIdsToMetricType(acc.id))) {
val maxMetricsTaskId =
metricsIdToMaxTaskValue.computeIfAbsent(acc.id, _ => Array(value,
taskId))
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala
index e63ff019a2b6..256d9d156c18 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala
@@ -57,7 +57,7 @@ import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap
import org.apache.spark.status.{AppStatusStore, ElementTrackingStore}
-import org.apache.spark.util.{AccumulatorMetadata, JsonProtocol,
LongAccumulator, SerializableConfiguration, Utils}
+import org.apache.spark.util.{AccumulatorMetadata, JsonProtocol,
LongAccumulator, MetricUtils, SerializableConfiguration, Utils}
import org.apache.spark.util.kvstore.InMemoryStore
@@ -597,9 +597,9 @@ abstract class SQLAppStatusListenerSuite extends
SharedSparkSession with JsonTes
val metrics = statusStore.executionMetrics(execId)
val driverMetric = physicalPlan.metrics("dummy")
val driverMetric2 = physicalPlan.metrics("dummy2")
- val expectedValue = SQLMetrics.stringValue(driverMetric.metricType,
+ val expectedValue = MetricUtils.stringValue(driverMetric.metricType,
Array(expectedAccumValue), Array.empty[Long])
- val expectedValue2 = SQLMetrics.stringValue(driverMetric2.metricType,
+ val expectedValue2 = MetricUtils.stringValue(driverMetric2.metricType,
Array(expectedAccumValue2), Array.empty[Long])
assert(metrics.contains(driverMetric.id))
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]