This is an automated email from the ASF dual-hosted git repository.

kabhwan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 24b67ca  [SPARK-35896][SS] Include more granular metrics for stateful operators in StreamingQueryProgress
24b67ca is described below

commit 24b67ca9a837250d25dbcd189b75c919c06aec26
Author: Venki Korukanti <venki.koruka...@gmail.com>
AuthorDate: Wed Jun 30 13:41:26 2021 +0900

    [SPARK-35896][SS] Include more granular metrics for stateful operators in StreamingQueryProgress
    
    ### What changes were proposed in this pull request?
    
    Currently the `StateOperatorProgress` in `StreamingQueryProgress` is missing a few metrics.
    
    ### Why are the changes needed?
    
    The main motivation is to find hotspots and have better visibility into stateful operations. Detailed explanations are in [SPARK-35896](https://issues.apache.org/jira/browse/SPARK-35896).
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes. The `StateOperatorProgress` entries within `StreamingQueryProgress` now contain additional fields, as listed in [SPARK-35896](https://issues.apache.org/jira/browse/SPARK-35896). Example `StreamingQueryProgress` output in JSON form:
    Before:
    ```
    {
      "id" : "510be3cd-a955-4faf-8456-d97c78d39af5",
      ....
      "durationMs" : {
        "triggerExecution" : 2856,
        ....
      },
      "stateOperators" : [ {
        "numRowsTotal" : 1,
        "numRowsUpdated" : 1,
        "numRowsDroppedByWatermark" : 0,
        "customMetrics" : {
          "loadedMapCacheHitCount" : 0,
          "loadedMapCacheMissCount" : 0,
          "stateOnCurrentVersionSizeBytes" : 392
        }
      }],
      ....
    }
    ```
    After:
    ```
    {
      "id" : "510be3cd-a955-4faf-8456-d97c78d39af5",
      ....
      "durationMs" : {
        "triggerExecution" : 2856,
        ....
      },
      "stateOperators" : [ {
        "operatorName" : "dedupe", <-- new
        "numRowsTotal" : 1,
        "numRowsUpdated" : 1, <-- new
        "allUpdatesTimeMs" : 56, <-- new
        "numRowsRemoved" : 2, <-- new
        "allRemovalsTimeMs" : 45, <-- new
        "commitTimeMs" : 40, <-- new
        "numRowsDroppedByWatermark" : 0,
        "numShufflePartitions" : 2, <-- new
        "numStateStoreInstances" : 2, <-- new
        "customMetrics" : {
          "loadedMapCacheHitCount" : 0,
          "loadedMapCacheMissCount" : 0,
          "stateOnCurrentVersionSizeBytes" : 392
        }
      }],
      ....
    }
    ```
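
    For illustration only (not part of this patch), a minimal sketch of how the new fields can be read programmatically from a query's most recent progress; `query` is assumed to be an already-started `StreamingQuery`:
    ```scala
    // Hypothetical usage: inspect the new per-operator metrics from the latest progress.
    val progress = query.lastProgress
    progress.stateOperators.foreach { op =>
      println(s"operator=${op.operatorName} totalRows=${op.numRowsTotal} " +
        s"updated=${op.numRowsUpdated} (${op.allUpdatesTimeMs} ms) " +
        s"removed=${op.numRowsRemoved} (${op.allRemovalsTimeMs} ms) " +
        s"commitMs=${op.commitTimeMs} shufflePartitions=${op.numShufflePartitions} " +
        s"stateStores=${op.numStateStoreInstances}")
    }
    ```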
    
    ### How was this patch tested?
    
    Existing tests for regressions. Added new UTs.
    
    Closes #33091 from vkorukanti/SPARK-35896.
    
    Lead-authored-by: Venki Korukanti <venki.koruka...@gmail.com>
    Co-authored-by: Venki Korukanti <venki.koruka...@databricks.com>
    Signed-off-by: Jungtaek Lim <kabhwan.opensou...@gmail.com>
---
 .../streaming/FlatMapGroupsWithStateExec.scala     |  31 +++++-
 .../streaming/StreamingSymmetricHashJoinExec.scala |   8 ++
 .../execution/streaming/statefulOperators.scala    |  41 ++++++-
 .../sql/execution/streaming/streamingLimits.scala  |   3 +
 .../org/apache/spark/sql/streaming/progress.scala  |  23 +++-
 .../streaming/FlatMapGroupsWithStateSuite.scala    | 119 +++++++++++++--------
 .../sql/streaming/StateStoreMetricsTest.scala      |  45 +++++++-
 .../sql/streaming/StreamingAggregationSuite.scala  |  58 ++++++++++
 .../streaming/StreamingDeduplicationSuite.scala    |  34 +++++-
 .../spark/sql/streaming/StreamingJoinSuite.scala   |  43 ++++++++
 .../StreamingQueryStatusAndProgressSuite.scala     |  28 ++++-
 11 files changed, 368 insertions(+), 65 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
index 981586e..fda26b0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
@@ -81,6 +81,8 @@ case class FlatMapGroupsWithStateExec(
 
   override def keyExpressions: Seq[Attribute] = groupingAttributes
 
+  override def shortName: String = "flatMapGroupsWithState"
+
   override def shouldRunAnotherBatch(newMetadata: OffsetSeqMetadata): Boolean = {
     timeoutConf match {
       case ProcessingTimeTimeout =>
@@ -115,10 +117,13 @@ case class FlatMapGroupsWithStateExec(
       Some(session.streams.stateStoreCoordinator)) { case (store, iter) =>
         val allUpdatesTimeMs = longMetric("allUpdatesTimeMs")
         val commitTimeMs = longMetric("commitTimeMs")
-        val updatesStartTimeNs = System.nanoTime
-
+        val timeoutLatencyMs = longMetric("allRemovalsTimeMs")
         val processor = new InputProcessor(store)
 
+        val currentTimeNs = System.nanoTime
+        val updatesStartTimeNs = currentTimeNs
+        var timeoutProcessingStartTimeNs = currentTimeNs
+
         // If timeout is based on event time, then filter late data based on watermark
         val filteredIter = watermarkPredicateForData match {
           case Some(predicate) if timeoutConf == EventTimeTimeout =>
@@ -127,12 +132,26 @@ case class FlatMapGroupsWithStateExec(
             iter
         }
 
+        val newDataProcessorIter =
+          CompletionIterator[InternalRow, Iterator[InternalRow]](
+            processor.processNewData(filteredIter), {
+            // Once the input is processed, mark the start time for timeout processing to measure
+            // it separately from the overall processing time.
+            timeoutProcessingStartTimeNs = System.nanoTime
+          })
+
+        val timeoutProcessorIter =
+          CompletionIterator[InternalRow, Iterator[InternalRow]](processor.processTimedOutState(), {
+            // Note: `timeoutLatencyMs` also includes the time the parent operator took for
+            // processing output returned through iterator.
+            timeoutLatencyMs += NANOSECONDS.toMillis(System.nanoTime - timeoutProcessingStartTimeNs)
+          })
+
         // Generate a iterator that returns the rows grouped by the grouping function
         // Note that this code ensures that the filtering for timeout occurs only after
         // all the data has been processed. This is to ensure that the timeout information of all
         // the keys with data is updated before they are processed for timeouts.
-        val outputIterator = processor.processNewData(filteredIter) ++
-          processor.processTimedOutState()
+        val outputIterator = newDataProcessorIter ++ timeoutProcessorIter
 
         // Return an iterator of all the rows generated by all the keys, such that when fully
         // consumed, all the state updates will be committed by the state store
@@ -144,6 +163,7 @@ case class FlatMapGroupsWithStateExec(
               store.commit()
             }
             setStoreMetrics(store)
+            setOperatorMetrics()
           }
         )
     }
@@ -162,6 +182,7 @@ case class FlatMapGroupsWithStateExec(
     // Metrics
     private val numUpdatedStateRows = longMetric("numUpdatedStateRows")
     private val numOutputRows = longMetric("numOutputRows")
+    private val numRemovedStateRows = longMetric("numRemovedStateRows")
 
     /**
     * For every group, get the key, values and corresponding state and call the function,
@@ -231,7 +252,7 @@ case class FlatMapGroupsWithStateExec(
       def onIteratorCompletion: Unit = {
         if (groupState.isRemoved && !groupState.getTimeoutTimestampMs.isPresent()) {
           stateManager.removeState(store, stateData.keyRow)
-          numUpdatedStateRows += 1
+          numRemovedStateRows += 1
         } else {
           val currentTimeoutTimestamp = groupState.getTimeoutTimestampMs.orElse(NO_TIMESTAMP)
           val hasTimeoutChanged = currentTimeoutTimestamp != stateData.timeoutTimestamp
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala
index de2e19b..616ae08 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala
@@ -207,6 +207,8 @@ case class StreamingSymmetricHashJoinExec(
     case _ => throwBadJoinTypeException()
   }
 
+  override def shortName: String = "symmetricHashJoin"
+
   override def shouldRunAnotherBatch(newMetadata: OffsetSeqMetadata): Boolean = {
     val watermarkUsedForStateCleanup =
       stateWatermarkPredicates.left.nonEmpty || stateWatermarkPredicates.right.nonEmpty
@@ -221,6 +223,7 @@ case class StreamingSymmetricHashJoinExec(
   protected override def doExecute(): RDD[InternalRow] = {
     val stateStoreCoord = session.sessionState.streamingQueryManager.stateStoreCoordinator
     val stateStoreNames = SymmetricHashJoinStateManager.allStateStoreNames(LeftSide, RightSide)
+    metrics // initialize metrics
     left.execute().stateStoreAwareZipPartitions(
       right.execute(), stateInfo.get, stateStoreNames, stateStoreCoord)(processPartitions)
   }
@@ -237,6 +240,7 @@ case class StreamingSymmetricHashJoinExec(
     val numUpdatedStateRows = longMetric("numUpdatedStateRows")
     val numTotalStateRows = longMetric("numTotalStateRows")
     val allUpdatesTimeMs = longMetric("allUpdatesTimeMs")
+    val numRemovedStateRows = longMetric("numRemovedStateRows")
     val allRemovalsTimeMs = longMetric("allRemovalsTimeMs")
     val commitTimeMs = longMetric("commitTimeMs")
     val stateMemory = longMetric("stateMemory")
@@ -407,6 +411,7 @@ case class StreamingSymmetricHashJoinExec(
         }
         while (cleanupIter.hasNext) {
           cleanupIter.next()
+          numRemovedStateRows += 1
         }
       }
 
@@ -425,6 +430,9 @@ case class StreamingSymmetricHashJoinExec(
           longMetric(metric.name) += value
         }
       }
+
+      val stateStoreNames = SymmetricHashJoinStateManager.allStateStoreNames(LeftSide, RightSide);
+      setOperatorMetrics(numStateStoreInstances = stateStoreNames.length)
     }
 
     CompletionIterator[InternalRow, Iterator[InternalRow]](
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
index 41dcfde..5365009 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
@@ -101,9 +101,13 @@ trait StateStoreWriter extends StatefulOperator { self: SparkPlan =>
     "numTotalStateRows" -> SQLMetrics.createMetric(sparkContext, "number of 
total state rows"),
     "numUpdatedStateRows" -> SQLMetrics.createMetric(sparkContext, "number of 
updated state rows"),
     "allUpdatesTimeMs" -> SQLMetrics.createTimingMetric(sparkContext, "time to 
update"),
+    "numRemovedStateRows" -> SQLMetrics.createMetric(sparkContext, "number of 
removed state rows"),
     "allRemovalsTimeMs" -> SQLMetrics.createTimingMetric(sparkContext, "time 
to remove"),
     "commitTimeMs" -> SQLMetrics.createTimingMetric(sparkContext, "time to 
commit changes"),
-    "stateMemory" -> SQLMetrics.createSizeMetric(sparkContext, "memory used by 
state")
+    "stateMemory" -> SQLMetrics.createSizeMetric(sparkContext, "memory used by 
state"),
+    "numShufflePartitions" -> SQLMetrics.createMetric(sparkContext, "number of 
shuffle partitions"),
+    "numStateStoreInstances" -> SQLMetrics.createMetric(sparkContext,
+      "number of state store instances")
   ) ++ stateStoreCustomMetrics
 
   /**
@@ -118,10 +122,17 @@ trait StateStoreWriter extends StatefulOperator { self: SparkPlan =>
       new java.util.HashMap(customMetrics.mapValues(long2Long).toMap.asJava)
 
     new StateOperatorProgress(
+      operatorName = shortName,
       numRowsTotal = longMetric("numTotalStateRows").value,
       numRowsUpdated = longMetric("numUpdatedStateRows").value,
+      allUpdatesTimeMs = longMetric("allUpdatesTimeMs").value,
+      numRowsRemoved = longMetric("numRemovedStateRows").value,
+      allRemovalsTimeMs = longMetric("allRemovalsTimeMs").value,
+      commitTimeMs = longMetric("commitTimeMs").value,
       memoryUsedBytes = longMetric("stateMemory").value,
       numRowsDroppedByWatermark = longMetric("numRowsDroppedByWatermark").value,
+      numShufflePartitions = longMetric("numShufflePartitions").value,
+      numStateStoreInstances = longMetric("numStateStoreInstances").value,
       javaConvertedCustomMetrics
     )
   }
@@ -129,6 +140,15 @@ trait StateStoreWriter extends StatefulOperator { self: SparkPlan =>
   /** Records the duration of running `body` for the next query progress update. */
   protected def timeTakenMs(body: => Unit): Long = Utils.timeTakenMs(body)._2
 
+  /** Set the operator level metrics */
+  protected def setOperatorMetrics(numStateStoreInstances: Int = 1): Unit = {
+    assert(numStateStoreInstances >= 1, s"invalid number of stores: $numStateStoreInstances")
+    // Shuffle partitions capture the number of tasks that have this stateful operator instance.
+    // For each task instance this number is incremented by one.
+    longMetric("numShufflePartitions") += 1
+    longMetric("numStateStoreInstances") += numStateStoreInstances
+  }
+
   /**
    * Set the SQL metrics related to the state store.
    * This should be called in that task after the store has been updated.
@@ -172,6 +192,9 @@ trait StateStoreWriter extends StatefulOperator { self: SparkPlan =>
     }
   }
 
+  /** Name to output in [[StreamingOperatorProgress]] to identify operator type */
+  protected def shortName: String = "defaultName"
+
   /**
   * Should the MicroBatchExecution run another batch based on this stateful operator and the
    * current updated metadata.
@@ -210,9 +233,11 @@ trait WatermarkSupport extends UnaryExecNode {
 
   protected def removeKeysOlderThanWatermark(store: StateStore): Unit = {
     if (watermarkPredicateForKeys.nonEmpty) {
+      val numRemovedStateRows = longMetric("numRemovedStateRows")
       store.getRange(None, None).foreach { rowPair =>
         if (watermarkPredicateForKeys.get.eval(rowPair.key)) {
           store.remove(rowPair.key)
+          numRemovedStateRows += 1
         }
       }
     }
@@ -222,9 +247,11 @@ trait WatermarkSupport extends UnaryExecNode {
       storeManager: StreamingAggregationStateManager,
       store: StateStore): Unit = {
     if (watermarkPredicateForKeys.nonEmpty) {
+      val numRemovedStateRows = longMetric("numRemovedStateRows")
       storeManager.keys(store).foreach { keyRow =>
         if (watermarkPredicateForKeys.get.eval(keyRow)) {
           storeManager.remove(store, keyRow)
+          numRemovedStateRows += 1
         }
       }
     }
@@ -345,6 +372,7 @@ case class StateStoreSaveExec(
         val numOutputRows = longMetric("numOutputRows")
         val numUpdatedStateRows = longMetric("numUpdatedStateRows")
         val allUpdatesTimeMs = longMetric("allUpdatesTimeMs")
+        val numRemovedStateRows = longMetric("numRemovedStateRows")
         val allRemovalsTimeMs = longMetric("allRemovalsTimeMs")
         val commitTimeMs = longMetric("commitTimeMs")
 
@@ -363,6 +391,7 @@ case class StateStoreSaveExec(
               stateManager.commit(store)
             }
             setStoreMetrics(store)
+            setOperatorMetrics()
             stateManager.values(store).map { valueRow =>
               numOutputRows += 1
               valueRow
@@ -391,6 +420,7 @@ case class StateStoreSaveExec(
                   val rowPair = rangeIter.next()
                   if (watermarkPredicateForKeys.get.eval(rowPair.key)) {
                     stateManager.remove(store, rowPair.key)
+                    numRemovedStateRows += 1
                     removedValueRow = rowPair.value
                   }
                 }
@@ -404,9 +434,12 @@ case class StateStoreSaveExec(
               }
 
               override protected def close(): Unit = {
+                // Note: Due to the iterator lazy exec, this metric also captures the time taken
+                // by the consumer operators in addition to the processing in this operator.
                 allRemovalsTimeMs += NANOSECONDS.toMillis(System.nanoTime - removalStartTimeNs)
                 commitTimeMs += timeTakenMs { stateManager.commit(store) }
                 setStoreMetrics(store)
+                setOperatorMetrics()
               }
             }
 
@@ -443,6 +476,7 @@ case class StateStoreSaveExec(
                 }
                 commitTimeMs += timeTakenMs { stateManager.commit(store) }
                 setStoreMetrics(store)
+                setOperatorMetrics()
               }
             }
 
@@ -463,6 +497,8 @@ case class StateStoreSaveExec(
     }
   }
 
+  override def shortName: String = "stateStoreSave"
+
   override def shouldRunAnotherBatch(newMetadata: OffsetSeqMetadata): Boolean = {
     (outputMode.contains(Append) || outputMode.contains(Update)) &&
       eventTimeWatermark.isDefined &&
@@ -534,6 +570,7 @@ case class StreamingDeduplicateExec(
         allRemovalsTimeMs += timeTakenMs { removeKeysOlderThanWatermark(store) }
         commitTimeMs += timeTakenMs { store.commit() }
         setStoreMetrics(store)
+        setOperatorMetrics()
       })
     }
   }
@@ -546,6 +583,8 @@ case class StreamingDeduplicateExec(
     Seq(StatefulOperatorCustomSumMetric("numDroppedDuplicateRows", "number of duplicates dropped"))
   }
 
+  override def shortName: String = "dedupe"
+
   override def shouldRunAnotherBatch(newMetadata: OffsetSeqMetadata): Boolean = {
     eventTimeWatermark.isDefined && newMetadata.batchWatermarkMs > eventTimeWatermark.get
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/streamingLimits.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/streamingLimits.scala
index 4200d49..0e9d12d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/streamingLimits.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/streamingLimits.scala
@@ -82,6 +82,7 @@ case class StreamingGlobalLimitExec(
         allUpdatesTimeMs += NANOSECONDS.toMillis(System.nanoTime - updatesStartTimeNs)
         commitTimeMs += timeTakenMs { store.commit() }
         setStoreMetrics(store)
+        setOperatorMetrics()
       })
     }
   }
@@ -96,6 +97,8 @@ case class StreamingGlobalLimitExec(
     UnsafeProjection.create(valueSchema)(new GenericInternalRow(Array[Any](value)))
   }
 
+  override def shortName: String = "globalLimit"
+
   override protected def withNewChildInternal(newChild: SparkPlan): StreamingGlobalLimitExec =
     copy(child = newChild)
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala
index 554780a..1565658 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala
@@ -41,10 +41,17 @@ import org.apache.spark.sql.streaming.SinkProgress.DEFAULT_NUM_OUTPUT_ROWS
  */
 @Evolving
 class StateOperatorProgress private[sql](
+    val operatorName: String,
     val numRowsTotal: Long,
     val numRowsUpdated: Long,
+    val allUpdatesTimeMs: Long,
+    val numRowsRemoved: Long,
+    val allRemovalsTimeMs: Long,
+    val commitTimeMs: Long,
     val memoryUsedBytes: Long,
     val numRowsDroppedByWatermark: Long,
+    val numShufflePartitions: Long,
+    val numStateStoreInstances: Long,
     val customMetrics: ju.Map[String, JLong] = new ju.HashMap()
   ) extends Serializable {
 
@@ -57,14 +64,26 @@ class StateOperatorProgress private[sql](
   private[sql] def copy(
       newNumRowsUpdated: Long,
       newNumRowsDroppedByWatermark: Long): StateOperatorProgress =
-    new StateOperatorProgress(numRowsTotal, newNumRowsUpdated, memoryUsedBytes,
-      newNumRowsDroppedByWatermark, customMetrics)
+    new StateOperatorProgress(
+      operatorName = operatorName, numRowsTotal = numRowsTotal, numRowsUpdated = newNumRowsUpdated,
+      allUpdatesTimeMs = allUpdatesTimeMs, numRowsRemoved = numRowsRemoved,
+      allRemovalsTimeMs = allRemovalsTimeMs, commitTimeMs = commitTimeMs,
+      memoryUsedBytes = memoryUsedBytes, numRowsDroppedByWatermark = newNumRowsDroppedByWatermark,
+      numShufflePartitions = numShufflePartitions, numStateStoreInstances = numStateStoreInstances,
+      customMetrics = customMetrics)
 
   private[sql] def jsonValue: JValue = {
+    ("operatorName" -> JString(operatorName)) ~
     ("numRowsTotal" -> JInt(numRowsTotal)) ~
     ("numRowsUpdated" -> JInt(numRowsUpdated)) ~
+    ("allUpdatesTimeMs" -> JInt(allUpdatesTimeMs)) ~
+    ("numRowsRemoved" -> JInt(numRowsRemoved)) ~
+    ("allRemovalsTimeMs" -> JInt(allRemovalsTimeMs)) ~
+    ("commitTimeMs" -> JInt(commitTimeMs)) ~
     ("memoryUsedBytes" -> JInt(memoryUsedBytes)) ~
     ("numRowsDroppedByWatermark" -> JInt(numRowsDroppedByWatermark)) ~
+    ("numShufflePartitions" -> JInt(numShufflePartitions)) ~
+    ("numStateStoreInstances" -> JInt(numStateStoreInstances)) ~
     ("customMetrics" -> {
       if (!customMetrics.isEmpty) {
         val keys = customMetrics.keySet.asScala.toSeq.sorted
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
index ad12d0d..171d330 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
@@ -54,6 +54,30 @@ class FlatMapGroupsWithStateSuite extends 
StateStoreMetricsTest {
   import GroupStateImpl._
   import GroupStateTimeout._
 
+  /**
+   * Sample `flatMapGroupsWithState` function implementation. It maintains the 
max event time as
+   * state and set the timeout timestamp based on the current max event time 
seen. It returns the
+   * max event time in the state, or -1 if the state was removed by timeout. 
Timeout is 5sec.
+   */
+  val sampleTestFunction =
+      (key: String, values: Iterator[(String, Long)], state: GroupState[Long]) 
=> {
+    assertCanGetProcessingTime { state.getCurrentProcessingTimeMs() >= 0 }
+    assertCanGetWatermark { state.getCurrentWatermarkMs() >= -1 }
+
+    val timeoutDelaySec = 5
+    if (state.hasTimedOut) {
+      state.remove()
+      Iterator((key, -1))
+    } else {
+      val valuesSeq = values.toSeq
+      val maxEventTimeSec = math.max(valuesSeq.map(_._2).max, 
state.getOption.getOrElse(0L))
+      val timeoutTimestampSec = maxEventTimeSec + timeoutDelaySec
+      state.update(maxEventTimeSec)
+      state.setTimeoutTimestamp(timeoutTimestampSec * 1000)
+      Iterator((key, maxEventTimeSec.toInt))
+    }
+  }
+
   test("SPARK-35800: ensure TestGroupState creates instances the same as 
prod") {
     val testState = TestGroupState.create[Int](
       Optional.of(5), EventTimeTimeout, 1L, Optional.of(1L), hasTimedOut = 
false)
@@ -737,7 +761,8 @@ class FlatMapGroupsWithStateSuite extends 
StateStoreMetricsTest {
       StartStream(),
       AddData(inputData, "a", "b"), // should remove state for "a" and not 
return anything for a
       CheckNewAnswer(("b", "2")),
-      assertNumStateRows(total = 1, updated = 2),
+      assertNumStateRows(
+        total = Seq(1), updated = Seq(1), droppedByWatermark = Seq(0), removed 
= Some(Seq(1))),
       StopStream,
       StartStream(),
       AddData(inputData, "a", "c"), // should recreate state for "a" and 
return count as 1 and
@@ -881,7 +906,8 @@ class FlatMapGroupsWithStateSuite extends 
StateStoreMetricsTest {
       AddData(inputData, "b"),
       AdvanceManualClock(10 * 1000),
       CheckNewAnswer(("a", "-1"), ("b", "2")),
-      assertNumStateRows(total = 1, updated = 2),
+      assertNumStateRows(
+        total = Seq(1), updated = Seq(1), droppedByWatermark = Seq(0), removed 
= Some(Seq(1))),
 
       StopStream,
       StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock),
@@ -889,7 +915,8 @@ class FlatMapGroupsWithStateSuite extends 
StateStoreMetricsTest {
       AddData(inputData, "c"),
       AdvanceManualClock(11 * 1000),
       CheckNewAnswer(("b", "-1"), ("c", "1")),
-      assertNumStateRows(total = 1, updated = 2),
+      assertNumStateRows(
+        total = Seq(1), updated = Seq(1), droppedByWatermark = Seq(0), removed 
= Some(Seq(1))),
 
       AdvanceManualClock(12 * 1000),
       AssertOnQuery { _ => clock.getTimeMillis() == 35000 },
@@ -901,31 +928,12 @@ class FlatMapGroupsWithStateSuite extends 
StateStoreMetricsTest {
         }
       },
       CheckNewAnswer(("c", "-1")),
-      assertNumStateRows(total = 0, updated = 1)
+      assertNumStateRows(
+        total = Seq(0), updated = Seq(0), droppedByWatermark = Seq(0), removed 
= Some(Seq(1)))
     )
   }
 
   testWithAllStateVersions("flatMapGroupsWithState - streaming w/ event time 
timeout + watermark") {
-    // Function to maintain the max event time as state and set the timeout 
timestamp based on the
-    // current max event time seen. It returns the max event time in the 
state, or -1 if the state
-    // was removed by timeout.
-    val stateFunc = (key: String, values: Iterator[(String, Long)], state: 
GroupState[Long]) => {
-      assertCanGetProcessingTime { state.getCurrentProcessingTimeMs() >= 0 }
-      assertCanGetWatermark { state.getCurrentWatermarkMs() >= -1 }
-
-      val timeoutDelaySec = 5
-      if (state.hasTimedOut) {
-        state.remove()
-        Iterator((key, -1))
-      } else {
-        val valuesSeq = values.toSeq
-        val maxEventTimeSec = math.max(valuesSeq.map(_._2).max, 
state.getOption.getOrElse(0L))
-        val timeoutTimestampSec = maxEventTimeSec + timeoutDelaySec
-        state.update(maxEventTimeSec)
-        state.setTimeoutTimestamp(timeoutTimestampSec * 1000)
-        Iterator((key, maxEventTimeSec.toInt))
-      }
-    }
     val inputData = MemoryStream[(String, Int)]
     val result =
       inputData.toDS
@@ -933,7 +941,7 @@ class FlatMapGroupsWithStateSuite extends 
StateStoreMetricsTest {
         .withWatermark("eventTime", "10 seconds")
         .as[(String, Long)]
         .groupByKey(_._1)
-        .flatMapGroupsWithState(Update, EventTimeTimeout)(stateFunc)
+        .flatMapGroupsWithState(Update, EventTimeTimeout)(sampleTestFunction)
 
     testStream(result, Update)(
       StartStream(),
@@ -981,26 +989,6 @@ class FlatMapGroupsWithStateSuite extends 
StateStoreMetricsTest {
   }
 
   test("flatMapGroupsWithState - recovery from checkpoint uses state format 
version 1") {
-    // Function to maintain the max event time as state and set the timeout 
timestamp based on the
-    // current max event time seen. It returns the max event time in the 
state, or -1 if the state
-    // was removed by timeout.
-    val stateFunc = (key: String, values: Iterator[(String, Long)], state: 
GroupState[Long]) => {
-      assertCanGetProcessingTime { state.getCurrentProcessingTimeMs() >= 0 }
-      assertCanGetWatermark { state.getCurrentWatermarkMs() >= -1 }
-
-      val timeoutDelaySec = 5
-      if (state.hasTimedOut) {
-        state.remove()
-        Iterator((key, -1))
-      } else {
-        val valuesSeq = values.toSeq
-        val maxEventTimeSec = math.max(valuesSeq.map(_._2).max, 
state.getOption.getOrElse(0L))
-        val timeoutTimestampSec = maxEventTimeSec + timeoutDelaySec
-        state.update(maxEventTimeSec)
-        state.setTimeoutTimestamp(timeoutTimestampSec * 1000)
-        Iterator((key, maxEventTimeSec.toInt))
-      }
-    }
     val inputData = MemoryStream[(String, Int)]
     val result =
       inputData.toDS
@@ -1008,7 +996,7 @@ class FlatMapGroupsWithStateSuite extends 
StateStoreMetricsTest {
         .withWatermark("eventTime", "10 seconds")
         .as[(String, Long)]
         .groupByKey(_._1)
-        .flatMapGroupsWithState(Update, EventTimeTimeout)(stateFunc)
+        .flatMapGroupsWithState(Update, EventTimeTimeout)(sampleTestFunction)
 
     val resourceUri = this.getClass.getResource(
       
"/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/").toURI
@@ -1089,7 +1077,7 @@ class FlatMapGroupsWithStateSuite extends 
StateStoreMetricsTest {
       StartStream(),
       AddData(inputData, "a", "b"), // should remove state for "a" and return 
count as -1
       CheckNewAnswer(("a", "-1"), ("b", "2")),
-      assertNumStateRows(total = 1, updated = 2),
+      assertNumStateRows(total = 1, updated = 1),
       StopStream,
       StartStream(),
       AddData(inputData, "a", "c"), // should recreate state for "a" and 
return count as 1
@@ -1122,6 +1110,43 @@ class FlatMapGroupsWithStateSuite extends 
StateStoreMetricsTest {
       spark.createDataset(Seq(("a", 2), ("b", 1))).toDF)
   }
 
+  test("SPARK-35896: metrics in StateOperatorProgress are output correctly") {
+    val inputData = MemoryStream[(String, Int)]
+    val result =
+      inputData.toDS
+        .select($"_1".as("key"), timestamp_seconds($"_2").as("eventTime"))
+        .withWatermark("eventTime", "10 seconds")
+        .as[(String, Long)]
+        .groupByKey(_._1)
+        .flatMapGroupsWithState(Update, EventTimeTimeout)(sampleTestFunction)
+
+    testStream(result, Update)(
+      StartStream(additionalConfs = Map(SQLConf.SHUFFLE_PARTITIONS.key -> 
"3")),
+
+      AddData(inputData, ("a", 11), ("a", 13), ("a", 15)),
+      // Max event time = 15. Timeout timestamp for "a" = 15 + 5 = 20. 
Watermark = 15 - 10 = 5.
+      CheckNewAnswer(("a", 15)),  // Output = max event time of a
+      assertNumStateRows(
+        total = Seq(1), updated = Seq(1), droppedByWatermark = Seq(0), removed 
= Some(Seq(0))),
+
+      AddData(inputData, ("a", 4)),       // Add data older than watermark for 
"a"
+      CheckNewAnswer(),                   // No output as data should get 
filtered by watermark
+      assertStateOperatorProgressMetric(operatorName = 
"flatMapGroupsWithState",
+        numShufflePartitions = 3, numStateStoreInstances = 3),
+
+      AddData(inputData, ("a", 10)),      // Add data newer than watermark for 
"a"
+      CheckNewAnswer(("a", 15)),          // Max event time is still the same
+      // Timeout timestamp for "a" is still 20 as max event time for "a" is 
still 15.
+      // Watermark is still 5 as max event time for all data is still 15.
+
+      AddData(inputData, ("b", 31)),      // Add data newer than watermark for 
"b", not "a"
+      // Watermark = 31 - 10 = 21, so "a" should be timed out as timeout 
timestamp for "a" is 20.
+      CheckNewAnswer(("a", -1), ("b", 31)), // State for "a" should timeout 
and emit -1
+      assertNumStateRows(
+        total = Seq(1), updated = Seq(2), droppedByWatermark = Seq(0), removed 
= Some(Seq(1)))
+    )
+  }
+
   testWithAllStateVersions("SPARK-29438: ensure UNION doesn't lead 
(flat)MapGroupsWithState" +
     " to use shifted partition IDs") {
     val stateFunc = (key: String, values: Iterator[String], state: 
GroupState[RunningCount]) => {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StateStoreMetricsTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StateStoreMetricsTest.scala
index 5073723..0abc79a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StateStoreMetricsTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StateStoreMetricsTest.scala
@@ -32,9 +32,10 @@ trait StateStoreMetricsTest extends StreamTest {
   def assertNumStateRows(
       total: Seq[Long],
       updated: Seq[Long],
-      droppedByWatermark: Seq[Long]): AssertOnQuery = {
+      droppedByWatermark: Seq[Long],
+      removed: Option[Seq[Long]]): AssertOnQuery = {
     AssertOnQuery(s"Check total state rows = $total, updated state rows = 
$updated" +
-      s", rows dropped by watermark = $droppedByWatermark") { q =>
+      s", rows dropped by watermark = $droppedByWatermark, removed state rows 
= $removed") { q =>
       // This assumes that the streaming query will not make any progress 
while the eventually
       // is being executed.
       eventually(timeout(streamingTimeout)) {
@@ -61,6 +62,13 @@ trait StateStoreMetricsTest extends StreamTest {
         assert(numRowsDroppedByWatermark === droppedByWatermark,
           s"incorrect dropped rows by watermark, $debugString")
 
+        if (removed.isDefined) {
+          val allNumRowsRemovedSinceLastCheck =
+            
progressesSinceLastCheck.map(_.stateOperators.map(_.numRowsRemoved))
+          val numRemovedRows = arraySum(allNumRowsRemovedSinceLastCheck, 
numStateOperators)
+          assert(numRemovedRows === removed.get, s"incorrect removed rows, 
$debugString")
+        }
+
         advanceLastCheckedRecentProgressIndex(lastCheckedProgressIndex)
       }
       true
@@ -92,16 +100,45 @@ trait StateStoreMetricsTest extends StreamTest {
     }
   }
 
+  /** Assert on [[StateOperatorProgress]] metrics */
+  def assertStateOperatorProgressMetric(operatorName: String, 
numShufflePartitions: Long,
+      numStateStoreInstances: Long, operatorIndex: Int = 0): AssertOnQuery = {
+    AssertOnQuery(s"Check operator progress metrics: operatorName = 
$operatorName, " +
+      s"numShufflePartitions = $numShufflePartitions, " +
+      s"numStateStoreInstances = $numStateStoreInstances") { q =>
+      eventually(timeout(streamingTimeout)) {
+        val (progressesSinceLastCheck, lastCheckedProgressIndex, 
numStateOperators) =
+          retrieveProgressesSinceLastCheck(q)
+        assert(operatorIndex < numStateOperators, s"Invalid operator Index: 
$operatorIndex")
+        val lastOpProgress = 
progressesSinceLastCheck.last.stateOperators(operatorIndex)
+
+        lazy val debugString = "recent progresses:\n" +
+          progressesSinceLastCheck.map(_.prettyJson).mkString("\n\n")
+
+        assert(lastOpProgress.operatorName === operatorName,
+          s"incorrect operator name, $debugString")
+        assert(lastOpProgress.numShufflePartitions === numShufflePartitions,
+          s"incorrect number of shuffle partitions, $debugString")
+        assert(lastOpProgress.numStateStoreInstances === 
numStateStoreInstances,
+          s"incorrect number of state stores, $debugString")
+
+        advanceLastCheckedRecentProgressIndex(lastCheckedProgressIndex)
+      }
+      true
+    }
+  }
+
   def assertNumStateRows(total: Seq[Long], updated: Seq[Long]): AssertOnQuery 
= {
     assert(total.length === updated.length)
-    assertNumStateRows(total, updated, droppedByWatermark = (0 until 
total.length).map(_ => 0L))
+    assertNumStateRows(
+      total, updated, droppedByWatermark = (0 until total.length).map(_ => 
0L), None)
   }
 
   def assertNumStateRows(
       total: Long,
       updated: Long,
       droppedByWatermark: Long = 0): AssertOnQuery = {
-    assertNumStateRows(Seq(total), Seq(updated), Seq(droppedByWatermark))
+    assertNumStateRows(Seq(total), Seq(updated), Seq(droppedByWatermark), None)
   }
 
   def arraySum(arraySeq: Seq[Array[Long]], arrayLength: Int): Seq[Long] = {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala
index 491b0d8..eef13ca 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala
@@ -796,6 +796,64 @@ class StreamingAggregationSuite extends 
StateStoreMetricsTest with Assertions {
     }
   }
 
+  test("SPARK-35896: metrics in StateOperatorProgress are output correctly") {
+    val inputData = MemoryStream[Int]
+    val aggregated =
+      inputData.toDF()
+        .groupBy($"value")
+        .agg(count("*"))
+        .as[(Int, Long)]
+
+    testStream(aggregated, Update) (
+      StartStream(additionalConfs = Map(SQLConf.SHUFFLE_PARTITIONS.key -> 
"3")),
+      AddData(inputData, 3, 2, 1, 3),
+      CheckLastBatch((3, 2), (2, 1), (1, 1)),
+      assertNumStateRows(
+        total = Seq(3), updated = Seq(3), droppedByWatermark = Seq(0), removed 
= Some(Seq(0))),
+
+      AddData(inputData, 1, 4),
+      CheckLastBatch((1, 2), (4, 1)),
+      assertStateOperatorProgressMetric(
+        operatorName = "stateStoreSave", numShufflePartitions = 3, 
numStateStoreInstances = 3)
+    )
+
+    inputData.reset() // reset the input to clear any data from prev test
+    testStream(aggregated, Complete) (
+      StartStream(additionalConfs = Map(SQLConf.SHUFFLE_PARTITIONS.key -> 
"3")),
+      AddData(inputData, 3, 2, 1, 3),
+      CheckLastBatch((3, 2), (2, 1), (1, 1)),
+      assertNumStateRows(
+        total = Seq(3), updated = Seq(3), droppedByWatermark = Seq(0), removed 
= Some(Seq(0))),
+
+      AddData(inputData, 1, 4),
+      CheckLastBatch((3, 2), (2, 1), (1, 2), (4, 1)),
+      assertStateOperatorProgressMetric(
+        operatorName = "stateStoreSave", numShufflePartitions = 3, 
numStateStoreInstances = 3)
+    )
+
+    // with watermark and append output mode
+    val aggWithWatermark = inputData.toDF()
+      .withColumn("eventTime", timestamp_seconds($"value"))
+      .withWatermark("eventTime", "10 seconds")
+      .groupBy(window($"eventTime", "5 seconds") as 'window)
+      .agg(count("*") as 'count)
+      .select($"window".getField("start").cast("long").as[Long], 
$"count".as[Long])
+
+    inputData.reset() // reset the input to clear any data from prev test
+    testStream(aggWithWatermark, Append) (
+      StartStream(additionalConfs = Map(SQLConf.SHUFFLE_PARTITIONS.key -> 
"3")),
+      AddData(inputData, 3, 2, 1, 9),
+      CheckLastBatch(),
+      assertStateOperatorProgressMetric(
+        operatorName = "stateStoreSave", numShufflePartitions = 3, 
numStateStoreInstances = 3),
+
+      AddData(inputData, 25), // Advance watermark to 15 secs, no-data-batch 
drops rows <= 15
+      CheckLastBatch((0, 3), (5, 1)),
+      assertNumStateRows(
+        total = Seq(1), updated = Seq(1), droppedByWatermark = Seq(0), removed 
= Some(Seq(2)))
+    )
+  }
+
   private def prepareTestForChangingSchemaOfState(
       tempDir: File): (MemoryStream[Int], DataFrame) = {
     val inputData = MemoryStream[Int]
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingDeduplicationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingDeduplicationSuite.scala
index dc2e787..aa03da6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingDeduplicationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingDeduplicationSuite.scala
@@ -80,6 +80,38 @@ class StreamingDeduplicationSuite extends 
StateStoreMetricsTest {
     )
   }
 
+  test("SPARK-35896: metrics in StateOperatorProgress are output correctly") {
+    val inputData = MemoryStream[Int]
+    val result = inputData.toDS()
+      .withColumn("eventTime", timestamp_seconds($"value"))
+      .withWatermark("eventTime", "10 seconds")
+      .dropDuplicates()
+      .select($"eventTime".cast("long").as[Long])
+
+    testStream(result, Append)(
+      StartStream(additionalConfs = Map(SQLConf.SHUFFLE_PARTITIONS.key -> 
"2")),
+      AddData(inputData, (1 to 5).flatMap(_ => (10 to 15)): _*),
+      CheckAnswer(10 to 15: _*),
+      assertNumStateRows(
+        total = Seq(6), updated = Seq(6), droppedByWatermark = Seq(0), removed 
= Some(Seq(0))),
+
+      AddData(inputData, 25), // Advance watermark to 15 secs, no-data-batch 
drops rows <= 15
+      CheckNewAnswer(25),
+      assertNumStateRows(
+        total = Seq(1), updated = Seq(1), droppedByWatermark = Seq(0), removed 
= Some(Seq(6))),
+
+      AddData(inputData, 10), // Should not emit anything as data less than 
watermark
+      CheckNewAnswer(),
+      assertNumStateRows(
+        total = Seq(1), updated = Seq(0), droppedByWatermark = Seq(1), removed 
= Some(Seq(0))),
+
+      AddData(inputData, 10),
+      CheckNewAnswer(),
+      assertStateOperatorProgressMetric(
+        operatorName = "dedupe", numShufflePartitions = 2, 
numStateStoreInstances = 2)
+    )
+  }
+
   test("deduplicate with watermark") {
     val inputData = MemoryStream[Int]
     val result = inputData.toDS()
@@ -134,7 +166,7 @@ class StreamingDeduplicationSuite extends 
StateStoreMetricsTest {
       AddData(inputData, 10), // Should not emit anything as data less than 
watermark
       CheckLastBatch(),
       assertNumStateRows(total = Seq(2L, 1L), updated = Seq(0L, 0L),
-        droppedByWatermark = Seq(0L, 1L)),
+        droppedByWatermark = Seq(0L, 1L), None),
 
       AddData(inputData, 40), // Advance watermark to 30 seconds
       CheckLastBatch((15 -> 1), (25 -> 1)),
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala
index 4a5422e..8c830d3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala
@@ -32,6 +32,7 @@ import 
org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
 import org.apache.spark.sql.execution.streaming.{MemoryStream, 
StatefulOperatorStateInfo, StreamingSymmetricHashJoinExec, 
StreamingSymmetricHashJoinHelper}
 import org.apache.spark.sql.execution.streaming.state.{StateStore, 
StateStoreProviderId}
 import org.apache.spark.sql.functions._
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.util.Utils
 
 abstract class StreamingJoinSuite
@@ -645,6 +646,48 @@ class StreamingInnerJoinSuite extends StreamingJoinSuite {
       }
     )
   }
+
+  test("SPARK-35896: metrics in StateOperatorProgress are output correctly") {
+    val input1 = MemoryStream[Int]
+    val input2 = MemoryStream[Int]
+
+    val df1 = input1.toDF
+      .select('value as "key", timestamp_seconds($"value") as "timestamp",
+        ('value * 2) as "leftValue")
+      .withWatermark("timestamp", "10 seconds")
+      .select('key, window('timestamp, "10 second"), 'leftValue)
+
+    val df2 = input2.toDF
+      .select('value as "key", timestamp_seconds($"value") as "timestamp",
+        ('value * 3) as "rightValue")
+      .select('key, window('timestamp, "10 second"), 'rightValue)
+
+    val joined = df1.join(df2, Seq("key", "window"))
+      .select('key, $"window.end".cast("long"), 'leftValue, 'rightValue)
+
+    testStream(joined)(
+      StartStream(additionalConfs = Map(SQLConf.SHUFFLE_PARTITIONS.key -> 
"3")),
+      AddData(input1, 1),
+      CheckAnswer(),
+      assertStateOperatorProgressMetric(operatorName = "symmetricHashJoin",
+        numShufflePartitions = 3, numStateStoreInstances = 3 * 4),
+
+      AddData(input2, 1),
+      CheckAnswer((1, 10, 2, 3)),
+      assertNumStateRows(
+        total = Seq(2), updated = Seq(1), droppedByWatermark = Seq(0), removed 
= Some(Seq(0))),
+
+      AddData(input1, 25),
+      CheckNewAnswer(),   // watermark = 15, no-data-batch should remove 2 
rows having window=[0,10]
+      assertNumStateRows(
+        total = Seq(1), updated = Seq(1), droppedByWatermark = Seq(0), removed 
= Some(Seq(2))),
+
+      AddData(input2, 25),
+      CheckNewAnswer((25, 30, 50, 75)),
+      assertNumStateRows(
+        total = Seq(2), updated = Seq(1), droppedByWatermark = Seq(0), removed 
= Some(Seq(0)))
+    )
+  }
 }
 
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryStatusAndProgressSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryStatusAndProgressSuite.scala
index 3eca465..99fcef1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryStatusAndProgressSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryStatusAndProgressSuite.scala
@@ -60,10 +60,17 @@ class StreamingQueryStatusAndProgressSuite extends 
StreamTest with Eventually {
         |    "watermark" : "2016-12-05T20:54:20.827Z"
         |  },
         |  "stateOperators" : [ {
+        |    "operatorName" : "op1",
         |    "numRowsTotal" : 0,
         |    "numRowsUpdated" : 1,
+        |    "allUpdatesTimeMs" : 1,
+        |    "numRowsRemoved" : 2,
+        |    "allRemovalsTimeMs" : 34,
+        |    "commitTimeMs" : 23,
         |    "memoryUsedBytes" : 3,
         |    "numRowsDroppedByWatermark" : 0,
+        |    "numShufflePartitions" : 2,
+        |    "numStateStoreInstances" : 2,
         |    "customMetrics" : {
         |      "loadedMapCacheHitCount" : 1,
         |      "loadedMapCacheMissCount" : 0,
@@ -112,10 +119,17 @@ class StreamingQueryStatusAndProgressSuite extends 
StreamTest with Eventually {
          |    "total" : 0
          |  },
          |  "stateOperators" : [ {
+         |    "operatorName" : "op2",
          |    "numRowsTotal" : 0,
          |    "numRowsUpdated" : 1,
+         |    "allUpdatesTimeMs" : 1,
+         |    "numRowsRemoved" : 2,
+         |    "allRemovalsTimeMs" : 34,
+         |    "commitTimeMs" : 23,
          |    "memoryUsedBytes" : 2,
-         |    "numRowsDroppedByWatermark" : 0
+         |    "numRowsDroppedByWatermark" : 0,
+         |    "numShufflePartitions" : 2,
+         |    "numStateStoreInstances" : 2
          |  } ],
          |  "sources" : [ {
          |    "description" : "source",
@@ -323,8 +337,10 @@ object StreamingQueryStatusAndProgressSuite {
       "min" -> "2016-12-05T20:54:20.827Z",
       "avg" -> "2016-12-05T20:54:20.827Z",
       "watermark" -> "2016-12-05T20:54:20.827Z").asJava),
-    stateOperators = Array(new StateOperatorProgress(
-      numRowsTotal = 0, numRowsUpdated = 1, memoryUsedBytes = 3, 
numRowsDroppedByWatermark = 0,
+    stateOperators = Array(new StateOperatorProgress(operatorName = "op1",
+      numRowsTotal = 0, numRowsUpdated = 1, allUpdatesTimeMs = 1, 
numRowsRemoved = 2,
+      allRemovalsTimeMs = 34, commitTimeMs = 23, memoryUsedBytes = 3, 
numRowsDroppedByWatermark = 0,
+      numShufflePartitions = 2, numStateStoreInstances = 2,
       customMetrics = new 
java.util.HashMap(Map("stateOnCurrentVersionSizeBytes" -> 2L,
         "loadedMapCacheHitCount" -> 1L, "loadedMapCacheMissCount" -> 0L)
         .mapValues(long2Long).toMap.asJava)
@@ -356,8 +372,10 @@ object StreamingQueryStatusAndProgressSuite {
     durationMs = new java.util.HashMap(Map("total" -> 
0L).mapValues(long2Long).toMap.asJava),
     // empty maps should be handled correctly
     eventTime = new java.util.HashMap(Map.empty[String, String].asJava),
-    stateOperators = Array(new StateOperatorProgress(
-      numRowsTotal = 0, numRowsUpdated = 1, memoryUsedBytes = 2, 
numRowsDroppedByWatermark = 0)),
+    stateOperators = Array(new StateOperatorProgress(operatorName = "op2",
+      numRowsTotal = 0, numRowsUpdated = 1, allUpdatesTimeMs = 1, 
numRowsRemoved = 2,
+      allRemovalsTimeMs = 34, commitTimeMs = 23, memoryUsedBytes = 2, 
numRowsDroppedByWatermark = 0,
+      numShufflePartitions = 2, numStateStoreInstances = 2)),
     sources = Array(
       new SourceProgress(
         description = "source",
