brkyvz commented on code in PR #47895:
URL: https://github.com/apache/spark/pull/47895#discussion_r1805462938


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala:
##########
@@ -900,12 +906,57 @@ class MicroBatchExecution(
    */
   protected def markMicroBatchExecutionStart(execCtx: MicroBatchExecutionContext): Unit = {}
 
+  /**
+   * Store the state store checkpoint id for a finishing batch to `currentStateStoreCkptId`,
+   * which will be retrieved later when the next batch starts.
+   */
+  private def updateStateStoreCkptIdForOperator(
+      execCtx: MicroBatchExecutionContext,
+      opId: Long,
+      checkpointInfo: Array[StatefulOpStateStoreCheckpointInfo]): Unit = {
+    // TODO validate baseStateStoreCkptId
+    checkpointInfo.map(_.batchVersion).foreach { v =>
+      assert(
+        execCtx.batchId == -1 || v == execCtx.batchId + 1,
+        s"Batch version ${execCtx.batchId} should generate state store 
checkpoint " +
+          s"version ${execCtx.batchId + 1} but we see ${v}")
+    }
+    val ckptIds = checkpointInfo.map { info =>
+      assert(info.stateStoreCkptId.isDefined)
+      info.stateStoreCkptId.get
+    }
+    currentStateStoreCkptId.put(opId, ckptIds)
+  }
+
+  /**
+   * Walk the query plan `latestExecPlan` to find a StateStoreWriter operator, retrieve the
+   * state store checkpoint ID from the operator, and record it in `currentStateStoreCkptId`.
+   * @param execCtx
+   * @param latestExecPlan

Review Comment:
   @siying missed this comment
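   
   For reference, something along these lines would close it out (the descriptions are my guess at the intent, adjust as needed):
   
   ```scala
   /**
    * Walk the query plan `latestExecPlan` to find a StateStoreWriter operator, retrieve the
    * state store checkpoint ID from the operator, and record it in `currentStateStoreCkptId`.
    * @param execCtx the execution context of the batch that just finished
    * @param latestExecPlan the executed plan of that batch, walked to locate the
    *                       StateStoreWriter operator holding the checkpoint info
    */
   ```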



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala:
##########
@@ -210,7 +210,7 @@ abstract class StreamExecution(
     this, s"spark.streaming.${Option(name).getOrElse(id)}")
 
   /** Isolated spark session to run the batches with. */
-  private val sparkSessionForStream = sparkSession.cloneSession()
+  protected val sparkSessionForStream = sparkSession.cloneSession()

Review Comment:
   for the future - I feel like we should refactor these abstractions a bit so that
   developers can't make the same session-misuse mistake again. Today it's too subtle
   and too easy to hit.
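   
   One possible shape, purely a sketch (`withStreamSession` is a made-up name): keep the
   cloned session private to StreamExecution and route batch work through a helper, so
   the session choice is explicit at every call site instead of relying on everyone
   picking the right field:
   
   ```scala
   // Only StreamExecution itself can touch the isolated session directly.
   private val sparkSessionForStream: SparkSession = sparkSession.cloneSession()
   
   // Subclasses run batch work through this helper, which hands them the correct
   // session; any reference to the outer `sparkSession` then stands out in review.
   protected final def withStreamSession[T](body: SparkSession => T): T =
     body(sparkSessionForStream)
   ```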



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala:
##########
@@ -44,19 +45,51 @@ import org.apache.spark.sql.execution.streaming.state._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.streaming.{OutputMode, StateOperatorProgress}
 import org.apache.spark.sql.types._
-import org.apache.spark.util.{CompletionIterator, NextIterator, Utils}
+import org.apache.spark.util.{CollectionAccumulator, CompletionIterator, NextIterator, Utils}
 
 
-/** Used to identify the state store for a given operator. */
+/** Used to identify the state store for a given operator.
+ *
+ * stateStoreCkptIds is used to identify the checkpoint used for a specific stateful operator.
+ * The basic workflow is as follows:
+ * 1. When a stateful operator is created, it passes in the checkpoint IDs for each stateful
+ *    operator through the StatefulOperatorStateInfo.
+ * 2. When a stateful task starts to execute, it finds the checkpoint ID for its shuffle
+ *    partition and uses it to recover the state store. The ID is passed down into the
+ *    StateStore layer and eventually the RocksDB state store, where it is used to make sure
+ *    the correct checkpoint is loaded.
+ * 3. When the stateful task is finishing, after the state store is committed, the checkpoint ID
+ *    is fetched from the state store by calling StateStore.getStateStoreCheckpointInfo() and
+ *    added to the stateStoreCkptIds accumulator by calling
+ *    StateStoreWriter.setStateStoreCheckpointInfo().
+ * 4. When ending the batch, MicroBatchExecution calls each stateful operator's
+ *    getStateStoreCheckpointInfo(), which aggregates the checkpoint IDs from different
+ *    partitions. The driver will persist them into the commit log (not implemented yet).
+ * 5. When forming the next batch, the driver constructs the StatefulOperatorStateInfo with the
+ *    checkpoint IDs of the previous batch.
+ */
 case class StatefulOperatorStateInfo(
     checkpointLocation: String,
     queryRunId: UUID,
     operatorId: Long,
     storeVersion: Long,
-    numPartitions: Int) {
+    numPartitions: Int,
+    stateStoreCkptIds: Option[Array[Array[String]]] = None) {
+
+  def getStateStoreCkptId(partitionId: Int): Option[Array[String]] = {
+    stateStoreCkptIds.map(_(partitionId))
+  }
+
   override def toString(): String = {
     s"state info [ checkpoint = $checkpointLocation, runId = $queryRunId, " +
-      s"opId = $operatorId, ver = $storeVersion, numPartitions = 
$numPartitions]"
+      s"opId = $operatorId, ver = $storeVersion, numPartitions = 
$numPartitions] " +
+      s"stateStoreCkptIds = $stateStoreCkptIds"
+  }
+}
+
+object StatefulOperatorStateInfo {
+  def enableStateStoreCheckpointIds(conf: SQLConf): Boolean = {

Review Comment:
   docs please
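   
   e.g. something like the below - the doc is the main ask; I'm assuming the flag just
   gates on the state store checkpoint format version conf (the condition shown is
   illustrative):
   
   ```scala
   /**
    * Whether per-commit state store checkpoint IDs are enabled, i.e. whether the
    * workflow documented on StatefulOperatorStateInfo above runs for this query.
    */
   def enableStateStoreCheckpointIds(conf: SQLConf): Boolean = {
     // Illustrative condition: the checkpoint-ID workflow applies only to the
     // newer checkpoint format version.
     conf.stateStoreCheckpointFormatVersion >= 2
   }
   ```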


