micheal-o commented on code in PR #53720:
URL: https://github.com/apache/spark/pull/53720#discussion_r2692923807


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateRewriter.scala:
##########
@@ -402,3 +477,15 @@ private[state] class StateRewriterUnsupportedStoreMetadataVersionError(
     checkpointLocation,
     subClass = "UNSUPPORTED_STATE_STORE_METADATA_VERSION",
     messageParameters = Map.empty)
+
+private[state] class StateRewriterCheckpointVersionMismatchError(

Review Comment:
   ditto: name



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateRewriter.scala:
##########
@@ -224,9 +261,8 @@ class StateRewriter(
         schemaProvider,
         executorSqlConf
       )
-
-      partitionWriter.write(partitionIter)
-    }
+      Iterator(partitionWriter.write(partitionIter))
+    }.sortBy(_.partitionId).collect()

Review Comment:
   qq: I thought the result is already sorted by partitionId, or is that not true?



##########
common/utils/src/main/resources/error/error-conditions.json:
##########
@@ -5608,6 +5608,13 @@
           "Unsupported state store metadata version encountered.",
           "Only StateStoreMetadataV1 and StateStoreMetadataV2 are supported."
         ]
+      },
+      "CHECKPOINT_VERSION_MISMATCH" : {

Review Comment:
   STATE_CHECKPOINT_FORMAT_VERSION_MISMATCH



##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatePartitionAllColumnFamiliesWriterSuite.scala:
##########
@@ -620,358 +957,42 @@ class StatePartitionAllColumnFamiliesWriterSuite extends StateDataSourceTestBase
         }
       }
     }
+  } // End of foreach loop for changelog checkpointing dimension
 
-    // Run transformWithState tests with different encoding formats
-    Seq("unsaferow", "avro").foreach { encodingFormat =>
-      def testWithChangelogAndEncodingConfig(testName: String)(testFun: => 
Unit): Unit = {
-        test(s"$testName ($changelogCpTestSuffix, encoding = 
$encodingFormat)") {
-          withSQLConf(
-            
"spark.sql.streaming.stateStore.rocksdb.changelogCheckpointing.enabled" ->
-              changelogCheckpointingEnabled.toString,
-            SQLConf.STREAMING_STATE_STORE_ENCODING_FORMAT.key -> 
encodingFormat) {
-            testFun
-          }
-        }
-      }
-
-      testWithChangelogAndEncodingConfig(
-          "SPARK-54411: transformWithState with multiple column families") {
-        withTempDir { sourceDir =>
-          withTempDir { targetDir =>
-            val inputData = MemoryStream[String]
-            val query = inputData.toDS()
-              .groupByKey(x => x)
-              .transformWithState(new MultiStateVarProcessor(),
-                TimeMode.None(),
-                OutputMode.Update())
-            def runQuery(checkpointLocation: String, roundsOfData: Int): Unit 
= {
-              val dataActions = (1 to roundsOfData).flatMap { _ =>
-                Seq(
-                  AddData(inputData, "a", "b", "a"),
-                  ProcessAllAvailable()
-                )
-              }
-              testStream(query)(
-                Seq(StartStream(checkpointLocation = checkpointLocation)) ++
-                  dataActions ++
-                  Seq(StopStream): _*
-              )
-            }
-
-            runQuery(sourceDir.getAbsolutePath, 2)
-            runQuery(targetDir.getAbsolutePath, 1)
-
-            val schemas = 
MultiStateVarProcessorTestUtils.getSchemasWithMetadata()
-            val columnFamilyToSelectExprs = MultiStateVarProcessorTestUtils
-              .getColumnFamilyToSelectExprs()
-
-            val columnFamilyToStateSourceOptions = schemas.keys.map { cfName =>
-              val base = Map(
-                StateSourceOptions.STATE_VAR_NAME -> cfName
-              )
-
-              val withFlatten =
-                if (cfName == MultiStateVarProcessorTestUtils.ITEMS_LIST) {
-                  base + (StateSourceOptions.FLATTEN_COLLECTION_TYPES -> 
"true")
-                } else {
-                  base
-                }
-
-              cfName -> withFlatten
-            }.toMap
-
-            performRoundTripTest(
-              sourceDir.getAbsolutePath,
-              targetDir.getAbsolutePath,
-              storeToColumnFamilies = Map(StateStoreId.DEFAULT_STORE_NAME -> 
schemas.keys.toList),
-              storeToColumnFamilyToSelectExprs =
-                Map(StateStoreId.DEFAULT_STORE_NAME -> 
columnFamilyToSelectExprs),
-              storeToColumnFamilyToStateSourceOptions =
-                Map(StateStoreId.DEFAULT_STORE_NAME -> 
columnFamilyToStateSourceOptions),
-              operatorName = 
StatefulOperatorsUtils.TRANSFORM_WITH_STATE_EXEC_OP_NAME
-            )
-          }
-        }
-      }
-
-      testWithChangelogAndEncodingConfig("SPARK-54411: transformWithState with 
eventTime timers") {
-        withTempDir { sourceDir =>
-          withTempDir { targetDir =>
-            val inputData = MemoryStream[(String, Long)]
-            val result = inputData.toDS()
-              .select(col("_1").as("key"), 
timestamp_seconds(col("_2")).as("eventTime"))
-              .withWatermark("eventTime", "10 seconds")
-              .as[(String, Timestamp)]
-              .groupByKey(_._1)
-              .transformWithState(
-                new EventTimeTimerProcessor(),
-                TimeMode.EventTime(),
-                OutputMode.Update())
-
-            testStream(result, OutputMode.Update())(
-              StartStream(checkpointLocation = sourceDir.getAbsolutePath),
-              AddData(inputData, ("a", 1L), ("b", 2L), ("c", 3L)),
-              ProcessAllAvailable(),
-              StopStream
-            )
-
-            testStream(result, OutputMode.Update())(
-              StartStream(checkpointLocation = targetDir.getAbsolutePath),
-              AddData(inputData, ("x", 1L)),
-              ProcessAllAvailable(),
-              StopStream
-            )
-
-            val (schemaMap, selectExprs, stateSourceOptions) =
-              getTimerStateConfigsForCountState(TimeMode.EventTime())
-
-            performRoundTripTest(
-              sourceDir.getAbsolutePath,
-              targetDir.getAbsolutePath,
-              storeToColumnFamilies =
-                Map(StateStoreId.DEFAULT_STORE_NAME -> schemaMap.keys.toList),
-              storeToColumnFamilyToSelectExprs =
-                Map(StateStoreId.DEFAULT_STORE_NAME -> selectExprs),
-              storeToColumnFamilyToStateSourceOptions =
-                Map(StateStoreId.DEFAULT_STORE_NAME -> stateSourceOptions),
-              operatorName = 
StatefulOperatorsUtils.TRANSFORM_WITH_STATE_EXEC_OP_NAME
-            )
-          }
-        }
-      }
-
-      testWithChangelogAndEncodingConfig(
-        "SPARK-54411: transformWithState with processing time timers") {
-        withTempDir { sourceDir =>
-          withTempDir { targetDir =>
-            val clock = new StreamManualClock
-            val inputData = MemoryStream[String]
-            val result = inputData.toDS()
-              .groupByKey(x => x)
-              .transformWithState(new 
RunningCountStatefulProcessorWithProcTimeTimer(),
-                TimeMode.ProcessingTime(),
-                OutputMode.Update())
-
-            testStream(result, OutputMode.Update())(
-              StartStream(checkpointLocation = sourceDir.getAbsolutePath,
-                trigger = Trigger.ProcessingTime("1 second"),
-                triggerClock = clock),
-              AddData(inputData, "a"),
-              AdvanceManualClock(1 * 1000),
-              CheckNewAnswer(("a", "1")),
-              StopStream
-            )
-
-            val clock2 = new StreamManualClock
-            testStream(result, OutputMode.Update())(
-              StartStream(checkpointLocation = targetDir.getAbsolutePath,
-                trigger = Trigger.ProcessingTime("1 second"),
-                triggerClock = clock2),
-              AddData(inputData, "x"),
-              AdvanceManualClock(1 * 1000),
-              CheckNewAnswer(("a", "1"), ("x", "1")),
-              StopStream
-            )
-
-            val (schemaMap, selectExprs, sourceOptions) =
-              getTimerStateConfigsForCountState(TimeMode.ProcessingTime())
-
-            performRoundTripTest(
-              sourceDir.getAbsolutePath,
-              targetDir.getAbsolutePath,
-              storeToColumnFamilies =
-                Map(StateStoreId.DEFAULT_STORE_NAME -> schemaMap.keys.toList),
-              storeToColumnFamilyToSelectExprs =
-                Map(StateStoreId.DEFAULT_STORE_NAME -> selectExprs),
-              storeToColumnFamilyToStateSourceOptions =
-                Map(StateStoreId.DEFAULT_STORE_NAME -> sourceOptions),
-              operatorName = 
StatefulOperatorsUtils.TRANSFORM_WITH_STATE_EXEC_OP_NAME
-            )
-          }
-        }
-      }
-
-      testWithChangelogAndEncodingConfig("SPARK-54411: transformWithState with 
list and TTL") {
-        withTempDir { sourceDir =>
-          withTempDir { targetDir =>
-            val clock = new StreamManualClock
-            val inputData = MemoryStream[InputEvent]
-            val ttlConfig = TTLConfig(ttlDuration = Duration.ofMinutes(1))
-            val result = inputData.toDS()
-              .groupByKey(x => x.key)
-              .transformWithState(new ListStateTTLProcessor(ttlConfig),
-                TimeMode.ProcessingTime(),
-                OutputMode.Update())
-
-            testStream(result, OutputMode.Update())(
-              StartStream(checkpointLocation = sourceDir.getAbsolutePath,
-                trigger = Trigger.ProcessingTime("1 second"),
-                triggerClock = clock),
-              AddData(inputData, InputEvent("k1", "put", 1)),
-              AdvanceManualClock(1 * 1000),
-              CheckNewAnswer(),
-              StopStream
-            )
-
-            val clock2 = new StreamManualClock
-            testStream(result, OutputMode.Update())(
-              StartStream(checkpointLocation = targetDir.getAbsolutePath,
-                trigger = Trigger.ProcessingTime("1 second"),
-                triggerClock = clock2),
-              AddData(inputData, InputEvent("k1", "append", 1)),
-              AdvanceManualClock(1 * 1000),
-              CheckNewAnswer(),
-              StopStream
-            )
-
-            val schemas = 
TTLProcessorUtils.getListStateTTLSchemasWithMetadata()
-
-            val columnFamilyToSelectExprs = Map(
-              TTLProcessorUtils.LIST_STATE -> 
TTLProcessorUtils.getTTLSelectExpressions(
-                TTLProcessorUtils.LIST_STATE
-            ))
-
-            val columnFamilyToStateSourceOptions = schemas.keys.map { cfName =>
-              val base = Map(StateSourceOptions.STATE_VAR_NAME -> cfName)
-
-              val withFlatten =
-                if (cfName == TTLProcessorUtils.LIST_STATE) {
-                  base + (StateSourceOptions.FLATTEN_COLLECTION_TYPES -> 
"true")
-                } else {
-                  base
-                }
-
-              cfName -> withFlatten
-            }.toMap
-
-            performRoundTripTest(
-              sourceDir.getAbsolutePath,
-              targetDir.getAbsolutePath,
-              storeToColumnFamilies =
-                Map(StateStoreId.DEFAULT_STORE_NAME -> schemas.keys.toList),
-              storeToColumnFamilyToSelectExprs =
-                Map(StateStoreId.DEFAULT_STORE_NAME -> 
columnFamilyToSelectExprs),
-              storeToColumnFamilyToStateSourceOptions =
-                Map(StateStoreId.DEFAULT_STORE_NAME -> 
columnFamilyToStateSourceOptions),
-              operatorName = 
StatefulOperatorsUtils.TRANSFORM_WITH_STATE_EXEC_OP_NAME
-            )
-          }
-        }
-      }
-
-      testWithChangelogAndEncodingConfig("SPARK-54411: transformWithState with 
map and TTL") {
-        withTempDir { sourceDir =>
-          withTempDir { targetDir =>
-            val clock = new StreamManualClock
-            val inputData = MemoryStream[MapInputEvent]
-            val ttlConfig = TTLConfig(ttlDuration = Duration.ofMinutes(1))
-            val result = inputData.toDS()
-              .groupByKey(x => x.key)
-              .transformWithState(new MapStateTTLProcessor(ttlConfig),
-                TimeMode.ProcessingTime(),
-                OutputMode.Update())
-
-            testStream(result)(
-              StartStream(checkpointLocation = sourceDir.getAbsolutePath,
-                trigger = Trigger.ProcessingTime("1 second"),
-                triggerClock = clock),
-              AddData(inputData, MapInputEvent("a", "key1", "put", 1)),
-              AdvanceManualClock(1 * 1000),
-              CheckNewAnswer(),
-              StopStream
-            )
-
-            val clock2 = new StreamManualClock
-            testStream(result)(
-              StartStream(checkpointLocation = targetDir.getAbsolutePath,
-                trigger = Trigger.ProcessingTime("1 second"),
-                triggerClock = clock2),
-              AddData(inputData, MapInputEvent("x", "key1", "put", 1)),
-              AdvanceManualClock(1 * 1000),
-              CheckNewAnswer(),
-              StopStream
-            )
-
-            val schemas = TTLProcessorUtils.getMapStateTTLSchemasWithMetadata()
-
-            val columnFamilyToSelectExprs = Map(
-              TTLProcessorUtils.MAP_STATE -> 
TTLProcessorUtils.getTTLSelectExpressions(
-                TTLProcessorUtils.MAP_STATE
-            ))
-
-            val columnFamilyToStateSourceOptions = schemas.keys.map { cfName =>
-              cfName -> Map(StateSourceOptions.STATE_VAR_NAME -> cfName)
-            }.toMap
-
-            performRoundTripTest(
-              sourceDir.getAbsolutePath,
-              targetDir.getAbsolutePath,
-              storeToColumnFamilies =
-                Map(StateStoreId.DEFAULT_STORE_NAME -> schemas.keys.toList),
-              storeToColumnFamilyToSelectExprs =
-                Map(StateStoreId.DEFAULT_STORE_NAME -> 
columnFamilyToSelectExprs),
-              storeToColumnFamilyToStateSourceOptions =
-                Map(StateStoreId.DEFAULT_STORE_NAME -> 
columnFamilyToStateSourceOptions),
-              operatorName = 
StatefulOperatorsUtils.TRANSFORM_WITH_STATE_EXEC_OP_NAME
-            )
+  test("SPARK-54590: Rewriter throw exception if checkpoint version is not set correct") {
+    withSQLConf(SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION.key -> "2") {
+      withTempDir { sourceDir =>
+        withTempDir { targetDir =>
+          // Step 1: Create state by running a streaming aggregation
+          runDropDuplicatesQuery(sourceDir.getAbsolutePath)
+          val sourceCheckpointMetadata = new StreamingQueryCheckpointMetadata(
+            spark, sourceDir.getAbsolutePath)
+          val readBatchId = sourceCheckpointMetadata.commitLog.getLatestBatchId().get
+          // Forced set STATE_STORE_CHECKPOINT_FORMAT_VERSION to 1 to mimic when user forgot to
+          // update checkpoint version to 2 in sqlConfig when running stateRewriter
+          // on checkpointV2 query.
+          spark.conf.unset(SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION.key)
+          val rewriter = new StateRewriter(
+            spark,
+            readBatchId,
+            readBatchId + 1,
+            sourceDir.getAbsolutePath,
+            spark.sessionState.newHadoopConf()
+          )
+          val ex = intercept[StateRewriterInvalidCheckpointError] {
+            rewriter.run()
           }
-        }
-      }
-
-      testWithChangelogAndEncodingConfig("SPARK-54411: transformWithState with 
value and TTL") {
-        withTempDir { sourceDir =>
-          withTempDir { targetDir =>
-            val clock = new StreamManualClock
-            val inputData = MemoryStream[InputEvent]
-            val ttlConfig = TTLConfig(ttlDuration = Duration.ofMinutes(1))
-            val result = inputData.toDS()
-              .groupByKey(x => x.key)
-              .transformWithState(new ValueStateTTLProcessor(ttlConfig),
-                TimeMode.ProcessingTime(),
-                OutputMode.Update())
-
-            testStream(result)(
-              StartStream(checkpointLocation = sourceDir.getAbsolutePath,
-                trigger = Trigger.ProcessingTime("1 second"),
-                triggerClock = clock),
-              AddData(inputData, InputEvent("k1", "put", 1)),
-              AddData(inputData, InputEvent("k2", "put", 2)),
-              AdvanceManualClock(1 * 1000),
-              CheckNewAnswer(),
-              StopStream
-            )
 
-            val clock2 = new StreamManualClock
-            testStream(result)(
-              StartStream(checkpointLocation = targetDir.getAbsolutePath,
-                trigger = Trigger.ProcessingTime("1 second"),
-                triggerClock = clock2),
-              AddData(inputData, InputEvent("x", "put", 1)),
-              AdvanceManualClock(1 * 1000),
-              CheckNewAnswer(),
-              StopStream
-            )
-
-            val schemas = 
TTLProcessorUtils.getValueStateTTLSchemasWithMetadata()
-
-            val columnFamilyToStateSourceOptions = schemas.keys.map { cfName =>
-              cfName -> Map(StateSourceOptions.STATE_VAR_NAME -> cfName)
-            }.toMap
-
-            performRoundTripTest(
-              sourceDir.getAbsolutePath,
-              targetDir.getAbsolutePath,
-              storeToColumnFamilies =
-                Map(StateStoreId.DEFAULT_STORE_NAME -> schemas.keys.toList),
-              storeToColumnFamilyToStateSourceOptions =
-                Map(StateStoreId.DEFAULT_STORE_NAME -> 
columnFamilyToStateSourceOptions),
-              operatorName = 
StatefulOperatorsUtils.TRANSFORM_WITH_STATE_EXEC_OP_NAME
-            )
-          }
+          assert(ex.getCondition == "STATE_REWRITER_INVALID_CHECKPOINT.CHECKPOINT_VERSION_MISMATCH")

Review Comment:
   use the `checkError` helper instead of these manual asserts.
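   e.g. something along these lines, as a sketch — it assumes the current `checkError(exception, condition, parameters)` signature, and the `parameters` values are illustrative, taken from the `<expectedVersion>`/`<actualVersion>`/`<sqlConfKey>` placeholders in the error-conditions.json entry:
   ```scala
   // Sketch only: adjust the parameter values to whatever the error class
   // actually populates for this scenario.
   checkError(
     exception = ex,
     condition = "STATE_REWRITER_INVALID_CHECKPOINT.CHECKPOINT_VERSION_MISMATCH",
     parameters = Map(
       "expectedVersion" -> "2",
       "actualVersion" -> "1",
       "sqlConfKey" -> SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION.key))
   ```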



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateRewriter.scala:
##########
@@ -342,6 +378,37 @@ class StateRewriter(
       None
     }
   }
+
+  private def verifyCheckpointVersion(): Unit = {
+    // Verify checkpoint version in sqlConf based on commitLog for readCheckpoint
+    // in case user forgot to set STATE_STORE_CHECKPOINT_FORMAT_VERSION.
+    // Using read batch commit since the latest commit could be a skipped batch.
+    // If SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION is wrong, readCheckpoint.commitLog
+    // will throw an exception, and we will propagate this exception upstream.
+    // This prevents the StateRewriter from failing to write the correct state files
+    try {
+      val writeCheckpoint =
+        new StreamingQueryCheckpointMetadata(sparkSession, resolvedCheckpointLocation)
+      val readCheckpoint = if (readResolvedCheckpointLocation.isDefined) {
+        new StreamingQueryCheckpointMetadata(sparkSession, readResolvedCheckpointLocation.get)
+      } else {
+        // Same checkpoint for read & write
+        writeCheckpoint
+      }
+      readCheckpoint.commitLog.get(readBatchId)
+    } catch {
+        case e: IllegalStateException =>

Review Comment:
   Do: `case e: IllegalStateException if e.getCause != null && e.getCause.isInstanceOf[SparkThrowable] =>`.
   
   Otherwise the code below can lead to a `NullPointerException`.
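   As a self-contained illustration of why the guard matters (the helper name is hypothetical; only the match guard is the point):
   ```scala
   import org.apache.spark.SparkThrowable

   // Hypothetical helper for illustration: the guard only matches when a cause
   // exists and is a SparkThrowable, so the handler can inspect the cause safely.
   // An IllegalStateException with a null cause falls through and propagates,
   // instead of triggering a NullPointerException inside the handler.
   def causeAsSparkThrowable(operation: => Unit): Option[SparkThrowable] = {
     try {
       operation
       None
     } catch {
       case e: IllegalStateException
           if e.getCause != null && e.getCause.isInstanceOf[SparkThrowable] =>
         Some(e.getCause.asInstanceOf[SparkThrowable])
     }
   }
   ```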



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateRewriter.scala:
##########
@@ -81,24 +82,58 @@ class StateRewriter(
   private val stateRootLocation = new Path(
     resolvedCheckpointLocation, 
StreamingCheckpointConstants.DIR_NAME_STATE).toString
 
-  def run(): Unit = {
+  // If checkpoint id is enabled, return
+  // Map[operatorId, Array[partition -> Array[stateStore -> StateStoreCheckpointId]]].
+  // Otherwise, return None
+  def run(): Option[Map[Long, Array[Array[String]]]] = {
     logInfo(log"Starting state rewrite for " +
       log"checkpointLocation=${MDC(CHECKPOINT_LOCATION, resolvedCheckpointLocation)}, " +
       log"readCheckpointLocation=" +
       log"${MDC(CHECKPOINT_LOCATION, readResolvedCheckpointLocation.getOrElse(""))}, " +
       log"readBatchId=${MDC(BATCH_ID, readBatchId)}, " +
       log"writeBatchId=${MDC(BATCH_ID, writeBatchId)}")
 
-    val (_, timeTakenMs) = Utils.timeTakenMs {
+    val (checkpointIds, timeTakenMs) = Utils.timeTakenMs {
       runInternal()
     }
 
     logInfo(log"State rewrite completed in ${MDC(DURATION, timeTakenMs)} ms for " +
-      log"checkpointLocation=${MDC(CHECKPOINT_LOCATION, resolvedCheckpointLocation)}")
+      log"checkpointLocation=${MDC(CHECKPOINT_LOCATION, resolvedCheckpointLocation)} " +
+      log"checkpointInfos=${MDC(LAST_COMMITTED_CHECKPOINT_ID, checkpointIds)}"

Review Comment:
   Actually we can skip logging the `checkpointIds` here, since there could be a lot of them (e.g. 200 per operator) and they would just add noise.



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateRewriter.scala:
##########
@@ -342,6 +378,37 @@ class StateRewriter(
       None
     }
   }
+
+  private def verifyCheckpointVersion(): Unit = {
+    // Verify checkpoint version in sqlConf based on commitLog for readCheckpoint
+    // in case user forgot to set STATE_STORE_CHECKPOINT_FORMAT_VERSION.
+    // Using read batch commit since the latest commit could be a skipped batch.
+    // If SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION is wrong, readCheckpoint.commitLog
+    // will throw an exception, and we will propagate this exception upstream.
+    // This prevents the StateRewriter from failing to write the correct state files
+    try {
+      val writeCheckpoint =

Review Comment:
   Please look at my previous comment [here](https://github.com/apache/spark/pull/53720#discussion_r2688023829) again, along with the sample code I gave there:
   1. This isn't how I suggested `writeCheckpoint` be implemented. It needs to use the passed-in write checkpoint metadata, i.e. `writeCheckpointMetadata.getOrElse(...)`.
   2. I also said `writeCheckpoint` & `readCheckpoint` should be class vals, i.e. defined at class level as `private lazy val`.
   
   Please take a look at that comment again.
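   For reference, roughly the shape being asked for, as a sketch (field names are taken from this diff; the earlier comment has the authoritative version):
   ```scala
   // Sketch: writeCheckpointMetadata is the optional metadata already passed in
   // to StateRewriter; fall back to building it from the resolved write location.
   private lazy val writeCheckpoint: StreamingQueryCheckpointMetadata =
     writeCheckpointMetadata.getOrElse(
       new StreamingQueryCheckpointMetadata(sparkSession, resolvedCheckpointLocation))

   // Reuse the write-side metadata when no separate read checkpoint is given.
   private lazy val readCheckpoint: StreamingQueryCheckpointMetadata =
     readResolvedCheckpointLocation
       .map(loc => new StreamingQueryCheckpointMetadata(sparkSession, loc))
       .getOrElse(writeCheckpoint)
   ```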



##########
common/utils/src/main/resources/error/error-conditions.json:
##########
@@ -5608,6 +5608,13 @@
           "Unsupported state store metadata version encountered.",
           "Only StateStoreMetadataV1 and StateStoreMetadataV2 are supported."
         ]
+      },
+      "CHECKPOINT_VERSION_MISMATCH" : {
+        "message" : [
+          "The checkpoint format version in SQLConf does not match the checkpoint version in the commit log.",
+          "Expected version v<expectedVersion>, but found v<actualVersion>.",
+          "Please set '<sqlConfKey>' to <expectedVersion> in your SQLConf before running StateRewriter."

Review Comment:
   nit: `before retrying.` — no need to mention running StateRewriter, since users don't know about it.



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateRewriter.scala:
##########
@@ -364,6 +431,14 @@ private[state] object StateRewriterErrors {
       checkpointLocation: String): StateRewriterInvalidCheckpointError = {
     new StateRewriterUnsupportedStoreMetadataVersionError(checkpointLocation)
   }
+
+  def checkpointVersionMismatchError(

Review Comment:
   ditto: name


