ericm-db commented on code in PR #49816:
URL: https://github.com/apache/spark/pull/49816#discussion_r1943550436


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala:
##########
@@ -1239,6 +1241,7 @@ class RocksDB(
         log"with uniqueId: ${MDC(LogKeys.UUID, snapshot.uniqueId)} " +
         log"time taken: ${MDC(LogKeys.TIME_UNITS, uploadTime)} ms. " +
         log"Current lineage: ${MDC(LogKeys.LINEAGE, lineageManager)}")
+      lastUploadedVersion = version

Review Comment:
   Does this not need to be snapshot.version?



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala:
##########
@@ -320,11 +320,24 @@ trait StateStoreWriter
    * the driver after this SparkPlan has been executed and metrics have been 
updated.
    */
   def getProgress(): StateOperatorProgress = {
+    val customPartitionMetrics = stateStoreCustomMetrics
+      .map(entry => entry._1 -> longMetric(entry._1).value)
+      .filter(entry => 
entry._1.contains(StateStoreProvider.PARTITION_METRIC_SUFFIX) && entry._2 != 0)

Review Comment:
   Did we not find a better way to do this? If a partition's snapshot has never 
been uploaded, this is something we should be alerted to. 



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala:
##########
@@ -504,10 +504,19 @@ trait StateStoreProvider {
    * (specifically, same names) through `StateStore.metrics`.
    */
   def supportedCustomMetrics: Seq[StateStoreCustomMetric] = Nil
+
+  /**
+   * Optional custom partition-specific metrics that the implementation may 
want to report.
+   * @note The StateStore objects created by this provider must report the 
same custom metrics
+   * (specifically, same names) through `StateStore.metrics`.
+   */
+  def supportedCustomPartitionMetrics: Seq[Long => StateStoreCustomMetric] = 
Nil

Review Comment:
   Maybe we should create a new case class, `StateStoreCustomPartitionMetric`, 
that carries a partitionId — what do you think?



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala:
##########
@@ -693,6 +697,15 @@ object RocksDBStateStoreProvider {
       expireAfterAccessTimeUnit = TimeUnit.HOURS
     )
 
+  val CUSTOM_METRIC_SNAPSHOT_LAST_UPLOADED = (partitionId: Long) => {

Review Comment:
   This seems more appropriate as a function as opposed to a val. 



##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreIntegrationSuite.scala:
##########
@@ -222,6 +234,81 @@ class RocksDBStateStoreIntegrationSuite extends StreamTest
     }
   }
 
+  Seq(
+    classOf[SkipMaintenanceOnCertainPartitionsProvider].getName,
+    classOf[RocksDBStateStoreProvider].getName
+  ).foreach { stateStoreClass =>
+    test(
+      s"Verify snapshot lag metric is updated correctly with stateStoreClass = 
$stateStoreClass"
+    ) {
+      withSQLConf(
+        SQLConf.STATE_STORE_PROVIDER_CLASS.key -> stateStoreClass,
+        SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "500",
+        SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1",
+        SQLConf.STATE_STORE_PARTITION_METRICS_REPORT_LIMIT.key -> "2",
+        // Set shuffle partitions to 15 so that 20% of shuffle partitions is 
above the report limit
+        SQLConf.SHUFFLE_PARTITIONS.key -> "15",
+        RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + 
".changelogCheckpointing.enabled" -> "true"
+      ) {
+        withTempDir { checkpointDir =>
+          val inputData = MemoryStream[String]
+          val result = inputData.toDS().dropDuplicates()
+
+          testStream(result, outputMode = OutputMode.Update)(
+            StartStream(checkpointLocation = checkpointDir.getCanonicalPath),
+            AddData(inputData, "a"),
+            Execute { _ =>
+              Thread.sleep(500)
+            },
+            AddData(inputData, "b"),
+            Execute { _ =>
+              Thread.sleep(500)
+            },
+            AddData(inputData, "c"),
+            Execute { _ =>
+              Thread.sleep(500)
+            },
+            AddData(inputData, "d"),
+            CheckNewAnswer("a", "b", "c", "d"),
+            Execute { q =>
+              val metricNamePrefix =
+                "rocksdbSnapshotLastUploaded" + 
StateStoreProvider.PARTITION_METRIC_SUFFIX
+              if (stateStoreClass == 
classOf[SkipMaintenanceOnCertainPartitionsProvider].getName) {
+                // Partitions getting skipped (id 0 and 1) do not have an 
uploaded version, leaving
+                // the metric empty.
+                assert(
+                  q.lastProgress
+                    .stateOperators(0)
+                    .customMetrics
+                    .containsKey(metricNamePrefix + "0") === false
+                )
+                assert(
+                  q.lastProgress
+                    .stateOperators(0)
+                    .customMetrics
+                    .containsKey(metricNamePrefix + "1") === false
+                )
+              }

Review Comment:
   There should be an else clause for the RocksDBStateStoreProvider to check 
that all partitions are at the same value



##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreIntegrationSuite.scala:
##########
@@ -222,6 +234,81 @@ class RocksDBStateStoreIntegrationSuite extends StreamTest
     }
   }
 
+  Seq(
+    classOf[SkipMaintenanceOnCertainPartitionsProvider].getName,
+    classOf[RocksDBStateStoreProvider].getName
+  ).foreach { stateStoreClass =>
+    test(
+      s"Verify snapshot lag metric is updated correctly with stateStoreClass = 
$stateStoreClass"
+    ) {
+      withSQLConf(
+        SQLConf.STATE_STORE_PROVIDER_CLASS.key -> stateStoreClass,
+        SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "500",
+        SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1",
+        SQLConf.STATE_STORE_PARTITION_METRICS_REPORT_LIMIT.key -> "2",
+        // Set shuffle partitions to 15 so that 20% of shuffle partitions is 
above the report limit
+        SQLConf.SHUFFLE_PARTITIONS.key -> "15",
+        RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + 
".changelogCheckpointing.enabled" -> "true"
+      ) {
+        withTempDir { checkpointDir =>
+          val inputData = MemoryStream[String]
+          val result = inputData.toDS().dropDuplicates()
+
+          testStream(result, outputMode = OutputMode.Update)(
+            StartStream(checkpointLocation = checkpointDir.getCanonicalPath),
+            AddData(inputData, "a"),
+            Execute { _ =>
+              Thread.sleep(500)
+            },
+            AddData(inputData, "b"),
+            Execute { _ =>
+              Thread.sleep(500)
+            },
+            AddData(inputData, "c"),
+            Execute { _ =>
+              Thread.sleep(500)
+            },
+            AddData(inputData, "d"),
+            CheckNewAnswer("a", "b", "c", "d"),
+            Execute { q =>
+              val metricNamePrefix =
+                "rocksdbSnapshotLastUploaded" + 
StateStoreProvider.PARTITION_METRIC_SUFFIX
+              if (stateStoreClass == 
classOf[SkipMaintenanceOnCertainPartitionsProvider].getName) {
+                // Partitions getting skipped (id 0 and 1) do not have an 
uploaded version, leaving
+                // the metric empty.
+                assert(
+                  q.lastProgress
+                    .stateOperators(0)
+                    .customMetrics
+                    .containsKey(metricNamePrefix + "0") === false

Review Comment:
   Can you make this a method? Maybe something like 
   ```
   def snapshotLagMetricName(partitionId: Long): String
   ```



##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreIntegrationSuite.scala:
##########
@@ -222,6 +234,81 @@ class RocksDBStateStoreIntegrationSuite extends StreamTest
     }
   }
 
+  Seq(
+    classOf[SkipMaintenanceOnCertainPartitionsProvider].getName,
+    classOf[RocksDBStateStoreProvider].getName
+  ).foreach { stateStoreClass =>
+    test(
+      s"Verify snapshot lag metric is updated correctly with stateStoreClass = 
$stateStoreClass"
+    ) {
+      withSQLConf(
+        SQLConf.STATE_STORE_PROVIDER_CLASS.key -> stateStoreClass,
+        SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "500",
+        SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1",
+        SQLConf.STATE_STORE_PARTITION_METRICS_REPORT_LIMIT.key -> "2",
+        // Set shuffle partitions to 15 so that 20% of shuffle partitions is 
above the report limit
+        SQLConf.SHUFFLE_PARTITIONS.key -> "15",
+        RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX + 
".changelogCheckpointing.enabled" -> "true"

Review Comment:
   use `testWithChangelogCheckpointingEnabled` instead



##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreIntegrationSuite.scala:
##########
@@ -222,6 +234,81 @@ class RocksDBStateStoreIntegrationSuite extends StreamTest
     }
   }
 
+  Seq(

Review Comment:
   Split this into two different tests



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to