This is an automated email from the ASF dual-hosted git repository.

ashrigondekar pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 885bfc22cb0d [SPARK-53333][SS] Enable StateDataSource with state checkpoint v2 (only readChangeFeed)
885bfc22cb0d is described below

commit 885bfc22cb0d315519384568c9cb0dce2c0f556f
Author: Dylan Wong <dylan.w...@databricks.com>
AuthorDate: Tue Sep 2 14:05:24 2025 -0700

[SPARK-53333][SS] Enable StateDataSource with state checkpoint v2 (only readChangeFeed)

### What changes were proposed in this pull request?

This PR extends StateDataSource (https://spark.apache.org/docs/latest/streaming/structured-streaming-state-data-source.html) support for the state checkpoint v2 format to include the `readChangeFeed` functionality. It enables users to read change feeds from state stores that use the checkpoint v2 format by:

- Implementing full lineage reconstruction across multiple changelog files using `getFullLineage` in RocksDBFileManager. This is needed because each changelog file only contains lineage for [snapshotVersion, version), and we may need the versions of all changelog files across snapshot boundaries.
- Adding an `endVersionStateStoreCkptId` parameter to `getStateStoreChangeDataReader` and using it. Since the full lineage down to the start version can be reconstructed from the last version and `endVersionStateStoreCkptId`, a `startVersionStateStoreCkptId` is not needed. However, once `snapshotStartBatchId` is implemented, both `startVersionStateStoreCkptId` and `endVersionStateStoreCkptId` will be needed to maintain the current behavior.
- Adding an extra parameter to `setStoreMetrics` that controls whether `store.getStateStoreCheckpointInfo()` is called. Calling it in the abort path of `TransformWithStateExec` or `TransformWithStateInPySparkExec` would throw an exception, which we want to avoid.

The key enhancement is the ability to read change feeds that span multiple snapshots by walking backwards through the lineage information embedded in changelog files to construct the complete version history.

NOTE: Reading checkpoint v2 state data sources requires `"spark.sql.streaming.stateStore.checkpointFormatVersion" -> 2`. It would be possible to allow reading state data sources arbitrarily based on what is in the CommitLog by relaxing assertion checks, but this is left as a future change.

### Why are the changes needed?

State checkpoint v2 (`"spark.sql.streaming.stateStore.checkpointFormatVersion"`) introduces a new format for storing state metadata that includes a unique identifier in the file path of each state store. The existing StateDataSource implementation only worked with the checkpoint v1 format, making it incompatible with streaming queries that use the newer checkpoint format. Only `batchId` support was implemented in https://github.com/apache/spark/pull/52047.

### Does this PR introduce _any_ user-facing change?

Yes. State Data Source now works when checkpoint v2 is used together with the `readChangeFeed` option.
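For illustration, a minimal usage sketch (the checkpoint path and batch ids are placeholders; the option names follow the State Data Source documentation, and the reading session is assumed to run with checkpoint format v2 as the NOTE above requires):

```
// Illustrative sketch: read the change feed of a query whose checkpoint was
// written with state checkpoint v2. "/path/to/checkpoint" and the batch ids
// are placeholders.
spark.conf.set("spark.sql.streaming.stateStore.checkpointFormatVersion", "2")

val changes = spark.read
  .format("statestore")
  .option("readChangeFeed", "true")   // read per-batch changes instead of a snapshot
  .option("changeStartBatchId", "0")  // first batch to read (inclusive)
  .option("changeEndBatchId", "5")    // last batch to read (inclusive)
  .load("/path/to/checkpoint")

changes.show()
```

To make the lineage reconstruction described under "What changes were proposed in this pull request" concrete, here is a simplified, self-contained sketch of the backward walk. `fullLineage`, `readLineage`, and this `LineageItem` shape are illustrative stand-ins, not the actual `RocksDB`/`RocksDBFileManager` API:

```
// Simplified sketch of reconstructing the full lineage [startVersion, endVersion]
// by hopping backwards through the lineage embedded in each changelog file.
case class LineageItem(version: Long, checkpointUniqueId: String)

def fullLineage(
    startVersion: Long,
    endVersion: Long,
    endCkptId: String,
    // Reads the lineage stored in the changelog file for (version, uniqueId);
    // it only covers [snapshotVersion, version) for that file.
    readLineage: (Long, String) => Array[LineageItem]): Array[LineageItem] = {
  require(startVersion <= endVersion)
  val buf = scala.collection.mutable.ArrayBuffer(LineageItem(endVersion, endCkptId))
  while (buf.last.version > startVersion) {
    val tail = buf.last
    val older = readLineage(tail.version, tail.checkpointUniqueId)
      .filter(_.version >= startVersion)
      .sortBy(-_.version) // newest first, so the buffer stays decreasing
    if (older.isEmpty || older.last.version >= tail.version) {
      throw new IllegalStateException(
        s"Cannot find version smaller than ${tail.version} in lineage")
    }
    buf.appendAll(older)
  }
  buf.reverse.toArray // increasing order: startVersion .. endVersion
}
```

For example, with changelog 5 carrying lineage [(4, i4), (3, i3)] and changelog 3 carrying [(2, i2), (1, i1)], `fullLineage(1, 5, "i5", read)` yields versions 1 through 5, which is the multi-hop case exercised by `RocksDBLineageSuite` in the diff below.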
### How was this patch tested?

Adds a new test suite `RocksDBWithCheckpointV2StateDataSourceChangeDataReaderSuite` that reuses the unit tests in `RocksDBWithChangelogCheckpointStateDataSourceChangeDataReaderSuite` with checkpoint v2 enabled, and adds tests for reading across snapshot boundaries.

```
testOnly *RocksDBWithCheckpointV2StateDataSourceChangeDataReaderSuite
```
```
[info] Total number of tests run: 10
[info] Suites: completed 1, aborted 0
[info] Tests: succeeded 10, failed 0, canceled 0, ignored 0, pending 0
[info] All tests passed.
```

Adds a new test suite `StateDataSourceTransformWithStateSuiteCheckpointV2` that reuses the unit tests in `StateDataSourceTransformWithStateSuite` with checkpoint v2 enabled.

```
testOnly *StateDataSourceTransformWithStateSuiteCheckpointV2
```

Note that the two canceled tests are the ones that use `snapshotStartBatchId`, which is not yet supported with checkpoint v2.

```
[info] Total number of tests run: 44
[info] Suites: completed 1, aborted 0
[info] Tests: succeeded 44, failed 0, canceled 2, ignored 0, pending 0
[info] All tests passed
```

Adds a new test suite `TransformWithStateInitialStateSuiteCheckpointV2` that reuses the unit tests in `TransformWithStateInitialStateSuite` with checkpoint v2 enabled.

```
testOnly *TransformWithStateInitialStateSuiteCheckpointV2
```
```
[info] Total number of tests run: 44
[info] Suites: completed 1, aborted 0
[info] Tests: succeeded 44, failed 0, canceled 0, ignored 0, pending 0
[info] All tests passed.
```

Adds new test classes `TransformWithStateInPandasWithCheckpointV2Tests` and `TransformWithStateInPySparkWithCheckpointV2Tests` that reuse the existing Python unit tests that exercise the State Data Source.

### Was this patch authored or co-authored using generative AI tooling?

No

Closes #52148 from dylanwong250/SPARK-53333.

Authored-by: Dylan Wong <dylan.w...@databricks.com>
Signed-off-by: Anish Shrigondekar <anish.shrigonde...@databricks.com>
---
 .../src/main/resources/error/error-conditions.json | 19 +++
 .../pandas/test_pandas_transform_with_state.py | 36 ++++
 .../spark/sql/errors/QueryExecutionErrors.scala | 11 ++
 .../datasources/v2/state/StateDataSource.scala | 42 ++++-
 .../v2/state/StateDataSourceErrors.scala | 27 +++
 .../v2/state/StatePartitionReader.scala | 22 ++-
 .../StreamStreamJoinStatePartitionReader.scala | 13 +-
 .../TransformWithStateInPySparkExec.scala | 2 +-
 .../operators/stateful/statefulOperators.scala | 4 +-
 .../TransformWithStateExec.scala | 2 +-
 .../state/HDFSBackedStateStoreProvider.scala | 26 ++-
 .../sql/execution/streaming/state/RocksDB.scala | 57 +++++++
 .../state/RocksDBStateStoreProvider.scala | 22 ++-
 .../sql/execution/streaming/state/StateStore.scala | 3 +-
 .../streaming/state/StateStoreChangelog.scala | 40 +++--
 .../streaming/state/StateStoreErrors.scala | 10 ++
 .../state/StateDataSourceChangeDataReadSuite.scala | 185 ++++++++++++++++++++-
 .../v2/state/StateDataSourceReadSuite.scala | 15 --
 .../StateDataSourceTransformWithStateSuite.scala | 11 ++
 .../streaming/state/RocksDBLineageSuite.scala | 170 +++++++++++++++++++
 .../streaming/state/StateStoreSuite.scala | 4 +-
 .../TransformWithStateInitialStateSuite.scala | 8 +
 22 files changed, 660 insertions(+), 69 deletions(-)
diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index 167c460536ac..8bbf24074f35 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -368,6 +368,11 @@ "The change log writer version cannot be <version>." ] }, + "INVALID_CHECKPOINT_LINEAGE" : { + "message" : [ + "Invalid checkpoint lineage: <lineage>. 
<message>" + ] + }, "KEY_ROW_FORMAT_VALIDATION_FAILURE" : { "message" : [ "<msg>" @@ -5162,6 +5167,12 @@ ], "sqlState" : "42802" }, + "STATE_STORE_CHECKPOINT_IDS_NOT_SUPPORTED" : { + "message" : [ + "<msg>" + ], + "sqlState" : "KD002" + }, "STATE_STORE_CHECKPOINT_LOCATION_NOT_EMPTY" : { "message" : [ "The checkpoint location <checkpointLocation> should be empty on batch 0", @@ -5407,6 +5418,14 @@ }, "sqlState" : "42616" }, + "STDS_MIXED_CHECKPOINT_FORMAT_VERSIONS_NOT_SUPPORTED" : { + "message" : [ + "Reading state across different checkpoint format versions is not supported.", + "startBatchId=<startBatchId>, endBatchId=<endBatchId>.", + "startFormatVersion=<startFormatVersion>, endFormatVersion=<endFormatVersion>." + ], + "sqlState" : "KD002" + }, "STDS_NO_PARTITION_DISCOVERED_IN_STATE_STORE" : { "message" : [ "The state does not have any partition. Please double check that the query points to the valid state. options: <sourceOptions>" diff --git a/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py index d3bda545e1c9..527fb7d370e7 100644 --- a/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py +++ b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py @@ -1916,6 +1916,30 @@ class TransformWithStateInPandasTestsMixin(TransformWithStateTestsMixin): return cfg +class TransformWithStateInPandasWithCheckpointV2TestsMixin(TransformWithStateInPandasTestsMixin): + @classmethod + def conf(cls): + cfg = super().conf() + cfg.set("spark.sql.streaming.stateStore.checkpointFormatVersion", "2") + return cfg + + # TODO(SPARK-53332): Add test back when checkpoint v2 support exists for snapshotStartBatchId + def test_transform_with_value_state_metadata(self): + pass + + +class TransformWithStateInPySparkWithCheckpointV2TestsMixin(TransformWithStateInPySparkTestsMixin): + @classmethod + def conf(cls): + cfg = super().conf() + cfg.set("spark.sql.streaming.stateStore.checkpointFormatVersion", "2") + return cfg + + # TODO(SPARK-53332): Add test back when checkpoint v2 support exists for snapshotStartBatchId + def test_transform_with_value_state_metadata(self): + pass + + class TransformWithStateInPandasTests(TransformWithStateInPandasTestsMixin, ReusedSQLTestCase): pass @@ -1924,6 +1948,18 @@ class TransformWithStateInPySparkTests(TransformWithStateInPySparkTestsMixin, Re pass +class TransformWithStateInPandasWithCheckpointV2Tests( + TransformWithStateInPandasWithCheckpointV2TestsMixin, ReusedSQLTestCase +): + pass + + +class TransformWithStateInPySparkWithCheckpointV2Tests( + TransformWithStateInPySparkWithCheckpointV2TestsMixin, ReusedSQLTestCase +): + pass + + if __name__ == "__main__": from pyspark.sql.tests.pandas.test_pandas_transform_with_state import * # noqa: F401 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala index ba229a2e746c..67bb80403b9f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala @@ -2740,6 +2740,17 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE ) } + def invalidCheckpointLineage(lineage: String, message: String): Throwable = { + new SparkException( + errorClass = "CANNOT_LOAD_STATE_STORE.INVALID_CHECKPOINT_LINEAGE", + messageParameters = Map( + "lineage" -> lineage, + 
"message" -> message + ), + cause = null + ) + } + def notEnoughMemoryToLoadStore( stateStoreId: String, stateStoreProviderName: String, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala index 828c06ab834a..54d3c45d237b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala @@ -371,7 +371,8 @@ case class StateSourceOptions( stateVarName: Option[String], readRegisteredTimers: Boolean, flattenCollectionTypes: Boolean, - operatorStateUniqueIds: Option[Array[Array[String]]] = None) { + startOperatorStateUniqueIds: Option[Array[Array[String]]] = None, + endOperatorStateUniqueIds: Option[Array[Array[String]]] = None) { def stateCheckpointLocation: Path = new Path(resolvedCpLocation, DIR_NAME_STATE) override def toString: String = { @@ -576,29 +577,52 @@ object StateSourceOptions extends DataSourceOptions { batchId.get } - val operatorStateUniqueIds = getOperatorStateUniqueIds( + val endBatchId = if (readChangeFeedOptions.isDefined) { + readChangeFeedOptions.get.changeEndBatchId + } else { + batchId.get + } + + val startOperatorStateUniqueIds = getOperatorStateUniqueIds( sparkSession, startBatchId, operatorId, resolvedCpLocation) - if (operatorStateUniqueIds.isDefined) { + val endOperatorStateUniqueIds = if (startBatchId == endBatchId) { + startOperatorStateUniqueIds + } else { + getOperatorStateUniqueIds( + sparkSession, + endBatchId, + operatorId, + resolvedCpLocation) + } + + if (startOperatorStateUniqueIds.isDefined != endOperatorStateUniqueIds.isDefined) { + val startFormatVersion = if (startOperatorStateUniqueIds.isDefined) 2 else 1 + val endFormatVersion = if (endOperatorStateUniqueIds.isDefined) 2 else 1 + throw StateDataSourceErrors.mixedCheckpointFormatVersionsNotSupported( + startBatchId, + endBatchId, + startFormatVersion, + endFormatVersion + ) + } + + if (startOperatorStateUniqueIds.isDefined) { if (fromSnapshotOptions.isDefined) { throw StateDataSourceErrors.invalidOptionValue( SNAPSHOT_START_BATCH_ID, "Snapshot reading is currently not supported with checkpoint v2.") } - if (readChangeFeedOptions.isDefined) { - throw StateDataSourceErrors.invalidOptionValue( - READ_CHANGE_FEED, - "Read change feed is currently not supported with checkpoint v2.") - } } StateSourceOptions( resolvedCpLocation, batchId.get, operatorId, storeName, joinSide, readChangeFeed, fromSnapshotOptions, readChangeFeedOptions, - stateVarName, readRegisteredTimers, flattenCollectionTypes, operatorStateUniqueIds) + stateVarName, readRegisteredTimers, flattenCollectionTypes, + startOperatorStateUniqueIds, endOperatorStateUniqueIds) } private def resolvedCheckpointLocation( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceErrors.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceErrors.scala index b6883a98f3ed..74ab308131f2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceErrors.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceErrors.scala @@ -81,6 +81,18 @@ object StateDataSourceErrors { sourceOptions: StateSourceOptions): StateDataSourceException = { new 
StateDataSourceNoPartitionDiscoveredInStateStore(sourceOptions) } + + def mixedCheckpointFormatVersionsNotSupported( + startBatchId: Long, + endBatchId: Long, + startFormatVersion: Int, + endFormatVersion: Int): StateDataSourceException = { + new StateDataSourceMixedCheckpointFormatVersionsNotSupported( + startBatchId, + endBatchId, + startFormatVersion, + endFormatVersion) + } } abstract class StateDataSourceException( @@ -172,3 +184,18 @@ class StateDataSourceReadOperatorMetadataFailure( "STDS_FAILED_TO_READ_OPERATOR_METADATA", Map("checkpointLocation" -> checkpointLocation, "batchId" -> batchId.toString), cause = null) + +class StateDataSourceMixedCheckpointFormatVersionsNotSupported( + startBatchId: Long, + endBatchId: Long, + startFormatVersion: Int, + endFormatVersion: Int) + extends StateDataSourceException( + "STDS_MIXED_CHECKPOINT_FORMAT_VERSIONS_NOT_SUPPORTED", + Map( + "startBatchId" -> startBatchId.toString, + "endBatchId" -> endBatchId.toString, + "startFormatVersion" -> startFormatVersion.toString, + "endFormatVersion" -> endFormatVersion.toString + ), + cause = null) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala index ebef6e3dac55..7180fe483fcc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala @@ -96,11 +96,20 @@ abstract class StatePartitionReaderBase( schema, "value").asInstanceOf[StructType] } - protected val getStoreUniqueId : Option[String] = { + protected def getStoreUniqueId( + operatorStateUniqueIds: Option[Array[Array[String]]]) : Option[String] = { SymmetricHashJoinStateManager.getStateStoreCheckpointId( storeName = partition.sourceOptions.storeName, partitionId = partition.partition, - stateStoreCkptIds = partition.sourceOptions.operatorStateUniqueIds) + stateStoreCkptIds = operatorStateUniqueIds) + } + + protected def getStartStoreUniqueId: Option[String] = { + getStoreUniqueId(partition.sourceOptions.startOperatorStateUniqueIds) + } + + protected def getEndStoreUniqueId: Option[String] = { + getStoreUniqueId(partition.sourceOptions.endOperatorStateUniqueIds) } protected lazy val provider: StateStoreProvider = { @@ -123,7 +132,7 @@ abstract class StatePartitionReaderBase( if (useColFamilies) { val store = provider.getStore( partition.sourceOptions.batchId + 1, - getStoreUniqueId) + getEndStoreUniqueId) require(stateStoreColFamilySchemaOpt.isDefined) val stateStoreColFamilySchema = stateStoreColFamilySchemaOpt.get require(stateStoreColFamilySchema.keyStateEncoderSpec.isDefined) @@ -182,9 +191,11 @@ class StatePartitionReader( private lazy val store: ReadStateStore = { partition.sourceOptions.fromSnapshotOptions match { case None => + assert(getStartStoreUniqueId == getEndStoreUniqueId, + "Start and end store unique IDs must be the same when not reading from snapshot") provider.getReadStore( partition.sourceOptions.batchId + 1, - getStoreUniqueId + getStartStoreUniqueId ) case Some(fromSnapshotOptions) => @@ -261,7 +272,8 @@ class StateStoreChangeDataPartitionReader( .getStateStoreChangeDataReader( partition.sourceOptions.readChangeFeedOptions.get.changeStartBatchId + 1, partition.sourceOptions.readChangeFeedOptions.get.changeEndBatchId + 1, - colFamilyNameOpt) + colFamilyNameOpt, + getEndStoreUniqueId) } override lazy 
val iter: Iterator[InternalRow] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StreamStreamJoinStatePartitionReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StreamStreamJoinStatePartitionReader.scala index 0f8a3b3b609f..bf0e8968789c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StreamStreamJoinStatePartitionReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StreamStreamJoinStatePartitionReader.scala @@ -76,21 +76,22 @@ class StreamStreamJoinStatePartitionReader( partition.sourceOptions.stateCheckpointLocation.toString, partition.sourceOptions.operatorId) - private val stateStoreCheckpointIds = SymmetricHashJoinStateManager.getStateStoreCheckpointIds( + private val startStateStoreCheckpointIds = + SymmetricHashJoinStateManager.getStateStoreCheckpointIds( partition.partition, - partition.sourceOptions.operatorStateUniqueIds, + partition.sourceOptions.startOperatorStateUniqueIds, usesVirtualColumnFamilies) private val keyToNumValuesStateStoreCkptId = if (joinSide == LeftSide) { - stateStoreCheckpointIds.left.keyToNumValues + startStateStoreCheckpointIds.left.keyToNumValues } else { - stateStoreCheckpointIds.right.keyToNumValues + startStateStoreCheckpointIds.right.keyToNumValues } private val keyWithIndexToValueStateStoreCkptId = if (joinSide == LeftSide) { - stateStoreCheckpointIds.left.keyWithIndexToValue + startStateStoreCheckpointIds.left.keyWithIndexToValue } else { - stateStoreCheckpointIds.right.keyWithIndexToValue + startStateStoreCheckpointIds.right.keyWithIndexToValue } /* diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkExec.scala index 1b967af38b6d..f8390b7d878f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkExec.scala @@ -389,7 +389,7 @@ case class TransformWithStateInPySparkExec( store.abort() } } - setStoreMetrics(store) + setStoreMetrics(store, isStreaming) setOperatorMetrics() }).map { row => numOutputRows += 1 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/statefulOperators.scala index cc8d354a0393..0634a2f05b41 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/statefulOperators.scala @@ -430,14 +430,14 @@ trait StateStoreWriter * Set the SQL metrics related to the state store. * This should be called in that task after the store has been updated. 
*/ - protected def setStoreMetrics(store: StateStore): Unit = { + protected def setStoreMetrics(store: StateStore, setCheckpointInfo: Boolean = true): Unit = { val storeMetrics = store.metrics longMetric("numTotalStateRows") += storeMetrics.numKeys longMetric("stateMemory") += storeMetrics.memoryUsedBytes setStoreCustomMetrics(storeMetrics.customMetrics) setStoreInstanceMetrics(storeMetrics.instanceMetrics) - if (StatefulOperatorStateInfo.enableStateStoreCheckpointIds(conf)) { + if (StatefulOperatorStateInfo.enableStateStoreCheckpointIds(conf) && setCheckpointInfo) { // Set the state store checkpoint information for the driver to collect val ssInfo = store.getStateStoreCheckpointInfo() setStateStoreCheckpointInfo( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/TransformWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/TransformWithStateExec.scala index 20e2c32015d8..52a0d470c266 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/TransformWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/TransformWithStateExec.scala @@ -346,7 +346,7 @@ case class TransformWithStateExec( store.abort() } } - setStoreMetrics(store) + setStoreMetrics(store, isStreaming) setOperatorMetrics() closeStatefulProcessor() statefulProcessor.setHandle(null) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala index ac7f1a021960..f37a26012e22 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala @@ -29,7 +29,7 @@ import scala.util.control.NonFatal import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs._ -import org.apache.spark.{SparkConf, SparkEnv, SparkException, TaskContext} +import org.apache.spark.{SparkConf, SparkEnv, TaskContext} import org.apache.spark.internal.{Logging, LogKeys} import org.apache.spark.io.CompressionCodec import org.apache.spark.sql.catalyst.expressions.UnsafeRow @@ -292,9 +292,9 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with /** Get the state store for making updates to create a new `version` of the store. 
*/ override def getStore(version: Long, uniqueId: Option[String] = None): StateStore = { if (uniqueId.isDefined) { - throw QueryExecutionErrors.cannotLoadStore(new SparkException( + throw StateStoreErrors.stateStoreCheckpointIdsNotSupported( "HDFSBackedStateStoreProvider does not support checkpointFormatVersion > 1 " + - "but a state store checkpointID is passed in")) + "but a state store checkpointID is passed in") } val newMap = getLoadedMapForStore(version) logInfo(log"Retrieved version ${MDC(LogKeys.STATE_STORE_VERSION, version)} " + @@ -369,10 +369,10 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with hadoopConf: Configuration, useMultipleValuesPerKey: Boolean = false, stateSchemaProvider: Option[StateSchemaProvider] = None): Unit = { - assert( - !storeConf.enableStateStoreCheckpointIds, - "HDFS State Store Provider doesn't support checkpointFormatVersion >= 2 " + - s"checkpointFormatVersion ${storeConf.stateStoreCheckpointFormatVersion}") + if (storeConf.enableStateStoreCheckpointIds) { + throw StateStoreErrors.stateStoreCheckpointIdsNotSupported( + "HDFSBackedStateStoreProvider does not support checkpointFormatVersion > 1") + } this.stateStoreId_ = stateStoreId this.keySchema = keySchema @@ -1064,8 +1064,16 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with override def getStateStoreChangeDataReader( startVersion: Long, endVersion: Long, - colFamilyNameOpt: Option[String] = None): + colFamilyNameOpt: Option[String] = None, + endVersionStateStoreCkptId: Option[String] = None): StateStoreChangeDataReader = { + + if (endVersionStateStoreCkptId.isDefined) { + throw StateStoreErrors.stateStoreCheckpointIdsNotSupported( + "HDFSBackedStateStoreProvider does not support checkpointFormatVersion > 1 " + + "but a state store checkpointID is passed in") + } + // Multiple column families are not supported with HDFSBackedStateStoreProvider if (colFamilyNameOpt.isDefined) { throw StateStoreErrors.multipleColumnFamiliesNotSupported(providerName) @@ -1099,7 +1107,7 @@ class HDFSBackedStateStoreChangeDataReader( extends StateStoreChangeDataReader( fm, stateLocation, startVersion, endVersion, compressionCodec) { - override protected var changelogSuffix: String = "delta" + override protected val changelogSuffix: String = "delta" override def getNext(): (RecordType.Value, UnsafeRow, UnsafeRow, Long) = { val reader = currentChangelogReader() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala index 85e2d72ec163..1e65b737e2bf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala @@ -342,6 +342,63 @@ class RocksDB( currLineage } + /** + * Construct the full lineage from startVersion to endVersion (inclusive) by + * walking backwards using lineage information embedded in changelog files. 
+ */ + def getFullLineage( + startVersion: Long, + endVersion: Long, + endVersionStateStoreCkptId: Option[String]): Array[LineageItem] = { + assert(startVersion <= endVersion, + s"startVersion $startVersion should be less than or equal to endVersion $endVersion") + assert(endVersionStateStoreCkptId.isDefined, + "endVersionStateStoreCkptId should be defined") + + // A buffer to collect the lineage information, the entries should be decreasing in version + val buf = mutable.ArrayBuffer[LineageItem]() + buf.append(LineageItem(endVersion, endVersionStateStoreCkptId.get)) + + while (buf.last.version > startVersion) { + val prevSmallestVersion = buf.last.version + val lineage = getLineageFromChangelogFile(buf.last.version, Some(buf.last.checkpointUniqueId)) + // lineage array is sorted in increasing order, we need to make it decreasing + val lineageSortedDecreasing = lineage.filter(_.version >= startVersion).sortBy(-_.version) + // append to the buffer in reverse order, so the buffer is always decreasing in version + buf.appendAll(lineageSortedDecreasing) + + // to prevent infinite loop if we make no progress, throw an exception + if (buf.last.version == prevSmallestVersion) { + throw QueryExecutionErrors.invalidCheckpointLineage(printLineageItems(buf.reverse.toArray), + s"Cannot find version smaller than ${buf.last.version} in lineage.") + } + } + + // we return the lineage in increasing order + val ret = buf.reverse.toArray + + // Sanity checks + if (ret.head.version != startVersion) { + throw QueryExecutionErrors.invalidCheckpointLineage(printLineageItems(ret), + s"Lineage does not start with startVersion: $startVersion.") + } + if (ret.last.version != endVersion) { + throw QueryExecutionErrors.invalidCheckpointLineage(printLineageItems(ret), + s"Lineage does not end with endVersion: $endVersion.") + } + // Verify that the lineage versions are increasing by one + // We do this by checking that each entry is one version higher than the previous one + val increasingByOne = ret.sliding(2).forall { + case Array(prev, next) => prev.version + 1 == next.version + case _ => true + } + if (!increasingByOne) { + throw QueryExecutionErrors.invalidCheckpointLineage(printLineageItems(ret), + "Lineage versions are not increasing by one.") + } + + ret + } /** * Load the given version of data in a native RocksDB instance. 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala index 4230ea3a2166..ce2a216534ea 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala @@ -878,15 +878,18 @@ private[sql] class RocksDBStateStoreProvider override def getStateStoreChangeDataReader( startVersion: Long, endVersion: Long, - colFamilyNameOpt: Option[String] = None): + colFamilyNameOpt: Option[String] = None, + endVersionStateStoreCkptId: Option[String] = None): StateStoreChangeDataReader = { val statePath = stateStoreId.storeCheckpointLocation() val sparkConf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf) new RocksDBStateStoreChangeDataReader( CheckpointFileManager.create(statePath, hadoopConf), + rocksDB, statePath, startVersion, endVersion, + endVersionStateStoreCkptId, CompressionCodec.createCodec(sparkConf, storeConf.compressionCodec), keyValueEncoderMap, colFamilyNameOpt) @@ -1225,9 +1228,11 @@ object RocksDBStateStoreProvider { /** [[StateStoreChangeDataReader]] implementation for [[RocksDBStateStoreProvider]] */ class RocksDBStateStoreChangeDataReader( fm: CheckpointFileManager, + rocksDB: RocksDB, stateLocation: Path, startVersion: Long, endVersion: Long, + endVersionStateStoreCkptId: Option[String], compressionCodec: CompressionCodec, keyValueEncoderMap: ConcurrentHashMap[String, (RocksDBKeyStateEncoder, RocksDBValueStateEncoder, Short)], @@ -1235,7 +1240,20 @@ class RocksDBStateStoreChangeDataReader( extends StateStoreChangeDataReader( fm, stateLocation, startVersion, endVersion, compressionCodec, colFamilyNameOpt) { - override protected var changelogSuffix: String = "changelog" + override protected val versionsAndUniqueIds: Array[(Long, Option[String])] = + if (endVersionStateStoreCkptId.isDefined) { + val fullVersionLineage = rocksDB.getFullLineage( + startVersion, + endVersion, + endVersionStateStoreCkptId) + fullVersionLineage + .sortBy(_.version) + .map(item => (item.version, Some(item.checkpointUniqueId))) + } else { + (startVersion to endVersion).map((_, None)).toArray + } + + override protected val changelogSuffix: String = "changelog" override def getNext(): (RecordType.Value, UnsafeRow, UnsafeRow, Long) = { var currRecord: (RecordType.Value, Array[Byte], Array[Byte]) = null diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index 604a27866f62..f94eecd1dd42 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -834,7 +834,8 @@ trait SupportsFineGrainedReplay { def getStateStoreChangeDataReader( startVersion: Long, endVersion: Long, - colFamilyNameOpt: Option[String] = None): + colFamilyNameOpt: Option[String] = None, + endVersionStateStoreCkptId: Option[String] = None): NextIterator[(RecordType.Value, UnsafeRow, UnsafeRow, Long)] } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreChangelog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreChangelog.scala index 4c5dea63baea..792f22cc574d 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreChangelog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreChangelog.scala @@ -447,6 +447,7 @@ abstract class StateStoreChangelogReader( Serialization.read[Array[LineageItem]](lineageStr) } + // The array contains lineage information from [mostRecentSnapShotVersion, version - 1] inclusive lazy val lineage: Array[LineageItem] = readLineage() def version: Short @@ -632,27 +633,41 @@ abstract class StateStoreChangeDataReader( * Iterator that iterates over the changelog files in the state store. */ private class ChangeLogFileIterator extends Iterator[Path] { + val versionsAndUniqueIds: Iterator[(Long, Option[String])] = + StateStoreChangeDataReader.this.versionsAndUniqueIds.iterator private var currentVersion = StateStoreChangeDataReader.this.startVersion - 1 + private var currentUniqueId: Option[String] = None /** returns the version of the changelog returned by the latest [[next]] function call */ def getVersion: Long = currentVersion - override def hasNext: Boolean = currentVersion < StateStoreChangeDataReader.this.endVersion + override def hasNext: Boolean = versionsAndUniqueIds.hasNext override def next(): Path = { - currentVersion += 1 - getChangelogPath(currentVersion) + val nextTuple = versionsAndUniqueIds.next() + currentVersion = nextTuple._1 + currentUniqueId = nextTuple._2 + getChangelogPath(currentVersion, currentUniqueId) } - private def getChangelogPath(version: Long): Path = - new Path( - StateStoreChangeDataReader.this.stateLocation, - s"$version.${StateStoreChangeDataReader.this.changelogSuffix}") + private def getChangelogPath(version: Long, checkpointUniqueId: Option[String]): Path = + if (checkpointUniqueId.isDefined) { + new Path( + StateStoreChangeDataReader.this.stateLocation, + s"${version}_${checkpointUniqueId.get}." 
+ + s"${StateStoreChangeDataReader.this.changelogSuffix}") + } else { + new Path( + StateStoreChangeDataReader.this.stateLocation, + s"$version.${StateStoreChangeDataReader.this.changelogSuffix}") + } } /** file format of the changelog files */ - protected var changelogSuffix: String + protected val changelogSuffix: String + protected val versionsAndUniqueIds: Array[(Long, Option[String])] = + (startVersion to endVersion).map((_, None)).toArray private lazy val fileIterator = new ChangeLogFileIterator private var changelogReader: StateStoreChangelogReader = null @@ -671,11 +686,10 @@ abstract class StateStoreChangeDataReader( return null } - changelogReader = if (colFamilyNameOpt.isDefined) { - new StateStoreChangelogReaderV2(fm, fileIterator.next(), compressionCodec) - } else { - new StateStoreChangelogReaderV1(fm, fileIterator.next(), compressionCodec) - } + val changelogFile = fileIterator.next() + changelogReader = + new StateStoreChangelogReaderFactory(fm, changelogFile, compressionCodec) + .constructChangelogReader() } changelogReader } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala index 43682de03446..8a44f5c28456 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala @@ -259,6 +259,10 @@ object StateStoreErrors { QueryExecutionErrors.cannotLoadStore(e) } } + + def stateStoreCheckpointIdsNotSupported(msg: String): StateStoreCheckpointIdsNotSupported = { + new StateStoreCheckpointIdsNotSupported(msg) + } } trait ConvertableToCannotLoadStoreError { @@ -545,6 +549,12 @@ class StateStoreOperationOutOfOrder(errorMsg: String) messageParameters = Map("errorMsg" -> errorMsg) ) +class StateStoreCheckpointIdsNotSupported(msg: String) + extends SparkRuntimeException( + errorClass = "STATE_STORE_CHECKPOINT_IDS_NOT_SUPPORTED", + messageParameters = Map("msg" -> msg) + ) + class StateStoreCommitValidationFailed( batchId: Long, expectedCommits: Int, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceChangeDataReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceChangeDataReadSuite.scala index be19981dc8a8..a1be83627f31 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceChangeDataReadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceChangeDataReadSuite.scala @@ -17,15 +17,18 @@ package org.apache.spark.sql.execution.datasources.v2.state +import java.io.File +import java.sql.Timestamp import java.util.UUID import org.apache.hadoop.conf.Configuration import org.scalatest.Assertions import org.apache.spark.sql.Row +import org.apache.spark.sql.execution.streaming.checkpointing.{CommitLog, CommitMetadata} import org.apache.spark.sql.execution.streaming.runtime.{MemoryStream, StreamExecution} import org.apache.spark.sql.execution.streaming.state._ -import org.apache.spark.sql.functions.col +import org.apache.spark.sql.functions.{col, window} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType @@ -47,6 +50,14 @@ class RocksDBWithChangelogCheckpointStateDataSourceChangeDataReaderSuite extends } } 
+class RocksDBWithCheckpointV2StateDataSourceChangeDataReaderSuite extends + RocksDBWithChangelogCheckpointStateDataSourceChangeDataReaderSuite { + override def beforeAll(): Unit = { + super.beforeAll() + spark.conf.set("spark.sql.streaming.stateStore.checkpointFormatVersion", "2") + } +} + abstract class StateDataSourceChangeDataReaderSuite extends StateDataSourceTestBase with Assertions { @@ -124,6 +135,39 @@ abstract class StateDataSourceChangeDataReaderSuite extends StateDataSourceTestB } } + test("ERROR: mixed checkpoint format versions not supported") { + withTempDir { tempDir => + val commitLog = new CommitLog(spark, + new File(tempDir.getAbsolutePath, "commits").getAbsolutePath) + + // Start version: treated as v1 (no operator unique ids) + val startMetadata = CommitMetadata(0, None) + assert(commitLog.add(0, startMetadata)) + + // End version: treated as v2 (operator 0 has unique ids) + val endMetadata = CommitMetadata(0, + Some(Map[Long, Array[Array[String]]](0L -> Array(Array("uid"))))) + assert(commitLog.add(1, endMetadata)) + + val exc = intercept[StateDataSourceMixedCheckpointFormatVersionsNotSupported] { + spark.read.format("statestore") + .option(StateSourceOptions.PATH, tempDir.getAbsolutePath) + .option(StateSourceOptions.READ_CHANGE_FEED, true) + .option(StateSourceOptions.CHANGE_START_BATCH_ID, 0) + .option(StateSourceOptions.CHANGE_END_BATCH_ID, 1) + .load() + } + + checkError(exc, "STDS_MIXED_CHECKPOINT_FORMAT_VERSIONS_NOT_SUPPORTED", "KD002", + Map( + "startBatchId" -> "0", + "endBatchId" -> "1", + "startFormatVersion" -> "1", + "endFormatVersion" -> "2" + )) + } + } + test("ERROR: joinSide option is used together with readChangeFeed") { withTempDir { tempDir => val exc = intercept[StateDataSourceConflictOptions] { @@ -139,11 +183,16 @@ abstract class StateDataSourceChangeDataReaderSuite extends StateDataSourceTestB } test("getChangeDataReader of state store provider") { + val versionToCkptId = scala.collection.mutable.Map[Long, Option[String]]() + def withNewStateStore(provider: StateStoreProvider, version: Int)(f: StateStore => Unit): Unit = { - val stateStore = provider.getStore(version) + val stateStore = provider.getStore(version, versionToCkptId.getOrElse(version, None)) f(stateStore) stateStore.commit() + + val ssInfo = stateStore.getStateStoreCheckpointInfo() + versionToCkptId(ssInfo.batchVersion) = ssInfo.stateStoreCkptId } withTempDir { tempDir => @@ -158,7 +207,8 @@ abstract class StateDataSourceChangeDataReaderSuite extends StateDataSourceTestB stateStore.remove(dataToKeyRow("b", 2)) } val reader = - provider.asInstanceOf[SupportsFineGrainedReplay].getStateStoreChangeDataReader(1, 4) + provider.asInstanceOf[SupportsFineGrainedReplay] + .getStateStoreChangeDataReader(1, 4, None, versionToCkptId.getOrElse(4, None)) assert(reader.next() === (RecordType.PUT_RECORD, dataToKeyRow("a", 1), dataToValueRow(1), 0L)) assert(reader.next() === (RecordType.PUT_RECORD, dataToKeyRow("b", 2), dataToValueRow(2), 1L)) @@ -322,4 +372,133 @@ abstract class StateDataSourceChangeDataReaderSuite extends StateDataSourceTestB checkAnswer(keyToNumValuesDf, keyToNumValuesDfExpectedDf) } } + + test("read change feed past multiple snapshots") { + withSQLConf("spark.sql.streaming.stateStore.minDeltasForSnapshot" -> "2") { + withTempDir { tempDir => + val inputData = MemoryStream[Int] + val df = inputData.toDF().groupBy("value").count() + testStream(df, OutputMode.Update)( + StartStream(checkpointLocation = tempDir.getAbsolutePath), + AddData(inputData, 1, 2, 3, 4, 1), + 
ProcessAllAvailable(), + AddData(inputData, 2, 3, 4, 5), + ProcessAllAvailable(), + AddData(inputData, 3, 4, 5, 6), + ProcessAllAvailable(), + AddData(inputData, 1, 1), + ProcessAllAvailable(), + AddData(inputData, 1, 1), + ProcessAllAvailable(), + AddData(inputData, 1, 1), + ProcessAllAvailable() + ) + + val stateDf = spark.read.format("statestore") + .option(StateSourceOptions.READ_CHANGE_FEED, value = true) + .option(StateSourceOptions.CHANGE_START_BATCH_ID, 0) + .option(StateSourceOptions.CHANGE_END_BATCH_ID, 5) + .load(tempDir.getAbsolutePath) + + val expectedDf = Seq( + Row(0L, "update", Row(3), Row(1), 1), + Row(1L, "update", Row(3), Row(2), 1), + Row(1L, "update", Row(5), Row(1), 1), + Row(2L, "update", Row(3), Row(3), 1), + Row(2L, "update", Row(5), Row(2), 1), + Row(0L, "update", Row(4), Row(1), 2), + Row(1L, "update", Row(4), Row(2), 2), + Row(2L, "update", Row(4), Row(3), 2), + Row(0L, "update", Row(1), Row(2), 3), + Row(3L, "update", Row(1), Row(4), 3), + Row(4L, "update", Row(1), Row(6), 3), + Row(5L, "update", Row(1), Row(8), 3), + Row(0L, "update", Row(2), Row(1), 4), + Row(1L, "update", Row(2), Row(2), 4), + Row(2L, "update", Row(6), Row(1), 4) + ) + + checkAnswer(stateDf, expectedDf) + + val stateDf2 = spark.read.format("statestore") + .option(StateSourceOptions.READ_CHANGE_FEED, value = true) + .option(StateSourceOptions.CHANGE_START_BATCH_ID, 1) + .option(StateSourceOptions.CHANGE_END_BATCH_ID, 3) + .load(tempDir.getAbsolutePath) + + val expectedDf2 = Seq( + Row(1L, "update", Row(3), Row(2), 1), + Row(1L, "update", Row(5), Row(1), 1), + Row(2L, "update", Row(3), Row(3), 1), + Row(2L, "update", Row(5), Row(2), 1), + Row(1L, "update", Row(4), Row(2), 2), + Row(2L, "update", Row(4), Row(3), 2), + Row(3L, "update", Row(1), Row(4), 3), + Row(1L, "update", Row(2), Row(2), 4), + Row(2L, "update", Row(6), Row(1), 4) + ) + + checkAnswer(stateDf2, expectedDf2) + + val stateDf3 = spark.read.format("statestore") + .option(StateSourceOptions.READ_CHANGE_FEED, value = true) + .option(StateSourceOptions.CHANGE_START_BATCH_ID, 2) + .option(StateSourceOptions.CHANGE_END_BATCH_ID, 4) + .load(tempDir.getAbsolutePath) + + val expectedDf3 = Seq( + Row(2L, "update", Row(3), Row(3), 1), + Row(2L, "update", Row(5), Row(2), 1), + Row(2L, "update", Row(4), Row(3), 2), + Row(3L, "update", Row(1), Row(4), 3), + Row(4L, "update", Row(1), Row(6), 3), + Row(2L, "update", Row(6), Row(1), 4) + ) + + checkAnswer(stateDf3, expectedDf3) + } + } + } + + test("read change feed with delete entries") { + withTempDir { tempDir => + val inputData = MemoryStream[(Int, Timestamp)] + val df = inputData.toDF() + .selectExpr("_1 as key", "_2 as ts") + .withWatermark("ts", "1 second") + .groupBy(window(col("ts"), "1 second")) + .count() + + val ts0 = Timestamp.valueOf("2025-01-01 00:00:00") + val ts1 = Timestamp.valueOf("2025-01-01 00:00:01") + val ts2 = Timestamp.valueOf("2025-01-01 00:00:02") + val ts3 = Timestamp.valueOf("2025-01-01 00:00:03") + val ts4 = Timestamp.valueOf("2025-01-01 00:00:04") + + testStream(df, OutputMode.Append)( + StartStream(checkpointLocation = tempDir.getAbsolutePath), + AddData(inputData, (1, ts0), (2, ts0)), + ProcessAllAvailable(), + AddData(inputData, (3, ts2)), + ProcessAllAvailable(), + AddData(inputData, (4, ts3)), + ProcessAllAvailable(), + StopStream + ) + + val stateDf = spark.read.format("statestore") + .option(StateSourceOptions.READ_CHANGE_FEED, value = true) + .option(StateSourceOptions.CHANGE_START_BATCH_ID, 0) + .load(tempDir.getAbsolutePath) + + val expectedDf = Seq( + 
Row(0L, "update", Row(Row(ts0, ts1)), Row(2), 4), + Row(1L, "update", Row(Row(ts2, ts3)), Row(1), 1), + Row(2L, "delete", Row(Row(ts0, ts1)), null, 4), + Row(2L, "update", Row(Row(ts3, ts4)), Row(1), 4) + ) + + checkAnswer(stateDf, expectedDf) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala index d744304afb42..59c67973a328 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala @@ -690,21 +690,6 @@ class RocksDBWithCheckpointV2StateDataSourceReaderSuite extends StateDataSourceR Map( "optionName" -> StateSourceOptions.SNAPSHOT_START_BATCH_ID, "message" -> "Snapshot reading is currently not supported with checkpoint v2.")) - - // Verify reading change feed throws error with checkpoint v2 - val exc2 = intercept[StateDataSourceInvalidOptionValue] { - val stateDf = spark.read.format("statestore") - .option(StateSourceOptions.READ_CHANGE_FEED, value = true) - .option(StateSourceOptions.CHANGE_START_BATCH_ID, 0) - .option(StateSourceOptions.CHANGE_END_BATCH_ID, 1) - .load(tmpDir.getAbsolutePath) - stateDf.collect() - } - - checkError(exc2, "STDS_INVALID_OPTION_VALUE.WITH_MESSAGE", "42616", - Map( - "optionName" -> StateSourceOptions.READ_CHANGE_FEED, - "message" -> "Read change feed is currently not supported with checkpoint v2.")) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTransformWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTransformWithStateSuite.scala index 1800319fb8b4..2061cf645a03 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTransformWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTransformWithStateSuite.scala @@ -1013,6 +1013,8 @@ class StateDataSourceTransformWithStateSuite extends StateStoreMetricsTest * the state data. */ testWithChangelogCheckpointingEnabled("snapshotStartBatchId with transformWithState") { + // TODO(SPARK-53332): Remove this line once snapshotStartBatchId is supported for V2 format + assume(SQLConf.get.stateStoreCheckpointFormatVersion == 1) class AggregationStatefulProcessor extends StatefulProcessor[Int, (Int, Long), (Int, Long)] { @transient protected var _countState: ValueState[Long] = _ @@ -1150,3 +1152,12 @@ class StateDataSourceTransformWithStateSuite extends StateStoreMetricsTest } } } + +class StateDataSourceTransformWithStateSuiteCheckpointV2 extends + StateDataSourceTransformWithStateSuite { + + override def beforeAll(): Unit = { + super.beforeAll() + spark.conf.set(SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION, 2) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBLineageSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBLineageSuite.scala new file mode 100644 index 000000000000..48ef4158266b --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBLineageSuite.scala @@ -0,0 +1,170 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.state + +import org.apache.hadoop.conf.Configuration + +import org.apache.spark.SparkException +import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.util.Utils + +class RocksDBLineageSuite extends SharedSparkSession { + private def newDB(root: String, enableCheckpointIds: Boolean): RocksDB = { + val conf = RocksDBConf().copy(enableChangelogCheckpointing = true) + new RocksDB( + root, + conf, + localRootDir = Utils.createTempDir(), + hadoopConf = new Configuration, + useColumnFamilies = false, + enableStateStoreCheckpointIds = enableCheckpointIds) + } + + private def writeChangelogWithLineage( + db: RocksDB, + version: Long, + uniqueId: String, + lineage: Array[LineageItem]): Unit = { + val writer = db.fileManager.getChangeLogWriter( + version, + useColumnFamilies = false, + checkpointUniqueId = Some(uniqueId), + stateStoreCheckpointIdLineage = Some(lineage)) + writer.commit() + } + + test("getFullLineage: single changelog covers full range") { + withTempDir { remoteDir => + val db = newDB(remoteDir.getAbsolutePath, enableCheckpointIds = true) + try { + val start = 3L + val end = 5L + val id3 = "i3" + val id4 = "i4" + val id5 = "i5" + writeChangelogWithLineage(db, end, id5, Array(LineageItem(4, id4), LineageItem(3, id3))) + + val result = db.getFullLineage(start, end, Some(id5)) + assert(result.map(_.version).sameElements(Array(3L, 4L, 5L))) + assert(result.map(_.checkpointUniqueId).sameElements(Array(id3, id4, id5))) + } finally { + db.close() + } + } + } + + test("getFullLineage: multi-hop across changelog files") { + withTempDir { remoteDir => + val db = newDB(remoteDir.getAbsolutePath, enableCheckpointIds = true) + try { + val start = 1L + val end = 5L + val id1 = "i1"; val id2 = "i2"; val id3 = "i3"; val id4 = "i4"; val id5 = "i5" + writeChangelogWithLineage(db, 3, id3, Array(LineageItem(2, id2), LineageItem(1, id1))) + writeChangelogWithLineage(db, 5, id5, Array(LineageItem(4, id4), LineageItem(3, id3))) + + val result = db.getFullLineage(start, end, Some(id5)) + assert(result.map(_.version).sameElements(Array(1L, 2L, 3L, 4L, 5L))) + assert(result.map(_.checkpointUniqueId).sameElements(Array(id1, id2, id3, id4, id5))) + } finally { + db.close() + } + } + } + + test("getFullLineage: multiple lineages exist for the same version") { + withTempDir { remoteDir => + val db = newDB(remoteDir.getAbsolutePath, enableCheckpointIds = true) + try { + val start = 1L + val end = 5L + val id1 = "i1"; val id2 = "i2"; val id3 = "i3"; val id4 = "i4"; val id5 = "i5" + writeChangelogWithLineage(db, 3, id3, Array(LineageItem(2, id2), LineageItem(1, id1))) + writeChangelogWithLineage(db, 5, id5, Array(LineageItem(4, id4), LineageItem(3, id3))) + // Insert a bad lineage for version 5 + // We should not use this lineage since we call 
getFullLineage with id5 + val badId4 = id4 + "bad" + val badId5 = id5 + "bad" + writeChangelogWithLineage(db, 5, badId5, Array(LineageItem(4, badId4))) + + val result = db.getFullLineage(start, end, Some(id5)) + assert(result.map(_.version).sameElements(Array(1L, 2L, 3L, 4L, 5L))) + assert(result.map(_.checkpointUniqueId).sameElements(Array(id1, id2, id3, id4, id5))) + } finally { + db.close() + } + } + } + + test("getFullLineage: start equals end returns single item") { + withTempDir { remoteDir => + val db = newDB(remoteDir.getAbsolutePath, enableCheckpointIds = true) + try { + val result = db.getFullLineage(7, 7, Some("i7")) + assert(result.map(_.version).sameElements(Array(7L))) + assert(result.map(_.checkpointUniqueId).sameElements(Array("i7"))) + } finally { + db.close() + } + } + } + + test("getFullLineage: missing intermediate version triggers validation error") { + withTempDir { remoteDir => + val db = newDB(remoteDir.getAbsolutePath, enableCheckpointIds = true) + try { + writeChangelogWithLineage(db, 5, "i5", Array(LineageItem(3, "i3"))) + val ex = intercept[SparkException] { + db.getFullLineage(3, 5, Some("i5")) + } + checkError( + ex, + condition = "CANNOT_LOAD_STATE_STORE.INVALID_CHECKPOINT_LINEAGE", + parameters = Map( + "lineage" -> "3:i3 5:i5", + "message" -> "Lineage versions are not increasing by one." + ) + ) + } finally { + db.close() + } + } + } + + test("getFullLineage: no progress in lineage triggers guard error") { + withTempDir { remoteDir => + val db = newDB(remoteDir.getAbsolutePath, enableCheckpointIds = true) + try { + writeChangelogWithLineage(db, 5, "i5", Array.empty) + val ex = intercept[SparkException] { + db.getFullLineage(3, 5, Some("i5")) + } + checkError( + ex, + condition = "CANNOT_LOAD_STATE_STORE.INVALID_CHECKPOINT_LINEAGE", + parameters = Map( + "lineage" -> "5:i5", + "message" -> "Cannot find version smaller than 5 in lineage." 
+ ) + ) + } finally { + db.close() + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala index 4d1e789a70b0..0b1483241b92 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala @@ -1400,7 +1400,7 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] val hadoopConf = new Configuration() hadoopConf.set(StreamExecution.RUN_ID_KEY, UUID.randomUUID().toString) - val e = intercept[AssertionError] { + val e = intercept[StateStoreCheckpointIdsNotSupported] { provider.init( StateStoreId(newDir(), Random.nextInt(), 0), keySchema, @@ -1411,7 +1411,7 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] hadoopConf) } assert(e.getMessage.contains( - "HDFS State Store Provider doesn't support checkpointFormatVersion >= 2")) + "HDFSBackedStateStoreProvider does not support checkpointFormatVersion > 1")) } override def newStoreProvider(): HDFSBackedStateStoreProvider = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateInitialStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateInitialStateSuite.scala index 5f4de279724a..1c8c567b73fc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateInitialStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateInitialStateSuite.scala @@ -755,3 +755,11 @@ class TransformWithStateInitialStateSuite extends StateStoreMetricsTest } } } + +class TransformWithStateInitialStateSuiteCheckpointV2 + extends TransformWithStateInitialStateSuite { + override def beforeAll(): Unit = { + super.beforeAll() + spark.conf.set(SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION, 2) + } +} --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org