neilramaswamy commented on code in PR #47895:
URL: https://github.com/apache/spark/pull/47895#discussion_r1779272544


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala:
##########
@@ -72,7 +72,8 @@ class RocksDB(
     localRootDir: File = Utils.createTempDir(),
     hadoopConf: Configuration = new Configuration,
     loggingId: String = "",
-    useColumnFamilies: Boolean = false) extends Logging {
+    useColumnFamilies: Boolean = false,
+    ifEnableCheckpointId: Boolean = false) extends Logging {

Review Comment:
   More consistent to call it `enableStateStoreCheckpointIds`.
   
   I also think that the term "checkpoint ID" is very confusing. The term makes 
it feel like it's an ID for an _entire_ checkpoint, when really it's an ID for 
a particular state store that has been checkpointed.
   
   I know it's a tedious modification to make. I would be happy to alleviate some of this work by creating a branch with that change and putting up a PR that you can merge back into this branch.
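   For concreteness, a minimal sketch of the renamed parameter, reusing just the parameters visible in the hunk above (everything else elided):
   
   ```scala
   import java.io.File
   
   import org.apache.hadoop.conf.Configuration
   
   import org.apache.spark.internal.Logging
   import org.apache.spark.util.Utils
   
   // Reduced sketch: only the parameters shown in the hunk above are included.
   class RocksDB(
       localRootDir: File = Utils.createTempDir(),
       hadoopConf: Configuration = new Configuration,
       loggingId: String = "",
       useColumnFamilies: Boolean = false,
       // Renamed from `ifEnableCheckpointId`: this toggles per-state-store
       // checkpoint IDs, not an ID for an entire checkpoint.
       enableStateStoreCheckpointIds: Boolean = false) extends Logging
   ```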



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala:
##########
@@ -808,6 +824,45 @@ object SymmetricHashJoinStateManager {
     result
   }
 
+  def mergeStateStoreCheckpointInfo(

Review Comment:
   I already commented about this elsewhere (that it shouldn't be in the 
symmetric hash join state manager), but this was confusing to read because it 
is used in two places:
   
   1. To merge the key with index state store with the key with index to value 
state store
   2. To merge the results from step (1) for both the left and the right into 
one result



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala:
##########
@@ -190,6 +190,11 @@ trait StateStore extends ReadStateStore {
   /** Current metrics of the state store */
   def metrics: StateStoreMetrics
 
+  /** Return information on recently generated checkpoints */
+  def getCheckpointInfo: StateStoreCheckpointInfo = {
+    StateStoreCheckpointInfo(-1, -1, None, None)

Review Comment:
   Why a default implementation? If all the subclasses override it, let's just make it abstract, with no default.
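   Concretely, something like the sketch below (it assumes the `StateStoreCheckpointInfo` case class added elsewhere in this PR, and only drops the default body):
   
   ```scala
   trait StateStore extends ReadStateStore {
     // ... other members unchanged ...
   
     /** Return information on recently generated checkpoints. */
     // No default body: every implementation must provide it, rather than
     // silently falling back to sentinel values like StateStoreCheckpointInfo(-1, -1, None, None).
     def getCheckpointInfo: StateStoreCheckpointInfo
   }
   ```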



##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreIntegrationSuite.scala:
##########
@@ -222,6 +375,456 @@ class RocksDBStateStoreIntegrationSuite extends StreamTest
     }
   }
 
+  testWithChangelogCheckpointingEnabled(s"checkpointFormatVersion2") {
+    withSQLConf((SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION.key, "2")) {
+      val checkpointDir = Utils.createTempDir().getCanonicalFile
+      checkpointDir.delete()
+
+      val inputData = MemoryStream[Int]
+      val aggregated =
+        inputData
+          .toDF()
+          .groupBy($"value")
+          .agg(count("*"))
+          .as[(Int, Long)]
+
+      testStream(aggregated, Update)(
+        StartStream(checkpointLocation = checkpointDir.getAbsolutePath),
+        AddData(inputData, 3),
+        CheckLastBatch((3, 1)),
+        AddData(inputData, 3, 2),
+        CheckLastBatch((3, 2), (2, 1)),
+        StopStream
+      )
+
+      // Run the stream with changelog checkpointing enabled.
+      testStream(aggregated, Update)(
+        StartStream(checkpointLocation = checkpointDir.getAbsolutePath),
+        AddData(inputData, 3, 2, 1),
+        CheckLastBatch((3, 3), (2, 2), (1, 1)),
+        // By default we run in new tuple mode.
+        AddData(inputData, 4, 4, 4, 4),
+        CheckLastBatch((4, 4)),
+        AddData(inputData, 5, 5),
+        CheckLastBatch((5, 2))
+      )
+
+      // Run the stream with changelog checkpointing disabled.
+      testStream(aggregated, Update)(
+        StartStream(checkpointLocation = checkpointDir.getAbsolutePath),
+        AddData(inputData, 4),
+        CheckLastBatch((4, 5))
+      )
+    }
+  }
+
+  testWithChangelogCheckpointingEnabled(s"checkpointFormatVersion2 validate ID") {
+    val providerClassName = classOf[TestStateStoreProviderWrapper].getCanonicalName
+    TestStateStoreWrapper.clear()
+    withSQLConf(
+      (SQLConf.STATE_STORE_PROVIDER_CLASS.key -> providerClassName),
+      (SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION.key -> "2"),
+      (SQLConf.SHUFFLE_PARTITIONS.key, "2")) {
+      val checkpointDir = Utils.createTempDir().getCanonicalFile
+      checkpointDir.delete()
+
+      val inputData = MemoryStream[Int]
+      val aggregated =
+        inputData
+          .toDF()
+          .groupBy($"value")
+          .agg(count("*"))
+          .as[(Int, Long)]
+
+      testStream(aggregated, Update)(
+        StartStream(checkpointLocation = checkpointDir.getAbsolutePath),
+        AddData(inputData, 3),
+        CheckLastBatch((3, 1)),
+        AddData(inputData, 3, 2),
+        CheckLastBatch((3, 2), (2, 1)),
+        StopStream
+      )
+
+      // Test recovery
+      testStream(aggregated, Update)(
+        StartStream(checkpointLocation = checkpointDir.getAbsolutePath),
+        AddData(inputData, 3, 2, 1),
+        CheckLastBatch((3, 3), (2, 2), (1, 1)),
+        // By default we run in new tuple mode.
+        AddData(inputData, 4, 4, 4, 4),
+        CheckLastBatch((4, 4)),
+        AddData(inputData, 5, 5),
+        CheckLastBatch((5, 2)),
+        StopStream
+      )
+
+      // crash recovery again

Review Comment:
   How is this simulating "crash recovery"? I think the previous stream would 
have executed cleanly.



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala:
##########
@@ -233,6 +238,15 @@ case class StateStoreMetrics(
     memoryUsedBytes: Long,
     customMetrics: Map[StateStoreCustomMetric, Long])
 
+case class StateStoreCheckpointInfo(
+    partitionId: Int,
+    batchVersion: Long,
+    // The checkpoint ID for a checkpoint at `batchVersion`. This is used to identify the checkpoint
+    checkpointId: Option[String],

Review Comment:
   If we use a String, we need to mention that it's not necessarily _one_ 
checkpoint ID. It could be many, comma-separated.
   
   But to be honest, I don't think we should be using `String` here, because 
it's ambiguous. Is it 1 checkpoint? 4 checkpoints? You cannot simply tell by 
looking at the code. The naming is also off in the case of multiple 
checkpoints; it's `StateStore*s*CheckpointInfo`.
   
   I think it makes more sense to return a `Seq[String]` all the way through the accumulator. Then the only place merging happens is inside `def getCheckpointInfo` in `StateStoreWriter`. That keeps us from having awkward one-off merging logic inside the s/s join, even though I know it's currently the only place that needs it.
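   Roughly what I have in mind, as a sketch (the field names and the helper are only suggestions, and the real classes have more members):
   
   ```scala
   // Hypothetical shape: a Seq makes "one or many checkpoint IDs" explicit,
   // instead of hiding the cardinality inside a comma-separated String.
   case class StateStoreCheckpointInfo(
       partitionId: Int,
       batchVersion: Long,
       checkpointIds: Seq[String],
       baseCheckpointIds: Seq[String])
   
   object StateStoreCheckpointInfo {
     // The only place flattening should happen, e.g. called from
     // `StateStoreWriter.getCheckpointInfo`, rather than one-off logic in the s/s join.
     def merge(infos: Seq[StateStoreCheckpointInfo]): StateStoreCheckpointInfo = {
       require(infos.nonEmpty, "need at least one per-store checkpoint info")
       StateStoreCheckpointInfo(
         partitionId = infos.head.partitionId,
         batchVersion = infos.head.batchVersion,
         checkpointIds = infos.flatMap(_.checkpointIds),
         baseCheckpointIds = infos.flatMap(_.baseCheckpointIds))
     }
   }
   ```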



##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreIntegrationSuite.scala:
##########
@@ -222,6 +375,456 @@ class RocksDBStateStoreIntegrationSuite extends StreamTest
     }
   }
 
+  testWithChangelogCheckpointingEnabled(s"checkpointFormatVersion2") {
+    withSQLConf((SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION.key, "2")) {
+      val checkpointDir = Utils.createTempDir().getCanonicalFile
+      checkpointDir.delete()
+
+      val inputData = MemoryStream[Int]
+      val aggregated =
+        inputData
+          .toDF()
+          .groupBy($"value")
+          .agg(count("*"))
+          .as[(Int, Long)]
+
+      testStream(aggregated, Update)(
+        StartStream(checkpointLocation = checkpointDir.getAbsolutePath),
+        AddData(inputData, 3),
+        CheckLastBatch((3, 1)),
+        AddData(inputData, 3, 2),
+        CheckLastBatch((3, 2), (2, 1)),
+        StopStream
+      )
+
+      // Run the stream with changelog checkpointing enabled.

Review Comment:
   How exactly are you enabling/disabling changelog checkpointing here?
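   If the intent is to actually exercise both modes, I'd expect the test to flip the RocksDB changelog checkpointing conf explicitly around the relevant `testStream` run, something like the sketch below (reusing `aggregated`, `inputData`, and `checkpointDir` from the test above):
   
   ```scala
   // Sketch: explicitly disable changelog checkpointing for this restart, instead of
   // relying on whatever testWithChangelogCheckpointingEnabled set up globally.
   withSQLConf(
     "spark.sql.streaming.stateStore.rocksdb.changelogCheckpointing.enabled" -> "false") {
     testStream(aggregated, Update)(
       StartStream(checkpointLocation = checkpointDir.getAbsolutePath),
       AddData(inputData, 4),
       CheckLastBatch((4, 5))
     )
   }
   ```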



##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreIntegrationSuite.scala:
##########
@@ -222,6 +375,456 @@ class RocksDBStateStoreIntegrationSuite extends StreamTest
     }
   }
 
+  testWithChangelogCheckpointingEnabled(s"checkpointFormatVersion2") {
+    withSQLConf((SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION.key, "2")) {
+      val checkpointDir = Utils.createTempDir().getCanonicalFile
+      checkpointDir.delete()
+
+      val inputData = MemoryStream[Int]
+      val aggregated =
+        inputData
+          .toDF()
+          .groupBy($"value")
+          .agg(count("*"))
+          .as[(Int, Long)]
+
+      testStream(aggregated, Update)(
+        StartStream(checkpointLocation = checkpointDir.getAbsolutePath),
+        AddData(inputData, 3),
+        CheckLastBatch((3, 1)),
+        AddData(inputData, 3, 2),
+        CheckLastBatch((3, 2), (2, 1)),
+        StopStream
+      )
+
+      // Run the stream with changelog checkpointing enabled.
+      testStream(aggregated, Update)(
+        StartStream(checkpointLocation = checkpointDir.getAbsolutePath),
+        AddData(inputData, 3, 2, 1),
+        CheckLastBatch((3, 3), (2, 2), (1, 1)),
+        // By default we run in new tuple mode.
+        AddData(inputData, 4, 4, 4, 4),
+        CheckLastBatch((4, 4)),
+        AddData(inputData, 5, 5),
+        CheckLastBatch((5, 2))
+      )
+
+      // Run the stream with changelog checkpointing disabled.
+      testStream(aggregated, Update)(
+        StartStream(checkpointLocation = checkpointDir.getAbsolutePath),
+        AddData(inputData, 4),
+        CheckLastBatch((4, 5))
+      )
+    }
+  }
+
+  testWithChangelogCheckpointingEnabled(s"checkpointFormatVersion2 validate ID") {
+    val providerClassName = classOf[TestStateStoreProviderWrapper].getCanonicalName
+    TestStateStoreWrapper.clear()
+    withSQLConf(
+      (SQLConf.STATE_STORE_PROVIDER_CLASS.key -> providerClassName),
+      (SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION.key -> "2"),
+      (SQLConf.SHUFFLE_PARTITIONS.key, "2")) {
+      val checkpointDir = Utils.createTempDir().getCanonicalFile
+      checkpointDir.delete()
+
+      val inputData = MemoryStream[Int]
+      val aggregated =
+        inputData
+          .toDF()
+          .groupBy($"value")
+          .agg(count("*"))
+          .as[(Int, Long)]
+
+      testStream(aggregated, Update)(
+        StartStream(checkpointLocation = checkpointDir.getAbsolutePath),
+        AddData(inputData, 3),
+        CheckLastBatch((3, 1)),
+        AddData(inputData, 3, 2),
+        CheckLastBatch((3, 2), (2, 1)),
+        StopStream
+      )
+
+      // Test recovery
+      testStream(aggregated, Update)(
+        StartStream(checkpointLocation = checkpointDir.getAbsolutePath),
+        AddData(inputData, 3, 2, 1),
+        CheckLastBatch((3, 3), (2, 2), (1, 1)),
+        // By default we run in new tuple mode.
+        AddData(inputData, 4, 4, 4, 4),
+        CheckLastBatch((4, 4)),
+        AddData(inputData, 5, 5),
+        CheckLastBatch((5, 2)),
+        StopStream
+      )
+
+      // crash recovery again
+      testStream(aggregated, Update)(
+        StartStream(checkpointLocation = checkpointDir.getAbsolutePath),
+        AddData(inputData, 4),
+        CheckLastBatch((4, 5))
+      )
+    }
+    val checkpointInfoList = TestStateStoreWrapper.getCheckpointInfos
+    // We have 6 batches, 2 partitions, and 1 state store per batch
+    assert(checkpointInfoList.size == 12)
+    checkpointInfoList.foreach { l =>
+      assert(l.checkpointId.isDefined)
+      if (l.batchVersion == 2 || l.batchVersion == 4 || l.batchVersion == 5) {

Review Comment:
   Sorry, I don't follow this. Why are we just checking these specific 
`batchVersions`? Shouldn't all of them, 0 to 5 inclusive, be present?



##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreIntegrationSuite.scala:
##########
@@ -222,6 +375,456 @@ class RocksDBStateStoreIntegrationSuite extends StreamTest
     }
   }
 
+  testWithChangelogCheckpointingEnabled(s"checkpointFormatVersion2") {
+    withSQLConf((SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION.key, "2")) {
+      val checkpointDir = Utils.createTempDir().getCanonicalFile
+      checkpointDir.delete()
+
+      val inputData = MemoryStream[Int]
+      val aggregated =
+        inputData
+          .toDF()
+          .groupBy($"value")
+          .agg(count("*"))
+          .as[(Int, Long)]
+
+      testStream(aggregated, Update)(
+        StartStream(checkpointLocation = checkpointDir.getAbsolutePath),
+        AddData(inputData, 3),
+        CheckLastBatch((3, 1)),
+        AddData(inputData, 3, 2),
+        CheckLastBatch((3, 2), (2, 1)),
+        StopStream
+      )
+
+      // Run the stream with changelog checkpointing enabled.
+      testStream(aggregated, Update)(
+        StartStream(checkpointLocation = checkpointDir.getAbsolutePath),
+        AddData(inputData, 3, 2, 1),
+        CheckLastBatch((3, 3), (2, 2), (1, 1)),
+        // By default we run in new tuple mode.
+        AddData(inputData, 4, 4, 4, 4),
+        CheckLastBatch((4, 4)),
+        AddData(inputData, 5, 5),
+        CheckLastBatch((5, 2))
+      )
+
+      // Run the stream with changelog checkpointing disabled.
+      testStream(aggregated, Update)(
+        StartStream(checkpointLocation = checkpointDir.getAbsolutePath),
+        AddData(inputData, 4),
+        CheckLastBatch((4, 5))
+      )
+    }
+  }
+
+  testWithChangelogCheckpointingEnabled(s"checkpointFormatVersion2 validate ID") {
+    val providerClassName = classOf[TestStateStoreProviderWrapper].getCanonicalName
+    TestStateStoreWrapper.clear()
+    withSQLConf(
+      (SQLConf.STATE_STORE_PROVIDER_CLASS.key -> providerClassName),
+      (SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION.key -> "2"),
+      (SQLConf.SHUFFLE_PARTITIONS.key, "2")) {
+      val checkpointDir = Utils.createTempDir().getCanonicalFile
+      checkpointDir.delete()
+
+      val inputData = MemoryStream[Int]
+      val aggregated =
+        inputData
+          .toDF()
+          .groupBy($"value")
+          .agg(count("*"))
+          .as[(Int, Long)]
+
+      testStream(aggregated, Update)(
+        StartStream(checkpointLocation = checkpointDir.getAbsolutePath),
+        AddData(inputData, 3),
+        CheckLastBatch((3, 1)),
+        AddData(inputData, 3, 2),
+        CheckLastBatch((3, 2), (2, 1)),
+        StopStream
+      )
+
+      // Test recovery
+      testStream(aggregated, Update)(
+        StartStream(checkpointLocation = checkpointDir.getAbsolutePath),
+        AddData(inputData, 3, 2, 1),
+        CheckLastBatch((3, 3), (2, 2), (1, 1)),
+        // By default we run in new tuple mode.
+        AddData(inputData, 4, 4, 4, 4),
+        CheckLastBatch((4, 4)),
+        AddData(inputData, 5, 5),
+        CheckLastBatch((5, 2)),
+        StopStream
+      )
+
+      // crash recovery again
+      testStream(aggregated, Update)(
+        StartStream(checkpointLocation = checkpointDir.getAbsolutePath),
+        AddData(inputData, 4),
+        CheckLastBatch((4, 5))
+      )
+    }
+    val checkpointInfoList = TestStateStoreWrapper.getCheckpointInfos
+    // We have 6 batches, 2 partitions, and 1 state store per batch
+    assert(checkpointInfoList.size == 12)
+    checkpointInfoList.foreach { l =>
+      assert(l.checkpointId.isDefined)
+      if (l.batchVersion == 2 || l.batchVersion == 4 || l.batchVersion == 5) {
+        assert(l.baseCheckpointId.isDefined)
+      }
+    }
+    assert(checkpointInfoList.count(_.partitionId == 0) == 6)
+    assert(checkpointInfoList.count(_.partitionId == 1) == 6)
+    for (i <- 1 to 6) {
+      assert(checkpointInfoList.count(_.batchVersion == i) == 2)
+    }
+    for {
+      a <- checkpointInfoList
+      b <- checkpointInfoList
+      if a.partitionId == b.partitionId && a.batchVersion == b.batchVersion + 1
+    } {
+      // if batch version exists, it should be the same as the checkpoint ID of the previous batch
+      assert(!a.baseCheckpointId.isDefined || b.checkpointId == a.baseCheckpointId)
+    }
+  }
+
+  testWithChangelogCheckpointingEnabled(
+    s"checkpointFormatVersion2 validate ID with dedup and groupBy") {
+    val providerClassName = classOf[TestStateStoreProviderWrapper].getCanonicalName
+    TestStateStoreWrapper.clear()

Review Comment:
   All of these can be refactored into a `beforeEach` in the class
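   e.g., a sketch of what that could look like (assuming the suite can mix in ScalaTest's `BeforeAndAfterEach`; the exact mixins it already has may differ):
   
   ```scala
   import org.scalatest.BeforeAndAfterEach
   
   class RocksDBStateStoreIntegrationSuite extends StreamTest with BeforeAndAfterEach {
     override protected def beforeEach(): Unit = {
       super.beforeEach()
       // Reset the collected checkpoint info once per test, instead of calling
       // TestStateStoreWrapper.clear() at the top of every individual test.
       TestStateStoreWrapper.clear()
     }
   
     // ... existing tests, minus the per-test clear() calls ...
   }
   ```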



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala:
##########
@@ -900,12 +907,42 @@ class MicroBatchExecution(
    */
   protected def markMicroBatchExecutionStart(execCtx: MicroBatchExecutionContext): Unit = {}
 
+  private def updateCheckpointIdForOperator(
+      execCtx: MicroBatchExecutionContext,
+      opId: Long,
+      checkpointInfo: Array[StateStoreCheckpointInfo]): Unit = {
+    // TODO validate baseCheckpointId
+    checkpointInfo.map(_.batchVersion).foreach { v =>
+      assert(
+        execCtx.batchId == -1 || v == execCtx.batchId + 1,
+        s"version $v doesn't match current Batch ID ${execCtx.batchId}")

Review Comment:
   I don't understand the assertion here. The condition checks `v == execCtx.batchId + 1`, but the failure message says that `v` doesn't match `batchId`. Which of the two is intended?
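   If the off-by-one is intentional (i.e. the committed store version is always `batchId + 1`), the message should say that, e.g. a sketch:
   
   ```scala
   assert(
     execCtx.batchId == -1 || v == execCtx.batchId + 1,
     // Message now matches the condition being checked.
     s"state store version $v does not equal current batch ID ${execCtx.batchId} + 1")
   ```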



##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreIntegrationSuite.scala:
##########
@@ -222,6 +375,456 @@ class RocksDBStateStoreIntegrationSuite extends StreamTest
     }
   }
 
+  testWithChangelogCheckpointingEnabled(s"checkpointFormatVersion2") {

Review Comment:
   From what I can tell, none of the new suites that were added cover the edge 
case in the design doc, right? There's no speculative execution here.
   
   I _think_ what you could do is create new manual StateStores that simulate 
the race 
[here](https://docs.google.com/document/d/1uWRMbN927cRXhSm5oeV3pbwb6o73am4r1ckEJDhAHa0/edit), without needing to write a query that does this. Right?



##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreIntegrationSuite.scala:
##########
@@ -222,6 +375,456 @@ class RocksDBStateStoreIntegrationSuite extends StreamTest
     }
   }
 
+  testWithChangelogCheckpointingEnabled(s"checkpointFormatVersion2") {
+    withSQLConf((SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION.key, "2")) {
+      val checkpointDir = Utils.createTempDir().getCanonicalFile
+      checkpointDir.delete()
+
+      val inputData = MemoryStream[Int]
+      val aggregated =
+        inputData
+          .toDF()
+          .groupBy($"value")
+          .agg(count("*"))
+          .as[(Int, Long)]
+
+      testStream(aggregated, Update)(
+        StartStream(checkpointLocation = checkpointDir.getAbsolutePath),
+        AddData(inputData, 3),
+        CheckLastBatch((3, 1)),
+        AddData(inputData, 3, 2),
+        CheckLastBatch((3, 2), (2, 1)),
+        StopStream
+      )
+
+      // Run the stream with changelog checkpointing enabled.
+      testStream(aggregated, Update)(
+        StartStream(checkpointLocation = checkpointDir.getAbsolutePath),
+        AddData(inputData, 3, 2, 1),
+        CheckLastBatch((3, 3), (2, 2), (1, 1)),
+        // By default we run in new tuple mode.
+        AddData(inputData, 4, 4, 4, 4),
+        CheckLastBatch((4, 4)),
+        AddData(inputData, 5, 5),
+        CheckLastBatch((5, 2))
+      )
+
+      // Run the stream with changelog checkpointing disabled.
+      testStream(aggregated, Update)(
+        StartStream(checkpointLocation = checkpointDir.getAbsolutePath),
+        AddData(inputData, 4),
+        CheckLastBatch((4, 5))
+      )
+    }
+  }
+
+  testWithChangelogCheckpointingEnabled(s"checkpointFormatVersion2 validate ID") {
+    val providerClassName = classOf[TestStateStoreProviderWrapper].getCanonicalName
+    TestStateStoreWrapper.clear()
+    withSQLConf(
+      (SQLConf.STATE_STORE_PROVIDER_CLASS.key -> providerClassName),
+      (SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION.key -> "2"),
+      (SQLConf.SHUFFLE_PARTITIONS.key, "2")) {
+      val checkpointDir = Utils.createTempDir().getCanonicalFile
+      checkpointDir.delete()
+
+      val inputData = MemoryStream[Int]
+      val aggregated =
+        inputData
+          .toDF()
+          .groupBy($"value")
+          .agg(count("*"))
+          .as[(Int, Long)]
+
+      testStream(aggregated, Update)(
+        StartStream(checkpointLocation = checkpointDir.getAbsolutePath),
+        AddData(inputData, 3),
+        CheckLastBatch((3, 1)),
+        AddData(inputData, 3, 2),
+        CheckLastBatch((3, 2), (2, 1)),
+        StopStream
+      )
+
+      // Test recovery
+      testStream(aggregated, Update)(
+        StartStream(checkpointLocation = checkpointDir.getAbsolutePath),
+        AddData(inputData, 3, 2, 1),
+        CheckLastBatch((3, 3), (2, 2), (1, 1)),
+        // By default we run in new tuple mode.
+        AddData(inputData, 4, 4, 4, 4),
+        CheckLastBatch((4, 4)),
+        AddData(inputData, 5, 5),
+        CheckLastBatch((5, 2)),
+        StopStream
+      )
+
+      // crash recovery again
+      testStream(aggregated, Update)(
+        StartStream(checkpointLocation = checkpointDir.getAbsolutePath),
+        AddData(inputData, 4),
+        CheckLastBatch((4, 5))
+      )
+    }
+    val checkpointInfoList = TestStateStoreWrapper.getCheckpointInfos
+    // We have 6 batches, 2 partitions, and 1 state store per batch
+    assert(checkpointInfoList.size == 12)
+    checkpointInfoList.foreach { l =>
+      assert(l.checkpointId.isDefined)
+      if (l.batchVersion == 2 || l.batchVersion == 4 || l.batchVersion == 5) {
+        assert(l.baseCheckpointId.isDefined)
+      }
+    }
+    assert(checkpointInfoList.count(_.partitionId == 0) == 6)
+    assert(checkpointInfoList.count(_.partitionId == 1) == 6)
+    for (i <- 1 to 6) {
+      assert(checkpointInfoList.count(_.batchVersion == i) == 2)
+    }
+    for {
+      a <- checkpointInfoList
+      b <- checkpointInfoList
+      if a.partitionId == b.partitionId && a.batchVersion == b.batchVersion + 1
+    } {
+      // if batch version exists, it should be the same as the checkpoint ID of the previous batch
+      assert(!a.baseCheckpointId.isDefined || b.checkpointId == a.baseCheckpointId)
+    }

Review Comment:
   This can definitely be refactored; you're using the same code snippet in all of these tests, right? Seems like a `StateStoreCheckpointIdTestUtils` could be good.
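   For example, something along these lines (a sketch; the object name is a placeholder and it assumes the `StateStoreCheckpointInfo` fields used in these tests, with one state store per batch and partition):
   
   ```scala
   object StateStoreCheckpointIdTestUtils {
     // Shared validation for the per-test checkpoint-info assertions repeated above.
     def validateCheckpointInfo(
         infos: Seq[StateStoreCheckpointInfo],
         numBatches: Int,
         numPartitions: Int): Unit = {
       // One entry per (batch, partition) when each operator has a single state store.
       assert(infos.size == numBatches * numPartitions)
       assert(infos.forall(_.checkpointId.isDefined))
       (0 until numPartitions).foreach { p =>
         assert(infos.count(_.partitionId == p) == numBatches)
       }
       (1 to numBatches).foreach { v =>
         assert(infos.count(_.batchVersion == v) == numPartitions)
       }
       // Lineage check: when a base checkpoint ID is present, it must equal the
       // checkpoint ID of the same partition's previous batch.
       for {
         a <- infos
         b <- infos
         if a.partitionId == b.partitionId && a.batchVersion == b.batchVersion + 1
       } {
         assert(a.baseCheckpointId.isEmpty || b.checkpointId == a.baseCheckpointId)
       }
     }
   }
   ```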



##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreIntegrationSuite.scala:
##########
@@ -21,15 +21,168 @@ import java.io.File
 
 import scala.jdk.CollectionConverters.SetHasAsScala
 
+import org.apache.hadoop.conf.Configuration
 import org.scalatest.time.{Minute, Span}
 
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
 import org.apache.spark.sql.execution.streaming.{MemoryStream, StreamingQueryWrapper}
 import org.apache.spark.sql.functions.count
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.streaming._
 import org.apache.spark.sql.streaming.OutputMode.Update
+import org.apache.spark.sql.types.StructType
 import org.apache.spark.util.Utils
 
+object TestStateStoreWrapper {

Review Comment:
   Probably will want a better name here.
   
   ```suggestion
   object CheckpointInfoCollectingStateStore {
   ```
   
   ?



##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreIntegrationSuite.scala:
##########
@@ -222,6 +375,456 @@ class RocksDBStateStoreIntegrationSuite extends StreamTest
     }
   }
 
+  testWithChangelogCheckpointingEnabled(s"checkpointFormatVersion2") {
+    withSQLConf((SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION.key, "2")) {
+      val checkpointDir = Utils.createTempDir().getCanonicalFile
+      checkpointDir.delete()
+
+      val inputData = MemoryStream[Int]
+      val aggregated =
+        inputData
+          .toDF()
+          .groupBy($"value")
+          .agg(count("*"))
+          .as[(Int, Long)]
+
+      testStream(aggregated, Update)(
+        StartStream(checkpointLocation = checkpointDir.getAbsolutePath),
+        AddData(inputData, 3),
+        CheckLastBatch((3, 1)),
+        AddData(inputData, 3, 2),
+        CheckLastBatch((3, 2), (2, 1)),
+        StopStream
+      )
+
+      // Run the stream with changelog checkpointing enabled.
+      testStream(aggregated, Update)(
+        StartStream(checkpointLocation = checkpointDir.getAbsolutePath),
+        AddData(inputData, 3, 2, 1),
+        CheckLastBatch((3, 3), (2, 2), (1, 1)),
+        // By default we run in new tuple mode.
+        AddData(inputData, 4, 4, 4, 4),
+        CheckLastBatch((4, 4)),
+        AddData(inputData, 5, 5),
+        CheckLastBatch((5, 2))
+      )
+
+      // Run the stream with changelog checkpointing disabled.
+      testStream(aggregated, Update)(
+        StartStream(checkpointLocation = checkpointDir.getAbsolutePath),
+        AddData(inputData, 4),
+        CheckLastBatch((4, 5))
+      )
+    }
+  }
+
+  testWithChangelogCheckpointingEnabled(s"checkpointFormatVersion2 validate ID") {
+    val providerClassName = classOf[TestStateStoreProviderWrapper].getCanonicalName
+    TestStateStoreWrapper.clear()
+    withSQLConf(
+      (SQLConf.STATE_STORE_PROVIDER_CLASS.key -> providerClassName),
+      (SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION.key -> "2"),
+      (SQLConf.SHUFFLE_PARTITIONS.key, "2")) {
+      val checkpointDir = Utils.createTempDir().getCanonicalFile
+      checkpointDir.delete()
+
+      val inputData = MemoryStream[Int]
+      val aggregated =
+        inputData
+          .toDF()
+          .groupBy($"value")
+          .agg(count("*"))
+          .as[(Int, Long)]
+
+      testStream(aggregated, Update)(
+        StartStream(checkpointLocation = checkpointDir.getAbsolutePath),
+        AddData(inputData, 3),
+        CheckLastBatch((3, 1)),
+        AddData(inputData, 3, 2),
+        CheckLastBatch((3, 2), (2, 1)),
+        StopStream
+      )
+
+      // Test recovery
+      testStream(aggregated, Update)(
+        StartStream(checkpointLocation = checkpointDir.getAbsolutePath),
+        AddData(inputData, 3, 2, 1),
+        CheckLastBatch((3, 3), (2, 2), (1, 1)),
+        // By default we run in new tuple mode.
+        AddData(inputData, 4, 4, 4, 4),
+        CheckLastBatch((4, 4)),
+        AddData(inputData, 5, 5),
+        CheckLastBatch((5, 2)),
+        StopStream
+      )
+
+      // crash recovery again
+      testStream(aggregated, Update)(
+        StartStream(checkpointLocation = checkpointDir.getAbsolutePath),
+        AddData(inputData, 4),
+        CheckLastBatch((4, 5))
+      )
+    }
+    val checkpointInfoList = TestStateStoreWrapper.getCheckpointInfos
+    // We have 6 batches, 2 partitions, and 1 state store per batch
+    assert(checkpointInfoList.size == 12)
+    checkpointInfoList.foreach { l =>
+      assert(l.checkpointId.isDefined)
+      if (l.batchVersion == 2 || l.batchVersion == 4 || l.batchVersion == 5) {
+        assert(l.baseCheckpointId.isDefined)
+      }
+    }
+    assert(checkpointInfoList.count(_.partitionId == 0) == 6)
+    assert(checkpointInfoList.count(_.partitionId == 1) == 6)
+    for (i <- 1 to 6) {
+      assert(checkpointInfoList.count(_.batchVersion == i) == 2)
+    }
+    for {
+      a <- checkpointInfoList
+      b <- checkpointInfoList
+      if a.partitionId == b.partitionId && a.batchVersion == b.batchVersion + 1
+    } {
+      // if batch version exists, it should be the same as the checkpoint ID of the previous batch
+      assert(!a.baseCheckpointId.isDefined || b.checkpointId == a.baseCheckpointId)
+    }
+  }
+
+  testWithChangelogCheckpointingEnabled(
+    s"checkpointFormatVersion2 validate ID with dedup and groupBy") {
+    val providerClassName = classOf[TestStateStoreProviderWrapper].getCanonicalName
+    TestStateStoreWrapper.clear()
+    withSQLConf(
+      (SQLConf.STATE_STORE_PROVIDER_CLASS.key -> providerClassName),
+      (SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION.key -> "2"),
+      (SQLConf.SHUFFLE_PARTITIONS.key, "2")) {
+      val checkpointDir = Utils.createTempDir().getCanonicalFile
+      checkpointDir.delete()

Review Comment:
   Why do you need to delete this? And why not use `withTempDir`?
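   i.e. something like the sketch below; `withTempDir` creates the directory and cleans it up afterwards, so there's no manual `delete()`:
   
   ```scala
   withTempDir { checkpointDir =>
     withSQLConf(
       SQLConf.STATE_STORE_PROVIDER_CLASS.key -> providerClassName,
       SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION.key -> "2",
       SQLConf.SHUFFLE_PARTITIONS.key -> "2") {
       // ... same testStream(...) runs as above, pointed at checkpointDir.getAbsolutePath ...
     }
   }
   ```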


