brkyvz commented on code in PR #47895:
URL: https://github.com/apache/spark/pull/47895#discussion_r1795913662


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala:
##########
@@ -126,6 +129,7 @@ class StateStoreRDD[T: ClassTag, U: ClassTag](
     val inputIter = dataRDD.iterator(partition, ctxt)
     val store = StateStore.get(
       storeProviderId, keySchema, valueSchema, keyStateEncoderSpec, 
storeVersion,
+      uniqueId.map(_(partition.index).head),

Review Comment:
   ditto



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala:
##########
@@ -90,6 +91,7 @@ class ReadStateStoreRDD[T: ClassTag, U: ClassTag](
     val inputIter = dataRDD.iterator(partition, ctxt)
     val store = StateStore.getReadOnly(
       storeProviderId, keySchema, valueSchema, keyStateEncoderSpec, 
storeVersion,
+      stateStoreCkptIds.map(_(partition.index).head),

Review Comment:
   nit: `_.apply(...).head`



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala:
##########
@@ -900,12 +906,57 @@ class MicroBatchExecution(
    */
   protected def markMicroBatchExecutionStart(execCtx: 
MicroBatchExecutionContext): Unit = {}
 
+  /**
+   * Store the state store checkpoint id for a finishing batch to 
`currentStateStoreCkptId`,
+   * which will be retrieved later when the next batch starts.
+   */
+  private def updateStateStoreCkptIdForOperator(
+      execCtx: MicroBatchExecutionContext,
+      opId: Long,
+      checkpointInfo: Array[StatefulOpStateStoreCheckpointInfo]): Unit = {
+    // TODO validate baseStateStoreCkptId
+    checkpointInfo.map(_.batchVersion).foreach { v =>
+      assert(
+        execCtx.batchId == -1 || v == execCtx.batchId + 1,
+        s"Batch version ${execCtx.batchId} should generate state store 
checkpoint " +
+          s"version ${execCtx.batchId + 1} but we see ${v}")
+    }
+    val ckptIds = checkpointInfo.map { info =>
+      assert(info.stateStoreCkptId.isDefined)
+      info.stateStoreCkptId.get
+    }
+    currentStateStoreCkptId.put(opId, ckptIds)
+  }
+
+  /**
+   * Walk the query plan `latestExecPlan` to find out a StateStoreWriter 
operator. Retrieve
+   * the state store checkpoint id from the operator and update it to 
`currentStateStoreCkptId`.
+   * @param execCtx
+   * @param latestExecPlan

Review Comment:
   You can remove these empty `@param` lines.



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala:
##########
@@ -233,6 +238,15 @@ case class StateStoreMetrics(
     memoryUsedBytes: Long,
     customMetrics: Map[StateStoreCustomMetric, Long])
 
+case class StateStoreCheckpointInfo(
+    partitionId: Int,
+    batchVersion: Long,
+    // The checkpoint ID for a checkpoint at `batchVersion`. This is used to 
identify the checkpoint

Review Comment:
   Can you move these inline comments up into `@param` lines in the scaladoc?



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala:
##########
@@ -900,12 +906,57 @@ class MicroBatchExecution(
    */
   protected def markMicroBatchExecutionStart(execCtx: 
MicroBatchExecutionContext): Unit = {}
 
+  /**
+   * Store the state store checkpoint id for a finishing batch to 
`currentStateStoreCkptId`,
+   * which will be retrieved later when the next batch starts.
+   */
+  private def updateStateStoreCkptIdForOperator(
+      execCtx: MicroBatchExecutionContext,
+      opId: Long,
+      checkpointInfo: Array[StatefulOpStateStoreCheckpointInfo]): Unit = {
+    // TODO validate baseStateStoreCkptId
+    checkpointInfo.map(_.batchVersion).foreach { v =>
+      assert(
+        execCtx.batchId == -1 || v == execCtx.batchId + 1,
+        s"Batch version ${execCtx.batchId} should generate state store 
checkpoint " +
+          s"version ${execCtx.batchId + 1} but we see ${v}")
+    }
+    val ckptIds = checkpointInfo.map { info =>
+      assert(info.stateStoreCkptId.isDefined)
+      info.stateStoreCkptId.get
+    }
+    currentStateStoreCkptId.put(opId, ckptIds)
+  }
+
+  /**
+   * Walk the query plan `latestExecPlan` to find out a StateStoreWriter 
operator. Retrieve
+   * the state store checkpoint id from the operator and update it to 
`currentStateStoreCkptId`.
+   * @param execCtx
+   * @param latestExecPlan
+   */
+  private def updateStateStoreCkptId(
+      execCtx: MicroBatchExecutionContext,
+      latestExecPlan: SparkPlan): Unit = {
+    latestExecPlan.collect {
+      case e: StateStoreWriter =>
+        assert(e.stateInfo.isDefined)
+        updateStateStoreCkptIdForOperator(
+          execCtx,
+          e.stateInfo.get.operatorId,
+          e.getStateStoreCheckpointInfo())
+    }
+  }
+
   /**
    * Called after the microbatch has completed execution. It takes care of 
committing the offset
    * to commit log and other bookkeeping.
    */
   protected def markMicroBatchEnd(execCtx: MicroBatchExecutionContext): Unit = 
{
-    watermarkTracker.updateWatermark(execCtx.executionPlan.executedPlan)
+    val latestExecPlan = execCtx.executionPlan.executedPlan
+    watermarkTracker.updateWatermark(latestExecPlan)
+    if 
(StatefulOperatorStateInfo.enableStateStoreCheckpointIds(sparkSession.sessionState.conf))
 {

Review Comment:
   Should you be using `sparkSessionForStream` here? Otherwise the conf value can
change from microbatch to microbatch, which is risky.



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala:
##########
@@ -900,12 +906,46 @@ class MicroBatchExecution(
    */
   protected def markMicroBatchExecutionStart(execCtx: 
MicroBatchExecutionContext): Unit = {}
 
+  private def updateStateStoreCkptIdForOperator(
+      execCtx: MicroBatchExecutionContext,
+      opId: Long,
+      checkpointInfo: Array[StatefulOpStateStoreCheckpointInfo]): Unit = {
+    // TODO validate baseStateStoreCkptId
+    checkpointInfo.map(_.batchVersion).foreach { v =>
+      assert(
+        execCtx.batchId == -1 || v == execCtx.batchId + 1,
+        s"Batch version ${execCtx.batchId} should generate state store 
checkpoint " +
+          s"version ${execCtx.batchId + 1} but we see ${v}")
+    }
+    currentStateStoreCkptId.put(opId, checkpointInfo.map { c =>
+      assert(c.stateStoreCkptId.isDefined)
+      c.stateStoreCkptId.get
+    })
+  }
+
+  private def updateStateStoreCkptId(
+      execCtx: MicroBatchExecutionContext,
+      latestExecPlan: SparkPlan): Unit = {
+    latestExecPlan.collect {
+      case e: StateStoreWriter =>
+        assert(e.stateInfo.isDefined)

Review Comment:
   Did you forget to address this?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to