HeartSaVioR commented on code in PR #45467:
URL: https://github.com/apache/spark/pull/45467#discussion_r1540460865


##########
sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala:
##########
@@ -676,6 +676,42 @@ class KeyValueGroupedDataset[K, V] private[sql](
     )
   }
 
+  /**
+   * (Scala-specific)
+   * Invokes methods defined in the stateful processor used in arbitrary state 
API v2.
+   * Functions as the function above, but with additional initial state.
+   *
+   * @tparam U The type of the output objects. Must be encodable to Spark SQL 
types.
+   * @tparam S The type of initial state objects. Must be encodable to Spark 
SQL types.
+   * @param StatefulProcessorWithInitialState Instance of statefulProcessor 
whose functions will
+   *                                          be invoked by the operator.
+   * @param timeoutMode       The timeout mode of the stateful processor.
+   * @param outputMode        The output mode of the stateful processor. 
Defaults to APPEND mode.
+   * @param initialState      User provided initial state that will be used to 
initiate state for
+   *                          the query in the first batch.
+   *
+   */
+  private[sql] def transformWithState[U: Encoder, S: Encoder](
+      statefulProcessor: StatefulProcessorWithInitialState[K, V, U, S],
+      timeoutMode: TimeoutMode,
+      outputMode: OutputMode,
+      initialState: KeyValueGroupedDataset[K, S]): Dataset[U] = {
+    Dataset[U](
+      sparkSession,
+      TransformWithState[K, V, U, S](
+        groupingAttributes,
+        dataAttributes,
+        statefulProcessor,
+        timeoutMode,
+        outputMode,
+        child = logicalPlan,
+        initialState.groupingAttributes,
+        initialState.dataAttributes,
+        initialState.queryExecution.logical

Review Comment:
   Shall we follow the practice we used in flatMapGroupsWithState, for safety's 
sake? 
   
   `initialState.queryExecution.analyzed`



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala:
##########
@@ -268,11 +268,13 @@ class IncrementalExecution(
         )
 
       case t: TransformWithStateExec =>
+        val hasInitialState = (isFirstBatch && t.hasInitialState)

Review Comment:
   I don't think we want to allow adding state in the middle of the query 
lifecycle. Here, `isFirstBatch` does not mean batch ID = 0; it means this is the 
first batch in this query run. 
   
   This should follow the logic used above for FlatMapGroupsWithStateExec, 
i.e. `currentBatchId == 0L`.



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala:
##########
@@ -268,11 +268,13 @@ class IncrementalExecution(
         )
 
       case t: TransformWithStateExec =>
+        val hasInitialState = (isFirstBatch && t.hasInitialState)

Review Comment:
   Please let me know if this is a different functionality than we had in 
flatMapGroupsWithState.



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala:
##########
@@ -271,57 +320,111 @@ case class TransformWithStateExec(
       case _ =>
     }
 
-    if (isStreaming) {
-      child.execute().mapPartitionsWithStateStore[InternalRow](
+    if (hasInitialState) {
+      val storeConf = new StateStoreConf(session.sqlContext.sessionState.conf)
+      val hadoopConfBroadcast = sparkContext.broadcast(

Review Comment:
   Yeah, there is a code comment. The practice seems to be that it's better to 
use broadcast rather than task serialization, as the Hadoop configuration could be huge.



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala:
##########
@@ -271,57 +320,111 @@ case class TransformWithStateExec(
       case _ =>
     }
 
-    if (isStreaming) {
-      child.execute().mapPartitionsWithStateStore[InternalRow](
+    if (hasInitialState) {
+      val storeConf = new StateStoreConf(session.sqlContext.sessionState.conf)
+      val hadoopConfBroadcast = sparkContext.broadcast(
+        new 
SerializableConfiguration(session.sqlContext.sessionState.newHadoopConf()))
+      child.execute().stateStoreAwareZipPartitions(
+        initialState.execute(),
         getStateInfo,
-        schemaForKeyRow,
-        schemaForValueRow,
-        NoPrefixKeyStateEncoderSpec(schemaForKeyRow),
-        session.sqlContext.sessionState,
-        Some(session.sqlContext.streams.stateStoreCoordinator),
-        useColumnFamilies = true,
-        useMultipleValuesPerKey = true
-      ) {
-        case (store: StateStore, singleIterator: Iterator[InternalRow]) =>
-          processData(store, singleIterator)
+        storeNames = Seq(),
+        session.sqlContext.streams.stateStoreCoordinator) {
+        // The state store aware zip partitions will provide us with two 
iterators,
+        // child data iterator and the initial state iterator per partition.
+        case (partitionId, childDataIterator, initStateIterator) =>
+          if (isStreaming) {
+            val stateStoreId = StateStoreId(stateInfo.get.checkpointLocation,
+              stateInfo.get.operatorId, partitionId)
+            val storeProviderId = StateStoreProviderId(stateStoreId, 
stateInfo.get.queryRunId)
+            val store = StateStore.get(
+              storeProviderId = storeProviderId,
+              keySchema = schemaForKeyRow,
+              valueSchema = schemaForValueRow,
+              NoPrefixKeyStateEncoderSpec(schemaForKeyRow),
+              version = stateInfo.get.storeVersion,
+              useColumnFamilies = true,
+              storeConf = storeConf,
+              hadoopConf = hadoopConfBroadcast.value.value
+            )
+
+            processDataWithInitialState(store, childDataIterator, 
initStateIterator)
+          } else {
+            val providerId = {
+              val tempDirPath = Utils.createTempDir().getAbsolutePath
+              new StateStoreProviderId(
+                StateStoreId(tempDirPath, 0, partitionId), 
getStateInfo.queryRunId)
+            }
+            val sqlConf = new SQLConf()
+            sqlConf.setConfString(SQLConf.STATE_STORE_PROVIDER_CLASS.key,
+              classOf[RocksDBStateStoreProvider].getName)
+
+            // Create StateStoreProvider for this partition
+            val stateStoreProvider = StateStoreProvider.createAndInit(
+              providerId,
+              schemaForKeyRow,
+              schemaForValueRow,
+              NoPrefixKeyStateEncoderSpec(schemaForKeyRow),
+              useColumnFamilies = true,
+              storeConf = new StateStoreConf(sqlConf),
+              hadoopConf = hadoopConfBroadcast.value.value,
+              useMultipleValuesPerKey = true)
+            val store = stateStoreProvider.getStore(0)
+
+            processDataWithInitialState(store, childDataIterator, 
initStateIterator)

Review Comment:
   We close the state store and state store provider in the batch code path (see 
below). Shall we do that here as well?
   
   Also, this is a good indication that we have duplicated code. The two batch 
parts are similar in how they spin up the state store provider and state store, 
and in how they close them. That could be extracted out.



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala:
##########
@@ -271,57 +320,111 @@ case class TransformWithStateExec(
       case _ =>
     }
 
-    if (isStreaming) {
-      child.execute().mapPartitionsWithStateStore[InternalRow](
+    if (hasInitialState) {
+      val storeConf = new StateStoreConf(session.sqlContext.sessionState.conf)
+      val hadoopConfBroadcast = sparkContext.broadcast(
+        new 
SerializableConfiguration(session.sqlContext.sessionState.newHadoopConf()))
+      child.execute().stateStoreAwareZipPartitions(
+        initialState.execute(),
         getStateInfo,
-        schemaForKeyRow,
-        schemaForValueRow,
-        NoPrefixKeyStateEncoderSpec(schemaForKeyRow),
-        session.sqlContext.sessionState,
-        Some(session.sqlContext.streams.stateStoreCoordinator),
-        useColumnFamilies = true,
-        useMultipleValuesPerKey = true
-      ) {
-        case (store: StateStore, singleIterator: Iterator[InternalRow]) =>
-          processData(store, singleIterator)
+        storeNames = Seq(),
+        session.sqlContext.streams.stateStoreCoordinator) {
+        // The state store aware zip partitions will provide us with two 
iterators,
+        // child data iterator and the initial state iterator per partition.
+        case (partitionId, childDataIterator, initStateIterator) =>
+          if (isStreaming) {
+            val stateStoreId = StateStoreId(stateInfo.get.checkpointLocation,
+              stateInfo.get.operatorId, partitionId)
+            val storeProviderId = StateStoreProviderId(stateStoreId, 
stateInfo.get.queryRunId)
+            val store = StateStore.get(
+              storeProviderId = storeProviderId,
+              keySchema = schemaForKeyRow,
+              valueSchema = schemaForValueRow,
+              NoPrefixKeyStateEncoderSpec(schemaForKeyRow),
+              version = stateInfo.get.storeVersion,
+              useColumnFamilies = true,
+              storeConf = storeConf,
+              hadoopConf = hadoopConfBroadcast.value.value
+            )
+
+            processDataWithInitialState(store, childDataIterator, 
initStateIterator)
+          } else {
+            val providerId = {
+              val tempDirPath = Utils.createTempDir().getAbsolutePath
+              new StateStoreProviderId(
+                StateStoreId(tempDirPath, 0, partitionId), 
getStateInfo.queryRunId)
+            }
+            val sqlConf = new SQLConf()
+            sqlConf.setConfString(SQLConf.STATE_STORE_PROVIDER_CLASS.key,
+              classOf[RocksDBStateStoreProvider].getName)
+
+            // Create StateStoreProvider for this partition
+            val stateStoreProvider = StateStoreProvider.createAndInit(
+              providerId,
+              schemaForKeyRow,
+              schemaForValueRow,
+              NoPrefixKeyStateEncoderSpec(schemaForKeyRow),
+              useColumnFamilies = true,
+              storeConf = new StateStoreConf(sqlConf),
+              hadoopConf = hadoopConfBroadcast.value.value,
+              useMultipleValuesPerKey = true)
+            val store = stateStoreProvider.getStore(0)
+
+            processDataWithInitialState(store, childDataIterator, 
initStateIterator)
+          }
       }
     } else {
-      // If the query is running in batch mode, we need to create a new 
StateStore and instantiate
-      // a temp directory on the executors in mapPartitionsWithIndex.
-      val broadcastedHadoopConf =
+      if (isStreaming) {
+        child.execute().mapPartitionsWithStateStore[InternalRow](
+          getStateInfo,
+          schemaForKeyRow,
+          schemaForValueRow,
+          NoPrefixKeyStateEncoderSpec(schemaForKeyRow),
+          session.sqlContext.sessionState,
+          Some(session.sqlContext.streams.stateStoreCoordinator),
+          useColumnFamilies = true
+        ) {
+          case (store: StateStore, singleIterator: Iterator[InternalRow]) =>
+            processData(store, singleIterator)
+        }
+      } else {
+        // If the query is running in batch mode, we need to create a new 
StateStore and instantiate

Review Comment:
   nit: shall we apply the same practice (broadcast) here while we are at it?



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala:
##########
@@ -341,8 +444,37 @@ case class TransformWithStateExec(
     processorHandle.setHandleState(StatefulProcessorHandleState.INITIALIZED)
     processDataWithPartition(singleIterator, store, processorHandle)
   }
+
+  private def processDataWithInitialState(
+      store: StateStore,
+      childDataIterator: Iterator[InternalRow],
+      initStateIterator: Iterator[InternalRow]):
+    CompletionIterator[InternalRow, Iterator[InternalRow]] = {
+    val processorHandle = new StatefulProcessorHandleImpl(store, 
getStateInfo.queryRunId,
+      keyEncoder, timeoutMode, isStreaming)
+    assert(processorHandle.getHandleState == 
StatefulProcessorHandleState.CREATED)
+    statefulProcessor.setHandle(processorHandle)
+    statefulProcessor.init(outputMode, timeoutMode)
+    processorHandle.setHandleState(StatefulProcessorHandleState.INITIALIZED)
+
+    // Check if is first batch
+    // Only process initial states for first batch
+    if (processorHandle.getQueryInfo().getBatchId == 0) {

Review Comment:
   OK, I see we have multiple checks. Still, it would be better to change the 
condition in IncrementalExecution, as readers could otherwise assume there is an 
inconsistency between flatMapGroupsWithState and transformWithState.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to