neilramaswamy commented on code in PR #47895:
URL: https://github.com/apache/spark/pull/47895#discussion_r1779272544
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala:
##########
@@ -72,7 +72,8 @@ class RocksDB(
localRootDir: File = Utils.createTempDir(),
hadoopConf: Configuration = new Configuration,
loggingId: String = "",
- useColumnFamilies: Boolean = false) extends Logging {
+ useColumnFamilies: Boolean = false,
+ ifEnableCheckpointId: Boolean = false) extends Logging {
Review Comment:
More consistent to call it `enableStateStoreCheckpointIds`.
I also think that the term "checkpoint ID" is very confusing. The term makes
it feel like it's an ID for an _entire_ checkpoint, when really it's an ID for
a particular state store that has been checkpointed.
I know it's a tedious modification to make. I'd be happy to take on some of that
work by creating a branch with the rename and putting up a PR that you can
merge back into this branch.
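i.e.:
```suggestion
    enableStateStoreCheckpointIds: Boolean = false) extends Logging {
```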
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala:
##########
@@ -808,6 +824,45 @@ object SymmetricHashJoinStateManager {
result
}
+ def mergeStateStoreCheckpointInfo(
Review Comment:
I already commented elsewhere that this shouldn't live in the symmetric hash
join state manager, but it was also confusing to read because it is used in two places:
1. To merge the key-to-num-values state store's info with the key-with-index-to-value state store's info
2. To merge the results of step (1) for the left and the right sides into one result
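A split along those lines would make each call site read unambiguously; a rough
sketch (the helper names are hypothetical, not from this PR, and the bodies are
elided since this is only about naming/placement):
```scala
// Hypothetical helpers; `???` bodies because only the split is being illustrated.
def mergeKeyStoresCheckpointInfo(
    keyToNumValuesInfo: StateStoreCheckpointInfo,
    keyWithIndexToValueInfo: StateStoreCheckpointInfo): StateStoreCheckpointInfo = ???

def mergeLeftRightCheckpointInfo(
    leftInfo: StateStoreCheckpointInfo,
    rightInfo: StateStoreCheckpointInfo): StateStoreCheckpointInfo = ???
```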
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala:
##########
@@ -190,6 +190,11 @@ trait StateStore extends ReadStateStore {
/** Current metrics of the state store */
def metrics: StateStoreMetrics
+ /** Return information on recently generated checkpoints */
+ def getCheckpointInfo: StateStoreCheckpointInfo = {
+ StateStoreCheckpointInfo(-1, -1, None, None)
Review Comment:
Why default implementation? If all the sub-classes are overriding it, let's
just make it required with no default.
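i.e., something like (sketch):
```scala
/** Return information on the most recently generated checkpoint for this store. */
def getCheckpointInfo: StateStoreCheckpointInfo
```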
##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreIntegrationSuite.scala:
##########
@@ -222,6 +375,456 @@ class RocksDBStateStoreIntegrationSuite extends StreamTest
}
}
+ testWithChangelogCheckpointingEnabled(s"checkpointFormatVersion2") {
+ withSQLConf((SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION.key, "2")) {
+ val checkpointDir = Utils.createTempDir().getCanonicalFile
+ checkpointDir.delete()
+
+ val inputData = MemoryStream[Int]
+ val aggregated =
+ inputData
+ .toDF()
+ .groupBy($"value")
+ .agg(count("*"))
+ .as[(Int, Long)]
+
+ testStream(aggregated, Update)(
+ StartStream(checkpointLocation = checkpointDir.getAbsolutePath),
+ AddData(inputData, 3),
+ CheckLastBatch((3, 1)),
+ AddData(inputData, 3, 2),
+ CheckLastBatch((3, 2), (2, 1)),
+ StopStream
+ )
+
+ // Run the stream with changelog checkpointing enabled.
+ testStream(aggregated, Update)(
+ StartStream(checkpointLocation = checkpointDir.getAbsolutePath),
+ AddData(inputData, 3, 2, 1),
+ CheckLastBatch((3, 3), (2, 2), (1, 1)),
+ // By default we run in new tuple mode.
+ AddData(inputData, 4, 4, 4, 4),
+ CheckLastBatch((4, 4)),
+ AddData(inputData, 5, 5),
+ CheckLastBatch((5, 2))
+ )
+
+ // Run the stream with changelog checkpointing disabled.
+ testStream(aggregated, Update)(
+ StartStream(checkpointLocation = checkpointDir.getAbsolutePath),
+ AddData(inputData, 4),
+ CheckLastBatch((4, 5))
+ )
+ }
+ }
+
+ testWithChangelogCheckpointingEnabled(s"checkpointFormatVersion2 validate ID") {
+ val providerClassName = classOf[TestStateStoreProviderWrapper].getCanonicalName
+ TestStateStoreWrapper.clear()
+ withSQLConf(
+ (SQLConf.STATE_STORE_PROVIDER_CLASS.key -> providerClassName),
+ (SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION.key -> "2"),
+ (SQLConf.SHUFFLE_PARTITIONS.key, "2")) {
+ val checkpointDir = Utils.createTempDir().getCanonicalFile
+ checkpointDir.delete()
+
+ val inputData = MemoryStream[Int]
+ val aggregated =
+ inputData
+ .toDF()
+ .groupBy($"value")
+ .agg(count("*"))
+ .as[(Int, Long)]
+
+ testStream(aggregated, Update)(
+ StartStream(checkpointLocation = checkpointDir.getAbsolutePath),
+ AddData(inputData, 3),
+ CheckLastBatch((3, 1)),
+ AddData(inputData, 3, 2),
+ CheckLastBatch((3, 2), (2, 1)),
+ StopStream
+ )
+
+ // Test recovery
+ testStream(aggregated, Update)(
+ StartStream(checkpointLocation = checkpointDir.getAbsolutePath),
+ AddData(inputData, 3, 2, 1),
+ CheckLastBatch((3, 3), (2, 2), (1, 1)),
+ // By default we run in new tuple mode.
+ AddData(inputData, 4, 4, 4, 4),
+ CheckLastBatch((4, 4)),
+ AddData(inputData, 5, 5),
+ CheckLastBatch((5, 2)),
+ StopStream
+ )
+
+ // crash recovery again
Review Comment:
How is this simulating "crash recovery"? I think the previous stream would
have executed cleanly.
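If you want an actual crash, one option is to inject a task failure before the
restart. A rough sketch below; the failure-injection object, the exception type
in `ExpectFailure`, and the assumption that these suites run in local mode (so
tasks see the same singleton) are all mine, not part of this PR. It also assumes
`testImplicits._` is in scope, as in the rest of this file.
```scala
// Assumed to live at the top level of the test file so tasks resolve the same
// singleton in local mode.
object FailureInjector {
  @volatile var enabled = true
}

val failingAgg = inputData.toDS()
  .map { x =>
    if (FailureInjector.enabled && x < 0) {
      throw new RuntimeException("injected failure") // crash the running batch
    }
    x
  }
  .groupBy($"value")
  .agg(count("*"))
  .as[(Int, Long)]

// First run dies mid-batch instead of stopping cleanly.
testStream(failingAgg, Update)(
  StartStream(checkpointLocation = checkpointDir.getAbsolutePath),
  AddData(inputData, 1),
  CheckLastBatch((1, 1)),
  AddData(inputData, -1),
  ExpectFailure[org.apache.spark.SparkException]() // exact wrapper exception may differ
)

// Disable the injector so the replayed batch succeeds; this restart now
// exercises crash recovery rather than a clean stop/start.
FailureInjector.enabled = false
testStream(failingAgg, Update)(
  StartStream(checkpointLocation = checkpointDir.getAbsolutePath),
  // The failed batch (containing -1) is replayed first, then 2 arrives in a new batch.
  AddData(inputData, 2),
  CheckLastBatch((2, 1))
)
```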
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala:
##########
@@ -233,6 +238,15 @@ case class StateStoreMetrics(
memoryUsedBytes: Long,
customMetrics: Map[StateStoreCustomMetric, Long])
+case class StateStoreCheckpointInfo(
+ partitionId: Int,
+ batchVersion: Long,
+ // The checkpoint ID for a checkpoint at `batchVersion`. This is used to identify the checkpoint
+ checkpointId: Option[String],
Review Comment:
If we use a String, we need to mention that it's not necessarily _one_
checkpoint ID; it could be many, comma-separated.
But honestly, I don't think we should be using `String` here, because it's
ambiguous. Is it 1 checkpoint? 4 checkpoints? You can't tell just by looking at
the code. The naming is also off in the multi-checkpoint case; it would really
have to be `StateStore*s*CheckpointInfo`.
I think it makes more sense to return a `Seq[String]` all the way through the
accumulator. Then the only place merging happens is `def getCheckpointInfo`
inside `StateStoreWriter`. That keeps us from having awkward one-off merging
logic inside the s/s join, even though I know it's the only place that needs it.
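Rough sketch of the shape I'm suggesting; the field and method names here are
mine, not what this PR implements:
```scala
case class StateStoreCheckpointInfo(
    partitionId: Int,
    batchVersion: Long,
    // One entry per state store instance checkpointed for this partition/batch,
    // instead of an ambiguous comma-separated String.
    checkpointIds: Seq[String],
    baseCheckpointIds: Seq[String])

// The single merge point, conceptually living in StateStoreWriter.getCheckpointInfo.
def mergeCheckpointInfo(
    perStoreInfos: Seq[StateStoreCheckpointInfo]): Seq[StateStoreCheckpointInfo] = {
  perStoreInfos
    .groupBy(info => (info.partitionId, info.batchVersion))
    .map { case ((partitionId, batchVersion), infos) =>
      StateStoreCheckpointInfo(
        partitionId,
        batchVersion,
        infos.flatMap(_.checkpointIds),
        infos.flatMap(_.baseCheckpointIds))
    }
    .toSeq
}
```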
##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreIntegrationSuite.scala:
##########
@@ -222,6 +375,456 @@ class RocksDBStateStoreIntegrationSuite extends StreamTest
}
}
+ testWithChangelogCheckpointingEnabled(s"checkpointFormatVersion2") {
+ withSQLConf((SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION.key, "2")) {
+ val checkpointDir = Utils.createTempDir().getCanonicalFile
+ checkpointDir.delete()
+
+ val inputData = MemoryStream[Int]
+ val aggregated =
+ inputData
+ .toDF()
+ .groupBy($"value")
+ .agg(count("*"))
+ .as[(Int, Long)]
+
+ testStream(aggregated, Update)(
+ StartStream(checkpointLocation = checkpointDir.getAbsolutePath),
+ AddData(inputData, 3),
+ CheckLastBatch((3, 1)),
+ AddData(inputData, 3, 2),
+ CheckLastBatch((3, 2), (2, 1)),
+ StopStream
+ )
+
+ // Run the stream with changelog checkpointing enabled.
Review Comment:
How exactly are you enabling/disabling changelog checkpointing here?
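As far as I can tell, nothing changes the conf between the runs. If the comment
is meant literally, the restart would need something like the following; the
conf key is my assumption from the RocksDB state store provider options, not
from this PR:
```scala
testStream(aggregated, Update)(
  StartStream(
    checkpointLocation = checkpointDir.getAbsolutePath,
    // Assumed RocksDB provider option for toggling changelog checkpointing.
    additionalConfs = Map(
      "spark.sql.streaming.stateStore.rocksdb.changelogCheckpointing.enabled" -> "false")),
  AddData(inputData, 4),
  CheckLastBatch((4, 5))
)
```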
##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreIntegrationSuite.scala:
##########
@@ -222,6 +375,456 @@ class RocksDBStateStoreIntegrationSuite extends StreamTest
}
}
+ testWithChangelogCheckpointingEnabled(s"checkpointFormatVersion2") {
+ withSQLConf((SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION.key, "2")) {
+ val checkpointDir = Utils.createTempDir().getCanonicalFile
+ checkpointDir.delete()
+
+ val inputData = MemoryStream[Int]
+ val aggregated =
+ inputData
+ .toDF()
+ .groupBy($"value")
+ .agg(count("*"))
+ .as[(Int, Long)]
+
+ testStream(aggregated, Update)(
+ StartStream(checkpointLocation = checkpointDir.getAbsolutePath),
+ AddData(inputData, 3),
+ CheckLastBatch((3, 1)),
+ AddData(inputData, 3, 2),
+ CheckLastBatch((3, 2), (2, 1)),
+ StopStream
+ )
+
+ // Run the stream with changelog checkpointing enabled.
+ testStream(aggregated, Update)(
+ StartStream(checkpointLocation = checkpointDir.getAbsolutePath),
+ AddData(inputData, 3, 2, 1),
+ CheckLastBatch((3, 3), (2, 2), (1, 1)),
+ // By default we run in new tuple mode.
+ AddData(inputData, 4, 4, 4, 4),
+ CheckLastBatch((4, 4)),
+ AddData(inputData, 5, 5),
+ CheckLastBatch((5, 2))
+ )
+
+ // Run the stream with changelog checkpointing disabled.
+ testStream(aggregated, Update)(
+ StartStream(checkpointLocation = checkpointDir.getAbsolutePath),
+ AddData(inputData, 4),
+ CheckLastBatch((4, 5))
+ )
+ }
+ }
+
+ testWithChangelogCheckpointingEnabled(s"checkpointFormatVersion2 validate ID") {
+ val providerClassName = classOf[TestStateStoreProviderWrapper].getCanonicalName
+ TestStateStoreWrapper.clear()
+ withSQLConf(
+ (SQLConf.STATE_STORE_PROVIDER_CLASS.key -> providerClassName),
+ (SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION.key -> "2"),
+ (SQLConf.SHUFFLE_PARTITIONS.key, "2")) {
+ val checkpointDir = Utils.createTempDir().getCanonicalFile
+ checkpointDir.delete()
+
+ val inputData = MemoryStream[Int]
+ val aggregated =
+ inputData
+ .toDF()
+ .groupBy($"value")
+ .agg(count("*"))
+ .as[(Int, Long)]
+
+ testStream(aggregated, Update)(
+ StartStream(checkpointLocation = checkpointDir.getAbsolutePath),
+ AddData(inputData, 3),
+ CheckLastBatch((3, 1)),
+ AddData(inputData, 3, 2),
+ CheckLastBatch((3, 2), (2, 1)),
+ StopStream
+ )
+
+ // Test recovery
+ testStream(aggregated, Update)(
+ StartStream(checkpointLocation = checkpointDir.getAbsolutePath),
+ AddData(inputData, 3, 2, 1),
+ CheckLastBatch((3, 3), (2, 2), (1, 1)),
+ // By default we run in new tuple mode.
+ AddData(inputData, 4, 4, 4, 4),
+ CheckLastBatch((4, 4)),
+ AddData(inputData, 5, 5),
+ CheckLastBatch((5, 2)),
+ StopStream
+ )
+
+ // crash recovery again
+ testStream(aggregated, Update)(
+ StartStream(checkpointLocation = checkpointDir.getAbsolutePath),
+ AddData(inputData, 4),
+ CheckLastBatch((4, 5))
+ )
+ }
+ val checkpointInfoList = TestStateStoreWrapper.getCheckpointInfos
+ // We have 6 batches, 2 partitions, and 1 state store per batch
+ assert(checkpointInfoList.size == 12)
+ checkpointInfoList.foreach { l =>
+ assert(l.checkpointId.isDefined)
+ if (l.batchVersion == 2 || l.batchVersion == 4 || l.batchVersion == 5) {
Review Comment:
Sorry, I don't follow this. Why are we just checking these specific
`batchVersions`? Shouldn't all of them, 0 to 5 inclusive, be present?
##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreIntegrationSuite.scala:
##########
@@ -222,6 +375,456 @@ class RocksDBStateStoreIntegrationSuite extends StreamTest
}
}
+ testWithChangelogCheckpointingEnabled(s"checkpointFormatVersion2") {
+ withSQLConf((SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION.key, "2")) {
+ val checkpointDir = Utils.createTempDir().getCanonicalFile
+ checkpointDir.delete()
+
+ val inputData = MemoryStream[Int]
+ val aggregated =
+ inputData
+ .toDF()
+ .groupBy($"value")
+ .agg(count("*"))
+ .as[(Int, Long)]
+
+ testStream(aggregated, Update)(
+ StartStream(checkpointLocation = checkpointDir.getAbsolutePath),
+ AddData(inputData, 3),
+ CheckLastBatch((3, 1)),
+ AddData(inputData, 3, 2),
+ CheckLastBatch((3, 2), (2, 1)),
+ StopStream
+ )
+
+ // Run the stream with changelog checkpointing enabled.
+ testStream(aggregated, Update)(
+ StartStream(checkpointLocation = checkpointDir.getAbsolutePath),
+ AddData(inputData, 3, 2, 1),
+ CheckLastBatch((3, 3), (2, 2), (1, 1)),
+ // By default we run in new tuple mode.
+ AddData(inputData, 4, 4, 4, 4),
+ CheckLastBatch((4, 4)),
+ AddData(inputData, 5, 5),
+ CheckLastBatch((5, 2))
+ )
+
+ // Run the stream with changelog checkpointing disabled.
+ testStream(aggregated, Update)(
+ StartStream(checkpointLocation = checkpointDir.getAbsolutePath),
+ AddData(inputData, 4),
+ CheckLastBatch((4, 5))
+ )
+ }
+ }
+
+ testWithChangelogCheckpointingEnabled(s"checkpointFormatVersion2 validate ID") {
+ val providerClassName = classOf[TestStateStoreProviderWrapper].getCanonicalName
+ TestStateStoreWrapper.clear()
+ withSQLConf(
+ (SQLConf.STATE_STORE_PROVIDER_CLASS.key -> providerClassName),
+ (SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION.key -> "2"),
+ (SQLConf.SHUFFLE_PARTITIONS.key, "2")) {
+ val checkpointDir = Utils.createTempDir().getCanonicalFile
+ checkpointDir.delete()
+
+ val inputData = MemoryStream[Int]
+ val aggregated =
+ inputData
+ .toDF()
+ .groupBy($"value")
+ .agg(count("*"))
+ .as[(Int, Long)]
+
+ testStream(aggregated, Update)(
+ StartStream(checkpointLocation = checkpointDir.getAbsolutePath),
+ AddData(inputData, 3),
+ CheckLastBatch((3, 1)),
+ AddData(inputData, 3, 2),
+ CheckLastBatch((3, 2), (2, 1)),
+ StopStream
+ )
+
+ // Test recovery
+ testStream(aggregated, Update)(
+ StartStream(checkpointLocation = checkpointDir.getAbsolutePath),
+ AddData(inputData, 3, 2, 1),
+ CheckLastBatch((3, 3), (2, 2), (1, 1)),
+ // By default we run in new tuple mode.
+ AddData(inputData, 4, 4, 4, 4),
+ CheckLastBatch((4, 4)),
+ AddData(inputData, 5, 5),
+ CheckLastBatch((5, 2)),
+ StopStream
+ )
+
+ // crash recovery again
+ testStream(aggregated, Update)(
+ StartStream(checkpointLocation = checkpointDir.getAbsolutePath),
+ AddData(inputData, 4),
+ CheckLastBatch((4, 5))
+ )
+ }
+ val checkpointInfoList = TestStateStoreWrapper.getCheckpointInfos
+ // We have 6 batches, 2 partitions, and 1 state store per batch
+ assert(checkpointInfoList.size == 12)
+ checkpointInfoList.foreach { l =>
+ assert(l.checkpointId.isDefined)
+ if (l.batchVersion == 2 || l.batchVersion == 4 || l.batchVersion == 5) {
+ assert(l.baseCheckpointId.isDefined)
+ }
+ }
+ assert(checkpointInfoList.count(_.partitionId == 0) == 6)
+ assert(checkpointInfoList.count(_.partitionId == 1) == 6)
+ for (i <- 1 to 6) {
+ assert(checkpointInfoList.count(_.batchVersion == i) == 2)
+ }
+ for {
+ a <- checkpointInfoList
+ b <- checkpointInfoList
+ if a.partitionId == b.partitionId && a.batchVersion == b.batchVersion + 1
+ } {
+ // if batch version exists, it should be the same as the checkpoint ID of the previous batch
+ assert(!a.baseCheckpointId.isDefined || b.checkpointId == a.baseCheckpointId)
+ }
+ }
+
+ testWithChangelogCheckpointingEnabled(
+ s"checkpointFormatVersion2 validate ID with dedup and groupBy") {
+ val providerClassName = classOf[TestStateStoreProviderWrapper].getCanonicalName
+ TestStateStoreWrapper.clear()
Review Comment:
All of these can be refactored into a `beforeEach` in the class
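e.g. (sketch, assuming the suite mixes in `BeforeAndAfterEach`):
```scala
override def beforeEach(): Unit = {
  super.beforeEach()
  TestStateStoreWrapper.clear()
}
```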
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala:
##########
@@ -900,12 +907,42 @@ class MicroBatchExecution(
*/
protected def markMicroBatchExecutionStart(execCtx: MicroBatchExecutionContext): Unit = {}
+ private def updateCheckpointIdForOperator(
+ execCtx: MicroBatchExecutionContext,
+ opId: Long,
+ checkpointInfo: Array[StateStoreCheckpointInfo]): Unit = {
+ // TODO validate baseCheckpointId
+ checkpointInfo.map(_.batchVersion).foreach { v =>
+ assert(
+ execCtx.batchId == -1 || v == execCtx.batchId + 1,
+ s"version $v doesn't match current Batch ID ${execCtx.batchId}")
Review Comment:
I don't understand the assertion here. The condition checks `v == execCtx.batchId + 1`,
but the error message says `v` doesn't match `batchId`. Which of the two is intended?
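If the intent is "the committed state store version is one past the batch that
just ran", the message should say so; something like (sketch):
```scala
assert(
  execCtx.batchId == -1 || v == execCtx.batchId + 1,
  s"state store batchVersion $v should be exactly one more than the current batch ID ${execCtx.batchId}")
```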
##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreIntegrationSuite.scala:
##########
@@ -222,6 +375,456 @@ class RocksDBStateStoreIntegrationSuite extends StreamTest
}
}
+ testWithChangelogCheckpointingEnabled(s"checkpointFormatVersion2") {
Review Comment:
From what I can tell, none of the new suites that were added cover the edge
case in the design doc, right? There's no speculative execution here.
I _think_ what you could do is create new manual StateStores that simulate
the race
[here](https://docs.google.com/document/d/1uWRMbN927cRXhSm5oeV3pbwb6o73am4r1ckEJDhAHa0/edit),
without needing to write a query that does this. Right?
##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreIntegrationSuite.scala:
##########
@@ -222,6 +375,456 @@ class RocksDBStateStoreIntegrationSuite extends StreamTest
}
}
+ testWithChangelogCheckpointingEnabled(s"checkpointFormatVersion2") {
+ withSQLConf((SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION.key, "2")) {
+ val checkpointDir = Utils.createTempDir().getCanonicalFile
+ checkpointDir.delete()
+
+ val inputData = MemoryStream[Int]
+ val aggregated =
+ inputData
+ .toDF()
+ .groupBy($"value")
+ .agg(count("*"))
+ .as[(Int, Long)]
+
+ testStream(aggregated, Update)(
+ StartStream(checkpointLocation = checkpointDir.getAbsolutePath),
+ AddData(inputData, 3),
+ CheckLastBatch((3, 1)),
+ AddData(inputData, 3, 2),
+ CheckLastBatch((3, 2), (2, 1)),
+ StopStream
+ )
+
+ // Run the stream with changelog checkpointing enabled.
+ testStream(aggregated, Update)(
+ StartStream(checkpointLocation = checkpointDir.getAbsolutePath),
+ AddData(inputData, 3, 2, 1),
+ CheckLastBatch((3, 3), (2, 2), (1, 1)),
+ // By default we run in new tuple mode.
+ AddData(inputData, 4, 4, 4, 4),
+ CheckLastBatch((4, 4)),
+ AddData(inputData, 5, 5),
+ CheckLastBatch((5, 2))
+ )
+
+ // Run the stream with changelog checkpointing disabled.
+ testStream(aggregated, Update)(
+ StartStream(checkpointLocation = checkpointDir.getAbsolutePath),
+ AddData(inputData, 4),
+ CheckLastBatch((4, 5))
+ )
+ }
+ }
+
+ testWithChangelogCheckpointingEnabled(s"checkpointFormatVersion2 validate ID") {
+ val providerClassName = classOf[TestStateStoreProviderWrapper].getCanonicalName
+ TestStateStoreWrapper.clear()
+ withSQLConf(
+ (SQLConf.STATE_STORE_PROVIDER_CLASS.key -> providerClassName),
+ (SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION.key -> "2"),
+ (SQLConf.SHUFFLE_PARTITIONS.key, "2")) {
+ val checkpointDir = Utils.createTempDir().getCanonicalFile
+ checkpointDir.delete()
+
+ val inputData = MemoryStream[Int]
+ val aggregated =
+ inputData
+ .toDF()
+ .groupBy($"value")
+ .agg(count("*"))
+ .as[(Int, Long)]
+
+ testStream(aggregated, Update)(
+ StartStream(checkpointLocation = checkpointDir.getAbsolutePath),
+ AddData(inputData, 3),
+ CheckLastBatch((3, 1)),
+ AddData(inputData, 3, 2),
+ CheckLastBatch((3, 2), (2, 1)),
+ StopStream
+ )
+
+ // Test recovery
+ testStream(aggregated, Update)(
+ StartStream(checkpointLocation = checkpointDir.getAbsolutePath),
+ AddData(inputData, 3, 2, 1),
+ CheckLastBatch((3, 3), (2, 2), (1, 1)),
+ // By default we run in new tuple mode.
+ AddData(inputData, 4, 4, 4, 4),
+ CheckLastBatch((4, 4)),
+ AddData(inputData, 5, 5),
+ CheckLastBatch((5, 2)),
+ StopStream
+ )
+
+ // crash recovery again
+ testStream(aggregated, Update)(
+ StartStream(checkpointLocation = checkpointDir.getAbsolutePath),
+ AddData(inputData, 4),
+ CheckLastBatch((4, 5))
+ )
+ }
+ val checkpointInfoList = TestStateStoreWrapper.getCheckpointInfos
+ // We have 6 batches, 2 partitions, and 1 state store per batch
+ assert(checkpointInfoList.size == 12)
+ checkpointInfoList.foreach { l =>
+ assert(l.checkpointId.isDefined)
+ if (l.batchVersion == 2 || l.batchVersion == 4 || l.batchVersion == 5) {
+ assert(l.baseCheckpointId.isDefined)
+ }
+ }
+ assert(checkpointInfoList.count(_.partitionId == 0) == 6)
+ assert(checkpointInfoList.count(_.partitionId == 1) == 6)
+ for (i <- 1 to 6) {
+ assert(checkpointInfoList.count(_.batchVersion == i) == 2)
+ }
+ for {
+ a <- checkpointInfoList
+ b <- checkpointInfoList
+ if a.partitionId == b.partitionId && a.batchVersion == b.batchVersion + 1
+ } {
+ // if batch version exists, it should be the same as the checkpoint ID of the previous batch
+ assert(!a.baseCheckpointId.isDefined || b.checkpointId == a.baseCheckpointId)
+ }
Review Comment:
This can definitely be refactored; you're using the same code snippet in every
test. Something like a `StateStoreCheckpointIdTestUtils` could be good.
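Sketch of what that helper could look like, lifted almost verbatim from the
repeated block (the object name and parameters are my suggestion):
```scala
object StateStoreCheckpointIdTestUtils {
  def validateCheckpointInfo(
      checkpointInfoList: Seq[StateStoreCheckpointInfo],
      expectedPerPartition: Int,
      batchVersionsWithBase: Set[Long]): Unit = {
    checkpointInfoList.foreach { info =>
      assert(info.checkpointId.isDefined)
      if (batchVersionsWithBase.contains(info.batchVersion)) {
        assert(info.baseCheckpointId.isDefined)
      }
    }
    assert(checkpointInfoList.count(_.partitionId == 0) == expectedPerPartition)
    assert(checkpointInfoList.count(_.partitionId == 1) == expectedPerPartition)
    // The baseCheckpointId of version n + 1 must equal the checkpointId of
    // version n for the same partition.
    for {
      a <- checkpointInfoList
      b <- checkpointInfoList
      if a.partitionId == b.partitionId && a.batchVersion == b.batchVersion + 1
    } {
      assert(a.baseCheckpointId.isEmpty || b.checkpointId == a.baseCheckpointId)
    }
  }
}
```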
##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreIntegrationSuite.scala:
##########
@@ -21,15 +21,168 @@ import java.io.File
import scala.jdk.CollectionConverters.SetHasAsScala
+import org.apache.hadoop.conf.Configuration
import org.scalatest.time.{Minute, Span}
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.execution.streaming.{MemoryStream, StreamingQueryWrapper}
import org.apache.spark.sql.functions.count
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.streaming._
import org.apache.spark.sql.streaming.OutputMode.Update
+import org.apache.spark.sql.types.StructType
import org.apache.spark.util.Utils
+object TestStateStoreWrapper {
Review Comment:
Probably will want a better name here.
```suggestion
object CheckpointInfoCollectingStateStore {
```
?
##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreIntegrationSuite.scala:
##########
@@ -222,6 +375,456 @@ class RocksDBStateStoreIntegrationSuite extends StreamTest
}
}
+ testWithChangelogCheckpointingEnabled(s"checkpointFormatVersion2") {
+ withSQLConf((SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION.key, "2")) {
+ val checkpointDir = Utils.createTempDir().getCanonicalFile
+ checkpointDir.delete()
+
+ val inputData = MemoryStream[Int]
+ val aggregated =
+ inputData
+ .toDF()
+ .groupBy($"value")
+ .agg(count("*"))
+ .as[(Int, Long)]
+
+ testStream(aggregated, Update)(
+ StartStream(checkpointLocation = checkpointDir.getAbsolutePath),
+ AddData(inputData, 3),
+ CheckLastBatch((3, 1)),
+ AddData(inputData, 3, 2),
+ CheckLastBatch((3, 2), (2, 1)),
+ StopStream
+ )
+
+ // Run the stream with changelog checkpointing enabled.
+ testStream(aggregated, Update)(
+ StartStream(checkpointLocation = checkpointDir.getAbsolutePath),
+ AddData(inputData, 3, 2, 1),
+ CheckLastBatch((3, 3), (2, 2), (1, 1)),
+ // By default we run in new tuple mode.
+ AddData(inputData, 4, 4, 4, 4),
+ CheckLastBatch((4, 4)),
+ AddData(inputData, 5, 5),
+ CheckLastBatch((5, 2))
+ )
+
+ // Run the stream with changelog checkpointing disabled.
+ testStream(aggregated, Update)(
+ StartStream(checkpointLocation = checkpointDir.getAbsolutePath),
+ AddData(inputData, 4),
+ CheckLastBatch((4, 5))
+ )
+ }
+ }
+
+ testWithChangelogCheckpointingEnabled(s"checkpointFormatVersion2 validate ID") {
+ val providerClassName = classOf[TestStateStoreProviderWrapper].getCanonicalName
+ TestStateStoreWrapper.clear()
+ withSQLConf(
+ (SQLConf.STATE_STORE_PROVIDER_CLASS.key -> providerClassName),
+ (SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION.key -> "2"),
+ (SQLConf.SHUFFLE_PARTITIONS.key, "2")) {
+ val checkpointDir = Utils.createTempDir().getCanonicalFile
+ checkpointDir.delete()
+
+ val inputData = MemoryStream[Int]
+ val aggregated =
+ inputData
+ .toDF()
+ .groupBy($"value")
+ .agg(count("*"))
+ .as[(Int, Long)]
+
+ testStream(aggregated, Update)(
+ StartStream(checkpointLocation = checkpointDir.getAbsolutePath),
+ AddData(inputData, 3),
+ CheckLastBatch((3, 1)),
+ AddData(inputData, 3, 2),
+ CheckLastBatch((3, 2), (2, 1)),
+ StopStream
+ )
+
+ // Test recovery
+ testStream(aggregated, Update)(
+ StartStream(checkpointLocation = checkpointDir.getAbsolutePath),
+ AddData(inputData, 3, 2, 1),
+ CheckLastBatch((3, 3), (2, 2), (1, 1)),
+ // By default we run in new tuple mode.
+ AddData(inputData, 4, 4, 4, 4),
+ CheckLastBatch((4, 4)),
+ AddData(inputData, 5, 5),
+ CheckLastBatch((5, 2)),
+ StopStream
+ )
+
+ // crash recovery again
+ testStream(aggregated, Update)(
+ StartStream(checkpointLocation = checkpointDir.getAbsolutePath),
+ AddData(inputData, 4),
+ CheckLastBatch((4, 5))
+ )
+ }
+ val checkpointInfoList = TestStateStoreWrapper.getCheckpointInfos
+ // We have 6 batches, 2 partitions, and 1 state store per batch
+ assert(checkpointInfoList.size == 12)
+ checkpointInfoList.foreach { l =>
+ assert(l.checkpointId.isDefined)
+ if (l.batchVersion == 2 || l.batchVersion == 4 || l.batchVersion == 5) {
+ assert(l.baseCheckpointId.isDefined)
+ }
+ }
+ assert(checkpointInfoList.count(_.partitionId == 0) == 6)
+ assert(checkpointInfoList.count(_.partitionId == 1) == 6)
+ for (i <- 1 to 6) {
+ assert(checkpointInfoList.count(_.batchVersion == i) == 2)
+ }
+ for {
+ a <- checkpointInfoList
+ b <- checkpointInfoList
+ if a.partitionId == b.partitionId && a.batchVersion == b.batchVersion + 1
+ } {
+ // if batch version exists, it should be the same as the checkpoint ID of the previous batch
+ assert(!a.baseCheckpointId.isDefined || b.checkpointId == a.baseCheckpointId)
+ }
+ }
+
+ testWithChangelogCheckpointingEnabled(
+ s"checkpointFormatVersion2 validate ID with dedup and groupBy") {
+ val providerClassName = classOf[TestStateStoreProviderWrapper].getCanonicalName
+ TestStateStoreWrapper.clear()
+ withSQLConf(
+ (SQLConf.STATE_STORE_PROVIDER_CLASS.key -> providerClassName),
+ (SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION.key -> "2"),
+ (SQLConf.SHUFFLE_PARTITIONS.key, "2")) {
+ val checkpointDir = Utils.createTempDir().getCanonicalFile
+ checkpointDir.delete()
Review Comment:
Why do you need to delete this? And why not use `withTempDir`?
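`withTempDir` gives you a fresh directory and cleans it up afterwards, so the
explicit `delete()` dance goes away; e.g. (sketch):
```scala
withTempDir { checkpointDir =>
  testStream(aggregated, Update)(
    StartStream(checkpointLocation = checkpointDir.getAbsolutePath),
    AddData(inputData, 3),
    CheckLastBatch((3, 1))
  )
}
```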
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]