This is an automated email from the ASF dual-hosted git repository.
ashrigondekar pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 6653e00a10c6 [SPARK-54420][SS] Introduce Offline Repartitioning
StatePartitionWriter for Single Column Family
6653e00a10c6 is described below
commit 6653e00a10c694fcfc9eb10082761ace5d702eae
Author: zifeif2 <[email protected]>
AuthorDate: Wed Dec 10 11:48:59 2025 -0800
[SPARK-54420][SS] Introduce Offline Repartitioning StatePartitionWriter for
Single Column Family
### What changes were proposed in this pull request?
This PR introduces StatePartitionAllColumnFamiliesWriter as part of the offline
repartitioning project. For now, only operators with a single column family are
supported.
This writer takes the repartitioned DataFrame returned from
**StatePartitionAllColumnFamiliesReader** and writes it to a new version in the
state store. See the code comments on the writer for the expected DataFrame
schema. In addition, this writer does not load previous state (since we are
overwriting the state with the repartitioned data), and it always commits a
snapshot.
**Major Changes**
- Introduce StatePartitionAllColumnFamiliesWriter
- Introduce a new parameter loadEmpty for StateStoreProvider.getStore()
- Introduce a new parameter loadEmpty for RocksDB.load(), backed by a new
  private helper loadEmptyStoreWithoutCheckpointId
### Why are the changes needed?
This will be used in offline repartitioning to allow
OfflineRepartitioningRunner to write data directly to the state store.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Integration tests in
`sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionAllColumnFamiliesWriterSuite.scala`
Unit tests in
`sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala`
### Was this patch authored or co-authored using generative AI tooling?
Yes. Sonnet 4.5
Closes #53287 from zifeif2/simple-writer.
Authored-by: zifeif2 <[email protected]>
Signed-off-by: Anish Shrigondekar <[email protected]>
---
.../datasources/v2/state/utils/SchemaUtil.scala | 4 +-
.../state/HDFSBackedStateStoreProvider.scala | 7 +-
.../sql/execution/streaming/state/RocksDB.scala | 63 +-
.../state/RocksDBStateStoreProvider.scala | 13 +-
.../streaming/state/StatePartitionWriter.scala | 128 ++++
.../sql/execution/streaming/state/StateStore.scala | 4 +-
.../v2/state/StateDataSourceTestBase.scala | 21 +-
.../RocksDBStateStoreCheckpointFormatV2Suite.scala | 6 +-
.../execution/streaming/state/RocksDBSuite.scala | 47 +-
...tatePartitionAllColumnFamiliesWriterSuite.scala | 680 +++++++++++++++++++++
.../streaming/state/StateStoreSuite.scala | 29 +-
.../apache/spark/sql/streaming/StreamSuite.scala | 3 +-
12 files changed, 965 insertions(+), 40 deletions(-)
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/utils/SchemaUtil.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/utils/SchemaUtil.scala
index 44d83fc99b57..44e032f5163a 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/utils/SchemaUtil.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/utils/SchemaUtil.scala
@@ -63,8 +63,8 @@ object SchemaUtil {
.add("partition_id", IntegerType)
} else if (sourceOptions.internalOnlyReadAllColumnFamilies) {
new StructType()
- // todo [SPARK-54443]: change keySchema to a more specific type after
we
- // can extract partition key from keySchema
+ // TODO [SPARK-54443]: change keySchema to a more specific type after
we
+ // can extract partition key from keySchema
.add("partition_key", keySchema)
.add("key_bytes", BinaryType)
.add("value_bytes", BinaryType)
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
index 9da75a9728dd..12399dccf422 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
@@ -322,12 +322,17 @@ private[sql] class HDFSBackedStateStoreProvider extends
StateStoreProvider with
override def getStore(
version: Long,
uniqueId: Option[String] = None,
- forceSnapshotOnCommit: Boolean = false): StateStore = {
+ forceSnapshotOnCommit: Boolean = false,
+ loadEmpty: Boolean = false): StateStore = {
if (uniqueId.isDefined) {
throw StateStoreErrors.stateStoreCheckpointIdsNotSupported(
"HDFSBackedStateStoreProvider does not support checkpointFormatVersion
> 1 " +
"but a state store checkpointID is passed in")
}
+ if (loadEmpty) {
+ throw StateStoreErrors.unsupportedOperationException("getStore",
+ "Internal Error: HDFSBackedStateStoreProvider doesn't support
loadEmpty")
+ }
val newMap = getLoadedMapForStore(version)
logInfo(log"Retrieved version ${MDC(LogKeys.STATE_STORE_VERSION, version)}
" +
log"of ${MDC(LogKeys.STATE_STORE_PROVIDER,
HDFSBackedStateStoreProvider.this)} " +
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala
index c92c5017cada..39410ca15432 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala
@@ -180,7 +180,7 @@ class RocksDB(
@volatile private var db: NativeRocksDB = _
@volatile private var changelogWriter: Option[StateStoreChangelogWriter] =
None
- private val enableChangelogCheckpointing: Boolean =
conf.enableChangelogCheckpointing
+ @volatile private var enableChangelogCheckpointing: Boolean =
conf.enableChangelogCheckpointing
@volatile protected var loadedVersion: Long = -1L // -1 = nothing valid is
loaded
// Can be updated by whichever thread uploaded a snapshot, which could be
either task,
@@ -553,21 +553,36 @@ class RocksDB(
this
}
+ private def loadEmptyStoreWithoutCheckpointId(version: Long): Unit = {
+ // Use version 0 logic to create empty directory with no SST files
+ val metadata = fileManager.loadCheckpointFromDfs(0, workingDir,
rocksDBFileMapping, None)
+ loadedVersion = version
+ fileManager.setMaxSeenVersion(version)
+ openLocalRocksDB(metadata)
+ }
+
private def loadWithoutCheckpointId(
version: Long,
- readOnly: Boolean = false): RocksDB = {
+ readOnly: Boolean = false,
+ loadEmpty: Boolean = false): RocksDB = {
+
try {
- if (loadedVersion != version) {
+ // For loadEmpty, always proceed; otherwise, only if version changed
+ if (loadEmpty || loadedVersion != version) {
closeDB(ignoreException = false)
- // load the latest snapshot
- loadSnapshotWithoutCheckpointId(version)
-
- if (loadedVersion != version) {
- val versionsAndUniqueIds: Array[(Long, Option[String])] =
- (loadedVersion + 1 to version).map((_, None)).toArray
- replayChangelog(versionsAndUniqueIds)
- loadedVersion = version
+ if (loadEmpty) {
+ loadEmptyStoreWithoutCheckpointId(version)
+ } else {
+ // load the latest snapshot
+ loadSnapshotWithoutCheckpointId(version)
+
+ if (loadedVersion != version) {
+ val versionsAndUniqueIds: Array[(Long, Option[String])] =
+ (loadedVersion + 1 to version).map((_, None)).toArray
+ replayChangelog(versionsAndUniqueIds)
+ loadedVersion = version
+ }
}
// After changelog replay the numKeysOnWritingVersion will be updated
to
// the correct number of keys in the loaded version.
@@ -578,16 +593,27 @@ class RocksDB(
if (conf.resetStatsOnLoad) {
nativeStats.reset
}
- logInfo(log"Loaded ${MDC(LogKeys.VERSION_NUM, version)}")
+ if (loadEmpty) {
+ logInfo(log"Loaded empty store at version ${MDC(LogKeys.VERSION_NUM,
version)}")
+ } else {
+ logInfo(log"Loaded ${MDC(LogKeys.VERSION_NUM, version)}")
+ }
} catch {
case t: Throwable =>
loadedVersion = -1 // invalidate loaded data
throw t
}
- if (enableChangelogCheckpointing && !readOnly) {
+ // Checking conf.enableChangelogCheckpointing instead of
enableChangelogCheckpointing.
+ // enableChangelogCheckpointing is set to false when loadEmpty is true,
but we still want
+ // to abort previous used changelogWriter if there is any
+ if (conf.enableChangelogCheckpointing && !readOnly) {
// Make sure we don't leak resource.
changelogWriter.foreach(_.abort())
- changelogWriter = Some(fileManager.getChangeLogWriter(version + 1,
useColumnFamilies))
+ if (loadEmpty) {
+ changelogWriter = None
+ } else {
+ changelogWriter = Some(fileManager.getChangeLogWriter(version + 1,
useColumnFamilies))
+ }
}
this
}
@@ -703,7 +729,8 @@ class RocksDB(
def load(
version: Long,
stateStoreCkptId: Option[String] = None,
- readOnly: Boolean = false): RocksDB = {
+ readOnly: Boolean = false,
+ loadEmpty: Boolean = false): RocksDB = {
val startTime = System.currentTimeMillis()
assert(version >= 0)
@@ -714,10 +741,14 @@ class RocksDB(
logInfo(log"Loading ${MDC(LogKeys.VERSION_NUM, version)} with
stateStoreCkptId: ${
MDC(LogKeys.UUID, stateStoreCkptId.getOrElse(""))}")
+ // If loadEmpty is true, we will not generate a changelog but only a
snapshot file to prevent
+ // mistakenly applying new changelog to older state version
+ enableChangelogCheckpointing = if (loadEmpty) false else
conf.enableChangelogCheckpointing
if (stateStoreCkptId.isDefined || enableStateStoreCheckpointIds && version
== 0) {
+ assert(!loadEmpty, "loadEmpty not supported for checkpointV2 yet")
loadWithCheckpointId(version, stateStoreCkptId, readOnly)
} else {
- loadWithoutCheckpointId(version, readOnly)
+ loadWithoutCheckpointId(version, readOnly, loadEmpty)
}
// Record the metrics after loading
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
index 96652ffe6fd7..2a6761733ae0 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
@@ -726,6 +726,7 @@ private[sql] class RocksDBStateStoreProvider
* @param readOnly Whether to open the store in read-only mode
* @param existingStore Optional existing store to reuse instead of creating
a new one
* @param forceSnapshotOnCommit Whether to force a snapshot upload on commit
+ * @param loadEmpty If true, creates an empty store at this version without
loading previous data
* @return The loaded state store
*/
private def loadStateStore(
@@ -733,7 +734,8 @@ private[sql] class RocksDBStateStoreProvider
uniqueId: Option[String] = None,
readOnly: Boolean,
existingStore: Option[RocksDBStateStore] = None,
- forceSnapshotOnCommit: Boolean = false): StateStore = {
+ forceSnapshotOnCommit: Boolean = false,
+ loadEmpty: Boolean = false): StateStore = {
var acquiredStamp: Option[Long] = None
var storeLoaded = false
try {
@@ -765,7 +767,8 @@ private[sql] class RocksDBStateStoreProvider
rocksDB.load(
version,
stateStoreCkptId = if (storeConf.enableStateStoreCheckpointIds)
uniqueId else None,
- readOnly = readOnly)
+ readOnly = readOnly,
+ loadEmpty = loadEmpty)
// Create or reuse store instance
val store = existingStore match {
@@ -806,12 +809,14 @@ private[sql] class RocksDBStateStoreProvider
override def getStore(
version: Long,
uniqueId: Option[String] = None,
- forceSnapshotOnCommit: Boolean = false): StateStore = {
+ forceSnapshotOnCommit: Boolean = false,
+ loadEmpty: Boolean = false): StateStore = {
loadStateStore(
version,
uniqueId,
readOnly = false,
- forceSnapshotOnCommit = forceSnapshotOnCommit
+ forceSnapshotOnCommit = forceSnapshotOnCommit,
+ loadEmpty = loadEmpty
)
}
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StatePartitionWriter.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StatePartitionWriter.scala
new file mode 100644
index 000000000000..8a4b2a5b8f45
--- /dev/null
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StatePartitionWriter.scala
@@ -0,0 +1,128 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.streaming.state
+
+import java.util.UUID
+
+import scala.collection.MapView
+import scala.collection.immutable.HashMap
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import
org.apache.spark.sql.execution.streaming.runtime.StreamingCheckpointConstants.DIR_NAME_STATE
+
+/**
+ * A writer that can directly write binary data to the streaming state store.
+ *
+ * This writer expects input rows with the same schema produced by
+ * StatePartitionAllColumnFamiliesReader:
+ * (partition_key, key_bytes, value_bytes, column_family_name)
+ *
+ * The writer creates a fresh (empty) state store instance for the target
commit version
+ * instead of loading previous partition data. After writing all rows for the
partition, it will
+ * commit all changes as a snapshot
+ */
+class StatePartitionAllColumnFamiliesWriter(
+ storeConf: StateStoreConf,
+ hadoopConf: Configuration,
+ partitionId: Int,
+ targetCpLocation: String,
+ operatorId: Int,
+ storeName: String,
+ currentBatchId: Long,
+ columnFamilyToSchemaMap: HashMap[String, StateStoreColFamilySchema]) {
+ private val defaultSchema = {
+ columnFamilyToSchemaMap.getOrElse(
+ StateStore.DEFAULT_COL_FAMILY_NAME,
+ throw new IllegalArgumentException(
+ s"Column family ${StateStore.DEFAULT_COL_FAMILY_NAME} not found in
schema map")
+ )
+ }
+
+ private val columnFamilyToKeySchemaLenMap: MapView[String, Int] =
+ columnFamilyToSchemaMap.view.mapValues(_.keySchema.length)
+ private val columnFamilyToValueSchemaLenMap: MapView[String, Int] =
+ columnFamilyToSchemaMap.view.mapValues(_.valueSchema.length)
+
+ protected lazy val provider: StateStoreProvider = {
+ val stateCheckpointLocation = new Path(targetCpLocation,
DIR_NAME_STATE).toString
+ val stateStoreId = StateStoreId(stateCheckpointLocation,
+ operatorId, partitionId, storeName)
+ val stateStoreProviderId = StateStoreProviderId(stateStoreId,
UUID.randomUUID())
+
+ val provider = StateStoreProvider.createAndInit(
+ stateStoreProviderId, defaultSchema.keySchema, defaultSchema.valueSchema,
+ defaultSchema.keyStateEncoderSpec.get,
+ useColumnFamilies = false, storeConf, hadoopConf,
+ useMultipleValuesPerKey = false, stateSchemaProvider = None)
+ provider
+ }
+
+ private lazy val stateStore: StateStore = {
+ // TODO[SPARK-54590]: Support checkpoint V2 in
StatePartitionAllColumnFamiliesWriter
+ // Create empty store to avoid loading old partition data since we are
rewriting the
+ // store e.g. during repartitioning
+ // Use loadEmpty=true to create a fresh state store without loading
previous versions
+ // We create the empty store AT version, and the next commit will
+ // produce version + 1
+ provider.getStore(
+ currentBatchId,
+ stateStoreCkptId = None,
+ loadEmpty = true
+ )
+ }
+
+ // The function that writes and commits data to state store. It takes in
rows with schema
+ // - partition_key, StructType
+ // - key_bytes, BinaryType
+ // - value_bytes, BinaryType
+ // - column_family_name, StringType
+ def write(rows: Iterator[InternalRow]): Unit = {
+ try {
+ rows.foreach(row => writeRow(row))
+ stateStore.commit()
+ } finally {
+ if (!stateStore.hasCommitted) {
+ stateStore.abort()
+ }
+ }
+ }
+
+ private def writeRow(record: InternalRow): Unit = {
+ assert(record.numFields == 4,
+ s"Invalid record schema: expected 4 fields (partition_key, key_bytes,
value_bytes, " +
+ s"column_family_name), got ${record.numFields}")
+
+ // Extract raw bytes and column family name from the record
+ val keyBytes = record.getBinary(1)
+ val valueBytes = record.getBinary(2)
+ val colFamilyName = record.getString(3)
+
+ // Reconstruct UnsafeRow objects from the raw bytes
+ // The bytes are in UnsafeRow memory format from
StatePartitionReaderAllColumnFamilies
+ val keyRow = new UnsafeRow(columnFamilyToKeySchemaLenMap(colFamilyName))
+ keyRow.pointTo(keyBytes, keyBytes.length)
+
+ val valueRow = new
UnsafeRow(columnFamilyToValueSchemaLenMap(colFamilyName))
+ valueRow.pointTo(valueBytes, valueBytes.length)
+
+ stateStore.put(keyRow, valueRow, colFamilyName)
+ }
+}
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
index 43b95766882f..bd6b4bede84b 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
@@ -665,11 +665,13 @@ trait StateStoreProvider {
/**
* Return an instance of [[StateStore]] representing state data of the given
version.
* If `stateStoreCkptId` is provided, the instance also needs to match the
ID.
+ * If `loadEmpty` is true, creates an empty store at this version without
loading previous data.
* */
def getStore(
version: Long,
stateStoreCkptId: Option[String] = None,
- forceSnapshotOnCommit: Boolean = false): StateStore
+ forceSnapshotOnCommit: Boolean = false,
+ loadEmpty: Boolean = false): StateStore
/**
* Return an instance of [[ReadStateStore]] representing state data of the
given version
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTestBase.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTestBase.scala
index 64d005c719b7..b66408bb7d69 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTestBase.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTestBase.scala
@@ -78,7 +78,7 @@ trait StateDataSourceTestBase extends StreamTest with
StateStoreMetricsTest {
)
}
- private def getCompositeKeyStreamingAggregationQuery(
+ protected def getCompositeKeyStreamingAggregationQuery(
inputData: MemoryStream[Int]): Dataset[(Int, String, Long, Long, Int,
Int)] = {
inputData.toDF()
.selectExpr("value", "value % 2 AS groupKey",
@@ -140,7 +140,7 @@ trait StateDataSourceTestBase extends StreamTest with
StateStoreMetricsTest {
)
}
- private def getLargeDataStreamingAggregationQuery(
+ protected def getLargeDataStreamingAggregationQuery(
inputData: MemoryStream[Int]): Dataset[(Int, Long, Long, Int, Int)] = {
inputData.toDF()
.selectExpr("value", "value % 10 AS groupKey")
@@ -179,7 +179,7 @@ trait StateDataSourceTestBase extends StreamTest with
StateStoreMetricsTest {
)
}
- private def getDropDuplicatesQuery(inputData: MemoryStream[Int]):
Dataset[Long] = {
+ protected def getDropDuplicatesQuery(inputData: MemoryStream[Int]):
Dataset[Long] = {
inputData.toDS()
.withColumn("eventTime", timestamp_seconds($"value"))
.withWatermark("eventTime", "10 seconds")
@@ -204,7 +204,7 @@ trait StateDataSourceTestBase extends StreamTest with
StateStoreMetricsTest {
)
}
- private def getDropDuplicatesQueryWithColumnSpecified(
+ protected def getDropDuplicatesQueryWithColumnSpecified(
inputData: MemoryStream[(String, Int)]): Dataset[(String, Int)] = {
inputData.toDS()
.selectExpr("_1 AS col1", "_2 AS col2")
@@ -256,7 +256,7 @@ trait StateDataSourceTestBase extends StreamTest with
StateStoreMetricsTest {
)
}
- private def getDropDuplicatesWithinWatermarkQuery(
+ protected def getDropDuplicatesWithinWatermarkQuery(
inputData: MemoryStream[(String, Int)]): DataFrame = {
inputData.toDS()
.withColumn("eventTime", timestamp_seconds($"_2"))
@@ -293,7 +293,7 @@ trait StateDataSourceTestBase extends StreamTest with
StateStoreMetricsTest {
)
}
- private def getFlatMapGroupsWithStateQuery(
+ protected def getFlatMapGroupsWithStateQuery(
inputData: MemoryStream[(String, Long)]): Dataset[(String, Int, Long,
Boolean)] = {
// scalastyle:off line.size.limit
// This test code is borrowed from Sessionization example, with
modification a bit to run with testStream
@@ -405,8 +405,7 @@ trait StateDataSourceTestBase extends StreamTest with
StateStoreMetricsTest {
col("rightId"), col("rightTime").cast("int"))
}
- protected def runSessionWindowAggregationQuery(checkpointRoot: String): Unit
= {
- val input = MemoryStream[(String, Long)]
+ protected def getSessionWindowAggregationQuery(input: MemoryStream[(String,
Long)]): DataFrame = {
val sessionWindow = session_window($"eventTime", "10 seconds")
val events = input.toDF()
@@ -415,13 +414,17 @@ trait StateDataSourceTestBase extends StreamTest with
StateStoreMetricsTest {
.withWatermark("eventTime", "30 seconds")
.selectExpr("explode(split(value, ' ')) AS sessionId", "eventTime")
- val streamingDf = events
+ events
.groupBy(sessionWindow as Symbol("session"), $"sessionId")
.agg(count("*").as("numEvents"))
.selectExpr("sessionId", "CAST(session.start AS LONG)",
"CAST(session.end AS LONG)",
"CAST(session.end AS LONG) - CAST(session.start AS LONG) AS
durationMs",
"numEvents")
+ }
+ protected def runSessionWindowAggregationQuery(checkpointRoot: String): Unit
= {
+ val input = MemoryStream[(String, Long)]
+ val streamingDf = getSessionWindowAggregationQuery(input)
testStream(streamingDf, OutputMode.Complete())(
StartStream(checkpointLocation = checkpointRoot),
AddData(input,
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreCheckpointFormatV2Suite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreCheckpointFormatV2Suite.scala
index 2afb3e69c4e4..5f1adadc30a3 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreCheckpointFormatV2Suite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreCheckpointFormatV2Suite.scala
@@ -200,8 +200,10 @@ class CkptIdCollectingStateStoreProviderWrapper extends
StateStoreProvider {
override def getStore(
version: Long,
stateStoreCkptId: Option[String] = None,
- forceSnapshotOnCommit: Boolean = false): StateStore = {
- val innerStateStore = innerProvider.getStore(version, stateStoreCkptId,
forceSnapshotOnCommit)
+ forceSnapshotOnCommit: Boolean = false,
+ loadEmpty: Boolean = false): StateStore = {
+ val innerStateStore = innerProvider.getStore(version, stateStoreCkptId,
+ forceSnapshotOnCommit, loadEmpty)
CkptIdCollectingStateStoreWrapper(innerStateStore)
}
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala
index ca72f7033118..04f879ed64d1 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala
@@ -3942,6 +3942,50 @@ class RocksDBSuite extends AlsoTestWithRocksDBFeatures
with SharedSparkSession
}}
}
+ test("SPARK-54420: load with createEmpty creates empty store") {
+ val remoteDir = Utils.createTempDir().toString
+ new File(remoteDir).delete()
+
+ withDB(remoteDir) { db =>
+ // loading batch 0 with loadEmpty = true
+ db.load(0, None, loadEmpty = true)
+ assert(iterator(db).isEmpty)
+ db.put("a", "1")
+ val (version1, _) = db.commit()
+ assert(toStr(db.get("a")) === "1")
+
+ // check we can load store normally even the previous one loadEmpty =
true
+ db.load(version1)
+ db.put("b", "2")
+ val (version2, _) = db.commit()
+ assert(version2 === version1 + 1)
+ assert(toStr(db.get("b")) === "2")
+ assert(toStr(db.get("a")) === "1")
+
+ // load an empty store
+ db.load(version2, loadEmpty = true)
+ db.put("c", "3")
+ val (version3, _) = db.commit()
+ assert(db.get("b") === null)
+ assert(db.get("a") === null)
+ assert(toStr(db.get("c")) === "3")
+ assert(version3 === version2 + 1)
+
+ // load 2 empty store in a row
+ db.load(version3, loadEmpty = true)
+ db.put("d", "4")
+ val (version4, _) = db.commit()
+ assert(db.get("c") === null)
+ assert(toStr(db.get("d")) === "4")
+ assert(version4 === version3 + 1)
+
+ db.load(version4)
+ db.put("e", "5")
+ db.commit()
+ assert(db.iterator().map(toStr).toSet === Set(("d", "4"), ("e", "5")))
+ }
+ }
+
test("SPARK-44639: Use Java tmp dir instead of configured local dirs on
Yarn") {
val conf = new Configuration()
conf.set(StreamExecution.RUN_ID_KEY, UUID.randomUUID().toString)
@@ -3981,7 +4025,8 @@ class RocksDBSuite extends AlsoTestWithRocksDBFeatures
with SharedSparkSession
override def load(
version: Long,
ckptId: Option[String] = None,
- readOnly: Boolean = false): RocksDB = {
+ readOnly: Boolean = false,
+ createEmpty: Boolean = false): RocksDB = {
// When a ckptId is defined, it means the test is explicitly using v2
semantic
// When it is not, it is possible that implicitly uses it.
// So still do a versionToUniqueId.get
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatePartitionAllColumnFamiliesWriterSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatePartitionAllColumnFamiliesWriterSuite.scala
new file mode 100644
index 000000000000..d18b8b0c4f1a
--- /dev/null
+++
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatePartitionAllColumnFamiliesWriterSuite.scala
@@ -0,0 +1,680 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.streaming.state
+
+import java.io.File
+
+import scala.collection.immutable.HashMap
+
+import org.apache.spark.TaskContext
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
+import
org.apache.spark.sql.execution.datasources.v2.state.{StateDataSourceTestBase,
StateSourceOptions}
+import org.apache.spark.sql.execution.streaming.runtime.{MemoryStream,
StreamingQueryCheckpointMetadata}
+import org.apache.spark.sql.execution.streaming.utils.StreamingUtils
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.{OutputMode, Trigger}
+import org.apache.spark.sql.streaming.util.StreamManualClock
+import org.apache.spark.sql.types.{BooleanType, IntegerType, LongType,
NullType, StructField, StructType, TimestampType}
+import org.apache.spark.util.SerializableConfiguration
+
+/**
+ * Test suite for StatePartitionAllColumnFamiliesWriter.
+ * Tests the writer's ability to correctly write raw bytes read from
+ * StatePartitionAllColumnFamiliesReader to a state store without loading
previous versions.
+ */
+class StatePartitionAllColumnFamiliesWriterSuite extends
StateDataSourceTestBase {
+ import testImplicits._
+
+ override def beforeAll(): Unit = {
+ super.beforeAll()
+ spark.conf.set(SQLConf.STATE_STORE_PROVIDER_CLASS.key,
+ classOf[RocksDBStateStoreProvider].getName)
+ spark.conf.set(SQLConf.SHUFFLE_PARTITIONS.key, "2")
+ }
+
+ /**
+ * Common helper method to perform round-trip test: read state bytes from
source,
+ * write to target, and verify target matches source.
+ *
+ * @param sourceDir Source checkpoint directory
+ * @param targetDir Target checkpoint directory
+ * @param keySchema Key schema for the state store
+ * @param valueSchema Value schema for the state store
+ * @param keyStateEncoderSpec Key state encoder spec
+ * @param storeName Optional store name (for stream-stream join which has
multiple stores)
+ */
+ private def performRoundTripTest(
+ sourceDir: String,
+ targetDir: String,
+ keySchema: StructType,
+ valueSchema: StructType,
+ keyStateEncoderSpec: KeyStateEncoderSpec,
+ storeName: Option[String] = None): Unit = {
+
+ // Step 1: Read original state using normal reader (for comparison later)
+ val sourceReader = spark.read
+ .format("statestore")
+ .option(StateSourceOptions.PATH, sourceDir)
+ val sourceNormalData = (storeName match {
+ case Some(name) => sourceReader.option(StateSourceOptions.STORE_NAME,
name)
+ case None => sourceReader
+ }).load()
+ .selectExpr("key", "value", "partition_id")
+ .collect()
+
+ // Step 2: Read from source using AllColumnFamiliesReader (raw bytes)
+ val sourceBytesReader = spark.read
+ .format("statestore")
+ .option(StateSourceOptions.PATH, sourceDir)
+ .option(StateSourceOptions.INTERNAL_ONLY_READ_ALL_COLUMN_FAMILIES,
"true")
+ val sourceBytesData = (storeName match {
+ case Some(name) =>
sourceBytesReader.option(StateSourceOptions.STORE_NAME, name)
+ case None => sourceBytesReader
+ }).load()
+
+ // Verify schema of raw bytes
+ val schema = sourceBytesData.schema
+ assert(schema.fieldNames === Array(
+ "partition_key", "key_bytes", "value_bytes", "column_family_name"))
+
+ // Step 3: Write raw bytes to target checkpoint location
+ val hadoopConf = spark.sessionState.newHadoopConf()
+ val targetCpLocation = StreamingUtils.resolvedCheckpointLocation(
+ hadoopConf, targetDir)
+ val targetCheckpointMetadata = new StreamingQueryCheckpointMetadata(
+ spark, targetCpLocation)
+ val lastBatch = targetCheckpointMetadata.commitLog.getLatestBatchId().get
+ val targetOffsetSeq = targetCheckpointMetadata.offsetLog.get(lastBatch).get
+ val currentBatchId = lastBatch + 1
+ targetCheckpointMetadata.offsetLog.add(currentBatchId, targetOffsetSeq)
+
+ // Create column family to schema map
+ val columnFamilyToSchemaMap = HashMap(
+ StateStore.DEFAULT_COL_FAMILY_NAME -> StateStoreColFamilySchema(
+ StateStore.DEFAULT_COL_FAMILY_NAME,
+ keySchemaId = 0,
+ keySchema,
+ valueSchemaId = 0,
+ valueSchema,
+ keyStateEncoderSpec = Some(keyStateEncoderSpec)
+ )
+ )
+
+ val storeConf: StateStoreConf = StateStoreConf(SQLConf.get)
+ val serializableHadoopConf = new SerializableConfiguration(hadoopConf)
+
+ // Define the partition processing function
+ val putPartitionFunc: Iterator[Row] => Unit = partition => {
+ val allCFWriter = new StatePartitionAllColumnFamiliesWriter(
+ storeConf,
+ serializableHadoopConf.value,
+ TaskContext.getPartitionId(),
+ targetCpLocation,
+ 0,
+ storeName.getOrElse(StateStoreId.DEFAULT_STORE_NAME),
+ currentBatchId,
+ columnFamilyToSchemaMap
+ )
+ val rowConverter =
CatalystTypeConverters.createToCatalystConverter(schema)
+
+
allCFWriter.write(partition.map(rowConverter(_).asInstanceOf[InternalRow]))
+ }
+
+ // Write raw bytes to target using foreachPartition
+ sourceBytesData.foreachPartition(putPartitionFunc)
+
+ // Commit to commitLog
+ val latestCommit = targetCheckpointMetadata.commitLog.get(lastBatch).get
+ targetCheckpointMetadata.commitLog.add(currentBatchId, latestCommit)
+ val versionToCheck = currentBatchId + 1
+ val storeNamePath = s"state/0/0${storeName.fold("")("/" + _)}"
+ assert(!checkpointFileExists(new File(targetDir, storeNamePath),
versionToCheck, ".changelog"))
+ assert(checkpointFileExists(new File(targetDir, storeNamePath),
versionToCheck, ".zip"))
+
+ // Step 4: Read from target using normal reader
+ val targetReader = spark.read
+ .format("statestore")
+ .option(StateSourceOptions.PATH, targetDir)
+ val targetNormalData = (storeName match {
+ case Some(name) => targetReader.option(StateSourceOptions.STORE_NAME,
name)
+ case None => targetReader
+ }).load()
+ .selectExpr("key", "value", "partition_id")
+ .collect()
+
+ // Step 5: Verify data matches
+ assert(sourceNormalData.length == targetNormalData.length,
+ s"Row count mismatch: source=${sourceNormalData.length}, " +
+ s"target=${targetNormalData.length}")
+
+ // Sort and compare row by row
+ val sourceSorted = sourceNormalData.sortBy(_.toString)
+ val targetSorted = targetNormalData.sortBy(_.toString)
+
+ sourceSorted.zip(targetSorted).zipWithIndex.foreach {
+ case ((sourceRow, targetRow), idx) =>
+ assert(sourceRow == targetRow,
+ s"Row mismatch at index $idx:\n" +
+ s" Source: $sourceRow\n" +
+ s" Target: $targetRow")
+ }
+ }
+
+ /**
+ * Checks if a checkpoint file for the given version and suffix exists in dir.
+ * Files whose names start with "." are ignored.
+ *
+ * @param dir Directory to search for checkpoint files
+ * @param version The version to check for existence
+ * @param suffix File name suffix to match, e.g. ".zip" or ".changelog"
+ * @return true if a file matching the suffix and version exists, else false
+ */
+ private def checkpointFileExists(dir: File, version: Long, suffix:
String): Boolean = {
+ Option(dir.listFiles)
+ .getOrElse(Array.empty)
+ .map { file =>
+ file
+ }
+ .filter { file =>
+ file.getName.endsWith(suffix) && !file.getName.startsWith(".")
+ }
+ .exists { file =>
+ val nameWithoutSuffix = file.getName.stripSuffix(suffix)
+ val parts = nameWithoutSuffix.split("_")
+ parts.headOption match {
+ case Some(ver) if ver.forall(_.isDigit) => ver.toLong == version
+ case _ => false
+ }
+ }
+ }
+
+ /**
+ * Helper method to test SPARK-54420 read and write with different state
format versions
+ * for simple aggregation (single grouping key).
+ * @param stateVersion The state format version (1 or 2)
+ */
+ private def testRoundTripForAggrStateVersion(stateVersion: Int): Unit = {
+ withSQLConf(SQLConf.STREAMING_AGGREGATION_STATE_FORMAT_VERSION.key ->
stateVersion.toString) {
+ withTempDir { sourceDir =>
+ withTempDir { targetDir =>
+ // Step 1: Create state by running a streaming aggregation
+ runLargeDataStreamingAggregationQuery(sourceDir.getAbsolutePath)
+ val inputData: MemoryStream[Int] = MemoryStream[Int]
+ val aggregated = getLargeDataStreamingAggregationQuery(inputData)
+
+ // add dummy data to the target to verify the writer won't load the previous store
+ testStream(aggregated, OutputMode.Update)(
+ StartStream(checkpointLocation = targetDir.getAbsolutePath),
+ // batch 0
+ AddData(inputData, 0 until 2: _*),
+ CheckLastBatch(
+ (0, 1, 0, 0, 0), // 0
+ (1, 1, 1, 1, 1) // 1
+ ),
+ // batch 1
+ AddData(inputData, 0 until 2: _*),
+ CheckLastBatch(
+ (0, 2, 0, 0, 0), // 0
+ (1, 2, 2, 1, 1) // 1
+ )
+ )
+
+ // Step 2: Define schemas based on state version
+ val keySchema = StructType(Array(
+ StructField("groupKey", IntegerType, nullable = false)))
+ val valueSchema = if (stateVersion == 1) {
+ // State version 1 includes key columns in the value
+ StructType(Array(
+ StructField("groupKey", IntegerType, nullable = false),
+ StructField("count", LongType, nullable = false),
+ StructField("sum", LongType, nullable = false),
+ StructField("max", IntegerType, nullable = false),
+ StructField("min", IntegerType, nullable = false)
+ ))
+ } else {
+ // State version 2 excludes key columns from the value
+ StructType(Array(
+ StructField("count", LongType, nullable = false),
+ StructField("sum", LongType, nullable = false),
+ StructField("max", IntegerType, nullable = false),
+ StructField("min", IntegerType, nullable = false)
+ ))
+ }
+
+ // Create key state encoder spec (no prefix key for simple
aggregation)
+ val keyStateEncoderSpec = NoPrefixKeyStateEncoderSpec(keySchema)
+
+ // Perform round-trip test using common helper
+ performRoundTripTest(
+ sourceDir.getAbsolutePath,
+ targetDir.getAbsolutePath,
+ keySchema,
+ valueSchema,
+ keyStateEncoderSpec
+ )
+ }
+ }
+ }
+ }
+
+ /**
+ * Helper method to test SPARK-54420 read and write with different state
format versions
+ * for composite key aggregation (multiple grouping keys).
+ * @param stateVersion The state format version (1 or 2)
+ */
+ private def testCompositeKeyRoundTripForStateVersion(stateVersion: Int):
Unit = {
+ withSQLConf(SQLConf.STREAMING_AGGREGATION_STATE_FORMAT_VERSION.key ->
stateVersion.toString) {
+ withTempDir { sourceDir =>
+ withTempDir { targetDir =>
+ // Step 1: Create state by running a composite key streaming
aggregation
+ runCompositeKeyStreamingAggregationQuery(sourceDir.getAbsolutePath)
+ val inputData: MemoryStream[Int] = MemoryStream[Int]
+ val aggregated = getCompositeKeyStreamingAggregationQuery(inputData)
+
+ // add dummy data to the target to verify the writer won't load the previous store
+ testStream(aggregated, OutputMode.Update)(
+ StartStream(checkpointLocation = targetDir.getAbsolutePath),
+ // batch 0
+ AddData(inputData, 0, 1),
+ CheckLastBatch(
+ (0, "Apple", 1, 0, 0, 0),
+ (1, "Banana", 1, 1, 1, 1)
+ )
+ )
+
+ // Step 2: Define schemas based on state version for composite key
+ val keySchema = StructType(Array(
+ StructField("groupKey", IntegerType, nullable = false),
+ StructField("fruit", org.apache.spark.sql.types.StringType,
nullable = true)
+ ))
+ val valueSchema = if (stateVersion == 1) {
+ // State version 1 includes key columns in the value
+ StructType(Array(
+ StructField("groupKey", IntegerType, nullable = false),
+ StructField("fruit", org.apache.spark.sql.types.StringType,
nullable = true),
+ StructField("count", LongType, nullable = false),
+ StructField("sum", LongType, nullable = false),
+ StructField("max", IntegerType, nullable = false),
+ StructField("min", IntegerType, nullable = false)
+ ))
+ } else {
+ // State version 2 excludes key columns from the value
+ StructType(Array(
+ StructField("count", LongType, nullable = false),
+ StructField("sum", LongType, nullable = false),
+ StructField("max", IntegerType, nullable = false),
+ StructField("min", IntegerType, nullable = false)
+ ))
+ }
+
+ // Create key state encoder spec (no prefix key for composite key
aggregation)
+ val keyStateEncoderSpec = NoPrefixKeyStateEncoderSpec(keySchema)
+
+ // Perform round-trip test using common helper
+ performRoundTripTest(
+ sourceDir.getAbsolutePath,
+ targetDir.getAbsolutePath,
+ keySchema,
+ valueSchema,
+ keyStateEncoderSpec
+ )
+ }
+ }
+ }
+ }
+
+ /**
+ * Helper method to test round-trip for stream-stream join with different
versions.
+ */
+ private def testStreamStreamJoinRoundTrip(stateVersion: Int): Unit = {
+ withSQLConf(SQLConf.STREAMING_JOIN_STATE_FORMAT_VERSION.key ->
stateVersion.toString) {
+ withTempDir { sourceDir =>
+ withTempDir { targetDir =>
+ // Step 1: Create state by running stream-stream join
+ runStreamStreamJoinQuery(sourceDir.getAbsolutePath)
+
+ // Create dummy data in target
+ val inputData: MemoryStream[(Int, Long)] = MemoryStream[(Int, Long)]
+ val query = getStreamStreamJoinQuery(inputData)
+ testStream(query)(
+ StartStream(checkpointLocation = targetDir.getAbsolutePath),
+ AddData(inputData, (1, 1L)),
+ CheckNewAnswer()
+ )
+
+ // Step 2: Test all 4 state stores created by stream-stream join
+ // Test keyToNumValues stores (both left and right)
+ Seq("left-keyToNumValues", "right-keyToNumValues").foreach {
storeName =>
+ val keySchema = StructType(Array(
+ StructField("key", IntegerType)
+ ))
+ val valueSchema = StructType(Array(
+ StructField("value", LongType)
+ ))
+ val keyStateEncoderSpec = NoPrefixKeyStateEncoderSpec(keySchema)
+
+ // Perform round-trip test using common helper
+ performRoundTripTest(
+ sourceDir.getAbsolutePath,
+ targetDir.getAbsolutePath,
+ keySchema,
+ valueSchema,
+ keyStateEncoderSpec,
+ storeName = Some(storeName)
+ )
+ }
+
+ // Test keyWithIndexToValue stores (both left and right)
+ Seq("left-keyWithIndexToValue", "right-keyWithIndexToValue").foreach
{ storeName =>
+ val keySchema = StructType(Array(
+ StructField("key", IntegerType, nullable = false),
+ StructField("index", LongType)
+ ))
+ val valueSchema = if (stateVersion == 2) {
+ StructType(Array(
+ StructField("value", IntegerType, nullable = false),
+ StructField("time", TimestampType, nullable = false),
+ StructField("matched", BooleanType)
+ ))
+ } else {
+ StructType(Array(
+ StructField("value", IntegerType, nullable = false),
+ StructField("time", TimestampType, nullable = false)
+ ))
+ }
+ val keyStateEncoderSpec = NoPrefixKeyStateEncoderSpec(keySchema)
+
+ // Perform round-trip test using common helper
+ performRoundTripTest(
+ sourceDir.getAbsolutePath,
+ targetDir.getAbsolutePath,
+ keySchema,
+ valueSchema,
+ keyStateEncoderSpec,
+ storeName = Some(storeName)
+ )
+ }
+ }
+ }
+ }
+ }
+
+ /**
+ * Helper method to test round-trip for flatMapGroupsWithState with
different versions.
+ */
+ private def testFlatMapGroupsWithStateRoundTrip(stateVersion: Int): Unit = {
+ // Skip this test on big endian platforms (version 1 only)
+ if (stateVersion == 1) {
+
assume(java.nio.ByteOrder.nativeOrder().equals(java.nio.ByteOrder.LITTLE_ENDIAN))
+ }
+
+ withSQLConf(SQLConf.FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION.key ->
stateVersion.toString) {
+ withTempDir { sourceDir =>
+ withTempDir { targetDir =>
+ // Step 1: Create state by running flatMapGroupsWithState
+ runFlatMapGroupsWithStateQuery(sourceDir.getAbsolutePath)
+
+ // Create dummy data in target
+ val clock = new StreamManualClock
+ val inputData: MemoryStream[(String, Long)] = MemoryStream[(String,
Long)]
+ val query = getFlatMapGroupsWithStateQuery(inputData)
+ testStream(query, OutputMode.Update)(
+ StartStream(Trigger.ProcessingTime("1 second"), triggerClock =
clock,
+ checkpointLocation = targetDir.getAbsolutePath),
+ AddData(inputData, ("a", 1L)),
+ AdvanceManualClock(1 * 1000),
+ CheckLastBatch(("a", 1, 0, false))
+ )
+
+ // Step 2: Define schemas for flatMapGroupsWithState
+ val keySchema = StructType(Array(
+ StructField("value", org.apache.spark.sql.types.StringType,
nullable = true)
+ ))
+ val valueSchema = if (stateVersion == 1) {
+ StructType(Array(
+ StructField("numEvents", IntegerType, nullable = false),
+ StructField("startTimestampMs", LongType, nullable = false),
+ StructField("endTimestampMs", LongType, nullable = false),
+ StructField("timeoutTimestamp", IntegerType, nullable = false)
+ ))
+ } else {
+ StructType(Array(
+ StructField("groupState",
org.apache.spark.sql.types.StructType(Array(
+ StructField("numEvents", IntegerType, nullable = false),
+ StructField("startTimestampMs", LongType, nullable = false),
+ StructField("endTimestampMs", LongType, nullable = false)
+ )), nullable = false),
+ StructField("timeoutTimestamp", LongType, nullable = false)
+ ))
+ }
+ val keyStateEncoderSpec = NoPrefixKeyStateEncoderSpec(keySchema)
+
+ // Perform round-trip test using common helper
+ performRoundTripTest(
+ sourceDir.getAbsolutePath,
+ targetDir.getAbsolutePath,
+ keySchema,
+ valueSchema,
+ keyStateEncoderSpec
+ )
+ }
+ }
+ }
+ }
+
+ // Run all tests with both changelog checkpointing enabled and disabled
+ Seq(true, false).foreach { changelogCheckpointingEnabled =>
+ val testSuffix = if (changelogCheckpointingEnabled) {
+ "with changelog checkpointing"
+ } else {
+ "without changelog checkpointing"
+ }
+
+ def testWithChangelogConfig(testName: String)(testFun: => Unit): Unit = {
+ test(s"$testName ($testSuffix)") {
+ withSQLConf(
+
"spark.sql.streaming.stateStore.rocksdb.changelogCheckpointing.enabled" ->
+ changelogCheckpointingEnabled.toString) {
+ testFun
+ }
+ }
+ }
+
+ testWithChangelogConfig("SPARK-54420: aggregation state ver 1") {
+ testRoundTripForAggrStateVersion(1)
+ }
+
+ testWithChangelogConfig("SPARK-54420: aggregation state ver 2") {
+ testRoundTripForAggrStateVersion(2)
+ }
+
+ Seq(1, 2).foreach { version =>
+ testWithChangelogConfig(s"SPARK-54420: composite key aggregation state
ver $version") {
+ testCompositeKeyRoundTripForStateVersion(version)
+ }
+ }
+
+ testWithChangelogConfig("SPARK-54420: dropDuplicatesWithinWatermark") {
+ withTempDir { sourceDir =>
+ withTempDir { targetDir =>
+ // Step 1: Create state by running dropDuplicatesWithinWatermark
+ runDropDuplicatesWithinWatermarkQuery(sourceDir.getAbsolutePath)
+
+ // Create dummy data in target
+ val inputData: MemoryStream[(String, Int)] = MemoryStream[(String,
Int)]
+ val deduped = getDropDuplicatesWithinWatermarkQuery(inputData)
+ testStream(deduped, OutputMode.Append)(
+ StartStream(checkpointLocation = targetDir.getAbsolutePath),
+ AddData(inputData, ("a", 1)),
+ CheckAnswer(("a", 1))
+ )
+
+ // Step 2: Define schemas for dropDuplicatesWithinWatermark
+ val keySchema = StructType(Array(
+ StructField("_1", org.apache.spark.sql.types.StringType, nullable
= true)
+ ))
+ val valueSchema = StructType(Array(
+ StructField("expiresAtMicros", LongType, nullable = false)
+ ))
+ val keyStateEncoderSpec = NoPrefixKeyStateEncoderSpec(keySchema)
+
+ // Perform round-trip test using common helper
+ performRoundTripTest(
+ sourceDir.getAbsolutePath,
+ targetDir.getAbsolutePath,
+ keySchema,
+ valueSchema,
+ keyStateEncoderSpec
+ )
+ }
+ }
+ }
+
+ testWithChangelogConfig("SPARK-54420: dropDuplicates with column
specified") {
+ withTempDir { sourceDir =>
+ withTempDir { targetDir =>
+ // Step 1: Create state by running dropDuplicates with column
+ runDropDuplicatesQueryWithColumnSpecified(sourceDir.getAbsolutePath)
+
+ // Create dummy data in target
+ val inputData: MemoryStream[(String, Int)] = MemoryStream[(String,
Int)]
+ val deduped = getDropDuplicatesQueryWithColumnSpecified(inputData)
+ testStream(deduped, OutputMode.Append)(
+ StartStream(checkpointLocation = targetDir.getAbsolutePath),
+ AddData(inputData, ("a", 1)),
+ CheckAnswer(("a", 1))
+ )
+
+ // Step 2: Define schemas for dropDuplicates with column specified
+ val keySchema = StructType(Array(
+ StructField("col1", org.apache.spark.sql.types.StringType,
nullable = true)
+ ))
+ val valueSchema = StructType(Array(
+ StructField("__dummy__", NullType, nullable = true)
+ ))
+ val keyStateEncoderSpec = NoPrefixKeyStateEncoderSpec(keySchema)
+
+ // Perform round-trip test using common helper
+ performRoundTripTest(
+ sourceDir.getAbsolutePath,
+ targetDir.getAbsolutePath,
+ keySchema,
+ valueSchema,
+ keyStateEncoderSpec
+ )
+ }
+ }
+ }
+
+ testWithChangelogConfig("SPARK-54420: session window aggregation") {
+ withTempDir { sourceDir =>
+ withTempDir { targetDir =>
+ // Step 1: Create state by running session window aggregation
+ runSessionWindowAggregationQuery(sourceDir.getAbsolutePath)
+
+ // Create dummy data in target
+ val inputData: MemoryStream[(String, Long)] = MemoryStream[(String,
Long)]
+ val aggregated = getSessionWindowAggregationQuery(inputData)
+ testStream(aggregated, OutputMode.Complete())(
+ StartStream(checkpointLocation = targetDir.getAbsolutePath),
+ AddData(inputData, ("a", 40L)),
+ CheckNewAnswer(
+ ("a", 40, 50, 10, 1)
+ ),
+ StopStream
+ )
+
+ // Step 2: Define schemas for session window aggregation
+ val keySchema = StructType(Array(
+ StructField("sessionId", org.apache.spark.sql.types.StringType,
nullable = false),
+ StructField("sessionStartTime",
+ org.apache.spark.sql.types.TimestampType, nullable = false)
+ ))
+ val valueSchema = StructType(Array(
+ StructField("session_window",
org.apache.spark.sql.types.StructType(Array(
+ StructField("start", org.apache.spark.sql.types.TimestampType),
+ StructField("end", org.apache.spark.sql.types.TimestampType)
+ )), nullable = false),
+ StructField("sessionId", org.apache.spark.sql.types.StringType,
nullable = false),
+ StructField("count", LongType, nullable = false)
+ ))
+ // Session window aggregation uses prefix key scanning where
sessionId is the prefix
+ val keyStateEncoderSpec = PrefixKeyScanStateEncoderSpec(keySchema, 1)
+
+ // Perform round-trip test using common helper
+ performRoundTripTest(
+ sourceDir.getAbsolutePath,
+ targetDir.getAbsolutePath,
+ keySchema,
+ valueSchema,
+ keyStateEncoderSpec
+ )
+ }
+ }
+ }
+
+ testWithChangelogConfig("SPARK-54420: dropDuplicates") {
+ withTempDir { sourceDir =>
+ withTempDir { targetDir =>
+
+ // Step 1: Create state by running dropDuplicates
+ runDropDuplicatesQuery(sourceDir.getAbsolutePath)
+ val inputData: MemoryStream[Int] = MemoryStream[Int]
+ val stream = getDropDuplicatesQuery(inputData)
+ testStream(stream, OutputMode.Append)(
+ StartStream(checkpointLocation = targetDir.getAbsolutePath),
+ AddData(inputData, (1 to 5).flatMap(_ => (10 to 15)): _*),
+ CheckAnswer(10 to 15: _*),
+ assertNumStateRows(total = 6, updated = 6)
+ )
+
+ // Step 2: Define schemas for dropDuplicates (state version 2)
+ val keySchema = StructType(Array(
+ StructField("value", IntegerType, nullable = false),
+ StructField("eventTime", org.apache.spark.sql.types.TimestampType)
+ ))
+ val valueSchema = StructType(Array(
+ StructField("__dummy__", NullType)
+ ))
+ val keyStateEncoderSpec = NoPrefixKeyStateEncoderSpec(keySchema)
+
+ // Perform round-trip test using common helper
+ performRoundTripTest(
+ sourceDir.getAbsolutePath,
+ targetDir.getAbsolutePath,
+ keySchema,
+ valueSchema,
+ keyStateEncoderSpec
+ )
+ }
+ }
+ }
+
+ Seq(1, 2).foreach { version =>
+ testWithChangelogConfig(s"SPARK-54420: flatMapGroupsWithState state ver
$version") {
+ testFlatMapGroupsWithStateRoundTrip(version)
+ }
+ }
+
+ Seq(1, 2).foreach { version =>
+ testWithChangelogConfig(s"SPARK-54420: stream-stream join state ver
$version") {
+ testStreamStreamJoinRoundTrip(version)
+ }
+ }
+ } // End of foreach loop for changelog checkpointing dimension
+}
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
index a997ead74097..6269df928a77 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
@@ -95,7 +95,8 @@ class SignalingStateStoreProvider extends StateStoreProvider
with Logging {
override def getStore(
version: Long,
uniqueId: Option[String],
- forceSnapshotOnCommit: Boolean = false): StateStore = null
+ forceSnapshotOnCommit: Boolean = false,
+ loadEmpty: Boolean = false): StateStore = null
/**
* Simulates a maintenance operation that blocks until a signal is received.
@@ -175,7 +176,8 @@ class FakeStateStoreProviderTracksCloseThread extends
StateStoreProvider {
override def getStore(
version: Long,
uniqueId: Option[String],
- forceSnapshotOnCommit: Boolean = false): StateStore = null
+ forceSnapshotOnCommit: Boolean = false,
+ loadEmpty: Boolean = false): StateStore = null
override def doMaintenance(): Unit = {}
}
@@ -247,7 +249,8 @@ class FakeStateStoreProviderWithMaintenanceError extends
StateStoreProvider {
override def getStore(
version: Long,
uniqueId: Option[String],
- forceSnapshotOnCommit: Boolean = false): StateStore = null
+ forceSnapshotOnCommit: Boolean = false,
+ loadEmpty: Boolean = false): StateStore = null
override def doMaintenance(): Unit = {
Thread.currentThread.setUncaughtExceptionHandler(exceptionHandler)
@@ -1438,6 +1441,26 @@ class StateStoreSuite extends
StateStoreSuiteBase[HDFSBackedStateStoreProvider]
"HDFSBackedStateStoreProvider does not support checkpointFormatVersion >
1"))
}
+ test("SPARK-54420: HDFSBackedStateStoreProvider does not support loading
empty store") {
+ val provider = new HDFSBackedStateStoreProvider()
+ val hadoopConf = new Configuration()
+ hadoopConf.set(StreamExecution.RUN_ID_KEY, UUID.randomUUID().toString)
+ provider.init(
+ StateStoreId(newDir(), Random.nextInt(), 0),
+ keySchema,
+ valueSchema,
+ NoPrefixKeyStateEncoderSpec(keySchema),
+ useColumnFamilies = false,
+ new StateStoreConf(),
+ hadoopConf)
+
+ val e = intercept[StateStoreUnsupportedOperationException] {
+ provider.getStore(0, loadEmpty = true)
+ }
+ assert(e.getMessage.contains(
+ "Internal Error: HDFSBackedStateStoreProvider doesn't support
loadEmpty"))
+ }
+
test("Auto snapshot repair") {
withSQLConf(
SQLConf.STREAMING_CHECKPOINT_FILE_CHECKSUM_ENABLED.key -> false.toString,
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
index cc1d0bc1ed17..b4fd41a6b550 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
@@ -1463,7 +1463,8 @@ class TestStateStoreProvider extends StateStoreProvider {
override def getStore(
version: Long,
stateStoreCkptId: Option[String] = None,
- forceSnapshotOnCommit: Boolean = false): StateStore = null
+ forceSnapshotOnCommit: Boolean = false,
+ loadEmpty: Boolean = false): StateStore = null
}
/** A fake source that throws `ThrowingExceptionInCreateSource.exception` in
`createSource` */
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]