brkyvz commented on code in PR #47895:
URL: https://github.com/apache/spark/pull/47895#discussion_r1805462938
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala:
##########
@@ -900,12 +906,57 @@ class MicroBatchExecution(
*/
protected def markMicroBatchExecutionStart(execCtx:
MicroBatchExecutionContext): Unit = {}
+ /**
+ * Store the state store checkpoint id for a finishing batch to
`currentStateStoreCkptId`,
+ * which will be retrieved later when the next batch starts.
+ */
+ private def updateStateStoreCkptIdForOperator(
+ execCtx: MicroBatchExecutionContext,
+ opId: Long,
+ checkpointInfo: Array[StatefulOpStateStoreCheckpointInfo]): Unit = {
+ // TODO validate baseStateStoreCkptId
+ checkpointInfo.map(_.batchVersion).foreach { v =>
+ assert(
+ execCtx.batchId == -1 || v == execCtx.batchId + 1,
+ s"Batch version ${execCtx.batchId} should generate state store
checkpoint " +
+ s"version ${execCtx.batchId + 1} but we see ${v}")
+ }
+ val ckptIds = checkpointInfo.map { info =>
+ assert(info.stateStoreCkptId.isDefined)
+ info.stateStoreCkptId.get
+ }
+ currentStateStoreCkptId.put(opId, ckptIds)
+ }
+
+ /**
+ * Walk the query plan `latestExecPlan` to find out a StateStoreWriter
operator. Retrieve
+ * the state store checkpoint id from the operator and update it to
`currentStateStoreCkptId`.
+ * @param execCtx
+ * @param latestExecPlan
Review Comment:
@siying missed this comment
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala:
##########
@@ -210,7 +210,7 @@ abstract class StreamExecution(
this, s"spark.streaming.${Option(name).getOrElse(id)}")
/** Isolated spark session to run the batches with. */
- private val sparkSessionForStream = sparkSession.cloneSession()
+ protected val sparkSessionForStream = sparkSession.cloneSession()
Review Comment:
for the future - I feel like we should refactor these abstractions a bit to
ensure that developers cannot make the same session-misuse mistakes again.
Today it's too subtle and easy to hit
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala:
##########
@@ -44,19 +45,51 @@ import org.apache.spark.sql.execution.streaming.state._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.streaming.{OutputMode, StateOperatorProgress}
import org.apache.spark.sql.types._
-import org.apache.spark.util.{CompletionIterator, NextIterator, Utils}
+import org.apache.spark.util.{CollectionAccumulator, CompletionIterator,
NextIterator, Utils}
-/** Used to identify the state store for a given operator. */
+/** Used to identify the state store for a given operator.
+ *
+ * stateStoreCkptIds is used to identify the checkpoint used for a specific
stateful operator
+ * The basic workflow works as follows:
+ * 1. When a stateful operator is created, it passes in the checkpoint IDs for
each stateful
+ * operator through the StatefulOperatorStateInfo.
+ * 2. When a stateful task starts to execute, it will find the checkpointID
for its shuffle
+ * partition and use it to recover the state store. The ID is eventually
passed into
+ * the StateStore layer and eventually RocksDB State Store, where it is
used to make sure
+ *        that it loads the correct checkpoint
+ * 3. When the stateful task is finishing, after the state store is committed,
the checkpoint ID
+ * is fetched from the state store by calling
StateStore.getStateStoreCheckpointInfo() and added
+ * to the stateStoreCkptIds accumulator by calling
+ * StateStoreWriter.setStateStoreCheckpointInfo().
+ * 4. When ending the batch, MicroBatchExecution calls each stateful operator's
+ * getStateStoreCheckpointInfo() which aggregates checkpointIDs from
different partitions. The
+ *    driver will persist it into commit logs (not implemented yet).
+ * 5. When forming the next batch, the driver constructs the
StatefulOperatorStateInfo with the
+ * checkpoint IDs for the previous batch.
+ * */
case class StatefulOperatorStateInfo(
checkpointLocation: String,
queryRunId: UUID,
operatorId: Long,
storeVersion: Long,
- numPartitions: Int) {
+ numPartitions: Int,
+ stateStoreCkptIds: Option[Array[Array[String]]] = None) {
+
+ def getStateStoreCkptId(partitionId: Int): Option[Array[String]] = {
+ stateStoreCkptIds.map(_(partitionId))
+ }
+
override def toString(): String = {
s"state info [ checkpoint = $checkpointLocation, runId = $queryRunId, " +
- s"opId = $operatorId, ver = $storeVersion, numPartitions =
$numPartitions]"
+ s"opId = $operatorId, ver = $storeVersion, numPartitions =
$numPartitions] " +
+ s"stateStoreCkptIds = $stateStoreCkptIds"
+ }
+}
+
+object StatefulOperatorStateInfo {
+ def enableStateStoreCheckpointIds(conf: SQLConf): Boolean = {
Review Comment:
docs please
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]