This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 89727bfa7529 [SPARK-46644] Change add and merge in SQLMetric to use 
isZero
89727bfa7529 is described below

commit 89727bfa7529aa28d85e5d9a58b21d8aa035a23f
Author: Davin Tjong <davin.tj...@databricks.com>
AuthorDate: Thu Jan 18 11:01:37 2024 +0800

    [SPARK-46644] Change add and merge in SQLMetric to use isZero
    
    ### What changes were proposed in this pull request?
    
    A previous refactor mistakenly used `isValid` for add. Since 
`defaultValidValue` was always `0`, this didn't cause any correctness issues.
    
    What we really want to do for add (and merge) is `if (isZero) _value = 0`.
    
    Also removing `isValid` since it's redundant if `defaultValidValue` is 
always `0`.
    
    ### Why are the changes needed?
    
    There are no correctness errors, but this is confusing and error-prone.
    
    A negative `defaultValidValue` was intended to allow creating optional 
metrics. With the previous behavior, this would incorrectly add the sentinel 
value. `defaultValidValue` is supposed to determine what value is exposed to 
the user.
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
    Running the tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No.
    
    Closes #44649 from davintjong-db/sql-metric-add-fix.
    
    Authored-by: Davin Tjong <davin.tj...@databricks.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../spark/sql/execution/metric/SQLMetrics.scala    | 50 ++++++++++++----------
 .../sql/execution/metric/SQLMetricsSuite.scala     | 11 +++--
 2 files changed, 32 insertions(+), 29 deletions(-)

diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala
index 8cd28f9a06a4..a246b47fe655 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala
@@ -39,21 +39,21 @@ import org.apache.spark.util.AccumulatorContext.internOption
  */
 class SQLMetric(
     val metricType: String,
-    initValue: Long = 0L,
-    defaultValidValue: Long = 0L) extends AccumulatorV2[Long, Long] {
-  // initValue defines the initial value of the metric. defaultValidValue 
defines the lowest value
-  // considered valid. If a SQLMetric is invalid, it is set to 
defaultValidValue upon receiving any
-  // updates, and it also reports defaultValidValue as its value to avoid 
exposing it to the user
-  // programatically.
+    initValue: Long = 0L) extends AccumulatorV2[Long, Long] {
+  // initValue defines the initial value of the metric. 0 is the lowest value 
considered valid.
+  // If a SQLMetric is invalid, it is set to 0 upon receiving any updates, and 
it also reports
+  // 0 as its value to avoid exposing it to the user programmatically.
   //
-  // For many SQLMetrics, we use initValue = -1 and defaultValidValue = 0 to 
indicate that the
-  // metric is by default invalid. At the end of a task, we will update the 
metric making it valid,
-  // and the invalid metrics will be filtered out when calculating min, max, 
etc. as a workaround
+  // For many SQLMetrics, we use initValue = -1 to indicate that the metric is 
by default invalid.
+  // At the end of a task, we will update the metric making it valid, and the 
invalid metrics will
+  // be filtered out when calculating min, max, etc. as a workaround
   // for SPARK-11013.
+  assert(initValue <= 0)
+  // _value will always be either initValue or non-negative.
   private var _value = initValue
 
   override def copy(): SQLMetric = {
-    val newAcc = new SQLMetric(metricType, initValue, defaultValidValue)
+    val newAcc = new SQLMetric(metricType, initValue)
     newAcc._value = _value
     newAcc
   }
@@ -62,8 +62,8 @@ class SQLMetric(
 
   override def merge(other: AccumulatorV2[Long, Long]): Unit = other match {
     case o: SQLMetric =>
-      if (o.isValid) {
-        if (!isValid) _value = defaultValidValue
+      if (!o.isZero) {
+        if (isZero) _value = 0
         _value += o.value
       }
     case _ => throw QueryExecutionErrors.cannotMergeClassWithOtherClassError(
@@ -73,28 +73,32 @@ class SQLMetric(
   // This is used to filter out metrics. Metrics with value equal to initValue 
should
   // be filtered out, since they are either invalid or safe to filter without 
changing
   // the aggregation defined in [[SQLMetrics.stringValue]].
-  // Note that we don't use defaultValidValue here since we may want to collect
-  // defaultValidValue metrics for calculating min, max, etc. See SPARK-11013.
+  // Note that we don't use 0 here since we may want to collect 0 metrics for
+  // calculating min, max, etc. See SPARK-11013.
   override def isZero: Boolean = _value == initValue
 
-  def isValid: Boolean = _value >= defaultValidValue
-
   override def add(v: Long): Unit = {
-    if (!isValid) _value = defaultValidValue
-    _value += v
+    if (v >= 0) {
+      if (isZero) _value = 0
+      _value += v
+    }
   }
 
   // We can set a double value to `SQLMetric` which stores only long value, if 
it is
   // average metrics.
-  def set(v: Double): Unit = SQLMetrics.setDoubleForAverageMetrics(this, v)
+  def set(v: Double): Unit = if (v >= 0) {
+    SQLMetrics.setDoubleForAverageMetrics(this, v)
+  }
 
-  def set(v: Long): Unit = _value = v
+  def set(v: Long): Unit = if (v >= 0) {
+    _value = v
+  }
 
   def +=(v: Long): Unit = add(v)
 
-  // _value may be invalid, in many cases being -1. We should not expose it to 
the user
-  // and instead return defaultValidValue.
-  override def value: Long = if (!isValid) defaultValidValue else _value
+  // _value may be uninitialized, in many cases being -1. We should not expose 
it to the user
+  // and instead return 0.
+  override def value: Long = if (isZero) 0 else _value
 
   // Provide special identifier as metadata so we can tell that this is a 
`SQLMetric` later
   override def toInfo(update: Option[Any], value: Option[Any]): 
AccumulableInfo = {
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
index e71451f2f742..45c775e6c463 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
@@ -938,27 +938,26 @@ class SQLMetricsSuite extends SharedSparkSession with 
SQLMetricsTestUtils
   test("Creating metrics with initial values") {
     assert(SQLMetrics.createSizeMetric(sparkContext, name = "m").value === 0)
     assert(SQLMetrics.createSizeMetric(sparkContext, name = "m", initValue = 
-1).value === 0)
-    assert(SQLMetrics.createSizeMetric(sparkContext, name = "m", initValue = 
5).value === 5)
 
     assert(SQLMetrics.createSizeMetric(sparkContext, name = "m").isZero)
     assert(SQLMetrics.createSizeMetric(sparkContext, name = "m", initValue = 
-1).isZero)
-    assert(SQLMetrics.createSizeMetric(sparkContext, name = "m", initValue = 
5).isZero)
 
     assert(SQLMetrics.createTimingMetric(sparkContext, name = "m").value === 0)
     assert(SQLMetrics.createTimingMetric(sparkContext, name = "m", initValue = 
-1).value === 0)
-    assert(SQLMetrics.createTimingMetric(sparkContext, name = "m", initValue = 
5).value === 5)
 
     assert(SQLMetrics.createTimingMetric(sparkContext, name = "m").isZero)
     assert(SQLMetrics.createTimingMetric(sparkContext, name = "m", initValue = 
-1).isZero)
-    assert(SQLMetrics.createTimingMetric(sparkContext, name = "m", initValue = 
5).isZero)
 
     assert(SQLMetrics.createNanoTimingMetric(sparkContext, name = "m").value 
=== 0)
     assert(SQLMetrics.createNanoTimingMetric(sparkContext, name = "m", 
initValue = -1).value === 0)
-    assert(SQLMetrics.createNanoTimingMetric(sparkContext, name = "m", 
initValue = 5).value === 5)
 
     assert(SQLMetrics.createNanoTimingMetric(sparkContext, name = "m").isZero)
     assert(SQLMetrics.createNanoTimingMetric(sparkContext, name = "m", 
initValue = -1).isZero)
-    assert(SQLMetrics.createNanoTimingMetric(sparkContext, name = "m", 
initValue = 5).isZero)
+
+    // initValue must be <= 0
+    intercept[AssertionError] {
+      SQLMetrics.createNanoTimingMetric(sparkContext, name = "m", initValue = 
5)
+    }
   }
 }
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to