neilramaswamy commented on code in PR #47895:
URL: https://github.com/apache/spark/pull/47895#discussion_r1762053718


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala:
##########
@@ -896,12 +903,55 @@ class MicroBatchExecution(
    */
   protected def markMicroBatchExecutionStart(execCtx: 
MicroBatchExecutionContext): Unit = {}
 
+  private def updateCheckpointIdForOperator(
+      execCtx: MicroBatchExecutionContext,
+      opId: Long,
+      checkpointInfo: Array[StateStoreCheckpointInfo]): Unit = {
+    // TODO validate baseCheckpointId
+    checkpointInfo.map(_.batchVersion).foreach { v =>
+      assert(
+        execCtx.batchId == -1 || v == execCtx.batchId + 1,
+        s"version $v doesn't match current Batch ID ${execCtx.batchId}")
+    }
+    currentCheckpointUniqueId.put(opId, checkpointInfo.map { c =>
+      assert(c.checkpointId.isDefined)
+      c.checkpointId.get
+    })
+  }
+
+  private def updateCheckpointId(
+      execCtx: MicroBatchExecutionContext,
+      latestExecPlan: SparkPlan): Unit = {
+    // This function cannot handle MBP now.
+    latestExecPlan.collect {
+      case e: StateStoreSaveExec =>
+        assert(e.stateInfo.isDefined)
+        updateCheckpointIdForOperator(execCtx, e.stateInfo.get.operatorId, 
e.getCheckpointInfo())
+      case e: SessionWindowStateStoreSaveExec =>
+        assert(e.stateInfo.isDefined)
+        updateCheckpointIdForOperator(execCtx, e.stateInfo.get.operatorId, 
e.getCheckpointInfo())
+      case e: StreamingDeduplicateExec =>
+        assert(e.stateInfo.isDefined)
+        updateCheckpointIdForOperator(execCtx, e.stateInfo.get.operatorId, 
e.getCheckpointInfo())
+      case e: StreamingDeduplicateWithinWatermarkExec =>
+        assert(e.stateInfo.isDefined)
+        updateCheckpointIdForOperator(execCtx, e.stateInfo.get.operatorId, 
e.getCheckpointInfo())
+      // TODO Need to deal with FlatMapGroupsWithStateExec, 
TransformWithStateExec,

Review Comment:
   You will, though, probably have to do some work to make sure that 
`getCheckpointInfo` can be called for any stateful operator.



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala:
##########
@@ -896,12 +903,55 @@ class MicroBatchExecution(
    */
   protected def markMicroBatchExecutionStart(execCtx: 
MicroBatchExecutionContext): Unit = {}
 
+  private def updateCheckpointIdForOperator(
+      execCtx: MicroBatchExecutionContext,
+      opId: Long,
+      checkpointInfo: Array[StateStoreCheckpointInfo]): Unit = {
+    // TODO validate baseCheckpointId
+    checkpointInfo.map(_.batchVersion).foreach { v =>
+      assert(
+        execCtx.batchId == -1 || v == execCtx.batchId + 1,
+        s"version $v doesn't match current Batch ID ${execCtx.batchId}")
+    }
+    currentCheckpointUniqueId.put(opId, checkpointInfo.map { c =>
+      assert(c.checkpointId.isDefined)
+      c.checkpointId.get
+    })
+  }
+
+  private def updateCheckpointId(
+      execCtx: MicroBatchExecutionContext,
+      latestExecPlan: SparkPlan): Unit = {
+    // This function cannot handle MBP now.
+    latestExecPlan.collect {
+      case e: StateStoreSaveExec =>
+        assert(e.stateInfo.isDefined)
+        updateCheckpointIdForOperator(execCtx, e.stateInfo.get.operatorId, 
e.getCheckpointInfo())
+      case e: SessionWindowStateStoreSaveExec =>
+        assert(e.stateInfo.isDefined)
+        updateCheckpointIdForOperator(execCtx, e.stateInfo.get.operatorId, 
e.getCheckpointInfo())
+      case e: StreamingDeduplicateExec =>
+        assert(e.stateInfo.isDefined)
+        updateCheckpointIdForOperator(execCtx, e.stateInfo.get.operatorId, 
e.getCheckpointInfo())
+      case e: StreamingDeduplicateWithinWatermarkExec =>
+        assert(e.stateInfo.isDefined)
+        updateCheckpointIdForOperator(execCtx, e.stateInfo.get.operatorId, 
e.getCheckpointInfo())
+      // TODO Need to deal with FlatMapGroupsWithStateExec, 
TransformWithStateExec,

Review Comment:
   Why not?
   
   And I also don't see why we need to enumerate all of these here. Can we 
leverage the `StatefulOperator` trait and use that to get the state info? It 
should clean this up quite a bit.



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala:
##########
@@ -126,12 +130,15 @@ class IncrementalExecution(
 
   /** Get the state info of the next stateful operator */
   private def nextStatefulOperationStateInfo(): StatefulOperatorStateInfo = {
-    StatefulOperatorStateInfo(
+    val operatorId = statefulOperatorId.getAndIncrement()
+    val ret = StatefulOperatorStateInfo(
       checkpointLocation,
       runId,
-      statefulOperatorId.getAndIncrement(),
+      operatorId,
       currentBatchId,
-      numStateStores)
+      numStateStores,
+      currentCheckpointUniqueId.get(operatorId))
+    ret

Review Comment:
   `ret` is not needed; the `StatefulOperatorStateInfo(...)` expression can be 
returned directly.



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala:
##########
@@ -57,7 +59,9 @@ class IncrementalExecution(
     val prevOffsetSeqMetadata: Option[OffsetSeqMetadata],
     val offsetSeqMetadata: OffsetSeqMetadata,
     val watermarkPropagator: WatermarkPropagator,
-    val isFirstBatch: Boolean)
+    val isFirstBatch: Boolean,
+    val currentCheckpointUniqueId:
+      MutableMap[Long, Array[String]] = MutableMap[Long, Array[String]]())

Review Comment:
   Is it always true that partition IDs are in the range `[0, numPartitions)`?



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala:
##########
@@ -896,12 +903,55 @@ class MicroBatchExecution(
    */
   protected def markMicroBatchExecutionStart(execCtx: 
MicroBatchExecutionContext): Unit = {}
 
+  private def updateCheckpointIdForOperator(
+      execCtx: MicroBatchExecutionContext,
+      opId: Long,
+      checkpointInfo: Array[StateStoreCheckpointInfo]): Unit = {
+    // TODO validate baseCheckpointId
+    checkpointInfo.map(_.batchVersion).foreach { v =>
+      assert(
+        execCtx.batchId == -1 || v == execCtx.batchId + 1,
+        s"version $v doesn't match current Batch ID ${execCtx.batchId}")
+    }
+    currentCheckpointUniqueId.put(opId, checkpointInfo.map { c =>
+      assert(c.checkpointId.isDefined)
+      c.checkpointId.get
+    })
+  }
+
+  private def updateCheckpointId(
+      execCtx: MicroBatchExecutionContext,
+      latestExecPlan: SparkPlan): Unit = {
+    // This function cannot handle MBP now.
+    latestExecPlan.collect {
+      case e: StateStoreSaveExec =>
+        assert(e.stateInfo.isDefined)
+        updateCheckpointIdForOperator(execCtx, e.stateInfo.get.operatorId, 
e.getCheckpointInfo())
+      case e: SessionWindowStateStoreSaveExec =>
+        assert(e.stateInfo.isDefined)
+        updateCheckpointIdForOperator(execCtx, e.stateInfo.get.operatorId, 
e.getCheckpointInfo())
+      case e: StreamingDeduplicateExec =>
+        assert(e.stateInfo.isDefined)
+        updateCheckpointIdForOperator(execCtx, e.stateInfo.get.operatorId, 
e.getCheckpointInfo())
+      case e: StreamingDeduplicateWithinWatermarkExec =>
+        assert(e.stateInfo.isDefined)
+        updateCheckpointIdForOperator(execCtx, e.stateInfo.get.operatorId, 
e.getCheckpointInfo())
+      // TODO Need to deal with FlatMapGroupsWithStateExec, 
TransformWithStateExec,
+      // FlatMapGroupsInPandasWithStateExec, StreamingSymmetricHashJoinExec,
+      // StreamingGlobalLimitExec later
+    }
+  }
+
   /**
    * Called after the microbatch has completed execution. It takes care of 
committing the offset
    * to commit log and other bookkeeping.
    */
   protected def markMicroBatchEnd(execCtx: MicroBatchExecutionContext): Unit = 
{
-    watermarkTracker.updateWatermark(execCtx.executionPlan.executedPlan)
+    val latestExecPlan = execCtx.executionPlan.executedPlan
+    watermarkTracker.updateWatermark(latestExecPlan)
+    if (sparkSession.sessionState.conf.stateStoreCheckpointFormatVersion >= 2) 
{

Review Comment:
   I don't really like the `>= 2` sprinkled everywhere. Can you define a 
constant somewhere, and then have a utility method that you call instead?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to