siying commented on code in PR #48460: URL: https://github.com/apache/spark/pull/48460#discussion_r1831890430
########## sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreCheckpointFormatV2Suite.scala: ########## @@ -0,0 +1,1103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.state + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path +import org.scalatest.Tag + +import org.apache.spark.{SparkContext, TaskContext} +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.execution.streaming.{CommitLog, MemoryStream} +import org.apache.spark.sql.functions.count +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.streaming._ +import org.apache.spark.sql.streaming.OutputMode.Update +import org.apache.spark.sql.test.TestSparkSession +import org.apache.spark.sql.types.StructType + +object CkptIdCollectingStateStoreWrapper { + // Internal list to hold checkpoint IDs (strings) + private var checkpointInfos: List[StateStoreCheckpointInfo] = List.empty + + // Method to add a string (checkpoint ID) to the list in a synchronized way + def addCheckpointInfo(checkpointID: StateStoreCheckpointInfo): Unit = synchronized { + 
checkpointInfos = checkpointID :: checkpointInfos + } + + // Method to read the list of checkpoint IDs in a synchronized way + def getStateStoreCheckpointInfos: List[StateStoreCheckpointInfo] = synchronized { + checkpointInfos + } + + def clear(): Unit = synchronized { + checkpointInfos = List.empty + } +} + +case class CkptIdCollectingStateStoreWrapper(innerStore: StateStore) extends StateStore { + + // Implement methods from ReadStateStore (parent trait) + + override def id: StateStoreId = innerStore.id + override def version: Long = innerStore.version + + override def get( + key: UnsafeRow, + colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): UnsafeRow = { + innerStore.get(key, colFamilyName) + } + + override def valuesIterator( + key: UnsafeRow, + colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): Iterator[UnsafeRow] = { + innerStore.valuesIterator(key, colFamilyName) + } + + override def prefixScan( + prefixKey: UnsafeRow, + colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): Iterator[UnsafeRowPair] = { + innerStore.prefixScan(prefixKey, colFamilyName) + } + + override def iterator( + colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): Iterator[UnsafeRowPair] = { + innerStore.iterator(colFamilyName) + } + + override def abort(): Unit = innerStore.abort() + + // Implement methods from StateStore (current trait) + + override def removeColFamilyIfExists(colFamilyName: String): Boolean = { + innerStore.removeColFamilyIfExists(colFamilyName) + } + + override def createColFamilyIfAbsent( + colFamilyName: String, + keySchema: StructType, + valueSchema: StructType, + keyStateEncoderSpec: KeyStateEncoderSpec, + useMultipleValuesPerKey: Boolean = false, + isInternal: Boolean = false): Unit = { + innerStore.createColFamilyIfAbsent( + colFamilyName, + keySchema, + valueSchema, + keyStateEncoderSpec, + useMultipleValuesPerKey, + isInternal + ) + } + + override def put( + key: UnsafeRow, + value: UnsafeRow, + colFamilyName: String = 
StateStore.DEFAULT_COL_FAMILY_NAME): Unit = { + innerStore.put(key, value, colFamilyName) + } + + override def remove( + key: UnsafeRow, + colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): Unit = { + innerStore.remove(key, colFamilyName) + } + + override def merge( + key: UnsafeRow, + value: UnsafeRow, + colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): Unit = { + innerStore.merge(key, value, colFamilyName) + } + + override def commit(): Long = innerStore.commit() + override def metrics: StateStoreMetrics = innerStore.metrics + override def getStateStoreCheckpointInfo: StateStoreCheckpointInfo = { + val ret = innerStore.getStateStoreCheckpointInfo + CkptIdCollectingStateStoreWrapper.addCheckpointInfo(ret) + ret + } + override def hasCommitted: Boolean = innerStore.hasCommitted +} + +class CkptIdCollectingStateStoreProviderWrapper extends StateStoreProvider { + + val innerProvider = new RocksDBStateStoreProvider() + + // Now, delegate all methods in the wrapper class to the inner object + override def init( + stateStoreId: StateStoreId, + keySchema: StructType, + valueSchema: StructType, + keyStateEncoderSpec: KeyStateEncoderSpec, + useColumnFamilies: Boolean, + storeConfs: StateStoreConf, + hadoopConf: Configuration, + useMultipleValuesPerKey: Boolean = false): Unit = { + innerProvider.init( + stateStoreId, + keySchema, + valueSchema, + keyStateEncoderSpec, + useColumnFamilies, + storeConfs, + hadoopConf, + useMultipleValuesPerKey + ) + } + + override def stateStoreId: StateStoreId = innerProvider.stateStoreId + + override def close(): Unit = innerProvider.close() + + override def getStore(version: Long, stateStoreCkptId: Option[String] = None): StateStore = { + val innerStateStore = innerProvider.getStore(version, stateStoreCkptId) + CkptIdCollectingStateStoreWrapper(innerStateStore) + } + + override def getReadStore(version: Long, uniqueId: Option[String] = None): ReadStateStore = { + new WrappedReadStateStore( + 
CkptIdCollectingStateStoreWrapper(innerProvider.getReadStore(version, uniqueId))) + } + + override def doMaintenance(): Unit = innerProvider.doMaintenance() + + override def supportedCustomMetrics: Seq[StateStoreCustomMetric] = + innerProvider.supportedCustomMetrics +} + +// TODO add a test case for two of the tasks for the same shuffle partitions to finish and +// return their own state store checkpointID. This can happen because of task retry or +// speculative execution. +class RocksDBStateStoreCheckpointFormatV2Suite extends StreamTest + with AlsoTestWithRocksDBFeatures { + import testImplicits._ + + val providerClassName = classOf[CkptIdCollectingStateStoreProviderWrapper].getCanonicalName + + // Force test task retry number to be 2 + protected override def createSparkSession: TestSparkSession = { + new TestSparkSession(new SparkContext("local[1, 2]", this.getClass.getSimpleName, sparkConf)) + } + + override protected def beforeAll(): Unit = { + super.beforeAll() + } + + override def beforeEach(): Unit = { + CkptIdCollectingStateStoreWrapper.clear() + } + + def testWithCheckpointInfoTracked(testName: String, testTags: Tag*)( + testBody: => Any): Unit = { + super.testWithChangelogCheckpointingEnabled(testName, testTags: _*) { + super.beforeEach() + withSQLConf( + (SQLConf.STATE_STORE_PROVIDER_CLASS.key -> providerClassName), + (SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION.key -> "2"), + (SQLConf.SHUFFLE_PARTITIONS.key, "2")) { + testBody + } + // in case tests have any code that needs to execute after every test + super.afterEach() + } + } + + val changelogEnabled = + "spark.sql.streaming.stateStore.rocksdb.changelogCheckpointing.enabled" -> "true" + val changelogDisabled = + "spark.sql.streaming.stateStore.rocksdb.changelogCheckpointing.enabled" -> "false" + val ckptv1 = SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION.key -> "1" + val ckptv2 = SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION.key -> "2" + + val testConfigSetups = Seq( + // Enable and disable 
changelog under ckpt v2 + (Seq(changelogEnabled, ckptv2), Seq(changelogEnabled, ckptv2)), + (Seq(changelogDisabled, ckptv2), Seq(changelogDisabled, ckptv2)), + // Cross version cross changelog enabled/disabled + (Seq(changelogDisabled, ckptv1), Seq(changelogDisabled, ckptv2)), + (Seq(changelogEnabled, ckptv1), Seq(changelogEnabled, ckptv2)), + (Seq(changelogDisabled, ckptv1), Seq(changelogEnabled, ckptv2)), + (Seq(changelogEnabled, ckptv1), Seq(changelogDisabled, ckptv2)) + ) + + testConfigSetups.foreach { + case (firstRunConfig, secondRunConfig) => + testWithRocksDBStateStore("checkpointFormatVersion2 Backward Compatibility - simple agg - " + + s"first run: (changeLogEnabled, ckpt ver): " + + s"${firstRunConfig(0)._2}, ${firstRunConfig(1)._2}" + + s" - second run: ${secondRunConfig(0)._2}, ${secondRunConfig(1)._2}") { + withTempDir { checkpointDir => + val inputData = MemoryStream[Int] + val aggregated = + inputData + .toDF() + .groupBy($"value") + .agg(count("*")) + .as[(Int, Long)] + + withSQLConf(firstRunConfig: _*) { + testStream(aggregated, Update)( + StartStream(checkpointLocation = checkpointDir.getAbsolutePath), + AddData(inputData, 3), + CheckLastBatch((3, 1)), + AddData(inputData, 3, 2), + CheckLastBatch((3, 2), (2, 1)), + StopStream + ) + + testStream(aggregated, Update)( + StartStream(checkpointLocation = checkpointDir.getAbsolutePath), + AddData(inputData, 3, 2, 1), + CheckLastBatch((3, 3), (2, 2), (1, 1)), + // By default we run in new tuple mode. 
+ AddData(inputData, 4, 4, 4, 4), + CheckLastBatch((4, 4)), + AddData(inputData, 5, 5), + CheckLastBatch((5, 2)) + ) + } + + withSQLConf(secondRunConfig: _*) { + testStream(aggregated, Update)( + StartStream(checkpointLocation = checkpointDir.getAbsolutePath), + AddData(inputData, 4), + CheckLastBatch((4, 5)) + ) + } + } + } + } + + testConfigSetups.foreach { + case (firstRunConfig, secondRunConfig) => + testWithRocksDBStateStore("checkpointFormatVersion2 Backward Compatibility - dedup - " + + s"first run: (changeLogEnabled, ckpt ver): " + + s"${firstRunConfig(0)._2}, ${firstRunConfig(1)._2}" + + s" - second run: ${secondRunConfig(0)._2}, ${secondRunConfig(1)._2}") { + withTempDir { checkpointDir => + val inputData = MemoryStream[Int] + val deduplicated = inputData + .toDF() + .dropDuplicates("value") + .as[Int] + + withSQLConf(firstRunConfig: _*) { + testStream(deduplicated, Update)( + StartStream(checkpointLocation = checkpointDir.getAbsolutePath), + AddData(inputData, 3), + CheckLastBatch(3), + AddData(inputData, 3, 2), + CheckLastBatch(2), + AddData(inputData, 3, 2, 1), + CheckLastBatch(1), + StopStream + ) + + // Test recovery + testStream(deduplicated, Update)( + StartStream(checkpointLocation = checkpointDir.getAbsolutePath), + AddData(inputData, 4, 1, 3), + CheckLastBatch(4), + AddData(inputData, 5, 4, 4), + CheckLastBatch(5), + StopStream + ) + } + + withSQLConf(secondRunConfig: _*) { + // crash recovery again + testStream(deduplicated, Update)( + StartStream(checkpointLocation = checkpointDir.getAbsolutePath), + AddData(inputData, 4, 7), + CheckLastBatch(7) + ) + } + } + } + } + + testConfigSetups.foreach { + case (firstRunConfig, secondRunConfig) => + testWithRocksDBStateStore("checkpointFormatVersion2 Backward Compatibility - " + + s"FlatMapGroupsWithState - first run: (changeLogEnabled, ckpt ver): " + + s"${firstRunConfig(0)._2}, ${firstRunConfig(1)._2}" + + s" - second run: ${secondRunConfig(0)._2}, ${secondRunConfig(1)._2}") { + withTempDir { 
checkpointDir => + val stateFunc = (key: Int, values: Iterator[Int], state: GroupState[Int]) => { + val count: Int = state.getOption.getOrElse(0) + values.size + state.update(count) + Iterator((key, count)) + } + + val inputData = MemoryStream[Int] + val aggregated = inputData + .toDF() + .toDF("key") + .selectExpr("key") + .as[Int] + .repartition($"key") + .groupByKey(x => x) + .flatMapGroupsWithState(OutputMode.Update, GroupStateTimeout.NoTimeout())(stateFunc) + + + withSQLConf(firstRunConfig: _*) { + testStream(aggregated, Update)( + StartStream(checkpointLocation = checkpointDir.getAbsolutePath), + AddData(inputData, 3), + CheckLastBatch((3, 1)), + AddData(inputData, 3, 2), + CheckLastBatch((3, 2), (2, 1)), + StopStream + ) + + // Test recovery + testStream(aggregated, Update)( + StartStream(checkpointLocation = checkpointDir.getAbsolutePath), + AddData(inputData, 4, 1, 3), + CheckLastBatch((4, 1), (1, 1), (3, 3)), + AddData(inputData, 5, 4, 4), + CheckLastBatch((5, 1), (4, 3)), + StopStream + ) + } + + withSQLConf(secondRunConfig: _*) { + // crash recovery again + // crash recovery again + testStream(aggregated, Update)( + StartStream(checkpointLocation = checkpointDir.getAbsolutePath), + AddData(inputData, 4, 7), + CheckLastBatch((4, 4), (7, 1)), + AddData (inputData, 5), + CheckLastBatch((5, 2)), + StopStream + ) + } + } + } + } + + testConfigSetups.foreach { + case (firstRunConfig, secondRunConfig) => + testWithRocksDBStateStore("checkpointFormatVersion2 Backward Compatibility - ss join - " + + s"first run: (changeLogEnabled, ckpt ver): " + + s"${firstRunConfig(0)._2}, ${firstRunConfig(1)._2}" + + s" - second run: ${secondRunConfig(0)._2}, ${secondRunConfig(1)._2}") { + withTempDir { checkpointDir => + val inputData1 = MemoryStream[Int] + val inputData2 = MemoryStream[Int] + + val df1 = inputData1.toDS().toDF("value") + val df2 = inputData2.toDS().toDF("value") + + val joined = df1.join(df2, df1("value") === df2("value")) + + withSQLConf(firstRunConfig: _*) 
{ + testStream(joined, OutputMode.Append)( + StartStream(checkpointLocation = checkpointDir.getAbsolutePath), + AddData(inputData1, 3, 2), + AddData(inputData2, 3), + CheckLastBatch((3, 3)), + AddData(inputData2, 2), + // This data will be used after restarting the query + AddData(inputData1, 5), + CheckLastBatch((2, 2)), + StopStream + ) + + // Test recovery. + testStream(joined, OutputMode.Append)( + StartStream(checkpointLocation = checkpointDir.getAbsolutePath), + AddData(inputData1, 4), + AddData(inputData2, 5), + CheckLastBatch((5, 5)), + AddData(inputData2, 4), + // This data will be used after restarting the query + AddData(inputData1, 7), + CheckLastBatch((4, 4)), + StopStream + ) + } + + withSQLConf(secondRunConfig: _*) { + // recovery again + testStream(joined, OutputMode.Append)( + StartStream(checkpointLocation = checkpointDir.getAbsolutePath), + AddData(inputData1, 6), + AddData(inputData2, 6), + CheckLastBatch((6, 6)), + AddData(inputData2, 7), + CheckLastBatch((7, 7)), + StopStream + ) + + // recovery again + testStream(joined, OutputMode.Append)( + StartStream(checkpointLocation = checkpointDir.getAbsolutePath), + AddData(inputData1, 8), + AddData(inputData2, 8), + CheckLastBatch((8, 8)), + StopStream + ) + } + } + } + } + + + testConfigSetups.foreach { + case (firstRunConfig, secondRunConfig) => + testWithRocksDBStateStore("checkpointFormatVersion2 Backward Compatibility - " + + "transformWithState - first run: (changeLogEnabled, ckpt ver): " + + s"${firstRunConfig(0)._2}, ${firstRunConfig(1)._2}" + + s" - second run: ${secondRunConfig(0)._2}, ${secondRunConfig(1)._2}") { + withTempDir { checkpointDir => + val inputData = MemoryStream[String] + val result = inputData.toDS() + .groupByKey(x => x) + .transformWithState(new RunningCountStatefulProcessor(), + TimeMode.None(), + OutputMode.Update()) + + withSQLConf(firstRunConfig: _*) { + testStream(result, Update())( + StartStream(checkpointLocation = checkpointDir.getAbsolutePath), + 
AddData(inputData, "a"), + CheckNewAnswer(("a", "1")), + Execute { q => + assert(q.lastProgress.stateOperators(0) + .customMetrics.get("numValueStateVars") > 0) + assert(q.lastProgress.stateOperators(0) + .customMetrics.get("numRegisteredTimers") == 0) + }, + AddData(inputData, "a", "b"), + CheckNewAnswer(("a", "2"), ("b", "1")), + StopStream + ) + testStream(result, Update())( + StartStream(checkpointLocation = checkpointDir.getAbsolutePath), + // should remove state for "a" and not return anything for a + AddData(inputData, "a", "b"), + CheckNewAnswer(("b", "2")), + StopStream + ) + } + + withSQLConf(secondRunConfig: _*) { + testStream(result, Update())( + StartStream(checkpointLocation = checkpointDir.getAbsolutePath), + // should recreate state for "a" and return count as 1 and + AddData(inputData, "a", "c"), + CheckNewAnswer(("a", "1"), ("c", "1")), + StopStream + ) + } + } + } + } + + test("checkpointFormatVersion2 validate ") { + val inputData = MemoryStream[String] + val result = inputData.toDS() + .groupByKey(x => x) + .transformWithState(new RunningCountStatefulProcessor(), + TimeMode.None(), + OutputMode.Update()) + + testStream(result, Update())( + AddData(inputData, "a"), + CheckNewAnswer(("a", "1")), + Execute { q => + assert(q.lastProgress.stateOperators(0).customMetrics.get("numValueStateVars") > 0) + assert(q.lastProgress.stateOperators(0).customMetrics.get("numRegisteredTimers") == 0) + }, + AddData(inputData, "a", "b"), + CheckNewAnswer(("a", "2"), ("b", "1")), + StopStream, + StartStream(), + AddData(inputData, "a", "b"), // should remove state for "a" and not return anything for a + CheckNewAnswer(("b", "2")), + StopStream, + StartStream(), + AddData(inputData, "a", "c"), // should recreate state for "a" and return count as 1 and + CheckNewAnswer(("a", "1"), ("c", "1")) + ) + } + + // This test enable checkpoint format V2 without validating the checkpoint ID. Just to make + // sure it doesn't break and return the correct query results. 
+ testWithChangelogCheckpointingEnabled(s"checkpointFormatVersion2") { + withSQLConf((SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION.key, "2")) { + withTempDir { checkpointDir => + val inputData = MemoryStream[Int] + val aggregated = + inputData + .toDF() + .groupBy($"value") + .agg(count("*")) + .as[(Int, Long)] + + testStream(aggregated, Update)( + StartStream(checkpointLocation = checkpointDir.getAbsolutePath), + AddData(inputData, 3), + CheckLastBatch((3, 1)), + AddData(inputData, 3, 2), + CheckLastBatch((3, 2), (2, 1)), + StopStream + ) + + // Run the stream with changelog checkpointing enabled. + testStream(aggregated, Update)( + StartStream(checkpointLocation = checkpointDir.getAbsolutePath), + AddData(inputData, 3, 2, 1), + CheckLastBatch((3, 3), (2, 2), (1, 1)), + // By default we run in new tuple mode. + AddData(inputData, 4, 4, 4, 4), + CheckLastBatch((4, 4)), + AddData(inputData, 5, 5), + CheckLastBatch((5, 2)) + ) + + // Run the stream with changelog checkpointing disabled. + testStream(aggregated, Update)( + StartStream(checkpointLocation = checkpointDir.getAbsolutePath), + AddData(inputData, 4), + CheckLastBatch((4, 5)) + ) + } + } + } + + def validateBaseCheckpointInfo(): Unit = { + val checkpointInfoList = CkptIdCollectingStateStoreWrapper.getStateStoreCheckpointInfos + // Here we assume for every task, we fetch checkpointID from the N state stores in the same + // order. So we can separate stateStoreCkptId for different stores based on the order inside the + // same (batchId, partitionId) group. 
+ val grouped = checkpointInfoList + .groupBy(info => (info.batchVersion, info.partitionId)) + .values + .flatMap { infos => + infos.zipWithIndex.map { case (info, index) => index -> info } + } + .groupBy(_._1) + .map { + case (_, grouped) => + grouped.map { case (_, info) => info } + } + + grouped.foreach { l => + for { + a <- l + b <- l + if a.partitionId == b.partitionId && a.batchVersion == b.batchVersion + 1 + } { + // if batch version exists, it should be the same as the checkpoint ID of the previous batch + assert(!a.baseStateStoreCkptId.isDefined || b.stateStoreCkptId == a.baseStateStoreCkptId) + } + } + } + + def validateCheckpointInfo( + numBatches: Int, + numStateStores: Int, + batchVersionSet: Set[Long]): Unit = { + val checkpointInfoList = CkptIdCollectingStateStoreWrapper.getStateStoreCheckpointInfos + // We have 6 batches, 2 partitions, and 1 state store per batch + assert(checkpointInfoList.size == numBatches * numStateStores * 2) + checkpointInfoList.foreach { l => + assert(l.stateStoreCkptId.isDefined) + if (batchVersionSet.contains(l.batchVersion)) { + assert(l.baseStateStoreCkptId.isDefined) + } + } + assert(checkpointInfoList.count(_.partitionId == 0) == numBatches * numStateStores) + assert(checkpointInfoList.count(_.partitionId == 1) == numBatches * numStateStores) + for (i <- 1 to numBatches) { + assert(checkpointInfoList.count(_.batchVersion == i) == numStateStores * 2) + } + validateBaseCheckpointInfo() + } + + + testWithCheckpointInfoTracked(s"checkpointFormatVersion2 validate ID - two jobs launched") { Review Comment: Can you explain what the test is for? ########## sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreCheckpointFormatV2Suite.scala: ########## @@ -0,0 +1,1103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.state + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path +import org.scalatest.Tag + +import org.apache.spark.{SparkContext, TaskContext} +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.execution.streaming.{CommitLog, MemoryStream} +import org.apache.spark.sql.functions.count +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.streaming._ +import org.apache.spark.sql.streaming.OutputMode.Update +import org.apache.spark.sql.test.TestSparkSession +import org.apache.spark.sql.types.StructType + +object CkptIdCollectingStateStoreWrapper { + // Internal list to hold checkpoint IDs (strings) + private var checkpointInfos: List[StateStoreCheckpointInfo] = List.empty + + // Method to add a string (checkpoint ID) to the list in a synchronized way + def addCheckpointInfo(checkpointID: StateStoreCheckpointInfo): Unit = synchronized { + checkpointInfos = checkpointID :: checkpointInfos + } + + // Method to read the list of checkpoint IDs in a synchronized way + def getStateStoreCheckpointInfos: List[StateStoreCheckpointInfo] = synchronized { + checkpointInfos + } + + def clear(): Unit = synchronized { + checkpointInfos = List.empty + } +} + +case class CkptIdCollectingStateStoreWrapper(innerStore: StateStore) 
extends StateStore { + + // Implement methods from ReadStateStore (parent trait) + + override def id: StateStoreId = innerStore.id + override def version: Long = innerStore.version + + override def get( + key: UnsafeRow, + colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): UnsafeRow = { + innerStore.get(key, colFamilyName) + } + + override def valuesIterator( + key: UnsafeRow, + colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): Iterator[UnsafeRow] = { + innerStore.valuesIterator(key, colFamilyName) + } + + override def prefixScan( + prefixKey: UnsafeRow, + colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): Iterator[UnsafeRowPair] = { + innerStore.prefixScan(prefixKey, colFamilyName) + } + + override def iterator( + colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): Iterator[UnsafeRowPair] = { + innerStore.iterator(colFamilyName) + } + + override def abort(): Unit = innerStore.abort() + + // Implement methods from StateStore (current trait) + + override def removeColFamilyIfExists(colFamilyName: String): Boolean = { + innerStore.removeColFamilyIfExists(colFamilyName) + } + + override def createColFamilyIfAbsent( + colFamilyName: String, + keySchema: StructType, + valueSchema: StructType, + keyStateEncoderSpec: KeyStateEncoderSpec, + useMultipleValuesPerKey: Boolean = false, + isInternal: Boolean = false): Unit = { + innerStore.createColFamilyIfAbsent( + colFamilyName, + keySchema, + valueSchema, + keyStateEncoderSpec, + useMultipleValuesPerKey, + isInternal + ) + } + + override def put( + key: UnsafeRow, + value: UnsafeRow, + colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): Unit = { + innerStore.put(key, value, colFamilyName) + } + + override def remove( + key: UnsafeRow, + colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): Unit = { + innerStore.remove(key, colFamilyName) + } + + override def merge( + key: UnsafeRow, + value: UnsafeRow, + colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): Unit 
= { + innerStore.merge(key, value, colFamilyName) + } + + override def commit(): Long = innerStore.commit() + override def metrics: StateStoreMetrics = innerStore.metrics + override def getStateStoreCheckpointInfo: StateStoreCheckpointInfo = { + val ret = innerStore.getStateStoreCheckpointInfo + CkptIdCollectingStateStoreWrapper.addCheckpointInfo(ret) + ret + } + override def hasCommitted: Boolean = innerStore.hasCommitted +} + +class CkptIdCollectingStateStoreProviderWrapper extends StateStoreProvider { + + val innerProvider = new RocksDBStateStoreProvider() + + // Now, delegate all methods in the wrapper class to the inner object + override def init( + stateStoreId: StateStoreId, + keySchema: StructType, + valueSchema: StructType, + keyStateEncoderSpec: KeyStateEncoderSpec, + useColumnFamilies: Boolean, + storeConfs: StateStoreConf, + hadoopConf: Configuration, + useMultipleValuesPerKey: Boolean = false): Unit = { + innerProvider.init( + stateStoreId, + keySchema, + valueSchema, + keyStateEncoderSpec, + useColumnFamilies, + storeConfs, + hadoopConf, + useMultipleValuesPerKey + ) + } + + override def stateStoreId: StateStoreId = innerProvider.stateStoreId + + override def close(): Unit = innerProvider.close() + + override def getStore(version: Long, stateStoreCkptId: Option[String] = None): StateStore = { + val innerStateStore = innerProvider.getStore(version, stateStoreCkptId) + CkptIdCollectingStateStoreWrapper(innerStateStore) + } + + override def getReadStore(version: Long, uniqueId: Option[String] = None): ReadStateStore = { + new WrappedReadStateStore( + CkptIdCollectingStateStoreWrapper(innerProvider.getReadStore(version, uniqueId))) + } + + override def doMaintenance(): Unit = innerProvider.doMaintenance() + + override def supportedCustomMetrics: Seq[StateStoreCustomMetric] = + innerProvider.supportedCustomMetrics +} + +// TODO add a test case for two of the tasks for the same shuffle partitions to finish and +// return their own state store checkpointID. 
This can happen because of task retry or +// speculative execution. +class RocksDBStateStoreCheckpointFormatV2Suite extends StreamTest + with AlsoTestWithRocksDBFeatures { + import testImplicits._ + + val providerClassName = classOf[CkptIdCollectingStateStoreProviderWrapper].getCanonicalName + + // Force test task retry number to be 2 + protected override def createSparkSession: TestSparkSession = { + new TestSparkSession(new SparkContext("local[1, 2]", this.getClass.getSimpleName, sparkConf)) + } + + override protected def beforeAll(): Unit = { + super.beforeAll() + } + + override def beforeEach(): Unit = { + CkptIdCollectingStateStoreWrapper.clear() + } + + def testWithCheckpointInfoTracked(testName: String, testTags: Tag*)( + testBody: => Any): Unit = { + super.testWithChangelogCheckpointingEnabled(testName, testTags: _*) { + super.beforeEach() + withSQLConf( + (SQLConf.STATE_STORE_PROVIDER_CLASS.key -> providerClassName), + (SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION.key -> "2"), + (SQLConf.SHUFFLE_PARTITIONS.key, "2")) { + testBody + } + // in case tests have any code that needs to execute after every test + super.afterEach() + } + } + + val changelogEnabled = + "spark.sql.streaming.stateStore.rocksdb.changelogCheckpointing.enabled" -> "true" + val changelogDisabled = + "spark.sql.streaming.stateStore.rocksdb.changelogCheckpointing.enabled" -> "false" + val ckptv1 = SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION.key -> "1" + val ckptv2 = SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION.key -> "2" + + val testConfigSetups = Seq( + // Enable and disable changelog under ckpt v2 + (Seq(changelogEnabled, ckptv2), Seq(changelogEnabled, ckptv2)), + (Seq(changelogDisabled, ckptv2), Seq(changelogDisabled, ckptv2)), + // Cross version cross changelog enabled/disabled + (Seq(changelogDisabled, ckptv1), Seq(changelogDisabled, ckptv2)), + (Seq(changelogEnabled, ckptv1), Seq(changelogEnabled, ckptv2)), + (Seq(changelogDisabled, ckptv1), Seq(changelogEnabled, ckptv2)), + 
(Seq(changelogEnabled, ckptv1), Seq(changelogDisabled, ckptv2)) + ) + + testConfigSetups.foreach { + case (firstRunConfig, secondRunConfig) => + testWithRocksDBStateStore("checkpointFormatVersion2 Backward Compatibility - simple agg - " + + s"first run: (changeLogEnabled, ckpt ver): " + + s"${firstRunConfig(0)._2}, ${firstRunConfig(1)._2}" + + s" - second run: ${secondRunConfig(0)._2}, ${secondRunConfig(1)._2}") { + withTempDir { checkpointDir => + val inputData = MemoryStream[Int] + val aggregated = + inputData + .toDF() + .groupBy($"value") + .agg(count("*")) + .as[(Int, Long)] + + withSQLConf(firstRunConfig: _*) { + testStream(aggregated, Update)( + StartStream(checkpointLocation = checkpointDir.getAbsolutePath), + AddData(inputData, 3), + CheckLastBatch((3, 1)), + AddData(inputData, 3, 2), + CheckLastBatch((3, 2), (2, 1)), + StopStream + ) + + testStream(aggregated, Update)( + StartStream(checkpointLocation = checkpointDir.getAbsolutePath), + AddData(inputData, 3, 2, 1), + CheckLastBatch((3, 3), (2, 2), (1, 1)), + // By default we run in new tuple mode. 
+ AddData(inputData, 4, 4, 4, 4), + CheckLastBatch((4, 4)), + AddData(inputData, 5, 5), + CheckLastBatch((5, 2)) + ) + } + + withSQLConf(secondRunConfig: _*) { + testStream(aggregated, Update)( + StartStream(checkpointLocation = checkpointDir.getAbsolutePath), + AddData(inputData, 4), + CheckLastBatch((4, 5)) + ) + } + } + } + } + + testConfigSetups.foreach { + case (firstRunConfig, secondRunConfig) => + testWithRocksDBStateStore("checkpointFormatVersion2 Backward Compatibility - dedup - " + + s"first run: (changeLogEnabled, ckpt ver): " + + s"${firstRunConfig(0)._2}, ${firstRunConfig(1)._2}" + + s" - second run: ${secondRunConfig(0)._2}, ${secondRunConfig(1)._2}") { + withTempDir { checkpointDir => + val inputData = MemoryStream[Int] + val deduplicated = inputData + .toDF() + .dropDuplicates("value") + .as[Int] + + withSQLConf(firstRunConfig: _*) { + testStream(deduplicated, Update)( + StartStream(checkpointLocation = checkpointDir.getAbsolutePath), + AddData(inputData, 3), + CheckLastBatch(3), + AddData(inputData, 3, 2), + CheckLastBatch(2), + AddData(inputData, 3, 2, 1), + CheckLastBatch(1), + StopStream + ) + + // Test recovery + testStream(deduplicated, Update)( + StartStream(checkpointLocation = checkpointDir.getAbsolutePath), + AddData(inputData, 4, 1, 3), + CheckLastBatch(4), + AddData(inputData, 5, 4, 4), + CheckLastBatch(5), + StopStream + ) + } + + withSQLConf(secondRunConfig: _*) { + // crash recovery again + testStream(deduplicated, Update)( + StartStream(checkpointLocation = checkpointDir.getAbsolutePath), + AddData(inputData, 4, 7), + CheckLastBatch(7) + ) + } + } + } + } + + testConfigSetups.foreach { + case (firstRunConfig, secondRunConfig) => + testWithRocksDBStateStore("checkpointFormatVersion2 Backward Compatibility - " + + s"FlatMapGroupsWithState - first run: (changeLogEnabled, ckpt ver): " + + s"${firstRunConfig(0)._2}, ${firstRunConfig(1)._2}" + + s" - second run: ${secondRunConfig(0)._2}, ${secondRunConfig(1)._2}") { + withTempDir { 
checkpointDir => + val stateFunc = (key: Int, values: Iterator[Int], state: GroupState[Int]) => { + val count: Int = state.getOption.getOrElse(0) + values.size + state.update(count) + Iterator((key, count)) + } + + val inputData = MemoryStream[Int] + val aggregated = inputData + .toDF() + .toDF("key") + .selectExpr("key") + .as[Int] + .repartition($"key") + .groupByKey(x => x) + .flatMapGroupsWithState(OutputMode.Update, GroupStateTimeout.NoTimeout())(stateFunc) + + + withSQLConf(firstRunConfig: _*) { + testStream(aggregated, Update)( + StartStream(checkpointLocation = checkpointDir.getAbsolutePath), + AddData(inputData, 3), + CheckLastBatch((3, 1)), + AddData(inputData, 3, 2), + CheckLastBatch((3, 2), (2, 1)), + StopStream + ) + + // Test recovery + testStream(aggregated, Update)( + StartStream(checkpointLocation = checkpointDir.getAbsolutePath), + AddData(inputData, 4, 1, 3), + CheckLastBatch((4, 1), (1, 1), (3, 3)), + AddData(inputData, 5, 4, 4), + CheckLastBatch((5, 1), (4, 3)), + StopStream + ) + } + + withSQLConf(secondRunConfig: _*) { + // crash recovery again + // crash recovery again + testStream(aggregated, Update)( + StartStream(checkpointLocation = checkpointDir.getAbsolutePath), + AddData(inputData, 4, 7), + CheckLastBatch((4, 4), (7, 1)), + AddData (inputData, 5), + CheckLastBatch((5, 2)), + StopStream + ) + } + } + } + } + + testConfigSetups.foreach { + case (firstRunConfig, secondRunConfig) => + testWithRocksDBStateStore("checkpointFormatVersion2 Backward Compatibility - ss join - " + + s"first run: (changeLogEnabled, ckpt ver): " + + s"${firstRunConfig(0)._2}, ${firstRunConfig(1)._2}" + + s" - second run: ${secondRunConfig(0)._2}, ${secondRunConfig(1)._2}") { + withTempDir { checkpointDir => + val inputData1 = MemoryStream[Int] + val inputData2 = MemoryStream[Int] + + val df1 = inputData1.toDS().toDF("value") + val df2 = inputData2.toDS().toDF("value") + + val joined = df1.join(df2, df1("value") === df2("value")) + + withSQLConf(firstRunConfig: _*) 
{ + testStream(joined, OutputMode.Append)( + StartStream(checkpointLocation = checkpointDir.getAbsolutePath), + AddData(inputData1, 3, 2), + AddData(inputData2, 3), + CheckLastBatch((3, 3)), + AddData(inputData2, 2), + // This data will be used after restarting the query + AddData(inputData1, 5), + CheckLastBatch((2, 2)), + StopStream + ) + + // Test recovery. + testStream(joined, OutputMode.Append)( + StartStream(checkpointLocation = checkpointDir.getAbsolutePath), + AddData(inputData1, 4), + AddData(inputData2, 5), + CheckLastBatch((5, 5)), + AddData(inputData2, 4), + // This data will be used after restarting the query + AddData(inputData1, 7), + CheckLastBatch((4, 4)), + StopStream + ) + } + + withSQLConf(secondRunConfig: _*) { + // recovery again + testStream(joined, OutputMode.Append)( + StartStream(checkpointLocation = checkpointDir.getAbsolutePath), + AddData(inputData1, 6), + AddData(inputData2, 6), + CheckLastBatch((6, 6)), + AddData(inputData2, 7), + CheckLastBatch((7, 7)), + StopStream + ) + + // recovery again + testStream(joined, OutputMode.Append)( + StartStream(checkpointLocation = checkpointDir.getAbsolutePath), + AddData(inputData1, 8), + AddData(inputData2, 8), + CheckLastBatch((8, 8)), + StopStream + ) + } + } + } + } + + + testConfigSetups.foreach { + case (firstRunConfig, secondRunConfig) => + testWithRocksDBStateStore("checkpointFormatVersion2 Backward Compatibility - " + + "transformWithState - first run: (changeLogEnabled, ckpt ver): " + + s"${firstRunConfig(0)._2}, ${firstRunConfig(1)._2}" + + s" - second run: ${secondRunConfig(0)._2}, ${secondRunConfig(1)._2}") { + withTempDir { checkpointDir => + val inputData = MemoryStream[String] + val result = inputData.toDS() + .groupByKey(x => x) + .transformWithState(new RunningCountStatefulProcessor(), + TimeMode.None(), + OutputMode.Update()) + + withSQLConf(firstRunConfig: _*) { + testStream(result, Update())( + StartStream(checkpointLocation = checkpointDir.getAbsolutePath), + 
AddData(inputData, "a"), + CheckNewAnswer(("a", "1")), + Execute { q => + assert(q.lastProgress.stateOperators(0) + .customMetrics.get("numValueStateVars") > 0) + assert(q.lastProgress.stateOperators(0) + .customMetrics.get("numRegisteredTimers") == 0) + }, + AddData(inputData, "a", "b"), + CheckNewAnswer(("a", "2"), ("b", "1")), + StopStream + ) + testStream(result, Update())( + StartStream(checkpointLocation = checkpointDir.getAbsolutePath), + // should remove state for "a" and not return anything for a + AddData(inputData, "a", "b"), + CheckNewAnswer(("b", "2")), + StopStream + ) + } + + withSQLConf(secondRunConfig: _*) { + testStream(result, Update())( + StartStream(checkpointLocation = checkpointDir.getAbsolutePath), + // should recreate state for "a" and return count as 1 and + AddData(inputData, "a", "c"), + CheckNewAnswer(("a", "1"), ("c", "1")), + StopStream + ) + } + } + } + } + + test("checkpointFormatVersion2 validate ") { + val inputData = MemoryStream[String] + val result = inputData.toDS() + .groupByKey(x => x) + .transformWithState(new RunningCountStatefulProcessor(), + TimeMode.None(), + OutputMode.Update()) + + testStream(result, Update())( + AddData(inputData, "a"), + CheckNewAnswer(("a", "1")), + Execute { q => + assert(q.lastProgress.stateOperators(0).customMetrics.get("numValueStateVars") > 0) + assert(q.lastProgress.stateOperators(0).customMetrics.get("numRegisteredTimers") == 0) + }, + AddData(inputData, "a", "b"), + CheckNewAnswer(("a", "2"), ("b", "1")), + StopStream, + StartStream(), + AddData(inputData, "a", "b"), // should remove state for "a" and not return anything for a + CheckNewAnswer(("b", "2")), + StopStream, + StartStream(), + AddData(inputData, "a", "c"), // should recreate state for "a" and return count as 1 and + CheckNewAnswer(("a", "1"), ("c", "1")) + ) + } + + // This test enable checkpoint format V2 without validating the checkpoint ID. Just to make + // sure it doesn't break and return the correct query results. 
+ testWithChangelogCheckpointingEnabled(s"checkpointFormatVersion2") { + withSQLConf((SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION.key, "2")) { + withTempDir { checkpointDir => + val inputData = MemoryStream[Int] + val aggregated = + inputData + .toDF() + .groupBy($"value") + .agg(count("*")) + .as[(Int, Long)] + + testStream(aggregated, Update)( + StartStream(checkpointLocation = checkpointDir.getAbsolutePath), + AddData(inputData, 3), + CheckLastBatch((3, 1)), + AddData(inputData, 3, 2), + CheckLastBatch((3, 2), (2, 1)), + StopStream + ) + + // Run the stream with changelog checkpointing enabled. + testStream(aggregated, Update)( + StartStream(checkpointLocation = checkpointDir.getAbsolutePath), + AddData(inputData, 3, 2, 1), + CheckLastBatch((3, 3), (2, 2), (1, 1)), + // By default we run in new tuple mode. + AddData(inputData, 4, 4, 4, 4), + CheckLastBatch((4, 4)), + AddData(inputData, 5, 5), + CheckLastBatch((5, 2)) + ) + + // Run the stream with changelog checkpointing disabled. + testStream(aggregated, Update)( + StartStream(checkpointLocation = checkpointDir.getAbsolutePath), + AddData(inputData, 4), + CheckLastBatch((4, 5)) + ) + } + } + } + + def validateBaseCheckpointInfo(): Unit = { + val checkpointInfoList = CkptIdCollectingStateStoreWrapper.getStateStoreCheckpointInfos + // Here we assume for every task, we fetch checkpointID from the N state stores in the same + // order. So we can separate stateStoreCkptId for different stores based on the order inside the + // same (batchId, partitionId) group. 
+ val grouped = checkpointInfoList + .groupBy(info => (info.batchVersion, info.partitionId)) + .values + .flatMap { infos => + infos.zipWithIndex.map { case (info, index) => index -> info } + } + .groupBy(_._1) + .map { + case (_, grouped) => + grouped.map { case (_, info) => info } + } + + grouped.foreach { l => + for { + a <- l + b <- l + if a.partitionId == b.partitionId && a.batchVersion == b.batchVersion + 1 + } { + // if batch version exists, it should be the same as the checkpoint ID of the previous batch + assert(!a.baseStateStoreCkptId.isDefined || b.stateStoreCkptId == a.baseStateStoreCkptId) + } + } + } + + def validateCheckpointInfo( + numBatches: Int, + numStateStores: Int, + batchVersionSet: Set[Long]): Unit = { + val checkpointInfoList = CkptIdCollectingStateStoreWrapper.getStateStoreCheckpointInfos + // We have 6 batches, 2 partitions, and 1 state store per batch + assert(checkpointInfoList.size == numBatches * numStateStores * 2) + checkpointInfoList.foreach { l => + assert(l.stateStoreCkptId.isDefined) + if (batchVersionSet.contains(l.batchVersion)) { + assert(l.baseStateStoreCkptId.isDefined) + } + } + assert(checkpointInfoList.count(_.partitionId == 0) == numBatches * numStateStores) + assert(checkpointInfoList.count(_.partitionId == 1) == numBatches * numStateStores) + for (i <- 1 to numBatches) { + assert(checkpointInfoList.count(_.batchVersion == i) == numStateStores * 2) + } + validateBaseCheckpointInfo() + } + + + testWithCheckpointInfoTracked(s"checkpointFormatVersion2 validate ID - two jobs launched") { + withTempDir { checkpointDir => + val inputData = MemoryStream[Int] + val aggregated = + inputData + .toDF() + .groupBy($"value") + .agg(count("*")) + + val writer = (ds: DataFrame, batchId: Long) => { + ds.write.mode("append").saveAsTable("wei_test_t1") + ds.write.mode("append").saveAsTable("wei_test_t2") + } + + val query = aggregated.writeStream + .foreachBatch(writer) + .option("checkpointLocation", checkpointDir.getAbsolutePath) + 
.outputMode("update") + .start() + + inputData.addData(1 to 100) + query.processAllAvailable() + + inputData.addData(1 to 100) + query.processAllAvailable() + + query.stop() + + val checkpointInfoList = CkptIdCollectingStateStoreWrapper.getStateStoreCheckpointInfos + + val pickedCheckpointInfoList = checkpointInfoList + .groupBy(x => (x.partitionId, x.batchVersion)).map(_._2.tail.head) + + println("wei== pickedCheckpointInfoList: ") + pickedCheckpointInfoList.foreach(println) + + println("wei== checkpointInfoList: ") + checkpointInfoList.foreach(println) + + Seq(0, 1).foreach { + partitionId => + val stateStoreCkptIds = pickedCheckpointInfoList + .filter(_.partitionId == partitionId).map(_.stateStoreCkptId) + val baseStateStoreCkptIds = pickedCheckpointInfoList + .filter(_.partitionId == partitionId).map(_.baseStateStoreCkptId) + + // Verify lineage for each partition across batches. Below should satisfy because + // these ids are stored in the following manner: + // stateStoreCkptIds: id3, id2, id1 + // baseStateStoreCkptIds: id2, id1, None + // Below checks [id2, id1] are the same, + // which is the lineage for this partition across batches + assert(stateStoreCkptIds.drop(1).iterator + .sameElements(baseStateStoreCkptIds.dropRight(1))) + } + + val versionToUniqueIdFromStateStore = Seq(1, 2).map { + batchVersion => + val res = pickedCheckpointInfoList + .filter(_.batchVersion == batchVersion).map(_.stateStoreCkptId.get) + + // batch Id is batchVersion - 1 + batchVersion - 1 -> res.toArray + }.toMap + + val commitLogPath = new Path( + new Path(checkpointDir.getAbsolutePath), "commits").toString + + val commitLog = new CommitLog(spark, commitLogPath) + val metadata_ = commitLog.get(Some(0), Some(1)).map(_._2) + + val versionToUniqueIdFromCommitLog = metadata_.zipWithIndex.map { case (metadata, idx) => + // Use stateUniqueIds(0) because there is only one state operator + val res2 = metadata.stateUniqueIds(0).map { uniqueIds => + uniqueIds + } + println("wei== res2") 
+ res2.foreach(x => for (elem <- x) {
+ println(elem)
+ })
+ idx -> res2
+ }.toMap
+
+ versionToUniqueIdFromCommitLog.foreach {
+ case (version, uniqueIds) =>
+ versionToUniqueIdFromStateStore(version).sameElements(uniqueIds)
+ }
+ }
+ }
+
+ // This test verifies when there are task retries, the unique ids actually stored in
+ // the commit log is the same as those recorded by CkptIdCollectingStateStoreWrapper
+ testWithCheckpointInfoTracked(s"checkpointFormatVersion2 validate ID - task retry") {
+ withTempDir { checkpointDir =>
+ val inputData = MemoryStream[Int]
+ val aggregated =
+ inputData
+ .toDF()
+ .groupBy($"value")
+ .agg(count("*"))
+
+ val writer = (ds: DataFrame, batchId: Long) => {
+ val _ = ds.rdd.filter { x =>
+ val context = TaskContext.get()
+ // Retry in the first attempt
+ if (context.attemptNumber() == 0) {
+ throw new RuntimeException(s"fail the task at " +
+ s"partition ${context.partitionId()} batch Id: $batchId")
+ }
+ x.length >= 0
+ }.collect()
+ }
+
+ val query = aggregated.writeStream
+ .foreachBatch(writer) Review Comment: We are betting foreachbatch will run in the same task as the stateful operator. I assume you tested and it is right, but it is not robust enough to maintain as the underlying implementation can change. I think it is slightly better to do the close in ForEach() rather than foreachBatch(). ########## sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreCheckpointFormatV2Suite.scala: ########## @@ -0,0 +1,1103 @@ +/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.state + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path +import org.scalatest.Tag + +import org.apache.spark.{SparkContext, TaskContext} +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.execution.streaming.{CommitLog, MemoryStream} +import org.apache.spark.sql.functions.count +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.streaming._ +import org.apache.spark.sql.streaming.OutputMode.Update +import org.apache.spark.sql.test.TestSparkSession +import org.apache.spark.sql.types.StructType + +object CkptIdCollectingStateStoreWrapper { + // Internal list to hold checkpoint IDs (strings) + private var checkpointInfos: List[StateStoreCheckpointInfo] = List.empty + + // Method to add a string (checkpoint ID) to the list in a synchronized way + def addCheckpointInfo(checkpointID: StateStoreCheckpointInfo): Unit = synchronized { + checkpointInfos = checkpointID :: checkpointInfos + } + + // Method to read the list of checkpoint IDs in a synchronized way + def getStateStoreCheckpointInfos: List[StateStoreCheckpointInfo] = synchronized { + checkpointInfos + } + + def clear(): Unit = synchronized { + checkpointInfos = List.empty + } +} + +case class CkptIdCollectingStateStoreWrapper(innerStore: StateStore) extends StateStore { + + // Implement methods from ReadStateStore (parent trait) + + override def id: StateStoreId = innerStore.id + override def version: Long = 
innerStore.version + + override def get( + key: UnsafeRow, + colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): UnsafeRow = { + innerStore.get(key, colFamilyName) + } + + override def valuesIterator( + key: UnsafeRow, + colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): Iterator[UnsafeRow] = { + innerStore.valuesIterator(key, colFamilyName) + } + + override def prefixScan( + prefixKey: UnsafeRow, + colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): Iterator[UnsafeRowPair] = { + innerStore.prefixScan(prefixKey, colFamilyName) + } + + override def iterator( + colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): Iterator[UnsafeRowPair] = { + innerStore.iterator(colFamilyName) + } + + override def abort(): Unit = innerStore.abort() + + // Implement methods from StateStore (current trait) + + override def removeColFamilyIfExists(colFamilyName: String): Boolean = { + innerStore.removeColFamilyIfExists(colFamilyName) + } + + override def createColFamilyIfAbsent( + colFamilyName: String, + keySchema: StructType, + valueSchema: StructType, + keyStateEncoderSpec: KeyStateEncoderSpec, + useMultipleValuesPerKey: Boolean = false, + isInternal: Boolean = false): Unit = { + innerStore.createColFamilyIfAbsent( + colFamilyName, + keySchema, + valueSchema, + keyStateEncoderSpec, + useMultipleValuesPerKey, + isInternal + ) + } + + override def put( + key: UnsafeRow, + value: UnsafeRow, + colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): Unit = { + innerStore.put(key, value, colFamilyName) + } + + override def remove( + key: UnsafeRow, + colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): Unit = { + innerStore.remove(key, colFamilyName) + } + + override def merge( + key: UnsafeRow, + value: UnsafeRow, + colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): Unit = { + innerStore.merge(key, value, colFamilyName) + } + + override def commit(): Long = innerStore.commit() + override def metrics: StateStoreMetrics = 
innerStore.metrics + override def getStateStoreCheckpointInfo: StateStoreCheckpointInfo = { + val ret = innerStore.getStateStoreCheckpointInfo + CkptIdCollectingStateStoreWrapper.addCheckpointInfo(ret) + ret + } + override def hasCommitted: Boolean = innerStore.hasCommitted +} + +class CkptIdCollectingStateStoreProviderWrapper extends StateStoreProvider { + + val innerProvider = new RocksDBStateStoreProvider() + + // Now, delegate all methods in the wrapper class to the inner object + override def init( + stateStoreId: StateStoreId, + keySchema: StructType, + valueSchema: StructType, + keyStateEncoderSpec: KeyStateEncoderSpec, + useColumnFamilies: Boolean, + storeConfs: StateStoreConf, + hadoopConf: Configuration, + useMultipleValuesPerKey: Boolean = false): Unit = { + innerProvider.init( + stateStoreId, + keySchema, + valueSchema, + keyStateEncoderSpec, + useColumnFamilies, + storeConfs, + hadoopConf, + useMultipleValuesPerKey + ) + } + + override def stateStoreId: StateStoreId = innerProvider.stateStoreId + + override def close(): Unit = innerProvider.close() + + override def getStore(version: Long, stateStoreCkptId: Option[String] = None): StateStore = { + val innerStateStore = innerProvider.getStore(version, stateStoreCkptId) + CkptIdCollectingStateStoreWrapper(innerStateStore) + } + + override def getReadStore(version: Long, uniqueId: Option[String] = None): ReadStateStore = { + new WrappedReadStateStore( + CkptIdCollectingStateStoreWrapper(innerProvider.getReadStore(version, uniqueId))) + } + + override def doMaintenance(): Unit = innerProvider.doMaintenance() + + override def supportedCustomMetrics: Seq[StateStoreCustomMetric] = + innerProvider.supportedCustomMetrics +} + +// TODO add a test case for two of the tasks for the same shuffle partitions to finish and +// return their own state store checkpointID. This can happen because of task retry or +// speculative execution. 
+class RocksDBStateStoreCheckpointFormatV2Suite extends StreamTest + with AlsoTestWithRocksDBFeatures { + import testImplicits._ + + val providerClassName = classOf[CkptIdCollectingStateStoreProviderWrapper].getCanonicalName + + // Force test task retry number to be 2 + protected override def createSparkSession: TestSparkSession = { + new TestSparkSession(new SparkContext("local[1, 2]", this.getClass.getSimpleName, sparkConf)) + } + + override protected def beforeAll(): Unit = { + super.beforeAll() + } + + override def beforeEach(): Unit = { + CkptIdCollectingStateStoreWrapper.clear() + } + + def testWithCheckpointInfoTracked(testName: String, testTags: Tag*)( + testBody: => Any): Unit = { + super.testWithChangelogCheckpointingEnabled(testName, testTags: _*) { + super.beforeEach() + withSQLConf( + (SQLConf.STATE_STORE_PROVIDER_CLASS.key -> providerClassName), + (SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION.key -> "2"), + (SQLConf.SHUFFLE_PARTITIONS.key, "2")) { + testBody + } + // in case tests have any code that needs to execute after every test + super.afterEach() + } + } + + val changelogEnabled = + "spark.sql.streaming.stateStore.rocksdb.changelogCheckpointing.enabled" -> "true" + val changelogDisabled = + "spark.sql.streaming.stateStore.rocksdb.changelogCheckpointing.enabled" -> "false" + val ckptv1 = SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION.key -> "1" + val ckptv2 = SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION.key -> "2" + + val testConfigSetups = Seq( + // Enable and disable changelog under ckpt v2 + (Seq(changelogEnabled, ckptv2), Seq(changelogEnabled, ckptv2)), + (Seq(changelogDisabled, ckptv2), Seq(changelogDisabled, ckptv2)), + // Cross version cross changelog enabled/disabled + (Seq(changelogDisabled, ckptv1), Seq(changelogDisabled, ckptv2)), + (Seq(changelogEnabled, ckptv1), Seq(changelogEnabled, ckptv2)), + (Seq(changelogDisabled, ckptv1), Seq(changelogEnabled, ckptv2)), + (Seq(changelogEnabled, ckptv1), Seq(changelogDisabled, ckptv2)) + ) + + 
testConfigSetups.foreach { + case (firstRunConfig, secondRunConfig) => + testWithRocksDBStateStore("checkpointFormatVersion2 Backward Compatibility - simple agg - " + + s"first run: (changeLogEnabled, ckpt ver): " + + s"${firstRunConfig(0)._2}, ${firstRunConfig(1)._2}" + + s" - second run: ${secondRunConfig(0)._2}, ${secondRunConfig(1)._2}") { + withTempDir { checkpointDir => + val inputData = MemoryStream[Int] + val aggregated = + inputData + .toDF() + .groupBy($"value") + .agg(count("*")) + .as[(Int, Long)] + + withSQLConf(firstRunConfig: _*) { + testStream(aggregated, Update)( + StartStream(checkpointLocation = checkpointDir.getAbsolutePath), + AddData(inputData, 3), + CheckLastBatch((3, 1)), + AddData(inputData, 3, 2), + CheckLastBatch((3, 2), (2, 1)), + StopStream + ) + + testStream(aggregated, Update)( + StartStream(checkpointLocation = checkpointDir.getAbsolutePath), + AddData(inputData, 3, 2, 1), + CheckLastBatch((3, 3), (2, 2), (1, 1)), + // By default we run in new tuple mode. 
+ AddData(inputData, 4, 4, 4, 4), + CheckLastBatch((4, 4)), + AddData(inputData, 5, 5), + CheckLastBatch((5, 2)) + ) + } + + withSQLConf(secondRunConfig: _*) { + testStream(aggregated, Update)( + StartStream(checkpointLocation = checkpointDir.getAbsolutePath), + AddData(inputData, 4), + CheckLastBatch((4, 5)) + ) + } + } + } + } + + testConfigSetups.foreach { + case (firstRunConfig, secondRunConfig) => + testWithRocksDBStateStore("checkpointFormatVersion2 Backward Compatibility - dedup - " + + s"first run: (changeLogEnabled, ckpt ver): " + + s"${firstRunConfig(0)._2}, ${firstRunConfig(1)._2}" + + s" - second run: ${secondRunConfig(0)._2}, ${secondRunConfig(1)._2}") { + withTempDir { checkpointDir => + val inputData = MemoryStream[Int] + val deduplicated = inputData + .toDF() + .dropDuplicates("value") + .as[Int] + + withSQLConf(firstRunConfig: _*) { + testStream(deduplicated, Update)( + StartStream(checkpointLocation = checkpointDir.getAbsolutePath), + AddData(inputData, 3), + CheckLastBatch(3), + AddData(inputData, 3, 2), + CheckLastBatch(2), + AddData(inputData, 3, 2, 1), + CheckLastBatch(1), + StopStream + ) + + // Test recovery + testStream(deduplicated, Update)( + StartStream(checkpointLocation = checkpointDir.getAbsolutePath), + AddData(inputData, 4, 1, 3), + CheckLastBatch(4), + AddData(inputData, 5, 4, 4), + CheckLastBatch(5), + StopStream + ) + } + + withSQLConf(secondRunConfig: _*) { + // crash recovery again + testStream(deduplicated, Update)( + StartStream(checkpointLocation = checkpointDir.getAbsolutePath), + AddData(inputData, 4, 7), + CheckLastBatch(7) + ) + } + } + } + } + + testConfigSetups.foreach { + case (firstRunConfig, secondRunConfig) => + testWithRocksDBStateStore("checkpointFormatVersion2 Backward Compatibility - " + + s"FlatMapGroupsWithState - first run: (changeLogEnabled, ckpt ver): " + + s"${firstRunConfig(0)._2}, ${firstRunConfig(1)._2}" + + s" - second run: ${secondRunConfig(0)._2}, ${secondRunConfig(1)._2}") { + withTempDir { 
checkpointDir => + val stateFunc = (key: Int, values: Iterator[Int], state: GroupState[Int]) => { + val count: Int = state.getOption.getOrElse(0) + values.size + state.update(count) + Iterator((key, count)) + } + + val inputData = MemoryStream[Int] + val aggregated = inputData + .toDF() + .toDF("key") + .selectExpr("key") + .as[Int] + .repartition($"key") + .groupByKey(x => x) + .flatMapGroupsWithState(OutputMode.Update, GroupStateTimeout.NoTimeout())(stateFunc) + + + withSQLConf(firstRunConfig: _*) { + testStream(aggregated, Update)( + StartStream(checkpointLocation = checkpointDir.getAbsolutePath), + AddData(inputData, 3), + CheckLastBatch((3, 1)), + AddData(inputData, 3, 2), + CheckLastBatch((3, 2), (2, 1)), + StopStream + ) + + // Test recovery + testStream(aggregated, Update)( + StartStream(checkpointLocation = checkpointDir.getAbsolutePath), + AddData(inputData, 4, 1, 3), + CheckLastBatch((4, 1), (1, 1), (3, 3)), + AddData(inputData, 5, 4, 4), + CheckLastBatch((5, 1), (4, 3)), + StopStream + ) + } + + withSQLConf(secondRunConfig: _*) { + // crash recovery again + // crash recovery again + testStream(aggregated, Update)( + StartStream(checkpointLocation = checkpointDir.getAbsolutePath), + AddData(inputData, 4, 7), + CheckLastBatch((4, 4), (7, 1)), + AddData (inputData, 5), + CheckLastBatch((5, 2)), + StopStream + ) + } + } + } + } + + testConfigSetups.foreach { + case (firstRunConfig, secondRunConfig) => + testWithRocksDBStateStore("checkpointFormatVersion2 Backward Compatibility - ss join - " + + s"first run: (changeLogEnabled, ckpt ver): " + + s"${firstRunConfig(0)._2}, ${firstRunConfig(1)._2}" + + s" - second run: ${secondRunConfig(0)._2}, ${secondRunConfig(1)._2}") { + withTempDir { checkpointDir => + val inputData1 = MemoryStream[Int] + val inputData2 = MemoryStream[Int] + + val df1 = inputData1.toDS().toDF("value") + val df2 = inputData2.toDS().toDF("value") + + val joined = df1.join(df2, df1("value") === df2("value")) + + withSQLConf(firstRunConfig: _*) 
{ + testStream(joined, OutputMode.Append)( + StartStream(checkpointLocation = checkpointDir.getAbsolutePath), + AddData(inputData1, 3, 2), + AddData(inputData2, 3), + CheckLastBatch((3, 3)), + AddData(inputData2, 2), + // This data will be used after restarting the query + AddData(inputData1, 5), + CheckLastBatch((2, 2)), + StopStream + ) + + // Test recovery. + testStream(joined, OutputMode.Append)( + StartStream(checkpointLocation = checkpointDir.getAbsolutePath), + AddData(inputData1, 4), + AddData(inputData2, 5), + CheckLastBatch((5, 5)), + AddData(inputData2, 4), + // This data will be used after restarting the query + AddData(inputData1, 7), + CheckLastBatch((4, 4)), + StopStream + ) + } + + withSQLConf(secondRunConfig: _*) { + // recovery again + testStream(joined, OutputMode.Append)( + StartStream(checkpointLocation = checkpointDir.getAbsolutePath), + AddData(inputData1, 6), + AddData(inputData2, 6), + CheckLastBatch((6, 6)), + AddData(inputData2, 7), + CheckLastBatch((7, 7)), + StopStream + ) + + // recovery again + testStream(joined, OutputMode.Append)( + StartStream(checkpointLocation = checkpointDir.getAbsolutePath), + AddData(inputData1, 8), + AddData(inputData2, 8), + CheckLastBatch((8, 8)), + StopStream + ) + } + } + } + } + + + testConfigSetups.foreach { + case (firstRunConfig, secondRunConfig) => + testWithRocksDBStateStore("checkpointFormatVersion2 Backward Compatibility - " + + "transformWithState - first run: (changeLogEnabled, ckpt ver): " + + s"${firstRunConfig(0)._2}, ${firstRunConfig(1)._2}" + + s" - second run: ${secondRunConfig(0)._2}, ${secondRunConfig(1)._2}") { + withTempDir { checkpointDir => + val inputData = MemoryStream[String] + val result = inputData.toDS() + .groupByKey(x => x) + .transformWithState(new RunningCountStatefulProcessor(), + TimeMode.None(), + OutputMode.Update()) + + withSQLConf(firstRunConfig: _*) { + testStream(result, Update())( + StartStream(checkpointLocation = checkpointDir.getAbsolutePath), + 
AddData(inputData, "a"), + CheckNewAnswer(("a", "1")), + Execute { q => + assert(q.lastProgress.stateOperators(0) + .customMetrics.get("numValueStateVars") > 0) + assert(q.lastProgress.stateOperators(0) + .customMetrics.get("numRegisteredTimers") == 0) + }, + AddData(inputData, "a", "b"), + CheckNewAnswer(("a", "2"), ("b", "1")), + StopStream + ) + testStream(result, Update())( + StartStream(checkpointLocation = checkpointDir.getAbsolutePath), + // should remove state for "a" and not return anything for a + AddData(inputData, "a", "b"), + CheckNewAnswer(("b", "2")), + StopStream + ) + } + + withSQLConf(secondRunConfig: _*) { + testStream(result, Update())( + StartStream(checkpointLocation = checkpointDir.getAbsolutePath), + // should recreate state for "a" and return count as 1 and + AddData(inputData, "a", "c"), + CheckNewAnswer(("a", "1"), ("c", "1")), + StopStream + ) + } + } + } + } + + test("checkpointFormatVersion2 validate ") { + val inputData = MemoryStream[String] + val result = inputData.toDS() + .groupByKey(x => x) + .transformWithState(new RunningCountStatefulProcessor(), + TimeMode.None(), + OutputMode.Update()) + + testStream(result, Update())( + AddData(inputData, "a"), + CheckNewAnswer(("a", "1")), + Execute { q => + assert(q.lastProgress.stateOperators(0).customMetrics.get("numValueStateVars") > 0) + assert(q.lastProgress.stateOperators(0).customMetrics.get("numRegisteredTimers") == 0) + }, + AddData(inputData, "a", "b"), + CheckNewAnswer(("a", "2"), ("b", "1")), + StopStream, + StartStream(), + AddData(inputData, "a", "b"), // should remove state for "a" and not return anything for a + CheckNewAnswer(("b", "2")), + StopStream, + StartStream(), + AddData(inputData, "a", "c"), // should recreate state for "a" and return count as 1 and + CheckNewAnswer(("a", "1"), ("c", "1")) + ) + } + + // This test enable checkpoint format V2 without validating the checkpoint ID. Just to make + // sure it doesn't break and return the correct query results. 
+ testWithChangelogCheckpointingEnabled(s"checkpointFormatVersion2") { + withSQLConf((SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION.key, "2")) { + withTempDir { checkpointDir => + val inputData = MemoryStream[Int] + val aggregated = + inputData + .toDF() + .groupBy($"value") + .agg(count("*")) + .as[(Int, Long)] + + testStream(aggregated, Update)( + StartStream(checkpointLocation = checkpointDir.getAbsolutePath), + AddData(inputData, 3), + CheckLastBatch((3, 1)), + AddData(inputData, 3, 2), + CheckLastBatch((3, 2), (2, 1)), + StopStream + ) + + // Run the stream with changelog checkpointing enabled. + testStream(aggregated, Update)( + StartStream(checkpointLocation = checkpointDir.getAbsolutePath), + AddData(inputData, 3, 2, 1), + CheckLastBatch((3, 3), (2, 2), (1, 1)), + // By default we run in new tuple mode. + AddData(inputData, 4, 4, 4, 4), + CheckLastBatch((4, 4)), + AddData(inputData, 5, 5), + CheckLastBatch((5, 2)) + ) + + // Run the stream with changelog checkpointing disabled. + testStream(aggregated, Update)( + StartStream(checkpointLocation = checkpointDir.getAbsolutePath), + AddData(inputData, 4), + CheckLastBatch((4, 5)) + ) + } + } + } + + def validateBaseCheckpointInfo(): Unit = { + val checkpointInfoList = CkptIdCollectingStateStoreWrapper.getStateStoreCheckpointInfos + // Here we assume for every task, we fetch checkpointID from the N state stores in the same + // order. So we can separate stateStoreCkptId for different stores based on the order inside the + // same (batchId, partitionId) group. 
+ val grouped = checkpointInfoList + .groupBy(info => (info.batchVersion, info.partitionId)) + .values + .flatMap { infos => + infos.zipWithIndex.map { case (info, index) => index -> info } + } + .groupBy(_._1) + .map { + case (_, grouped) => + grouped.map { case (_, info) => info } + } + + grouped.foreach { l => + for { + a <- l + b <- l + if a.partitionId == b.partitionId && a.batchVersion == b.batchVersion + 1 + } { + // if batch version exists, it should be the same as the checkpoint ID of the previous batch + assert(!a.baseStateStoreCkptId.isDefined || b.stateStoreCkptId == a.baseStateStoreCkptId) + } + } + } + + def validateCheckpointInfo( + numBatches: Int, + numStateStores: Int, + batchVersionSet: Set[Long]): Unit = { + val checkpointInfoList = CkptIdCollectingStateStoreWrapper.getStateStoreCheckpointInfos + // We have 6 batches, 2 partitions, and 1 state store per batch + assert(checkpointInfoList.size == numBatches * numStateStores * 2) + checkpointInfoList.foreach { l => + assert(l.stateStoreCkptId.isDefined) + if (batchVersionSet.contains(l.batchVersion)) { + assert(l.baseStateStoreCkptId.isDefined) + } + } + assert(checkpointInfoList.count(_.partitionId == 0) == numBatches * numStateStores) + assert(checkpointInfoList.count(_.partitionId == 1) == numBatches * numStateStores) + for (i <- 1 to numBatches) { + assert(checkpointInfoList.count(_.batchVersion == i) == numStateStores * 2) + } + validateBaseCheckpointInfo() + } + + + testWithCheckpointInfoTracked(s"checkpointFormatVersion2 validate ID - two jobs launched") { + withTempDir { checkpointDir => + val inputData = MemoryStream[Int] + val aggregated = + inputData + .toDF() + .groupBy($"value") + .agg(count("*")) + + val writer = (ds: DataFrame, batchId: Long) => { + ds.write.mode("append").saveAsTable("wei_test_t1") + ds.write.mode("append").saveAsTable("wei_test_t2") + } + + val query = aggregated.writeStream + .foreachBatch(writer) + .option("checkpointLocation", checkpointDir.getAbsolutePath) + 
.outputMode("update")
+ .start()
+
+ inputData.addData(1 to 100)
+ query.processAllAvailable()
+
+ inputData.addData(1 to 100)
+ query.processAllAvailable()
+
+ query.stop()
+
+ val checkpointInfoList = CkptIdCollectingStateStoreWrapper.getStateStoreCheckpointInfos
+
+ // Each (partitionId, batchVersion) pair has two entries because foreachBatch launched
+ // two jobs per micro-batch; pick the second entry of every group.
+ val pickedCheckpointInfoList = checkpointInfoList
+ .groupBy(x => (x.partitionId, x.batchVersion)).map(_._2.tail.head)
+
+ Seq(0, 1).foreach {
+ partitionId =>
+ val stateStoreCkptIds = pickedCheckpointInfoList
+ .filter(_.partitionId == partitionId).map(_.stateStoreCkptId)
+ val baseStateStoreCkptIds = pickedCheckpointInfoList
+ .filter(_.partitionId == partitionId).map(_.baseStateStoreCkptId)
+
+ // Verify lineage for each partition across batches. Below should satisfy because
+ // these ids are stored in the following manner:
+ // stateStoreCkptIds: id3, id2, id1
+ // baseStateStoreCkptIds: id2, id1, None
+ // Below checks [id2, id1] are the same,
+ // which is the lineage for this partition across batches
+ assert(stateStoreCkptIds.drop(1).iterator
+ .sameElements(baseStateStoreCkptIds.dropRight(1)))
+ }
+
+ // Map batchId -> per-partition unique IDs as observed by the state store wrapper.
+ val versionToUniqueIdFromStateStore = Seq(1, 2).map {
+ batchVersion =>
+ val res = pickedCheckpointInfoList
+ .filter(_.batchVersion == batchVersion).map(_.stateStoreCkptId.get)
+
+ // batch Id is batchVersion - 1
+ batchVersion - 1 -> res.toArray
+ }.toMap
+
+ val commitLogPath = new Path(
+ new Path(checkpointDir.getAbsolutePath), "commits").toString
+
+ val commitLog = new CommitLog(spark, commitLogPath)
+ val metadata_ = commitLog.get(Some(0), Some(1)).map(_._2)
+
+ // Map batchId -> unique IDs persisted in the commit log.
+ val versionToUniqueIdFromCommitLog = metadata_.zipWithIndex.map { case (metadata, idx) =>
+ // Use stateUniqueIds(0) because there is only one state operator
+ val res2 = metadata.stateUniqueIds(0)
+ idx -> res2
+ }.toMap
+
+ // The unique IDs recorded by the wrapper must match those persisted in the commit log.
+ // BUG FIX: the result of sameElements was previously discarded, so this verification
+ // never failed; wrap it in assert so a mismatch actually fails the test.
+ versionToUniqueIdFromCommitLog.foreach {
+ case (version, uniqueIds) =>
+ assert(versionToUniqueIdFromStateStore(version).sameElements(uniqueIds))
+ }
+ }
+ }
+
+ // This test verifies when there are task retries, the unique ids actually stored in
+ // the commit log is the same as those recorded by CkptIdCollectingStateStoreWrapper
+ testWithCheckpointInfoTracked(s"checkpointFormatVersion2 validate ID - task retry") {
+ withTempDir { checkpointDir =>
+ val inputData = MemoryStream[Int]
+ val aggregated =
+ inputData
+ .toDF()
+ .groupBy($"value")
+ .agg(count("*"))
+
+ // Force exactly one retry per task: the first attempt of every task throws,
+ // the second succeeds.
+ val writer = (ds: DataFrame, batchId: Long) => {
+ val _ = ds.rdd.filter { x =>
+ val context = TaskContext.get()
+ // Retry in the first attempt
+ if (context.attemptNumber() == 0) {
+ throw new RuntimeException(s"fail the task at " +
+ s"partition ${context.partitionId()} batch Id: $batchId")
+ }
+ x.length >= 0
+ }.collect()
+ }
+
+ val query = aggregated.writeStream
+ .foreachBatch(writer)
+ .option("checkpointLocation", checkpointDir.getAbsolutePath)
+ .outputMode("update")
+ .start()
+
+ inputData.addData(1 to 100)
+ query.processAllAvailable()
+
+ inputData.addData(1 to 100)
+ query.processAllAvailable()
+
+ query.stop()
+
+ val checkpointInfoList = CkptIdCollectingStateStoreWrapper.getStateStoreCheckpointInfos
+ // scalastyle:off line.size.limit
+ // Since every task is retried once, for each partition in each batch, we should have two
+ // state store checkpointInfo. Since these infos are appended sequentially, we can group them
+ // by partitionId and batchVersion, and pick the first one as they are the retried ones.
+ // e.g.
+ // [Picked] StateStoreCheckpointInfo[partitionId=1, batchVersion=2, stateStoreCkptId=Some(a9d5afec-0e8d-4473-b948-6c55513aa509), baseStateStoreCkptId=Some(061f7c53-b300-477a-a599-5387d55e315a)] + // [Picked] StateStoreCheckpointInfo[partitionId=0, batchVersion=2, stateStoreCkptId=Some(879cc517-6b85-4dae-abba-794bf2dbab82), baseStateStoreCkptId=Some(513726e7-2448-41a6-a874-92053c5cf86b)] + // StateStoreCheckpointInfo[partitionId=1, batchVersion=2, stateStoreCkptId=Some(7f4ad39f-d019-4ca2-8cf4-300379821cd6), baseStateStoreCkptId=Some(061f7c53-b300-477a-a599-5387d55e315a)] + // StateStoreCheckpointInfo[partitionId=0, batchVersion=2, stateStoreCkptId=Some(9dc215fe-54f9-4dc1-a59b-a8734f359e46), baseStateStoreCkptId=Some(513726e7-2448-41a6-a874-92053c5cf86b)] + // scalastyle:on line.size.limit + val pickedCheckpointInfoList = checkpointInfoList + .groupBy(x => (x.partitionId, x.batchVersion)).map(_._2.head) Review Comment: Does it assume that we always pick the first one in the list? Can we relax the check? -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected] --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
