This is an automated email from the ASF dual-hosted git repository.
ashrigondekar pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 813f7576aef2 [SPARK-54590][SS] Support Checkpoint V2 for State Rewriter and Repartitioning
813f7576aef2 is described below
commit 813f7576aef225095393bce86df5379409bcd40b
Author: zifeif2 <[email protected]>
AuthorDate: Thu Jan 15 17:38:16 2026 -0800
[SPARK-54590][SS] Support Checkpoint V2 for State Rewriter and Repartitioning
### What changes were proposed in this pull request?
Support checkpoint V2 for the repartition writer and StateRewriter by returning
the checkpoint ID to the caller after the write is done (see the sketch after
this list). Changes include:
- RocksDB loadWithCheckpointId supports loadEmpty
- StatePartitionAllColumnFamiliesWriter returns StateStoreCheckpointInfo
- StateRewriter also propagates StateStoreCheckpointInfo back to the RepartitionRunner
- RepartitionRunner stores the checkpointIds in the commitLog
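The returned IDs come back store-major from the rewrite and are pivoted into the partition-major layout stored in the commit log. Below is a minimal, self-contained Scala sketch of that reshaping; `CkptInfo`, `toCommitLogLayout`, and the sample IDs are illustrative stand-ins, not the actual Spark classes or APIs.

```scala
object CheckpointIdShapeSketch {

  // Hypothetical stand-in for StateStoreCheckpointInfo: one entry per (store, partition).
  final case class CkptInfo(partitionId: Int, stateStoreCkptId: Option[String])

  // Pivot store-major results (one Array[CkptInfo] per store, indexed by partition)
  // into the partition-major layout that goes into the commit log:
  // Map[operatorId, Array[partition -> Array[stateStore -> checkpointId]]].
  def toCommitLogLayout(
      perOperator: Map[Long, Array[Array[CkptInfo]]]): Map[Long, Array[Array[String]]] = {
    perOperator.map { case (operatorId, stores) =>
      val numPartitions = stores.head.length
      operatorId -> Array.tabulate(numPartitions) { partition =>
        stores.map { storePartitions =>
          val info = storePartitions(partition)
          require(info.partitionId == partition, "partition index mismatch")
          info.stateStoreCkptId.getOrElse(sys.error(s"missing checkpoint id for $partition"))
        }
      }
    }
  }

  def main(args: Array[String]): Unit = {
    // Operator 0 with two state stores and two partitions (IDs are made up for illustration).
    val fromRewrite = Map(
      0L -> Array(
        Array(CkptInfo(0, Some("store0-p0")), CkptInfo(1, Some("store0-p1"))),
        Array(CkptInfo(0, Some("store1-p0")), CkptInfo(1, Some("store1-p1")))))
    val forCommitLog = toCommitLogLayout(fromRewrite)
    // Prints: partition 0 -> [store0-p0, store1-p0], partition 1 -> [store0-p1, store1-p1]
    forCommitLog(0L).zipWithIndex.foreach { case (ids, partition) =>
      println(s"partition $partition -> ${ids.mkString("[", ", ", "]")}")
    }
  }
}
```

In the actual change, `StateRewriter.run()` returns `Option[Map[Long, Array[Array[String]]]]` and the `RepartitionRunner` copies it into the commit metadata's `stateUniqueIds` before adding the new batch to the commit log.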
### Why are the changes needed?
This is required in PrPr for the repartition project.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
See the added unit tests on both operators, with a single state store and with multiple
state stores.
### Was this patch authored or co-authored using generative AI tooling?
Yes. Sonnet 4.5
Closes #53720 from zifeif2/repartition-cp-v2.
Lead-authored-by: zifeif2 <[email protected]>
Co-authored-by: Zifei Feng <[email protected]>
Signed-off-by: Anish Shrigondekar <[email protected]>
---
.../src/main/resources/error/error-conditions.json | 7 +
.../state/OfflineStateRepartitionRunner.scala | 14 +-
.../sql/execution/streaming/state/RocksDB.scala | 77 +-
.../streaming/state/RocksDBFileManager.scala | 2 +
.../streaming/state/StatePartitionWriter.scala | 5 +-
.../execution/streaming/state/StateRewriter.scala | 114 ++-
.../state/OfflineStateRepartitionSuite.scala | 361 ++++++----
.../execution/streaming/state/RocksDBSuite.scala | 106 +--
...tatePartitionAllColumnFamiliesWriterSuite.scala | 791 +++++++++++----------
9 files changed, 861 insertions(+), 616 deletions(-)
diff --git a/common/utils/src/main/resources/error/error-conditions.json
b/common/utils/src/main/resources/error/error-conditions.json
index 2b205908333c..519a3cafd3be 100644
--- a/common/utils/src/main/resources/error/error-conditions.json
+++ b/common/utils/src/main/resources/error/error-conditions.json
@@ -5603,6 +5603,13 @@
"Ensure that the checkpoint is for a stateful streaming query and
the query ran on a Spark version that supports operator metadata (Spark 4.0+)."
]
},
+ "STATE_CHECKPOINT_FORMAT_VERSION_MISMATCH" : {
+ "message" : [
+ "The checkpoint format version in SQLConf does not match the checkpoint version in the commit log.",
+ "Expected version <expectedVersion>, but found <actualVersion>.",
+ "Please set '<sqlConfKey>' to <expectedVersion> in your SQLConf before retrying."
+ ]
+ },
"UNSUPPORTED_STATE_STORE_METADATA_VERSION" : {
"message" : [
"Unsupported state store metadata version encountered.",
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionRunner.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionRunner.scala
index 19f75eb385f5..f53efe1cc0ac 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionRunner.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionRunner.scala
@@ -95,12 +95,12 @@ class OfflineStateRepartitionRunner(
transformFunc = Some(stateRepartitionFunc),
writeCheckpointMetadata = Some(checkpointMetadata)
)
- rewriter.run()
+ val operatorToCkptIds = rewriter.run()
updateNumPartitionsInOperatorMetadata(newBatchId, readBatchId =
lastCommittedBatchId)
// Commit the repartition batch
- commitBatch(newBatchId, lastCommittedBatchId)
+ commitBatch(newBatchId, lastCommittedBatchId, operatorToCkptIds)
newBatchId
}
@@ -289,12 +289,14 @@ class OfflineStateRepartitionRunner(
}
}
- private def commitBatch(newBatchId: Long, lastCommittedBatchId: Long): Unit
= {
+ private def commitBatch(
+ newBatchId: Long,
+ lastCommittedBatchId: Long,
+ opIdToStateStoreCkptInfo: Option[Map[Long, Array[Array[String]]]]): Unit
= {
val latestCommit =
checkpointMetadata.commitLog.get(lastCommittedBatchId).get
+ val commitMetadata = latestCommit.copy(stateUniqueIds =
opIdToStateStoreCkptInfo)
- // todo: For checkpoint v2, we need to update the stateUniqueIds based on
the
- // newly created state commit. Will be done in subsequent PR.
- if (!checkpointMetadata.commitLog.add(newBatchId, latestCommit)) {
+ if (!checkpointMetadata.commitLog.add(newBatchId, commitMetadata)) {
throw QueryExecutionErrors.concurrentStreamLogUpdate(newBatchId)
}
}
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala
index f9b94160f9eb..4c9b2282ba27 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala
@@ -443,13 +443,26 @@ class RocksDB(
private def loadWithCheckpointId(
version: Long,
stateStoreCkptId: Option[String],
- readOnly: Boolean = false): RocksDB = {
+ readOnly: Boolean = false,
+ loadEmpty: Boolean = false): RocksDB = {
// An array contains lineage information from [snapShotVersion, version]
// (inclusive in both ends)
var currVersionLineage: Array[LineageItem] =
lineageManager.getLineageForCurrVersion()
try {
- if (loadedVersion != version || (loadedStateStoreCkptId.isEmpty ||
- stateStoreCkptId.get != loadedStateStoreCkptId.get)) {
+ if (loadEmpty) {
+ // Handle empty store loading separately for clarity
+ require(stateStoreCkptId.isEmpty,
+ "stateStoreCkptId should be empty when loadEmpty is true")
+ closeDB(ignoreException = false)
+ loadEmptyStore(version)
+ lineageManager.clear()
+ // After loading empty store, set the key counts and metrics
+ numKeysOnLoadedVersion = numKeysOnWritingVersion
+ numInternalKeysOnLoadedVersion = numInternalKeysOnWritingVersion
+ fileManagerMetrics = fileManager.latestLoadCheckpointMetrics
+ } else if (loadedVersion != version || loadedStateStoreCkptId.isEmpty ||
+ stateStoreCkptId.get != loadedStateStoreCkptId.get) {
+ // Handle normal checkpoint loading
closeDB(ignoreException = false)
val (latestSnapshotVersion, latestSnapshotUniqueId) = {
@@ -508,9 +521,9 @@ class RocksDB(
if (loadedVersion != version) {
val versionsAndUniqueIds = currVersionLineage.collect {
- case i if i.version > loadedVersion && i.version <= version =>
- (i.version, Option(i.checkpointUniqueId))
- }
+ case i if i.version > loadedVersion && i.version <= version =>
+ (i.version, Option(i.checkpointUniqueId))
+ }
replayChangelog(versionsAndUniqueIds)
loadedVersion = version
lineageManager.resetLineage(currVersionLineage)
@@ -530,9 +543,13 @@ class RocksDB(
if (conf.resetStatsOnLoad) {
nativeStats.reset
}
-
- logInfo(log"Loaded ${MDC(LogKeys.VERSION_NUM, version)} " +
- log"with uniqueId ${MDC(LogKeys.UUID, stateStoreCkptId)}")
+ if (loadEmpty) {
+ logInfo(log"Loaded empty store at version ${MDC(LogKeys.VERSION_NUM,
version)} " +
+ log"with empty uniqueId")
+ } else {
+ logInfo(log"Loaded ${MDC(LogKeys.VERSION_NUM, version)} " +
+ log"with uniqueId ${MDC(LogKeys.UUID, stateStoreCkptId)}")
+ }
} catch {
case t: Throwable =>
loadedVersion = -1 // invalidate loaded data
@@ -543,25 +560,33 @@ class RocksDB(
lineageManager.clear()
throw t
}
- if (enableChangelogCheckpointing && !readOnly) {
+ // Check conf.enableChangelogCheckpointing instead of enableChangelogCheckpointing:
+ // enableChangelogCheckpointing is set to false when loadEmpty is true, but we still want
+ // to abort the previously used changelogWriter if there is one.
+ if (conf.enableChangelogCheckpointing && !readOnly) {
// Make sure we don't leak resource.
changelogWriter.foreach(_.abort())
- // Initialize the changelog writer with lineage info
- // The lineage stored in changelog files should normally start with
- // the version of a snapshot, except for the first few versions.
- // Because they are solely loaded from changelog file.
- // (e.g. with default minDeltasForSnapshot, there is only
1_uuid1.changelog, no 1_uuid1.zip)
- // It should end with exactly one version before the change log's
version.
- changelogWriter = Some(fileManager.getChangeLogWriter(
- version + 1,
- useColumnFamilies,
- sessionStateStoreCkptId,
- Some(currVersionLineage)))
+ if (loadEmpty) {
+ // We don't want to write a changelog file when loadEmpty is true
+ changelogWriter = None
+ } else {
+ // Initialize the changelog writer with lineage info
+ // The lineage stored in changelog files should normally start with
+ // the version of a snapshot, except for the first few versions.
+ // Because they are solely loaded from changelog file.
+ // (e.g. with default minDeltasForSnapshot, there is only
1_uuid1.changelog, no 1_uuid1.zip)
+ // It should end with exactly one version before the change log's
version.
+ changelogWriter = Some(fileManager.getChangeLogWriter(
+ version + 1,
+ useColumnFamilies,
+ sessionStateStoreCkptId,
+ Some(currVersionLineage)))
+ }
}
this
}
- private def loadEmptyStoreWithoutCheckpointId(version: Long): Unit = {
+ private def loadEmptyStore(version: Long): Unit = {
// Use version 0 logic to create empty directory with no SST files
val metadata = fileManager.loadCheckpointFromDfs(0, workingDir,
rocksDBFileMapping, None)
loadedVersion = version
@@ -580,7 +605,7 @@ class RocksDB(
closeDB(ignoreException = false)
if (loadEmpty) {
- loadEmptyStoreWithoutCheckpointId(version)
+ loadEmptyStore(version)
} else {
// load the latest snapshot
loadSnapshotWithoutCheckpointId(version)
@@ -752,9 +777,9 @@ class RocksDB(
// If loadEmpty is true, we will not generate a changelog but only a
snapshot file to prevent
// mistakenly applying new changelog to older state version
enableChangelogCheckpointing = if (loadEmpty) false else
conf.enableChangelogCheckpointing
- if (stateStoreCkptId.isDefined || enableStateStoreCheckpointIds && version
== 0) {
- assert(!loadEmpty, "loadEmpty not supported for checkpointV2 yet")
- loadWithCheckpointId(version, stateStoreCkptId, readOnly)
+ if (stateStoreCkptId.isDefined ||
+ enableStateStoreCheckpointIds && (version == 0 || loadEmpty)) {
+ loadWithCheckpointId(version, stateStoreCkptId, readOnly, loadEmpty)
} else {
loadWithoutCheckpointId(version, readOnly, loadEmpty)
}
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBFileManager.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBFileManager.scala
index 75e459b26511..7135421f4866 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBFileManager.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBFileManager.scala
@@ -375,6 +375,8 @@ class RocksDBFileManager(
createDfsRootDirIfNotExist()
// Since we cleared the local dir, we should also clear the local file
mapping
rocksDBFileMapping.clear()
+ // Set empty metrics since we're not loading any files from DFS
+ loadCheckpointMetrics = RocksDBFileManagerMetrics.EMPTY_METRICS
RocksDBCheckpointMetadata(Seq.empty, 0)
} else {
// Delete all non-immutable files in local dir, and unzip new ones from
DFS commit file
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StatePartitionWriter.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StatePartitionWriter.scala
index 3df97d3adc0e..64c0fb6d6e76 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StatePartitionWriter.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StatePartitionWriter.scala
@@ -145,10 +145,13 @@ class StatePartitionAllColumnFamiliesWriter(
// - key_bytes, BinaryType
// - value_bytes, BinaryType
// - column_family_name, StringType
- def write(rows: Iterator[InternalRow]): Unit = {
+ // Returns StateStoreCheckpointInfo containing the checkpoint ID after the commit
+ // if checkpointV2 is enabled
+ def write(rows: Iterator[InternalRow]): StateStoreCheckpointInfo = {
try {
rows.foreach(row => writeRow(row))
stateStore.commit()
+ stateStore.getStateStoreCheckpointInfo()
} finally {
if (!stateStore.hasCommitted) {
stateStore.abort()
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateRewriter.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateRewriter.scala
index da28a3c907f7..42561ec5ca7b 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateRewriter.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateRewriter.scala
@@ -22,16 +22,17 @@ import java.util.UUID
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
-import org.apache.spark.{SparkIllegalStateException, TaskContext}
+import org.apache.spark.{SparkIllegalStateException, SparkThrowable,
TaskContext}
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.internal.Logging
import org.apache.spark.internal.LogKeys._
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions
import
org.apache.spark.sql.execution.datasources.v2.state.metadata.StateMetadataPartitionReader
import org.apache.spark.sql.execution.streaming.checkpointing.OffsetSeqMetadata
-import
org.apache.spark.sql.execution.streaming.operators.stateful.StatefulOperatorsUtils
+import
org.apache.spark.sql.execution.streaming.operators.stateful.{StatefulOperatorStateInfo,
StatefulOperatorsUtils}
import
org.apache.spark.sql.execution.streaming.operators.stateful.transformwithstate.{StateVariableType,
TransformWithStateOperatorProperties, TransformWithStateVariableInfo}
import
org.apache.spark.sql.execution.streaming.runtime.{StreamingCheckpointConstants,
StreamingQueryCheckpointMetadata}
import
org.apache.spark.sql.execution.streaming.state.{StatePartitionAllColumnFamiliesWriter,
StateSchemaCompatibilityChecker}
@@ -80,8 +81,19 @@ class StateRewriter(
readResolvedCheckpointLocation.getOrElse(resolvedCheckpointLocation)
private val stateRootLocation = new Path(
resolvedCheckpointLocation,
StreamingCheckpointConstants.DIR_NAME_STATE).toString
+ private lazy val writeCheckpoint = writeCheckpointMetadata.getOrElse(
+ new StreamingQueryCheckpointMetadata(sparkSession,
resolvedCheckpointLocation))
+ private lazy val readCheckpoint = if
(readResolvedCheckpointLocation.isDefined) {
+ new StreamingQueryCheckpointMetadata(sparkSession,
readResolvedCheckpointLocation.get)
+ } else {
+ // Same checkpoint for read & write
+ writeCheckpoint
+ }
- def run(): Unit = {
+ // If checkpoint IDs are enabled, returns
+ // Map[operatorId, Array[partition -> Array[stateStore -> StateStoreCheckpointId]]].
+ // Otherwise, returns None.
+ def run(): Option[Map[Long, Array[Array[String]]]] = {
logInfo(log"Starting state rewrite for " +
log"checkpointLocation=${MDC(CHECKPOINT_LOCATION,
resolvedCheckpointLocation)}, " +
log"readCheckpointLocation=" +
@@ -89,16 +101,45 @@ class StateRewriter(
log"readBatchId=${MDC(BATCH_ID, readBatchId)}, " +
log"writeBatchId=${MDC(BATCH_ID, writeBatchId)}")
- val (_, timeTakenMs) = Utils.timeTakenMs {
+ val (checkpointIds, timeTakenMs) = Utils.timeTakenMs {
runInternal()
}
logInfo(log"State rewrite completed in ${MDC(DURATION, timeTakenMs)} ms
for " +
log"checkpointLocation=${MDC(CHECKPOINT_LOCATION,
resolvedCheckpointLocation)}")
+ checkpointIds
}
- private def runInternal(): Unit = {
+ private def extractCheckpointIdsPerPartition(
+ checkpointInfos: Map[Long, Array[Array[StateStoreCheckpointInfo]]]
+ ): Option[Map[Long, Array[Array[String]]]] = {
+ val enableCheckpointId = StatefulOperatorStateInfo.
+ enableStateStoreCheckpointIds(sparkSession.sessionState.conf)
+ if (!enableCheckpointId) {
+ return None
+ }
+ // convert Map[operatorId, Array[stateStore -> Array[partition -> StateStoreCheckpointInfo]]]
+ // to Map[operatorId, Array[partition -> Array[stateStore -> StateStoreCheckpointId]]].
+ Option(checkpointInfos.map {
+ case(operator, storesSeq) =>
+ val numPartitions = storesSeq.head.length
+ operator -> (0 until numPartitions).map { partitionIdx =>
+ storesSeq.map { storePartitions =>
+ val checkpointInfoPerPartition = storePartitions(partitionIdx)
+ assert(checkpointInfoPerPartition.partitionId == partitionIdx)
+ assert(checkpointInfoPerPartition.batchVersion == writeBatchId + 1)
+ // expect baseStateStoreCkptId to be empty because we didn't load
+ // any previous stores when doing the rewrite
+ assert(checkpointInfoPerPartition.baseStateStoreCkptId.isEmpty)
+ checkpointInfoPerPartition.stateStoreCkptId.get
+ }
+ }.toArray
+ })
+ }
+
+ private def runInternal(): Option[Map[Long, Array[Array[String]]]] = {
try {
+ verifyCheckpointFormatVersion()
val stateMetadataReader = new StateMetadataPartitionReader(
resolvedCheckpointLocation,
new SerializableConfiguration(hadoopConf),
@@ -124,7 +165,7 @@ class StateRewriter(
// Do rewrite for each operator
// We can potentially parallelize this, but for now, do sequentially
- allOperatorsMetadata.foreach { opMetadata =>
+ val checkpointInfos = allOperatorsMetadata.map { opMetadata =>
val stateStoresMetadata = opMetadata.stateStoresMetadata
assert(!stateStoresMetadata.isEmpty,
s"Operator ${opMetadata.operatorInfo.operatorName} has no state
stores")
@@ -133,7 +174,7 @@ class StateRewriter(
val stateVarsIfTws = getStateVariablesIfTWS(opMetadata)
// Rewrite each state store of the operator
- stateStoresMetadata.foreach { stateStoreMetadata =>
+ val checkpointInfo = stateStoresMetadata.map { stateStoreMetadata =>
rewriteStore(
opMetadata,
stateStoreMetadata,
@@ -143,8 +184,10 @@ class StateRewriter(
stateVarsIfTws,
sqlConfEntries
)
- }
- }
+ }.toArray
+ opMetadata.operatorInfo.operatorId -> checkpointInfo
+ }.toMap
+ extractCheckpointIdsPerPartition(checkpointInfos)
} catch {
case e: Throwable =>
logError(log"State rewrite failed for " +
@@ -163,7 +206,7 @@ class StateRewriter(
storeSchemaFiles: List[Path],
stateVarsIfTws: Map[String, TransformWithStateVariableInfo],
sqlConfEntries: Map[String, String]
- ): Unit = {
+ ): Array[StateStoreCheckpointInfo] = {
// Read state
val stateDf = sparkSession.read
.format("statestore")
@@ -206,7 +249,7 @@ class StateRewriter(
// to avoid serializing the entire Rewriter object per partition.
val targetCheckpointLocation = resolvedCheckpointLocation
val currentBatchId = writeBatchId
- updatedStateDf.queryExecution.toRdd.foreachPartition { partitionIter =>
+ updatedStateDf.queryExecution.toRdd.mapPartitions { partitionIter:
Iterator[InternalRow] =>
// Recreate SQLConf on executor from serialized entries
val executorSqlConf = new SQLConf()
sqlConfEntries.foreach { case (k, v) => executorSqlConf.setConfString(k,
v) }
@@ -224,9 +267,8 @@ class StateRewriter(
schemaProvider,
executorSqlConf
)
-
- partitionWriter.write(partitionIter)
- }
+ Iterator(partitionWriter.write(partitionIter))
+ }.collect()
}
/** Create the store and sql confs from the conf written in the offset log */
@@ -342,6 +384,30 @@ class StateRewriter(
None
}
}
+
+ private def verifyCheckpointFormatVersion(): Unit = {
+ // Verify the checkpoint version in SQLConf against the commitLog of readCheckpoint,
+ // in case the user forgot to set STATE_STORE_CHECKPOINT_FORMAT_VERSION.
+ // Use the read batch's commit since the latest commit could be a skipped batch.
+ // If SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION is wrong, readCheckpoint.commitLog
+ // will throw an exception, and we will propagate this exception upstream.
+ // This prevents the StateRewriter from writing state files in the wrong format.
+ try {
+ readCheckpoint.commitLog.get(readBatchId)
+ } catch {
+ case e: IllegalStateException if e.getCause != null &&
+ e.getCause.isInstanceOf[SparkThrowable] =>
+ val sparkThrowable = e.getCause.asInstanceOf[SparkThrowable]
+ if (sparkThrowable.getCondition ==
"INVALID_LOG_VERSION.EXACT_MATCH_VERSION") {
+ val params = sparkThrowable.getMessageParameters
+ val expectedVersion = params.get("version")
+ val actualVersion = params.get("matchVersion")
+ throw
StateRewriterErrors.stateCheckpointFormatVersionMismatchError(
+ checkpointLocationForRead, expectedVersion, actualVersion)
+ }
+ throw e
+ }
+ }
}
/**
@@ -364,6 +430,14 @@ private[state] object StateRewriterErrors {
checkpointLocation: String): StateRewriterInvalidCheckpointError = {
new StateRewriterUnsupportedStoreMetadataVersionError(checkpointLocation)
}
+
+ def stateCheckpointFormatVersionMismatchError(
+ checkpointLocation: String,
+ expectedVersion: String,
+ actualVersion: String): StateRewriterInvalidCheckpointError = {
+ new StateRewriterStateCheckpointFormatVersionMismatchError(
+ checkpointLocation, expectedVersion, actualVersion)
+ }
}
/**
@@ -402,3 +476,15 @@ private[state] class
StateRewriterUnsupportedStoreMetadataVersionError(
checkpointLocation,
subClass = "UNSUPPORTED_STATE_STORE_METADATA_VERSION",
messageParameters = Map.empty)
+
+private[state] class StateRewriterStateCheckpointFormatVersionMismatchError(
+ checkpointLocation: String,
+ expectedVersion: String,
+ actualVersion: String)
+ extends StateRewriterInvalidCheckpointError(
+ checkpointLocation,
+ subClass = "STATE_CHECKPOINT_FORMAT_VERSION_MISMATCH",
+ messageParameters = Map(
+ "expectedVersion" -> expectedVersion,
+ "actualVersion" -> actualVersion,
+ "sqlConfKey" -> SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION.key))
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionSuite.scala
index 1ab581d79437..b4bd50dfb42f 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionSuite.scala
@@ -21,8 +21,10 @@ import scala.util.Try
import org.apache.hadoop.conf.Configuration
+import org.apache.spark.sql.SparkSession
import
org.apache.spark.sql.execution.datasources.v2.state.metadata.StateMetadataPartitionReader
import org.apache.spark.sql.execution.streaming.checkpointing.{CommitLog,
CommitMetadata}
+import
org.apache.spark.sql.execution.streaming.operators.stateful.StatefulOperatorStateInfo
import org.apache.spark.sql.execution.streaming.runtime.{MemoryStream,
StreamingQueryCheckpointMetadata}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.streaming._
@@ -106,147 +108,171 @@ class OfflineStateRepartitionSuite extends StreamTest
)
}
- test("Repartition: success, failure, retry") {
- withTempDir { dir =>
- val originalPartitions = 3
- val input = MemoryStream[Int]
- val batchId = runSimpleStreamQuery(originalPartitions,
dir.getAbsolutePath, input)
- val checkpointMetadata = new StreamingQueryCheckpointMetadata(spark,
dir.getAbsolutePath)
- // Shouldn't be seen as a repartition batch
- assert(!isRepartitionBatch(batchId, checkpointMetadata.offsetLog,
dir.getAbsolutePath))
-
- // Trying to repartition to the same number should fail
- val ex = intercept[StateRepartitionShufflePartitionsAlreadyMatchError] {
- spark.streamingCheckpointManager.repartition(dir.getAbsolutePath,
originalPartitions)
+ Seq(1, 2).foreach { ckptVersion =>
+ def testWithCheckpointId(testName: String)(testFun: => Unit): Unit = {
+ test(s"$testName (enableCkptId = ${ckptVersion >= 2})") {
+ withSQLConf(
+ SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION.key ->
ckptVersion.toString) {
+ testFun
+ }
}
- checkError(
- ex,
- condition =
"STATE_REPARTITION_INVALID_CHECKPOINT.SHUFFLE_PARTITIONS_ALREADY_MATCH",
- parameters = Map(
- "checkpointLocation" -> dir.getAbsolutePath,
- "batchId" -> batchId.toString,
- "numPartitions" -> originalPartitions.toString
- )
- )
+ }
- // Trying to repartition to a different number should succeed
- val newPartitions = originalPartitions + 1
- spark.streamingCheckpointManager.repartition(dir.getAbsolutePath,
newPartitions)
- val repartitionBatchId = batchId + 1
- val hadoopConf = spark.sessionState.newHadoopConf()
- verifyRepartitionBatch(
- repartitionBatchId, checkpointMetadata, hadoopConf,
dir.getAbsolutePath, newPartitions)
-
- // Now delete the repartition commit to simulate a failed repartition
attempt.
- // This will delete all the commits after the batchId.
- checkpointMetadata.commitLog.purgeAfter(batchId)
-
- // Try to repartition with a different numPartitions should fail,
- // since it will see an uncommitted repartition batch with a different
numPartitions.
- val ex2 = intercept[StateRepartitionLastBatchAbandonedRepartitionError] {
- spark.streamingCheckpointManager.repartition(dir.getAbsolutePath,
newPartitions + 1)
- }
- checkError(
- ex2,
- condition =
"STATE_REPARTITION_INVALID_CHECKPOINT.LAST_BATCH_ABANDONED_REPARTITION",
- parameters = Map(
- "checkpointLocation" -> dir.getAbsolutePath,
- "lastBatchId" -> repartitionBatchId.toString,
- "lastBatchShufflePartitions" -> newPartitions.toString,
- "numPartitions" -> (newPartitions + 1).toString
+ testWithCheckpointId("Repartition: success, failure, retry") {
+ withTempDir { dir =>
+ val originalPartitions = 3
+ val input = MemoryStream[Int]
+ val batchId = runSimpleStreamQuery(originalPartitions,
dir.getAbsolutePath, input)
+ val checkpointMetadata = new StreamingQueryCheckpointMetadata(spark,
dir.getAbsolutePath)
+ // Shouldn't be seen as a repartition batch
+ assert(!isRepartitionBatch(batchId, checkpointMetadata.offsetLog,
dir.getAbsolutePath))
+
+ // Trying to repartition to the same number should fail
+ val ex = intercept[StateRepartitionShufflePartitionsAlreadyMatchError]
{
+ spark.streamingCheckpointManager.repartition(dir.getAbsolutePath,
originalPartitions)
+ }
+ checkError(
+ ex,
+ condition =
"STATE_REPARTITION_INVALID_CHECKPOINT.SHUFFLE_PARTITIONS_ALREADY_MATCH",
+ parameters = Map(
+ "checkpointLocation" -> dir.getAbsolutePath,
+ "batchId" -> batchId.toString,
+ "numPartitions" -> originalPartitions.toString
+ )
)
- )
-
- // Retrying with the same numPartitions should work
- spark.streamingCheckpointManager.repartition(dir.getAbsolutePath,
newPartitions)
- verifyRepartitionBatch(
- repartitionBatchId, checkpointMetadata, hadoopConf,
dir.getAbsolutePath, newPartitions)
- // Repartition with way more partitions, to verify that empty partitions
are properly created
- val morePartitions = newPartitions * 3
- spark.streamingCheckpointManager.repartition(dir.getAbsolutePath,
morePartitions)
- verifyRepartitionBatch(
- repartitionBatchId + 1, checkpointMetadata, hadoopConf,
- dir.getAbsolutePath, morePartitions)
+ // Trying to repartition to a different number should succeed
+ val newPartitions = originalPartitions + 1
+ spark.streamingCheckpointManager.repartition(dir.getAbsolutePath,
newPartitions)
+ val repartitionBatchId = batchId + 1
+ val hadoopConf = spark.sessionState.newHadoopConf()
+ verifyRepartitionBatch(
+ repartitionBatchId,
+ checkpointMetadata,
+ hadoopConf,
+ dir.getAbsolutePath,
+ newPartitions,
+ spark)
+
+ // Now delete the repartition commit to simulate a failed repartition
attempt.
+ // This will delete all the commits after the batchId.
+ checkpointMetadata.commitLog.purgeAfter(batchId)
+
+ // Try to repartition with a different numPartitions should fail,
+ // since it will see an uncommitted repartition batch with a different
numPartitions.
+ val ex2 =
intercept[StateRepartitionLastBatchAbandonedRepartitionError] {
+ spark.streamingCheckpointManager.repartition(dir.getAbsolutePath,
newPartitions + 1)
+ }
+ checkError(
+ ex2,
+ condition =
"STATE_REPARTITION_INVALID_CHECKPOINT.LAST_BATCH_ABANDONED_REPARTITION",
+ parameters = Map(
+ "checkpointLocation" -> dir.getAbsolutePath,
+ "lastBatchId" -> repartitionBatchId.toString,
+ "lastBatchShufflePartitions" -> newPartitions.toString,
+ "numPartitions" -> (newPartitions + 1).toString
+ )
+ )
- // Restart the query to make sure it can start after repartitioning
- runSimpleStreamQuery(morePartitions, dir.getAbsolutePath, input)
+ // Retrying with the same numPartitions should work
+ spark.streamingCheckpointManager.repartition(dir.getAbsolutePath,
newPartitions)
+ verifyRepartitionBatch(
+ repartitionBatchId,
+ checkpointMetadata,
+ hadoopConf,
+ dir.getAbsolutePath,
+ newPartitions,
+ spark)
+
+ // Repartition with way more partitions, to verify that empty
partitions are properly
+ // created
+ val morePartitions = newPartitions * 3
+ spark.streamingCheckpointManager.repartition(dir.getAbsolutePath,
morePartitions)
+ verifyRepartitionBatch(
+ repartitionBatchId + 1, checkpointMetadata, hadoopConf,
+ dir.getAbsolutePath, morePartitions, spark)
+ // Restart the query to make sure it can start after repartitioning
+ runSimpleStreamQuery(morePartitions, dir.getAbsolutePath, input)
+ }
}
- }
- test("Query last batch failed before repartitioning") {
- withTempDir { dir =>
- val originalPartitions = 3
- val input = MemoryStream[Int]
- // Run 3 batches
- val firstBatchId = 0
- val lastBatchId = firstBatchId + 2
- (firstBatchId to lastBatchId).foreach { _ =>
- runSimpleStreamQuery(originalPartitions, dir.getAbsolutePath, input)
- }
- val checkpointMetadata = new StreamingQueryCheckpointMetadata(spark,
dir.getAbsolutePath)
+ testWithCheckpointId("Query last batch failed before repartitioning") {
+ withTempDir { dir =>
+ val originalPartitions = 3
+ val input = MemoryStream[Int]
+ // Run 3 batches
+ val firstBatchId = 0
+ val lastBatchId = firstBatchId + 2
+ (firstBatchId to lastBatchId).foreach { _ =>
+ runSimpleStreamQuery(originalPartitions, dir.getAbsolutePath, input)
+ }
+ val checkpointMetadata = new StreamingQueryCheckpointMetadata(spark,
dir.getAbsolutePath)
- // Lets keep only the first commit to simulate multiple failed batches
- checkpointMetadata.commitLog.purgeAfter(firstBatchId)
+ // Let's keep only the first commit to simulate multiple failed batches
+ checkpointMetadata.commitLog.purgeAfter(firstBatchId)
- // Now repartitioning should fail
- val ex = intercept[StateRepartitionLastBatchFailedError] {
- spark.streamingCheckpointManager.repartition(dir.getAbsolutePath,
originalPartitions + 1)
- }
- checkError(
- ex,
- condition = "STATE_REPARTITION_INVALID_CHECKPOINT.LAST_BATCH_FAILED",
- parameters = Map(
- "checkpointLocation" -> dir.getAbsolutePath,
- "lastBatchId" -> lastBatchId.toString
+ // Now repartitioning should fail
+ val ex = intercept[StateRepartitionLastBatchFailedError] {
+ spark.streamingCheckpointManager.repartition(dir.getAbsolutePath,
originalPartitions + 1)
+ }
+ checkError(
+ ex,
+ condition = "STATE_REPARTITION_INVALID_CHECKPOINT.LAST_BATCH_FAILED",
+ parameters = Map(
+ "checkpointLocation" -> dir.getAbsolutePath,
+ "lastBatchId" -> lastBatchId.toString
+ )
)
- )
- // Setting enforceExactlyOnceSink to false should allow repartitioning
- spark.streamingCheckpointManager.repartition(
- dir.getAbsolutePath, originalPartitions + 1, enforceExactlyOnceSink =
false)
- verifyRepartitionBatch(
- lastBatchId + 1,
- checkpointMetadata,
- spark.sessionState.newHadoopConf(),
- dir.getAbsolutePath,
- originalPartitions + 1,
- // Repartition should be based on the first batch, since we skipped
the others
- baseBatchId = Some(firstBatchId))
+ // Setting enforceExactlyOnceSink to false should allow repartitioning
+ spark.streamingCheckpointManager.repartition(
+ dir.getAbsolutePath, originalPartitions + 1, enforceExactlyOnceSink
= false)
+ verifyRepartitionBatch(
+ lastBatchId + 1,
+ checkpointMetadata,
+ spark.sessionState.newHadoopConf(),
+ dir.getAbsolutePath,
+ originalPartitions + 1,
+ spark,
+ // Repartition should be based on the first batch, since we skipped
the others
+ baseBatchId = Some(firstBatchId))
+ }
}
- }
- test("Consecutive repartition") {
- withTempDir { dir =>
- val originalPartitions = 5
- val input = MemoryStream[Int]
- val batchId = runSimpleStreamQuery(originalPartitions,
dir.getAbsolutePath, input)
-
- val checkpointMetadata = new StreamingQueryCheckpointMetadata(spark,
dir.getAbsolutePath)
- val hadoopConf = spark.sessionState.newHadoopConf()
-
- // decrease
- spark.streamingCheckpointManager.repartition(dir.getAbsolutePath,
originalPartitions - 3)
- verifyRepartitionBatch(
- batchId + 1,
- checkpointMetadata,
- hadoopConf,
- dir.getAbsolutePath,
- originalPartitions - 3
- )
+ testWithCheckpointId("Consecutive repartition") {
+ withTempDir { dir =>
+ val originalPartitions = 5
+ val input = MemoryStream[Int]
+ val batchId = runSimpleStreamQuery(originalPartitions,
dir.getAbsolutePath, input)
+
+ val checkpointMetadata = new StreamingQueryCheckpointMetadata(spark,
dir.getAbsolutePath)
+ val hadoopConf = spark.sessionState.newHadoopConf()
+
+ // decrease
+ spark.streamingCheckpointManager.repartition(dir.getAbsolutePath,
originalPartitions - 3)
+ verifyRepartitionBatch(
+ batchId + 1,
+ checkpointMetadata,
+ hadoopConf,
+ dir.getAbsolutePath,
+ originalPartitions - 3,
+ spark
+ )
- // increase
- spark.streamingCheckpointManager.repartition(dir.getAbsolutePath,
originalPartitions + 1)
- verifyRepartitionBatch(
- batchId + 2,
- checkpointMetadata,
- hadoopConf,
- dir.getAbsolutePath,
- originalPartitions + 1
- )
+ // increase
+ spark.streamingCheckpointManager.repartition(dir.getAbsolutePath,
originalPartitions + 1)
+ verifyRepartitionBatch(
+ batchId + 2,
+ checkpointMetadata,
+ hadoopConf,
+ dir.getAbsolutePath,
+ originalPartitions + 1,
+ spark
+ )
- // Restart the query to make sure it can start after repartitioning
- runSimpleStreamQuery(originalPartitions + 1, dir.getAbsolutePath, input)
+ // Restart the query to make sure it can start after repartitioning
+ runSimpleStreamQuery(originalPartitions + 1, dir.getAbsolutePath,
input)
+ }
}
}
@@ -285,6 +311,7 @@ object OfflineStateRepartitionTestUtils {
hadoopConf: Configuration,
checkpointLocation: String,
expectedShufflePartitions: Int,
+ spark: SparkSession,
baseBatchId: Option[Long] = None): Unit = {
// Should be seen as a repartition batch
assert(isRepartitionBatch(batchId, checkpointMetadata.offsetLog,
checkpointLocation))
@@ -296,8 +323,21 @@ object OfflineStateRepartitionTestUtils {
verifyOffsetAndCommitLog(
batchId, previousBatchId, expectedShufflePartitions, checkpointMetadata)
verifyPartitionDirs(checkpointLocation, expectedShufflePartitions)
+
+ val serializableConf = new SerializableConfiguration(hadoopConf)
+ val baseOperatorsMetadata = getOperatorMetadata(
+ checkpointLocation, serializableConf, previousBatchId)
+ val repartitionOperatorsMetadata = getOperatorMetadata(
+ checkpointLocation, serializableConf, batchId)
verifyOperatorMetadata(
- batchId, previousBatchId, checkpointLocation, expectedShufflePartitions,
hadoopConf)
+ baseOperatorsMetadata, repartitionOperatorsMetadata,
expectedShufflePartitions)
+ if
(StatefulOperatorStateInfo.enableStateStoreCheckpointIds(spark.sessionState.conf))
{
+ verifyCheckpointIds(
+ batchId,
+ checkpointMetadata,
+ expectedShufflePartitions,
+ baseOperatorsMetadata)
+ }
}
private def verifyOffsetAndCommitLog(
@@ -376,23 +416,20 @@ object OfflineStateRepartitionTestUtils {
}
}
- private def verifyOperatorMetadata(
- repartitionBatchId: Long,
- baseBatchId: Long,
+ private def getOperatorMetadata(
checkpointLocation: String,
- expectedShufflePartitions: Int,
- hadoopConf: Configuration): Unit = {
- val serializableConf = new SerializableConfiguration(hadoopConf)
-
- // Read operator metadata for both batches
- val baseMetadataReader = new StateMetadataPartitionReader(
- checkpointLocation, serializableConf, baseBatchId)
- val repartitionMetadataReader = new StateMetadataPartitionReader(
- checkpointLocation, serializableConf, repartitionBatchId)
-
- val baseOperatorsMetadata = baseMetadataReader.allOperatorStateMetadata
- val repartitionOperatorsMetadata =
repartitionMetadataReader.allOperatorStateMetadata
+ serializableConf: SerializableConfiguration,
+ batchId: Long
+ ): Array[OperatorStateMetadata] = {
+ val metadataPartitionReader = new StateMetadataPartitionReader(
+ checkpointLocation, serializableConf, batchId)
+ metadataPartitionReader.allOperatorStateMetadata
+ }
+ private def verifyOperatorMetadata(
+ baseOperatorsMetadata: Array[OperatorStateMetadata],
+ repartitionOperatorsMetadata: Array[OperatorStateMetadata],
+ expectedShufflePartitions: Int): Unit = {
assert(baseOperatorsMetadata.nonEmpty, "Base batch should have operator
metadata")
assert(repartitionOperatorsMetadata.nonEmpty, "Repartition batch should
have operator metadata")
assert(baseOperatorsMetadata.length == repartitionOperatorsMetadata.length,
@@ -454,4 +491,44 @@ object OfflineStateRepartitionTestUtils {
}
}
}
+
+ private def verifyCheckpointIds(
+ repartitionBatchId: Long,
+ checkpointMetadata: StreamingQueryCheckpointMetadata,
+ expectedShufflePartitions: Int,
+ baseOperatorsMetadata: Array[OperatorStateMetadata]): Unit = {
+ val expectedStoreCnts: Map[Long, Int] = baseOperatorsMetadata.map {
+ case metadataV2: OperatorStateMetadataV2 =>
+ metadataV2.operatorInfo.operatorId -> metadataV2.stateStoreInfo.length
+ case metadataV1: OperatorStateMetadataV1 =>
+ metadataV1.operatorInfo.operatorId -> metadataV1.stateStoreInfo.length
+ }.toMap
+ // Verify commit log has the repartition batch with checkpoint IDs
+ val commitOpt = checkpointMetadata.commitLog.get(repartitionBatchId)
+ assert(commitOpt.isDefined, s"Commit for batch $repartitionBatchId should
exist")
+
+ val commitMetadata = commitOpt.get
+
+ // Verify stateUniqueIds is present for checkpoint V2
+ assert(commitMetadata.stateUniqueIds.isDefined,
+ "stateUniqueIds should be present in commit metadata when checkpoint
version >= 2")
+
+ val operatorIdToCkptInfos = commitMetadata.stateUniqueIds.get
+ assert(operatorIdToCkptInfos.nonEmpty,
+ "operatorIdToCkptInfos should not be empty")
+
+ // Verify structure for each operator
+ operatorIdToCkptInfos.foreach { case (operatorId, partitionToCkptIds) =>
+ // Should have checkpoint IDs for all partitions
+ assert(partitionToCkptIds.length == expectedShufflePartitions,
+ s"Operator $operatorId: Expected $expectedShufflePartitions partition
checkpoint IDs, " +
+ s"but found ${partitionToCkptIds.length}")
+ // Each partition should have checkpoint IDs (at least one per store)
+ partitionToCkptIds.zipWithIndex.foreach { case (ckptIds, partitionId) =>
+ assert(ckptIds.length == expectedStoreCnts(operatorId),
+ s"Operator $operatorId, partition $partitionId should " +
+ s"have ${expectedStoreCnts(operatorId)} checkpoint IDs")
+ }
+ }
+ }
}
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala
index ad532fb71bf0..36b65fdf50e5 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala
@@ -3943,48 +3943,68 @@ class RocksDBSuite extends AlsoTestWithRocksDBFeatures
with SharedSparkSession
}}
}
- test("SPARK-54420: load with createEmpty creates empty store") {
- val remoteDir = Utils.createTempDir().toString
- new File(remoteDir).delete()
-
- withDB(remoteDir) { db =>
- // loading batch 0 with loadEmpty = true
- db.load(0, None, loadEmpty = true)
- assert(iterator(db).isEmpty)
- db.put("a", "1")
- val (version1, _) = db.commit()
- assert(toStr(db.get("a")) === "1")
+ testWithStateStoreCheckpointIds(
+ "SPARK-54420: load with createEmpty creates empty store") { enableCkptId =>
+ val remoteDir = Utils.createTempDir().toString
+ new File(remoteDir).delete()
+ var lastVersion = 0L
+ var lastCheckpointInfo: Option[StateStoreCheckpointInfo] = None
- // check we can load store normally even the previous one loadEmpty =
true
- db.load(version1)
- db.put("b", "2")
- val (version2, _) = db.commit()
- assert(version2 === version1 + 1)
- assert(toStr(db.get("b")) === "2")
- assert(toStr(db.get("a")) === "1")
+ withDB(remoteDir, enableStateStoreCheckpointIds = enableCkptId) { db =>
+ // loading batch 0 with loadEmpty = true
+ db.load(0, None, loadEmpty = true)
+ assert(iterator(db).isEmpty)
+ db.put("a", "1")
+ val (version1, checkpointInfoV1) = db.commit()
+ assert(toStr(db.get("a")) === "1")
- // load an empty store
- db.load(version2, loadEmpty = true)
- db.put("c", "3")
- val (version3, _) = db.commit()
- assert(db.get("b") === null)
- assert(db.get("a") === null)
- assert(toStr(db.get("c")) === "3")
- assert(version3 === version2 + 1)
-
- // load 2 empty store in a row
- db.load(version3, loadEmpty = true)
- db.put("d", "4")
- val (version4, _) = db.commit()
- assert(db.get("c") === null)
- assert(toStr(db.get("d")) === "4")
- assert(version4 === version3 + 1)
-
- db.load(version4)
- db.put("e", "5")
- db.commit()
- assert(db.iterator().map(toStr).toSet === Set(("d", "4"), ("e", "5")))
- }
+ // check we can load the store normally even if the previous load had loadEmpty = true
+ db.load(version1, checkpointInfoV1.stateStoreCkptId)
+ db.put("b", "2")
+ val (version2, _) = db.commit()
+ assert(version2 === version1 + 1)
+ assert(toStr(db.get("b")) === "2")
+ assert(toStr(db.get("a")) === "1")
+
+ // load an empty store
+ db.load(version2, loadEmpty = true)
+ db.put("c", "3")
+ val (version3, _) = db.commit()
+ assert(db.get("b") === null)
+ assert(db.get("a") === null)
+ assert(toStr(db.get("c")) === "3")
+ assert(version3 === version2 + 1)
+
+ // load 2 empty stores in a row
+ db.load(version3, loadEmpty = true)
+ db.put("d", "4")
+
+ val (version4, checkpointV4) = db.commit()
+ assert(db.get("c") === null)
+ assert(toStr(db.get("d")) === "4")
+ lastVersion = version4
+ lastCheckpointInfo = Option(checkpointV4)
+ assert(lastVersion === version3 + 1)
+ }
+
+ withDB(remoteDir, enableStateStoreCheckpointIds = enableCkptId) { db =>
+ db.load(lastVersion, lastCheckpointInfo.map(_.stateStoreCkptId).orNull)
+ db.put("e", "5")
+ db.commit()
+ assert(db.iterator().map(toStr).toSet === Set(("d", "4"), ("e", "5")))
+ }
+
+ if (enableCkptId) {
+ withDB(remoteDir, enableStateStoreCheckpointIds = enableCkptId) { db =>
+ val ex = intercept[IllegalArgumentException] {
+ db.load(
+ lastVersion,
+ lastCheckpointInfo.map(_.stateStoreCkptId).orNull,
+ loadEmpty = true)
+ }
+ assert(ex.getMessage.contains("stateStoreCkptId should be empty when
loadEmpty is true"))
+ }
+ }
}
test("SPARK-44639: Use Java tmp dir instead of configured local dirs on
Yarn") {
@@ -4027,13 +4047,13 @@ class RocksDBSuite extends AlsoTestWithRocksDBFeatures
with SharedSparkSession
version: Long,
ckptId: Option[String] = None,
readOnly: Boolean = false,
- createEmpty: Boolean = false): RocksDB = {
+ loadEmpty: Boolean = false): RocksDB = {
// When a ckptId is defined, it means the test is explicitly using v2
semantic
// When it is not, it is possible that implicitly uses it.
// So still do a versionToUniqueId.get
ckptId match {
- case Some(_) => super.load(version, ckptId, readOnly)
- case None => super.load(version, versionToUniqueId.get(version),
readOnly)
+ case Some(_) => super.load(version, ckptId, readOnly, loadEmpty)
+ case None => super.load(version, versionToUniqueId.get(version),
readOnly, loadEmpty)
}
}
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatePartitionAllColumnFamiliesWriterSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatePartitionAllColumnFamiliesWriterSuite.scala
index e495db499bfe..dc6b9d7e8987 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatePartitionAllColumnFamiliesWriterSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatePartitionAllColumnFamiliesWriterSuite.scala
@@ -95,11 +95,12 @@ class StatePartitionAllColumnFamiliesWriterSuite extends
StateDataSourceTestBase
transformFunc = None,
writeCheckpointMetadata = Some(targetCheckpointMetadata)
)
- rewriter.run()
+ val checkpointInfos = rewriter.run()
- // Commit to commitLog
+ // Commit to commitLog with checkpoint IDs
val latestCommit = targetCheckpointMetadata.commitLog.get(lastBatch).get
- targetCheckpointMetadata.commitLog.add(writeBatchId, latestCommit)
+ val commitMetadata = latestCommit.copy(stateUniqueIds = checkpointInfos)
+ targetCheckpointMetadata.commitLog.add(writeBatchId, commitMetadata)
val versionToCheck = writeBatchId + 1
storeToColumnFamilies.foreach { case (storeName, columnFamilies) =>
@@ -215,20 +216,6 @@ class StatePartitionAllColumnFamiliesWriterSuite extends
StateDataSourceTestBase
useMultipleValuePerKey)
}
- /**
- * Helper method to create a single-entry column family schema map.
- * This simplifies the common case where only the default column family is
used.
- */
- private def createSingleColumnFamilySchemaMap(
- keySchema: StructType,
- valueSchema: StructType,
- keyStateEncoderSpec: KeyStateEncoderSpec,
- colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME
- ): Map[String, StatePartitionWriterColumnFamilyInfo] = {
- Map(colFamilyName -> createColFamilyInfo(keySchema, valueSchema,
- keyStateEncoderSpec, colFamilyName))
- }
-
/**
* Helper method to test SPARK-54420 read and write with different state
format versions
* for simple aggregation (single grouping key).
@@ -307,17 +294,6 @@ class StatePartitionAllColumnFamiliesWriterSuite extends
StateDataSourceTestBase
}
}
- private def getJoinV3ColumnSchemaMap(): Map[String,
StatePartitionWriterColumnFamilyInfo] = {
- val schemas =
StreamStreamJoinTestUtils.getJoinV3ColumnSchemaMapWithMetadata()
- schemas.map { case (cfName, metadata) =>
- cfName -> createColFamilyInfo(
- metadata.keySchema,
- metadata.valueSchema,
- metadata.encoderSpec,
- cfName,
- metadata.useMultipleValuePerKey)
- }
- }
/**
* Helper method to test round-trip for stream-stream join with different
versions.
*/
@@ -450,8 +426,375 @@ class StatePartitionAllColumnFamiliesWriterSuite extends
StateDataSourceTestBase
}
}
- testWithChangelogConfig("SPARK-54420: aggregation state ver 1") {
- testRoundTripForAggrStateVersion(1)
+ // Run transformWithState tests with checkpoint V2 enabled and disabled
+ Seq(1, 2).foreach { ckptVersion =>
+ def testWithChangelogAndCheckpointId(testName: String)(testFun: =>
Unit): Unit = {
+ testWithChangelogConfig(s"$testName (enableCkptId = ${ckptVersion >=
2})") {
+ withSQLConf(SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION.key ->
ckptVersion.toString) {
+ testFun
+ }
+ }
+ }
+
+ testWithChangelogAndCheckpointId("SPARK-54420: aggregation state ver 1")
{
+ testRoundTripForAggrStateVersion(1)
+ }
+
+ Seq(1, 2).foreach { version =>
+ testWithChangelogAndCheckpointId(s"SPARK-54420: stream-stream join
state ver $version") {
+ testStreamStreamJoinRoundTrip(version)
+ }
+ }
+ // Run transformWithState tests with different encoding formats
+ Seq("unsaferow", "avro").foreach { encodingFormat =>
+ def testWithChangelogAndEncodingConfig(testName: String)(testFun: =>
Unit): Unit = {
+ testWithChangelogAndCheckpointId(
+ s"$testName (encoding = $encodingFormat)") {
+ withSQLConf(
+ SQLConf.STREAMING_STATE_STORE_ENCODING_FORMAT.key ->
encodingFormat) {
+ testFun
+ }
+ }
+ }
+
+ testWithChangelogAndEncodingConfig(
+ "SPARK-54411: transformWithState with multiple column families") {
+ withTempDir { sourceDir =>
+ withTempDir { targetDir =>
+ val inputData = MemoryStream[String]
+ val query = inputData.toDS()
+ .groupByKey(x => x)
+ .transformWithState(new MultiStateVarProcessor(),
+ TimeMode.None(),
+ OutputMode.Update())
+ def runQuery(checkpointLocation: String, roundsOfData: Int):
Unit = {
+ val dataActions = (1 to roundsOfData).flatMap { _ =>
+ Seq(
+ AddData(inputData, "a", "b", "a"),
+ ProcessAllAvailable()
+ )
+ }
+ testStream(query)(
+ Seq(StartStream(checkpointLocation = checkpointLocation)) ++
+ dataActions ++
+ Seq(StopStream): _*
+ )
+ }
+
+ runQuery(sourceDir.getAbsolutePath, 2)
+ runQuery(targetDir.getAbsolutePath, 1)
+
+ val schemas =
MultiStateVarProcessorTestUtils.getSchemasWithMetadata()
+ val columnFamilyToSelectExprs = MultiStateVarProcessorTestUtils
+ .getColumnFamilyToSelectExprs()
+
+ val columnFamilyToStateSourceOptions = schemas.keys.map { cfName
=>
+ val base = Map(
+ StateSourceOptions.STATE_VAR_NAME -> cfName
+ )
+
+ val withFlatten =
+ if (cfName == MultiStateVarProcessorTestUtils.ITEMS_LIST) {
+ base + (StateSourceOptions.FLATTEN_COLLECTION_TYPES ->
"true")
+ } else {
+ base
+ }
+
+ cfName -> withFlatten
+ }.toMap
+
+ performRoundTripTest(
+ sourceDir.getAbsolutePath,
+ targetDir.getAbsolutePath,
+ storeToColumnFamilies = Map(StateStoreId.DEFAULT_STORE_NAME ->
schemas.keys.toList),
+ storeToColumnFamilyToSelectExprs =
+ Map(StateStoreId.DEFAULT_STORE_NAME ->
columnFamilyToSelectExprs),
+ storeToColumnFamilyToStateSourceOptions =
+ Map(StateStoreId.DEFAULT_STORE_NAME ->
columnFamilyToStateSourceOptions),
+ operatorName =
StatefulOperatorsUtils.TRANSFORM_WITH_STATE_EXEC_OP_NAME
+ )
+ }
+ }
+ }
+
+ testWithChangelogAndEncodingConfig(
+ "SPARK-54411: transformWithState with eventTime timers") {
+ withTempDir { sourceDir =>
+ withTempDir { targetDir =>
+ val inputData = MemoryStream[(String, Long)]
+ val result = inputData.toDS()
+ .select(col("_1").as("key"),
timestamp_seconds(col("_2")).as("eventTime"))
+ .withWatermark("eventTime", "10 seconds")
+ .as[(String, Timestamp)]
+ .groupByKey(_._1)
+ .transformWithState(
+ new EventTimeTimerProcessor(),
+ TimeMode.EventTime(),
+ OutputMode.Update())
+
+ testStream(result, OutputMode.Update())(
+ StartStream(checkpointLocation = sourceDir.getAbsolutePath),
+ AddData(inputData, ("a", 1L), ("b", 2L), ("c", 3L)),
+ ProcessAllAvailable(),
+ StopStream
+ )
+
+ testStream(result, OutputMode.Update())(
+ StartStream(checkpointLocation = targetDir.getAbsolutePath),
+ AddData(inputData, ("x", 1L)),
+ ProcessAllAvailable(),
+ StopStream
+ )
+
+ val (schemaMap, selectExprs, stateSourceOptions) =
+ getTimerStateConfigsForCountState(TimeMode.EventTime())
+
+ performRoundTripTest(
+ sourceDir.getAbsolutePath,
+ targetDir.getAbsolutePath,
+ storeToColumnFamilies =
+ Map(StateStoreId.DEFAULT_STORE_NAME ->
schemaMap.keys.toList),
+ storeToColumnFamilyToSelectExprs =
+ Map(StateStoreId.DEFAULT_STORE_NAME -> selectExprs),
+ storeToColumnFamilyToStateSourceOptions =
+ Map(StateStoreId.DEFAULT_STORE_NAME -> stateSourceOptions),
+ operatorName =
StatefulOperatorsUtils.TRANSFORM_WITH_STATE_EXEC_OP_NAME
+ )
+ }
+ }
+ }
+
+ testWithChangelogAndEncodingConfig(
+ "SPARK-54411: transformWithState with processing time timers") {
+ withTempDir { sourceDir =>
+ withTempDir { targetDir =>
+ val clock = new StreamManualClock
+ val inputData = MemoryStream[String]
+ val result = inputData.toDS()
+ .groupByKey(x => x)
+ .transformWithState(new
RunningCountStatefulProcessorWithProcTimeTimer(),
+ TimeMode.ProcessingTime(),
+ OutputMode.Update())
+
+ testStream(result, OutputMode.Update())(
+ StartStream(checkpointLocation = sourceDir.getAbsolutePath,
+ trigger = Trigger.ProcessingTime("1 second"),
+ triggerClock = clock),
+ AddData(inputData, "a"),
+ AdvanceManualClock(1 * 1000),
+ CheckNewAnswer(("a", "1")),
+ StopStream
+ )
+
+ val clock2 = new StreamManualClock
+ testStream(result, OutputMode.Update())(
+ StartStream(checkpointLocation = targetDir.getAbsolutePath,
+ trigger = Trigger.ProcessingTime("1 second"),
+ triggerClock = clock2),
+ AddData(inputData, "x"),
+ AdvanceManualClock(1 * 1000),
+ CheckNewAnswer(("a", "1"), ("x", "1")),
+ StopStream
+ )
+
+ val (schemaMap, selectExprs, sourceOptions) =
+ getTimerStateConfigsForCountState(TimeMode.ProcessingTime())
+
+ performRoundTripTest(
+ sourceDir.getAbsolutePath,
+ targetDir.getAbsolutePath,
+ storeToColumnFamilies =
+ Map(StateStoreId.DEFAULT_STORE_NAME ->
schemaMap.keys.toList),
+ storeToColumnFamilyToSelectExprs =
+ Map(StateStoreId.DEFAULT_STORE_NAME -> selectExprs),
+ storeToColumnFamilyToStateSourceOptions =
+ Map(StateStoreId.DEFAULT_STORE_NAME -> sourceOptions),
+ operatorName =
StatefulOperatorsUtils.TRANSFORM_WITH_STATE_EXEC_OP_NAME
+ )
+ }
+ }
+ }
+
+ testWithChangelogAndEncodingConfig("SPARK-54411: transformWithState
with list and TTL") {
+ withTempDir { sourceDir =>
+ withTempDir { targetDir =>
+ val clock = new StreamManualClock
+ val inputData = MemoryStream[InputEvent]
+ val ttlConfig = TTLConfig(ttlDuration = Duration.ofMinutes(1))
+ val result = inputData.toDS()
+ .groupByKey(x => x.key)
+ .transformWithState(new ListStateTTLProcessor(ttlConfig),
+ TimeMode.ProcessingTime(),
+ OutputMode.Update())
+
+ testStream(result, OutputMode.Update())(
+ StartStream(checkpointLocation = sourceDir.getAbsolutePath,
+ trigger = Trigger.ProcessingTime("1 second"),
+ triggerClock = clock),
+ AddData(inputData, InputEvent("k1", "put", 1)),
+ AdvanceManualClock(1 * 1000),
+ CheckNewAnswer(),
+ StopStream
+ )
+
+ val clock2 = new StreamManualClock
+ testStream(result, OutputMode.Update())(
+ StartStream(checkpointLocation = targetDir.getAbsolutePath,
+ trigger = Trigger.ProcessingTime("1 second"),
+ triggerClock = clock2),
+ AddData(inputData, InputEvent("k1", "append", 1)),
+ AdvanceManualClock(1 * 1000),
+ CheckNewAnswer(),
+ StopStream
+ )
+
+ val schemas =
TTLProcessorUtils.getListStateTTLSchemasWithMetadata()
+
+ val columnFamilyToSelectExprs = Map(
+ TTLProcessorUtils.LIST_STATE ->
TTLProcessorUtils.getTTLSelectExpressions(
+ TTLProcessorUtils.LIST_STATE
+ ))
+
+ val columnFamilyToStateSourceOptions = schemas.keys.map { cfName
=>
+ val base = Map(StateSourceOptions.STATE_VAR_NAME -> cfName)
+
+ val withFlatten =
+ if (cfName == TTLProcessorUtils.LIST_STATE) {
+ base + (StateSourceOptions.FLATTEN_COLLECTION_TYPES ->
"true")
+ } else {
+ base
+ }
+
+ cfName -> withFlatten
+ }.toMap
+
+ performRoundTripTest(
+ sourceDir.getAbsolutePath,
+ targetDir.getAbsolutePath,
+ storeToColumnFamilies =
+ Map(StateStoreId.DEFAULT_STORE_NAME -> schemas.keys.toList),
+ storeToColumnFamilyToSelectExprs =
+ Map(StateStoreId.DEFAULT_STORE_NAME ->
columnFamilyToSelectExprs),
+ storeToColumnFamilyToStateSourceOptions =
+ Map(StateStoreId.DEFAULT_STORE_NAME ->
columnFamilyToStateSourceOptions),
+ operatorName =
StatefulOperatorsUtils.TRANSFORM_WITH_STATE_EXEC_OP_NAME
+ )
+ }
+ }
+ }
+
+ testWithChangelogAndEncodingConfig("SPARK-54411: transformWithState
with map and TTL") {
+ withTempDir { sourceDir =>
+ withTempDir { targetDir =>
+ val clock = new StreamManualClock
+ val inputData = MemoryStream[MapInputEvent]
+ val ttlConfig = TTLConfig(ttlDuration = Duration.ofMinutes(1))
+ val result = inputData.toDS()
+ .groupByKey(x => x.key)
+ .transformWithState(new MapStateTTLProcessor(ttlConfig),
+ TimeMode.ProcessingTime(),
+ OutputMode.Update())
+
+ testStream(result)(
+ StartStream(checkpointLocation = sourceDir.getAbsolutePath,
+ trigger = Trigger.ProcessingTime("1 second"),
+ triggerClock = clock),
+ AddData(inputData, MapInputEvent("a", "key1", "put", 1)),
+ AdvanceManualClock(1 * 1000),
+ CheckNewAnswer(),
+ StopStream
+ )
+
+ val clock2 = new StreamManualClock
+ testStream(result)(
+ StartStream(checkpointLocation = targetDir.getAbsolutePath,
+ trigger = Trigger.ProcessingTime("1 second"),
+ triggerClock = clock2),
+ AddData(inputData, MapInputEvent("x", "key1", "put", 1)),
+ AdvanceManualClock(1 * 1000),
+ CheckNewAnswer(),
+ StopStream
+ )
+
+ val schemas =
TTLProcessorUtils.getMapStateTTLSchemasWithMetadata()
+
+ val columnFamilyToSelectExprs = Map(
+ TTLProcessorUtils.MAP_STATE ->
TTLProcessorUtils.getTTLSelectExpressions(
+ TTLProcessorUtils.MAP_STATE
+ ))
+
+ val columnFamilyToStateSourceOptions = schemas.keys.map { cfName
=>
+ cfName -> Map(StateSourceOptions.STATE_VAR_NAME -> cfName)
+ }.toMap
+
+ performRoundTripTest(
+ sourceDir.getAbsolutePath,
+ targetDir.getAbsolutePath,
+ storeToColumnFamilies =
+ Map(StateStoreId.DEFAULT_STORE_NAME -> schemas.keys.toList),
+ storeToColumnFamilyToSelectExprs =
+ Map(StateStoreId.DEFAULT_STORE_NAME ->
columnFamilyToSelectExprs),
+ storeToColumnFamilyToStateSourceOptions =
+ Map(StateStoreId.DEFAULT_STORE_NAME ->
columnFamilyToStateSourceOptions),
+ operatorName =
StatefulOperatorsUtils.TRANSFORM_WITH_STATE_EXEC_OP_NAME
+ )
+ }
+ }
+ }
+
+ testWithChangelogAndEncodingConfig("SPARK-54411: transformWithState
with value and TTL") {
+ withTempDir { sourceDir =>
+ withTempDir { targetDir =>
+ val clock = new StreamManualClock
+ val inputData = MemoryStream[InputEvent]
+ val ttlConfig = TTLConfig(ttlDuration = Duration.ofMinutes(1))
+ val result = inputData.toDS()
+ .groupByKey(x => x.key)
+ .transformWithState(new ValueStateTTLProcessor(ttlConfig),
+ TimeMode.ProcessingTime(),
+ OutputMode.Update())
+
+ testStream(result)(
+ StartStream(checkpointLocation = sourceDir.getAbsolutePath,
+ trigger = Trigger.ProcessingTime("1 second"),
+ triggerClock = clock),
+ AddData(inputData, InputEvent("k1", "put", 1)),
+ AddData(inputData, InputEvent("k2", "put", 2)),
+ AdvanceManualClock(1 * 1000),
+ CheckNewAnswer(),
+ StopStream
+ )
+
+ val clock2 = new StreamManualClock
+ testStream(result)(
+ StartStream(checkpointLocation = targetDir.getAbsolutePath,
+ trigger = Trigger.ProcessingTime("1 second"),
+ triggerClock = clock2),
+ AddData(inputData, InputEvent("x", "put", 1)),
+ AdvanceManualClock(1 * 1000),
+ CheckNewAnswer(),
+ StopStream
+ )
+
+ val schemas =
TTLProcessorUtils.getValueStateTTLSchemasWithMetadata()
+
+ val columnFamilyToStateSourceOptions = schemas.keys.map { cfName
=>
+ cfName -> Map(StateSourceOptions.STATE_VAR_NAME -> cfName)
+ }.toMap
+
+ performRoundTripTest(
+ sourceDir.getAbsolutePath,
+ targetDir.getAbsolutePath,
+ storeToColumnFamilies =
+ Map(StateStoreId.DEFAULT_STORE_NAME -> schemas.keys.toList),
+ storeToColumnFamilyToStateSourceOptions =
+ Map(StateStoreId.DEFAULT_STORE_NAME ->
columnFamilyToStateSourceOptions),
+ operatorName =
StatefulOperatorsUtils.TRANSFORM_WITH_STATE_EXEC_OP_NAME
+ )
+ }
+ }
+ }
+ } // End of foreach loop for encoding format (transformWithState tests
only)
}
testWithChangelogConfig("SPARK-54420: aggregation state ver 2") {
@@ -573,12 +916,6 @@ class StatePartitionAllColumnFamiliesWriterSuite extends
StateDataSourceTestBase
}
}
- Seq(1, 2).foreach { version =>
- testWithChangelogConfig(s"SPARK-54420: stream-stream join state ver
$version") {
- testStreamStreamJoinRoundTrip(version)
- }
- }
-
testWithChangelogConfig("SPARK-54411: stream-stream join state ver 3") {
withSQLConf(
SQLConf.STREAMING_JOIN_STATE_FORMAT_VERSION.key -> "3"
@@ -620,358 +957,44 @@ class StatePartitionAllColumnFamiliesWriterSuite extends
StateDataSourceTestBase
}
}
}
+ } // End of foreach loop for changelog checkpointing dimension
- // Run transformWithState tests with different encoding formats
- Seq("unsaferow", "avro").foreach { encodingFormat =>
- def testWithChangelogAndEncodingConfig(testName: String)(testFun: =>
Unit): Unit = {
- test(s"$testName ($changelogCpTestSuffix, encoding =
$encodingFormat)") {
- withSQLConf(
-
"spark.sql.streaming.stateStore.rocksdb.changelogCheckpointing.enabled" ->
- changelogCheckpointingEnabled.toString,
- SQLConf.STREAMING_STATE_STORE_ENCODING_FORMAT.key ->
encodingFormat) {
- testFun
- }
- }
- }
-
- testWithChangelogAndEncodingConfig(
- "SPARK-54411: transformWithState with multiple column families") {
- withTempDir { sourceDir =>
- withTempDir { targetDir =>
- val inputData = MemoryStream[String]
- val query = inputData.toDS()
- .groupByKey(x => x)
- .transformWithState(new MultiStateVarProcessor(),
- TimeMode.None(),
- OutputMode.Update())
- def runQuery(checkpointLocation: String, roundsOfData: Int): Unit
= {
- val dataActions = (1 to roundsOfData).flatMap { _ =>
- Seq(
- AddData(inputData, "a", "b", "a"),
- ProcessAllAvailable()
- )
- }
- testStream(query)(
- Seq(StartStream(checkpointLocation = checkpointLocation)) ++
- dataActions ++
- Seq(StopStream): _*
- )
- }
-
- runQuery(sourceDir.getAbsolutePath, 2)
- runQuery(targetDir.getAbsolutePath, 1)
-
- val schemas =
MultiStateVarProcessorTestUtils.getSchemasWithMetadata()
- val columnFamilyToSelectExprs = MultiStateVarProcessorTestUtils
- .getColumnFamilyToSelectExprs()
-
- val columnFamilyToStateSourceOptions = schemas.keys.map { cfName =>
- val base = Map(
- StateSourceOptions.STATE_VAR_NAME -> cfName
- )
-
- val withFlatten =
- if (cfName == MultiStateVarProcessorTestUtils.ITEMS_LIST) {
- base + (StateSourceOptions.FLATTEN_COLLECTION_TYPES ->
"true")
- } else {
- base
- }
-
- cfName -> withFlatten
- }.toMap
-
- performRoundTripTest(
- sourceDir.getAbsolutePath,
- targetDir.getAbsolutePath,
- storeToColumnFamilies = Map(StateStoreId.DEFAULT_STORE_NAME ->
schemas.keys.toList),
- storeToColumnFamilyToSelectExprs =
- Map(StateStoreId.DEFAULT_STORE_NAME ->
columnFamilyToSelectExprs),
- storeToColumnFamilyToStateSourceOptions =
- Map(StateStoreId.DEFAULT_STORE_NAME ->
columnFamilyToStateSourceOptions),
- operatorName =
StatefulOperatorsUtils.TRANSFORM_WITH_STATE_EXEC_OP_NAME
- )
- }
- }
- }
-
- testWithChangelogAndEncodingConfig("SPARK-54411: transformWithState with
eventTime timers") {
- withTempDir { sourceDir =>
- withTempDir { targetDir =>
- val inputData = MemoryStream[(String, Long)]
- val result = inputData.toDS()
- .select(col("_1").as("key"),
timestamp_seconds(col("_2")).as("eventTime"))
- .withWatermark("eventTime", "10 seconds")
- .as[(String, Timestamp)]
- .groupByKey(_._1)
- .transformWithState(
- new EventTimeTimerProcessor(),
- TimeMode.EventTime(),
- OutputMode.Update())
-
- testStream(result, OutputMode.Update())(
- StartStream(checkpointLocation = sourceDir.getAbsolutePath),
- AddData(inputData, ("a", 1L), ("b", 2L), ("c", 3L)),
- ProcessAllAvailable(),
- StopStream
- )
-
- testStream(result, OutputMode.Update())(
- StartStream(checkpointLocation = targetDir.getAbsolutePath),
- AddData(inputData, ("x", 1L)),
- ProcessAllAvailable(),
- StopStream
- )
-
- val (schemaMap, selectExprs, stateSourceOptions) =
- getTimerStateConfigsForCountState(TimeMode.EventTime())
-
- performRoundTripTest(
- sourceDir.getAbsolutePath,
- targetDir.getAbsolutePath,
- storeToColumnFamilies =
- Map(StateStoreId.DEFAULT_STORE_NAME -> schemaMap.keys.toList),
- storeToColumnFamilyToSelectExprs =
- Map(StateStoreId.DEFAULT_STORE_NAME -> selectExprs),
- storeToColumnFamilyToStateSourceOptions =
- Map(StateStoreId.DEFAULT_STORE_NAME -> stateSourceOptions),
- operatorName =
StatefulOperatorsUtils.TRANSFORM_WITH_STATE_EXEC_OP_NAME
- )
- }
- }
- }
-
- testWithChangelogAndEncodingConfig(
- "SPARK-54411: transformWithState with processing time timers") {
- withTempDir { sourceDir =>
- withTempDir { targetDir =>
- val clock = new StreamManualClock
- val inputData = MemoryStream[String]
- val result = inputData.toDS()
- .groupByKey(x => x)
- .transformWithState(new
RunningCountStatefulProcessorWithProcTimeTimer(),
- TimeMode.ProcessingTime(),
- OutputMode.Update())
-
- testStream(result, OutputMode.Update())(
- StartStream(checkpointLocation = sourceDir.getAbsolutePath,
- trigger = Trigger.ProcessingTime("1 second"),
- triggerClock = clock),
- AddData(inputData, "a"),
- AdvanceManualClock(1 * 1000),
- CheckNewAnswer(("a", "1")),
- StopStream
- )
-
- val clock2 = new StreamManualClock
- testStream(result, OutputMode.Update())(
- StartStream(checkpointLocation = targetDir.getAbsolutePath,
- trigger = Trigger.ProcessingTime("1 second"),
- triggerClock = clock2),
- AddData(inputData, "x"),
- AdvanceManualClock(1 * 1000),
- CheckNewAnswer(("a", "1"), ("x", "1")),
- StopStream
- )
-
- val (schemaMap, selectExprs, sourceOptions) =
- getTimerStateConfigsForCountState(TimeMode.ProcessingTime())
-
- performRoundTripTest(
- sourceDir.getAbsolutePath,
- targetDir.getAbsolutePath,
- storeToColumnFamilies =
- Map(StateStoreId.DEFAULT_STORE_NAME -> schemaMap.keys.toList),
- storeToColumnFamilyToSelectExprs =
- Map(StateStoreId.DEFAULT_STORE_NAME -> selectExprs),
- storeToColumnFamilyToStateSourceOptions =
- Map(StateStoreId.DEFAULT_STORE_NAME -> sourceOptions),
- operatorName =
StatefulOperatorsUtils.TRANSFORM_WITH_STATE_EXEC_OP_NAME
- )
- }
- }
- }
-
- testWithChangelogAndEncodingConfig("SPARK-54411: transformWithState with
list and TTL") {
- withTempDir { sourceDir =>
- withTempDir { targetDir =>
- val clock = new StreamManualClock
- val inputData = MemoryStream[InputEvent]
- val ttlConfig = TTLConfig(ttlDuration = Duration.ofMinutes(1))
- val result = inputData.toDS()
- .groupByKey(x => x.key)
- .transformWithState(new ListStateTTLProcessor(ttlConfig),
- TimeMode.ProcessingTime(),
- OutputMode.Update())
-
- testStream(result, OutputMode.Update())(
- StartStream(checkpointLocation = sourceDir.getAbsolutePath,
- trigger = Trigger.ProcessingTime("1 second"),
- triggerClock = clock),
- AddData(inputData, InputEvent("k1", "put", 1)),
- AdvanceManualClock(1 * 1000),
- CheckNewAnswer(),
- StopStream
- )
-
- val clock2 = new StreamManualClock
- testStream(result, OutputMode.Update())(
- StartStream(checkpointLocation = targetDir.getAbsolutePath,
- trigger = Trigger.ProcessingTime("1 second"),
- triggerClock = clock2),
- AddData(inputData, InputEvent("k1", "append", 1)),
- AdvanceManualClock(1 * 1000),
- CheckNewAnswer(),
- StopStream
- )
-
- val schemas =
TTLProcessorUtils.getListStateTTLSchemasWithMetadata()
-
- val columnFamilyToSelectExprs = Map(
- TTLProcessorUtils.LIST_STATE ->
TTLProcessorUtils.getTTLSelectExpressions(
- TTLProcessorUtils.LIST_STATE
- ))
-
- val columnFamilyToStateSourceOptions = schemas.keys.map { cfName =>
- val base = Map(StateSourceOptions.STATE_VAR_NAME -> cfName)
-
- val withFlatten =
- if (cfName == TTLProcessorUtils.LIST_STATE) {
- base + (StateSourceOptions.FLATTEN_COLLECTION_TYPES ->
"true")
- } else {
- base
- }
-
- cfName -> withFlatten
- }.toMap
-
- performRoundTripTest(
- sourceDir.getAbsolutePath,
- targetDir.getAbsolutePath,
- storeToColumnFamilies =
- Map(StateStoreId.DEFAULT_STORE_NAME -> schemas.keys.toList),
- storeToColumnFamilyToSelectExprs =
- Map(StateStoreId.DEFAULT_STORE_NAME ->
columnFamilyToSelectExprs),
- storeToColumnFamilyToStateSourceOptions =
- Map(StateStoreId.DEFAULT_STORE_NAME ->
columnFamilyToStateSourceOptions),
- operatorName =
StatefulOperatorsUtils.TRANSFORM_WITH_STATE_EXEC_OP_NAME
- )
- }
- }
- }
-
- testWithChangelogAndEncodingConfig("SPARK-54411: transformWithState with
map and TTL") {
- withTempDir { sourceDir =>
- withTempDir { targetDir =>
- val clock = new StreamManualClock
- val inputData = MemoryStream[MapInputEvent]
- val ttlConfig = TTLConfig(ttlDuration = Duration.ofMinutes(1))
- val result = inputData.toDS()
- .groupByKey(x => x.key)
- .transformWithState(new MapStateTTLProcessor(ttlConfig),
- TimeMode.ProcessingTime(),
- OutputMode.Update())
-
- testStream(result)(
- StartStream(checkpointLocation = sourceDir.getAbsolutePath,
- trigger = Trigger.ProcessingTime("1 second"),
- triggerClock = clock),
- AddData(inputData, MapInputEvent("a", "key1", "put", 1)),
- AdvanceManualClock(1 * 1000),
- CheckNewAnswer(),
- StopStream
- )
-
- val clock2 = new StreamManualClock
- testStream(result)(
- StartStream(checkpointLocation = targetDir.getAbsolutePath,
- trigger = Trigger.ProcessingTime("1 second"),
- triggerClock = clock2),
- AddData(inputData, MapInputEvent("x", "key1", "put", 1)),
- AdvanceManualClock(1 * 1000),
- CheckNewAnswer(),
- StopStream
- )
-
- val schemas = TTLProcessorUtils.getMapStateTTLSchemasWithMetadata()
-
- val columnFamilyToSelectExprs = Map(
- TTLProcessorUtils.MAP_STATE ->
TTLProcessorUtils.getTTLSelectExpressions(
- TTLProcessorUtils.MAP_STATE
- ))
-
- val columnFamilyToStateSourceOptions = schemas.keys.map { cfName =>
- cfName -> Map(StateSourceOptions.STATE_VAR_NAME -> cfName)
- }.toMap
-
- performRoundTripTest(
- sourceDir.getAbsolutePath,
- targetDir.getAbsolutePath,
- storeToColumnFamilies =
- Map(StateStoreId.DEFAULT_STORE_NAME -> schemas.keys.toList),
- storeToColumnFamilyToSelectExprs =
- Map(StateStoreId.DEFAULT_STORE_NAME ->
columnFamilyToSelectExprs),
- storeToColumnFamilyToStateSourceOptions =
- Map(StateStoreId.DEFAULT_STORE_NAME ->
columnFamilyToStateSourceOptions),
- operatorName =
StatefulOperatorsUtils.TRANSFORM_WITH_STATE_EXEC_OP_NAME
- )
- }
+ test("SPARK-54590: Rewriter throw exception if checkpoint version is not set
correct") {
+ withSQLConf(SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION.key -> "2") {
+ withTempDir { sourceDir =>
+ // Step 1: Create checkpoint V2 state by running a dropDuplicates streaming query
+ runDropDuplicatesQuery(sourceDir.getAbsolutePath)
+ val sourceCheckpointMetadata = new StreamingQueryCheckpointMetadata(
+ spark, sourceDir.getAbsolutePath)
+ val readBatchId =
sourceCheckpointMetadata.commitLog.getLatestBatchId().get
+ // Unset STATE_STORE_CHECKPOINT_FORMAT_VERSION so it falls back to its default
+ // (version 1 here), mimicking a user who forgot to set the checkpoint format
+ // version to 2 in SQLConf before running StateRewriter on a checkpoint V2 query.
+ spark.conf.unset(SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION.key)
+ val ex = intercept[StateRewriterInvalidCheckpointError] {
+ val rewriter = new StateRewriter(
+ spark,
+ readBatchId,
+ readBatchId + 1,
+ sourceDir.getAbsolutePath,
+ spark.sessionState.newHadoopConf()
+ )
+ rewriter.run()
}
- }
-
- testWithChangelogAndEncodingConfig("SPARK-54411: transformWithState with
value and TTL") {
- withTempDir { sourceDir =>
- withTempDir { targetDir =>
- val clock = new StreamManualClock
- val inputData = MemoryStream[InputEvent]
- val ttlConfig = TTLConfig(ttlDuration = Duration.ofMinutes(1))
- val result = inputData.toDS()
- .groupByKey(x => x.key)
- .transformWithState(new ValueStateTTLProcessor(ttlConfig),
- TimeMode.ProcessingTime(),
- OutputMode.Update())
-
- testStream(result)(
- StartStream(checkpointLocation = sourceDir.getAbsolutePath,
- trigger = Trigger.ProcessingTime("1 second"),
- triggerClock = clock),
- AddData(inputData, InputEvent("k1", "put", 1)),
- AddData(inputData, InputEvent("k2", "put", 2)),
- AdvanceManualClock(1 * 1000),
- CheckNewAnswer(),
- StopStream
- )
-
- val clock2 = new StreamManualClock
- testStream(result)(
- StartStream(checkpointLocation = targetDir.getAbsolutePath,
- trigger = Trigger.ProcessingTime("1 second"),
- triggerClock = clock2),
- AddData(inputData, InputEvent("x", "put", 1)),
- AdvanceManualClock(1 * 1000),
- CheckNewAnswer(),
- StopStream
- )
- val schemas =
TTLProcessorUtils.getValueStateTTLSchemasWithMetadata()
-
- val columnFamilyToStateSourceOptions = schemas.keys.map { cfName =>
- cfName -> Map(StateSourceOptions.STATE_VAR_NAME -> cfName)
- }.toMap
-
- performRoundTripTest(
- sourceDir.getAbsolutePath,
- targetDir.getAbsolutePath,
- storeToColumnFamilies =
- Map(StateStoreId.DEFAULT_STORE_NAME -> schemas.keys.toList),
- storeToColumnFamilyToStateSourceOptions =
- Map(StateStoreId.DEFAULT_STORE_NAME ->
columnFamilyToStateSourceOptions),
- operatorName =
StatefulOperatorsUtils.TRANSFORM_WITH_STATE_EXEC_OP_NAME
- )
- }
- }
+ checkError(
+ ex,
+
"STATE_REWRITER_INVALID_CHECKPOINT.STATE_CHECKPOINT_FORMAT_VERSION_MISMATCH",
+ "55019",
+ parameters = Map(
+ "checkpointLocation" -> sourceDir.getAbsolutePath,
+ "expectedVersion" -> "2",
+ "actualVersion" -> "1",
+ "sqlConfKey" -> SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION.key)
+ )
}
- } // End of foreach loop for encoding format (transformWithState tests
only)
- } // End of foreach loop for changelog checkpointing dimension
+ }
+ }
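As a usage note, a minimal sketch of the remediation the error message points to: set the checkpoint format version back to 2 in SQLConf before constructing StateRewriter against this checkpoint V2 query. It reuses `readBatchId` and `sourceDir` from the test above and assumes the same SparkSession and imports:

    // Sketch only: with the SQLConf version matching the commit log version (2),
    // the STATE_CHECKPOINT_FORMAT_VERSION_MISMATCH error above should not be raised.
    spark.conf.set(SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION.key, "2")
    val rewriter = new StateRewriter(
      spark,
      readBatchId,               // same constructor arguments as in the test above
      readBatchId + 1,
      sourceDir.getAbsolutePath,
      spark.sessionState.newHadoopConf()
    )
    rewriter.run()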
test("SPARK-54411: Non-JoinV3 operator requires default column family in
schema") {
withTempDir { targetDir =>