micheal-o commented on code in PR #53720:
URL: https://github.com/apache/spark/pull/53720#discussion_r2687692566
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionRunner.scala:
##########
@@ -289,12 +290,27 @@ class OfflineStateRepartitionRunner(
}
}
- private def commitBatch(newBatchId: Long, lastCommittedBatchId: Long): Unit = {
+ private def commitBatch(
+ newBatchId: Long,
+ lastCommittedBatchId: Long,
+ opIdToStateStoreCkptInfo: Map[Long, Array[Array[StateStoreCheckpointInfo]]]): Unit = {
+ val enableCheckpointId = StatefulOperatorStateInfo.
+ enableStateStoreCheckpointIds(sparkSession.sessionState.conf)
val latestCommit =
checkpointMetadata.commitLog.get(lastCommittedBatchId).get
+ val commitMetadata = if (enableCheckpointId) {
+ val opIdToPartitionCkptInfo: Map[Long, Array[Array[String]]] =
+ opIdToStateStoreCkptInfo.map {
+ case (operator, partitionSeq) =>
+ operator -> partitionSeq.map(storeSeq =>
+ storeSeq.flatMap(info => info.stateStoreCkptId))
Review Comment:
Why are we still doing this conversion here? I mentioned previously that it
should be done in StateRewriter, to avoid having every caller implement this
additional conversion.
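For illustration, here is a self-contained sketch of the conversion the reviewer is asking to move into StateRewriter, using a stand-in case class instead of Spark's real `StateStoreCheckpointInfo` (only its `stateStoreCkptId: Option[String]` field, visible in this diff, is assumed):
```
// Stand-in for StateStoreCheckpointInfo; the real class lives in Spark.
case class CkptInfo(stateStoreCkptId: Option[String])

// Convert per-store checkpoint info into the Array[Array[String]] shape the
// commit log expects, so callers like commitBatch no longer repeat this step.
def toCommitLogIds(
    infos: Map[Long, Array[Array[CkptInfo]]]): Map[Long, Array[Array[String]]] =
  infos.map { case (operatorId, partitions) =>
    operatorId -> partitions.map(_.flatMap(_.stateStoreCkptId))
  }

// Example: operator 0, two partitions, one store each.
val sample = Map(0L -> Array(Array(CkptInfo(Some("id-a"))), Array(CkptInfo(Some("id-b")))))
assert(toCommitLogIds(sample)(0L).flatten.toSeq == Seq("id-a", "id-b"))
```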
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionRunner.scala:
##########
@@ -289,12 +290,27 @@ class OfflineStateRepartitionRunner(
}
}
- private def commitBatch(newBatchId: Long, lastCommittedBatchId: Long): Unit = {
+ private def commitBatch(
+ newBatchId: Long,
+ lastCommittedBatchId: Long,
+ opIdToStateStoreCkptInfo: Map[Long, Array[Array[StateStoreCheckpointInfo]]]): Unit = {
+ val enableCheckpointId = StatefulOperatorStateInfo.
+ enableStateStoreCheckpointIds(sparkSession.sessionState.conf)
val latestCommit =
checkpointMetadata.commitLog.get(lastCommittedBatchId).get
+ val commitMetadata = if (enableCheckpointId) {
Review Comment:
Why are we doing this check? Why not just check if checkpointId was passed
in?
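A minimal sketch of the alternative being suggested, assuming (hypothetically) that the checkpoint ids are passed in as an `Option` (as suggested for `run()` later in this review) and that `CommitMetadata` is a case class with the `stateUniqueIds` field that other parts of this PR reference:
```
// Hypothetical: branch on whether checkpoint ids were passed in, instead of
// re-deriving the flag from the session conf.
val commitMetadata = opIdToCommitLogCkptIds match {
  case Some(ids) => latestCommit.copy(stateUniqueIds = Some(ids))
  case None => latestCommit
}
```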
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala:
##########
@@ -443,77 +443,84 @@ class RocksDB(
private def loadWithCheckpointId(
version: Long,
stateStoreCkptId: Option[String],
- readOnly: Boolean = false): RocksDB = {
+ readOnly: Boolean = false,
+ loadEmpty: Boolean = false): RocksDB = {
// An array contains lineage information from [snapShotVersion, version]
// (inclusive in both ends)
var currVersionLineage: Array[LineageItem] =
lineageManager.getLineageForCurrVersion()
try {
- if (loadedVersion != version || (loadedStateStoreCkptId.isEmpty ||
- stateStoreCkptId.get != loadedStateStoreCkptId.get)) {
+ if (loadEmpty || loadedVersion != version || loadedStateStoreCkptId.isEmpty ||
+ stateStoreCkptId.get != loadedStateStoreCkptId.get) {
closeDB(ignoreException = false)
-
- val (latestSnapshotVersion, latestSnapshotUniqueId) = {
- // Special handling when version is 0.
- // When loading the very first version (0), stateStoreCkptId does not need to be defined
- // because there won't be 0.changelog / 0.zip file created in RocksDB under v2.
- if (version == 0) {
- assert(stateStoreCkptId.isEmpty,
- "stateStoreCkptId should be empty when version is zero")
- (0L, None)
- // When there is a snapshot file, it is the ground truth, we can skip
- // reconstructing the lineage from changelog file.
- } else if (fileManager.existsSnapshotFile(version, stateStoreCkptId)) {
- currVersionLineage = Array(LineageItem(version, stateStoreCkptId.get))
- (version, stateStoreCkptId)
- } else {
- currVersionLineage = getLineageFromChangelogFile(version, stateStoreCkptId) :+
- LineageItem(version, stateStoreCkptId.get)
- currVersionLineage = currVersionLineage.sortBy(_.version)
-
- val latestSnapshotVersionsAndUniqueId =
- fileManager.getLatestSnapshotVersionAndUniqueIdFromLineage(currVersionLineage)
- latestSnapshotVersionsAndUniqueId match {
- case Some(pair) => (pair._1, Option(pair._2))
- case None if currVersionLineage.head.version == 1L =>
- logDebug(log"Cannot find latest snapshot based on lineage but first version " +
- log"is 1, use 0 as default. Lineage: ${MDC(LogKeys.LINEAGE, lineageManager)}")
- (0L, None)
- case _ =>
- throw QueryExecutionErrors.cannotFindBaseSnapshotCheckpoint(
- printLineageItems(currVersionLineage))
+ if (loadEmpty) {
+ require(stateStoreCkptId.isEmpty,
+ "stateStoeCkptId should be empty when loadEmpty is true")
Review Comment:
nit: typo `stateStoreCkptId`
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala:
##########
@@ -530,9 +537,13 @@ class RocksDB(
if (conf.resetStatsOnLoad) {
nativeStats.reset
}
-
- logInfo(log"Loaded ${MDC(LogKeys.VERSION_NUM, version)} " +
- log"with uniqueId ${MDC(LogKeys.UUID, stateStoreCkptId)}")
+ if (loadEmpty) {
+ logInfo(log"Loaded empty store at version ${MDC(LogKeys.VERSION_NUM, version)} " +
+ log"without uniqueId")
Review Comment:
nit: This can be confusing; someone might think this is for
`loadWithoutCheckpointId`. Just say "with uniqueId".
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateRewriter.scala:
##########
@@ -81,24 +82,46 @@ class StateRewriter(
private val stateRootLocation = new Path(
resolvedCheckpointLocation,
StreamingCheckpointConstants.DIR_NAME_STATE).toString
- def run(): Unit = {
+ // return a Map[operator id, Array[partition -> Array[stateStore -> StateStoreCheckpointInfo]]]
+ def run(): Map[Long, Array[Array[StateStoreCheckpointInfo]]] = {
logInfo(log"Starting state rewrite for " +
log"checkpointLocation=${MDC(CHECKPOINT_LOCATION, resolvedCheckpointLocation)}, " +
log"readCheckpointLocation=" +
log"${MDC(CHECKPOINT_LOCATION, readResolvedCheckpointLocation.getOrElse(""))}, " +
log"readBatchId=${MDC(BATCH_ID, readBatchId)}, " +
log"writeBatchId=${MDC(BATCH_ID, writeBatchId)}")
- val (_, timeTakenMs) = Utils.timeTakenMs {
+ val (checkpointInfos, timeTakenMs) = Utils.timeTakenMs {
runInternal()
}
+ val allCheckpointInfos = checkpointInfos.values.flatMap { checkpointInfos =>
+ checkpointInfos.flatten.toSeq
+ }.toSeq
logInfo(log"State rewrite completed in ${MDC(DURATION, timeTakenMs)} ms for " +
- log"checkpointLocation=${MDC(CHECKPOINT_LOCATION, resolvedCheckpointLocation)}")
+ log"checkpointLocation=${MDC(CHECKPOINT_LOCATION, resolvedCheckpointLocation)} " +
+ log"checkpointInfos=${MDC(LAST_COMMITTED_CHECKPOINT_ID, allCheckpointInfos)}"
+ )
+ checkpointInfos.map {
Review Comment:
This should be a separate function, which is then called within `runInternal`.
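A rough sketch of that shape, for illustration only (the exact signatures are assumptions, not the PR code): `runInternal` returns the already-converted structure, and `run` only adds timing and logging around it, reusing the `Utils.timeTakenMs` and structured-logging helpers already shown in this diff.
```
// Hypothetical structure: the conversion lives in a helper used by runInternal,
// so run() stays a thin timing/logging wrapper.
def run(): Map[Long, Array[Array[String]]] = {
  val (checkpointIds, timeTakenMs) = Utils.timeTakenMs {
    runInternal()
  }
  logInfo(log"State rewrite completed in ${MDC(DURATION, timeTakenMs)} ms")
  checkpointIds
}
```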
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateRewriter.scala:
##########
@@ -81,24 +82,46 @@ class StateRewriter(
private val stateRootLocation = new Path(
resolvedCheckpointLocation,
StreamingCheckpointConstants.DIR_NAME_STATE).toString
- def run(): Unit = {
+ // return a Map[operator id, Array[partition -> Array[stateStore -> StateStoreCheckpointInfo]]]
+ def run(): Map[Long, Array[Array[StateStoreCheckpointInfo]]] = {
Review Comment:
Also, it should return the structure needed in the commit log, so callers don't
need to do extra conversion work.
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateRewriter.scala:
##########
@@ -81,24 +82,46 @@ class StateRewriter(
private val stateRootLocation = new Path(
resolvedCheckpointLocation,
StreamingCheckpointConstants.DIR_NAME_STATE).toString
- def run(): Unit = {
+ // return a Map[operator id, Array[partition -> Array[stateStore -> StateStoreCheckpointInfo]]]
+ def run(): Map[Long, Array[Array[StateStoreCheckpointInfo]]] = {
logInfo(log"Starting state rewrite for " +
log"checkpointLocation=${MDC(CHECKPOINT_LOCATION, resolvedCheckpointLocation)}, " +
log"readCheckpointLocation=" +
log"${MDC(CHECKPOINT_LOCATION, readResolvedCheckpointLocation.getOrElse(""))}, " +
log"readBatchId=${MDC(BATCH_ID, readBatchId)}, " +
log"writeBatchId=${MDC(BATCH_ID, writeBatchId)}")
- val (_, timeTakenMs) = Utils.timeTakenMs {
+ val (checkpointInfos, timeTakenMs) = Utils.timeTakenMs {
runInternal()
}
+ val allCheckpointInfos = checkpointInfos.values.flatMap { checkpointInfos =>
+ checkpointInfos.flatten.toSeq
+ }.toSeq
logInfo(log"State rewrite completed in ${MDC(DURATION, timeTakenMs)} ms for " +
- log"checkpointLocation=${MDC(CHECKPOINT_LOCATION, resolvedCheckpointLocation)}")
+ log"checkpointLocation=${MDC(CHECKPOINT_LOCATION, resolvedCheckpointLocation)} " +
+ log"checkpointInfos=${MDC(LAST_COMMITTED_CHECKPOINT_ID, allCheckpointInfos)}"
+ )
+ checkpointInfos.map {
Review Comment:
The `run` function here should just return the checkpointIds without doing
any extra conversion work
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateRewriter.scala:
##########
@@ -81,24 +82,46 @@ class StateRewriter(
private val stateRootLocation = new Path(
resolvedCheckpointLocation,
StreamingCheckpointConstants.DIR_NAME_STATE).toString
- def run(): Unit = {
+ // return a Map[operator id, Array[partition -> Array[stateStore -> StateStoreCheckpointInfo]]]
+ def run(): Map[Long, Array[Array[StateStoreCheckpointInfo]]] = {
logInfo(log"Starting state rewrite for " +
log"checkpointLocation=${MDC(CHECKPOINT_LOCATION, resolvedCheckpointLocation)}, " +
log"readCheckpointLocation=" +
log"${MDC(CHECKPOINT_LOCATION, readResolvedCheckpointLocation.getOrElse(""))}, " +
log"readBatchId=${MDC(BATCH_ID, readBatchId)}, " +
log"writeBatchId=${MDC(BATCH_ID, writeBatchId)}")
- val (_, timeTakenMs) = Utils.timeTakenMs {
+ val (checkpointInfos, timeTakenMs) = Utils.timeTakenMs {
runInternal()
}
+ val allCheckpointInfos = checkpointInfos.values.flatMap { checkpointInfos =>
+ checkpointInfos.flatten.toSeq
Review Comment:
Why are you doing this? Why not log the same thing you're returning? That would
make it easier to debug.
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateRewriter.scala:
##########
@@ -81,24 +82,46 @@ class StateRewriter(
private val stateRootLocation = new Path(
resolvedCheckpointLocation,
StreamingCheckpointConstants.DIR_NAME_STATE).toString
- def run(): Unit = {
+ // return a Map[operator id, Array[partition -> Array[stateStore -> StateStoreCheckpointInfo]]]
+ def run(): Map[Long, Array[Array[StateStoreCheckpointInfo]]] = {
Review Comment:
Should return an `Option`, returning `None` if there are no checkpointIds.
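Putting the comments on `run` together, a minimal sketch of the suggested contract (a simplified trait for illustration only, not the actual class in the PR):
```
// Simplified illustration of the suggested return type: commit-log-ready ids,
// or None when checkpoint ids are not enabled for the query.
trait StateRewriterContract {
  def run(): Option[Map[Long, Array[Array[String]]]]
}
```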
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateRewriter.scala:
##########
@@ -342,6 +366,21 @@ class StateRewriter(
None
}
}
+
+ private def setCheckpointVersion(): Unit = {
+ // Setting checkpoint version in sqlConf based on previous commitLog in case user forgot to
+ // set STATE_STORE_CHECKPOINT_FORMAT_VERSION and crash rewriter.
+ val commitLog = writeCheckpointMetadata.getOrElse(
+ new StreamingQueryCheckpointMetadata(sparkSession, resolvedCheckpointLocation)).commitLog
Review Comment:
I mentioned in my previous
[comment](https://github.com/apache/spark/pull/53720#discussion_r2679166160)
that we should be checking this for the `readBatchId`. The latest commit in
the commitLog could be a skipped batch (e.g. skipped by repartition) with no
checkpointId.
So we should do something like:
create 2 class vals:
```
private lazy val writeCheckpoint = writeCheckpointMetadata.getOrElse(
  new StreamingQueryCheckpointMetadata(sparkSession, resolvedCheckpointLocation))

private lazy val readCheckpoint = if (readResolvedCheckpointLocation.isDefined) {
  new StreamingQueryCheckpointMetadata(sparkSession, readResolvedCheckpointLocation.get)
} else {
// Same checkpoint for read & write
writeCheckpoint
}
```
Then in your func:
```
// Using read batch commit since the latest commit could be a skipped batch
val readBatchCommit = readCheckpoint.commitLog.get(readBatchId)
```
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateRewriter.scala:
##########
@@ -342,6 +366,21 @@ class StateRewriter(
None
}
}
+
+ private def setCheckpointVersion(): Unit = {
+ // Setting checkpoint version in sqlConf based on previous commitLog in case user forgot to
+ // set STATE_STORE_CHECKPOINT_FORMAT_VERSION and crash rewriter.
+ val commitLog = writeCheckpointMetadata.getOrElse(
+ new StreamingQueryCheckpointMetadata(sparkSession, resolvedCheckpointLocation)).commitLog
+ val latestCommitLog = commitLog.getLatest()
+ assert(latestCommitLog.isDefined)
+ if (latestCommitLog.get._2.stateUniqueIds.isDefined) {
+ // Hard code checkpoint version to 2 if previous commit log has checkpoint id.
+ // This prevents StateRewriter from crashing in case client forget to set correct
Review Comment:
This prevents the StateRewriter from failing to write the correct state
files ...
##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionSuite.scala:
##########
@@ -296,8 +323,21 @@ object OfflineStateRepartitionTestUtils {
verifyOffsetAndCommitLog(
batchId, previousBatchId, expectedShufflePartitions, checkpointMetadata)
verifyPartitionDirs(checkpointLocation, expectedShufflePartitions)
+
+ val serializableConf = new SerializableConfiguration(hadoopConf)
+ val baseOperatorsMetadata = getOperatorMetadata(
+ checkpointLocation, serializableConf, previousBatchId)
+ val repartitionOperatorsMetadata = getOperatorMetadata(
+ checkpointLocation, serializableConf, batchId)
verifyOperatorMetadata(
- batchId, previousBatchId, checkpointLocation, expectedShufflePartitions, hadoopConf)
+ baseOperatorsMetadata, repartitionOperatorsMetadata, expectedShufflePartitions)
+ if (StatefulOperatorStateInfo.enableStateStoreCheckpointIds(spark.sessionState.conf)) {
+ val expectedStoreCnt: Int = baseOperatorsMetadata.head match {
Review Comment:
Move this into `verifyCheckpointIds`. Also, we need to get the store count per
operator, since this verification method can be called for a test query with
multiple operators.
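A self-contained sketch of deriving a store count per operator, using stand-in types rather than Spark's real `OperatorStateMetadata` (the store names below are just the join/aggregation examples mentioned elsewhere in this review):
```
// Stand-in for operator metadata: operator id plus the state store names it owns.
case class OpMeta(operatorId: Long, storeNames: Seq[String])

def expectedStoreCounts(ops: Seq[OpMeta]): Map[Long, Int] =
  ops.map(op => op.operatorId -> op.storeNames.size).toMap

// e.g. a simple aggregation (1 store) and a stream-stream join (4 stores)
val metas = Seq(
  OpMeta(0L, Seq("default")),
  OpMeta(1L, Seq("left-keyToNumValues", "left-keyWithIndexToValue",
    "right-keyToNumValues", "right-keyWithIndexToValue")))
assert(expectedStoreCounts(metas) == Map(0L -> 1, 1L -> 4))
```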
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateRewriter.scala:
##########
@@ -342,6 +366,21 @@ class StateRewriter(
None
}
}
+
+ private def setCheckpointVersion(): Unit = {
+ // Setting checkpoint version in sqlConf based on previous commitLog in case user forgot to
+ // set STATE_STORE_CHECKPOINT_FORMAT_VERSION and crash rewriter.
+ val commitLog = writeCheckpointMetadata.getOrElse(
+ new StreamingQueryCheckpointMetadata(sparkSession, resolvedCheckpointLocation)).commitLog
+ val latestCommitLog = commitLog.getLatest()
+ assert(latestCommitLog.isDefined)
+ if (latestCommitLog.get._2.stateUniqueIds.isDefined) {
+ // Hard code checkpoint version to 2 if previous commit log has checkpoint id.
Review Comment:
"Set checkpoint..." and fix the comment
##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionSuite.scala:
##########
@@ -376,23 +416,20 @@ object OfflineStateRepartitionTestUtils {
}
}
- private def verifyOperatorMetadata(
- repartitionBatchId: Long,
- baseBatchId: Long,
+ private def getOperatorMetadata(
checkpointLocation: String,
- expectedShufflePartitions: Int,
- hadoopConf: Configuration): Unit = {
- val serializableConf = new SerializableConfiguration(hadoopConf)
-
- // Read operator metadata for both batches
- val baseMetadataReader = new StateMetadataPartitionReader(
- checkpointLocation, serializableConf, baseBatchId)
+ serializableConf: SerializableConfiguration,
+ batchId: Long
+ ): Array[OperatorStateMetadata] = {
val repartitionMetadataReader = new StateMetadataPartitionReader(
Review Comment:
Fix the val name; it's no longer specific to the repartition batch now that this
helper takes any batchId.
##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala:
##########
@@ -3943,47 +3943,55 @@ class RocksDBSuite extends AlsoTestWithRocksDBFeatures with SharedSparkSession
}}
}
- test("SPARK-54420: load with createEmpty creates empty store") {
- val remoteDir = Utils.createTempDir().toString
- new File(remoteDir).delete()
-
- withDB(remoteDir) { db =>
- // loading batch 0 with loadEmpty = true
- db.load(0, None, loadEmpty = true)
- assert(iterator(db).isEmpty)
- db.put("a", "1")
- val (version1, _) = db.commit()
- assert(toStr(db.get("a")) === "1")
+ testWithStateStoreCheckpointIds(
+ "SPARK-54420: load with createEmpty creates empty store") { enableCkptId =>
+ val remoteDir = Utils.createTempDir().toString
+ new File(remoteDir).delete()
- // check we can load store normally even the previous one loadEmpty = true
- db.load(version1)
- db.put("b", "2")
- val (version2, _) = db.commit()
- assert(version2 === version1 + 1)
- assert(toStr(db.get("b")) === "2")
- assert(toStr(db.get("a")) === "1")
+ withDB(remoteDir, enableStateStoreCheckpointIds = enableCkptId) { db =>
+ // loading batch 0 with loadEmpty = true
+ db.load(0, None, loadEmpty = true)
+ assert(iterator(db).isEmpty)
+ db.put("a", "1")
+ val (version1, checkpointInfoV1) = db.commit()
+ assert(toStr(db.get("a")) === "1")
- // load an empty store
- db.load(version2, loadEmpty = true)
- db.put("c", "3")
- val (version3, _) = db.commit()
- assert(db.get("b") === null)
- assert(db.get("a") === null)
- assert(toStr(db.get("c")) === "3")
- assert(version3 === version2 + 1)
-
- // load 2 empty store in a row
- db.load(version3, loadEmpty = true)
- db.put("d", "4")
- val (version4, _) = db.commit()
- assert(db.get("c") === null)
- assert(toStr(db.get("d")) === "4")
- assert(version4 === version3 + 1)
-
- db.load(version4)
- db.put("e", "5")
- db.commit()
- assert(db.iterator().map(toStr).toSet === Set(("d", "4"), ("e", "5")))
+ // check we can load store normally even the previous one loadEmpty = true
+ db.load(version1, checkpointInfoV1.stateStoreCkptId)
+ db.put("b", "2")
+ val (version2, _) = db.commit()
+ assert(version2 === version1 + 1)
+ assert(toStr(db.get("b")) === "2")
+ assert(toStr(db.get("a")) === "1")
+
+ // load an empty store
+ db.load(version2, loadEmpty = true)
+ db.put("c", "3")
+ val (version3, _) = db.commit()
+ assert(db.get("b") === null)
+ assert(db.get("a") === null)
+ assert(toStr(db.get("c")) === "3")
+ assert(version3 === version2 + 1)
+
+ // load 2 empty store in a row
+ db.load(version3, loadEmpty = true)
+ db.put("d", "4")
+ val (version4, checkpointInfoV4) = db.commit()
+ assert(db.get("c") === null)
+ assert(toStr(db.get("d")) === "4")
+ assert(version4 === version3 + 1)
+
+ db.load(version4, checkpointInfoV4.stateStoreCkptId)
Review Comment:
Use a new `withDB` for this one, to make sure we are reloading from the
checkpoint. Currently, since the last commit is v4, the db is already loaded at
version 4, so the load won't do anything and will just reuse the open instance.
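A sketch of the suggested restructuring, reusing the helpers from the quoted test (`withDB`, `toStr`); note that `version4` and `checkpointInfoV4` would need to be captured outside the first `withDB` block for this to compile:
```
// Reopen the store in a fresh withDB block so load() really replays from the
// checkpoint instead of reusing the instance that just committed version 4.
withDB(remoteDir, enableStateStoreCheckpointIds = enableCkptId) { reopenedDb =>
  reopenedDb.load(version4, checkpointInfoV4.stateStoreCkptId)
  reopenedDb.put("e", "5")
  reopenedDb.commit()
  assert(reopenedDb.iterator().map(toStr).toSet === Set(("d", "4"), ("e", "5")))
}
```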
##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatePartitionAllColumnFamiliesWriterSuite.scala:
##########
@@ -95,11 +101,29 @@ class StatePartitionAllColumnFamiliesWriterSuite extends StateDataSourceTestBase
transformFunc = None,
writeCheckpointMetadata = Some(targetCheckpointMetadata)
)
- rewriter.run()
+ val checkpointInfos = rewriter.run()
+ assert(spark.conf.get(SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION.key) == previousCkptVersion)
- // Commit to commitLog
+ // Commit to commitLog with checkpoint IDs
val latestCommit = targetCheckpointMetadata.commitLog.get(lastBatch).get
- targetCheckpointMetadata.commitLog.add(writeBatchId, latestCommit)
+ val operatorId = 0L
Review Comment:
why only for operator 0?
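For illustration, a sketch that builds the commit-log ids for every operator returned by the rewriter rather than only operator 0 (assuming `checkpointInfos` keeps the `Map[Long, Array[Array[StateStoreCheckpointInfo]]]` shape shown above):
```
// Convert the rewriter output for all operators, not just operatorId 0.
val allCkptIds: Map[Long, Array[Array[String]]] = checkpointInfos.map {
  case (operatorId, partitions) =>
    operatorId -> partitions.map(_.flatMap(_.stateStoreCkptId))
}
```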
##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala:
##########
@@ -3943,47 +3943,55 @@ class RocksDBSuite extends AlsoTestWithRocksDBFeatures with SharedSparkSession
}}
}
- test("SPARK-54420: load with createEmpty creates empty store") {
- val remoteDir = Utils.createTempDir().toString
- new File(remoteDir).delete()
-
- withDB(remoteDir) { db =>
- // loading batch 0 with loadEmpty = true
- db.load(0, None, loadEmpty = true)
- assert(iterator(db).isEmpty)
- db.put("a", "1")
- val (version1, _) = db.commit()
- assert(toStr(db.get("a")) === "1")
+ testWithStateStoreCheckpointIds(
+ "SPARK-54420: load with createEmpty creates empty store") { enableCkptId =>
+ val remoteDir = Utils.createTempDir().toString
+ new File(remoteDir).delete()
- // check we can load store normally even the previous one loadEmpty = true
- db.load(version1)
- db.put("b", "2")
- val (version2, _) = db.commit()
- assert(version2 === version1 + 1)
- assert(toStr(db.get("b")) === "2")
- assert(toStr(db.get("a")) === "1")
+ withDB(remoteDir, enableStateStoreCheckpointIds = enableCkptId) { db =>
+ // loading batch 0 with loadEmpty = true
+ db.load(0, None, loadEmpty = true)
+ assert(iterator(db).isEmpty)
+ db.put("a", "1")
+ val (version1, checkpointInfoV1) = db.commit()
+ assert(toStr(db.get("a")) === "1")
- // load an empty store
- db.load(version2, loadEmpty = true)
- db.put("c", "3")
- val (version3, _) = db.commit()
- assert(db.get("b") === null)
- assert(db.get("a") === null)
- assert(toStr(db.get("c")) === "3")
- assert(version3 === version2 + 1)
-
- // load 2 empty store in a row
- db.load(version3, loadEmpty = true)
- db.put("d", "4")
- val (version4, _) = db.commit()
- assert(db.get("c") === null)
- assert(toStr(db.get("d")) === "4")
- assert(version4 === version3 + 1)
-
- db.load(version4)
- db.put("e", "5")
- db.commit()
- assert(db.iterator().map(toStr).toSet === Set(("d", "4"), ("e", "5")))
+ // check we can load store normally even the previous one loadEmpty = true
+ db.load(version1, checkpointInfoV1.stateStoreCkptId)
+ db.put("b", "2")
+ val (version2, _) = db.commit()
+ assert(version2 === version1 + 1)
+ assert(toStr(db.get("b")) === "2")
+ assert(toStr(db.get("a")) === "1")
+
+ // load an empty store
+ db.load(version2, loadEmpty = true)
+ db.put("c", "3")
+ val (version3, _) = db.commit()
+ assert(db.get("b") === null)
+ assert(db.get("a") === null)
+ assert(toStr(db.get("c")) === "3")
+ assert(version3 === version2 + 1)
+
+ // load 2 empty store in a row
+ db.load(version3, loadEmpty = true)
+ db.put("d", "4")
+ val (version4, checkpointInfoV4) = db.commit()
+ assert(db.get("c") === null)
+ assert(toStr(db.get("d")) === "4")
+ assert(version4 === version3 + 1)
+
+ db.load(version4, checkpointInfoV4.stateStoreCkptId)
+ db.put("e", "5")
+ db.commit()
+ assert(db.iterator().map(toStr).toSet === Set(("d", "4"), ("e", "5")))
+
+ if (enableCkptId) {
+ val ex = intercept[IllegalArgumentException] {
+ db.load(version4, checkpointInfoV4.stateStoreCkptId, loadEmpty = true)
+ }
+ assert(ex.getMessage.contains("stateStoeCkptId should be empty when loadEmpty is true"))
Review Comment:
stateStoreCkptId
##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionSuite.scala:
##########
@@ -454,4 +491,38 @@ object OfflineStateRepartitionTestUtils {
}
}
}
+
+ private def verifyCheckpointIds(
+ repartitionBatchId: Long,
+ checkpointMetadata: StreamingQueryCheckpointMetadata,
+ expectedShufflePartitions: Int,
+ expectedStoreCount: Int): Unit = {
+ // Verify commit log has the repartition batch with checkpoint IDs
+ val commitOpt = checkpointMetadata.commitLog.get(repartitionBatchId)
+ assert(commitOpt.isDefined, s"Commit for batch $repartitionBatchId should exist")
+
+ val commitMetadata = commitOpt.get
+
+ // Verify stateUniqueIds is present for checkpoint V2
+ assert(commitMetadata.stateUniqueIds.isDefined,
+ "stateUniqueIds should be present in commit metadata when checkpoint version >= 2")
+
+ val operatorIdToCkptInfos = commitMetadata.stateUniqueIds.get
+ assert(operatorIdToCkptInfos.nonEmpty,
+ "operatorIdToCkptInfos should not be empty")
+
+ // Verify structure for each operator
+ operatorIdToCkptInfos.foreach { case (operatorId, partitionToCkptIds) =>
+ // Should have checkpoint IDs for all partitions
+ assert(partitionToCkptIds.length == expectedShufflePartitions,
+ s"Operator $operatorId: Expected $expectedShufflePartitions partition checkpoint IDs, " +
+ s"but found ${partitionToCkptIds.length}")
+ // Each partition should have checkpoint IDs (at least one per store)
+ partitionToCkptIds.zipWithIndex.foreach { case (ckptIds, partitionId) =>
+ assert(ckptIds.length == expectedStoreCount,
Review Comment:
Each operator can have a different store count, so we need an
`expectedStoreCount` per operator.
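A sketch of the per-operator check, where `expectedStoreCountPerOperator: Map[Long, Int]` is a hypothetical parameter that would replace the single `expectedStoreCount: Int`:
```
// Hypothetical: look up the expected count per operator before asserting.
operatorIdToCkptInfos.foreach { case (operatorId, partitionToCkptIds) =>
  val expectedStores = expectedStoreCountPerOperator(operatorId)
  partitionToCkptIds.zipWithIndex.foreach { case (ckptIds, partitionId) =>
    assert(ckptIds.length == expectedStores,
      s"Operator $operatorId, partition $partitionId: expected $expectedStores " +
        s"store checkpoint ids, but found ${ckptIds.length}")
  }
}
```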
##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatePartitionAllColumnFamiliesWriterSuite.scala:
##########
@@ -95,11 +101,29 @@ class StatePartitionAllColumnFamiliesWriterSuite extends StateDataSourceTestBase
transformFunc = None,
writeCheckpointMetadata = Some(targetCheckpointMetadata)
)
- rewriter.run()
+ val checkpointInfos = rewriter.run()
+ assert(spark.conf.get(SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION.key) == previousCkptVersion)
- // Commit to commitLog
+ // Commit to commitLog with checkpoint IDs
val latestCommit = targetCheckpointMetadata.commitLog.get(lastBatch).get
- targetCheckpointMetadata.commitLog.add(writeBatchId, latestCommit)
+ val operatorId = 0L
+ val partitionSeq: Array[Array[StateStoreCheckpointInfo]] = checkpointInfos(operatorId)
+ val commitMetadata = if (StateStoreConf(conf).enableStateStoreCheckpointIds) {
+ // For join operators: 4 stores (left-keyToNumValues, left-keyWithIndexToValue,
+ // right-keyToNumValues, right-keyWithIndexToValue)
+ // For regular operators: 1 store
+ val ckptIds: Array[Array[String]] = partitionSeq.map { storesSeq =>
Review Comment:
ditto
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]