sahnib commented on code in PR #44884:
URL: https://github.com/apache/spark/pull/44884#discussion_r1473582927


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala:
##########
@@ -67,29 +67,37 @@ class QueryInfoImpl(
  * Class that provides a concrete implementation of a StatefulProcessorHandle. 
Note that we keep
  * track of valid transitions as various functions are invoked to track object 
lifecycle.
  * @param store - instance of state store
+ * @param runId - unique id for the current run
+ * @param isStreaming - defines whether the query is streaming or batch
  */
-class StatefulProcessorHandleImpl(store: StateStore, runId: UUID)
+class StatefulProcessorHandleImpl(store: StateStore, runId: UUID, isStreaming: 
Boolean = true)
   extends StatefulProcessorHandle with Logging {
   import StatefulProcessorHandleState._
 
   private def buildQueryInfo(): QueryInfo = {
-    val taskCtxOpt = Option(TaskContext.get())
-    // Task context is not available in tests, so we generate a random query 
id and batch id here
-    val queryId = if (taskCtxOpt.isDefined) {
-      taskCtxOpt.get.getLocalProperty(StreamExecution.QUERY_ID_KEY)
-    } else {
-      assert(Utils.isTesting, "Failed to find query id in task context")
-      UUID.randomUUID().toString
-    }
 
-    val batchId = if (taskCtxOpt.isDefined) {
-      taskCtxOpt.get.getLocalProperty(MicroBatchExecution.BATCH_ID_KEY).toLong
+    if (!isStreaming) {
+      val queryId = "00000000-0000-0000-0000-000000000000"
+      val batchId = 0L
+      new QueryInfoImpl(UUID.fromString(queryId), runId, batchId)
     } else {
-      assert(Utils.isTesting, "Failed to find batch id in task context")
-      0
+      val taskCtxOpt = Option(TaskContext.get())
+      // Task context is not available in tests, so we generate a random query 
id and batch id here
+      val queryId = if (taskCtxOpt.isDefined) {
+        taskCtxOpt.get.getLocalProperty(StreamExecution.QUERY_ID_KEY)
+      } else {
+        assert(Utils.isTesting, "Failed to find query id in task context")
+        UUID.randomUUID().toString
+      }
+
+      val batchId = if (taskCtxOpt.isDefined) {
+        
taskCtxOpt.get.getLocalProperty(MicroBatchExecution.BATCH_ID_KEY).toLong
+      } else {
+        assert(Utils.isTesting, "Failed to find batch id in task context")
+        0
+      }

Review Comment:
   [nit] We can simplify this to the below code. 
   
   ```
   private val BATCH_QUERY_ID = "00000000-0000-0000-0000-000000000000"
   
   .... 
   
      val (queryId, batchId) =  if (!isStreaming) {
         (BATCH_QUERY_ID, 0L)
       } else if (taskCtxOpt.isDefined) {
         (taskCtxOpt.get.getLocalProperty(StreamExecution.QUERY_ID_KEY),
           
taskCtxOpt.get.getLocalProperty(MicroBatchExecution.BATCH_ID_KEY).toLong)
       } else {
         assert(Utils.isTesting, "Failed to find query id/batch Id in task 
context")
         (UUID.randomUUID().toString, 0)
       }
   
   ```



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala:
##########
@@ -742,6 +742,7 @@ abstract class SparkStrategies extends 
QueryPlanner[SparkPlan] {
           batchTimestampMs = None,
           eventTimeWatermarkForLateEvents = None,
           eventTimeWatermarkForEviction = None,
+          isStreaming = true,

Review Comment:
   [nit] Should we now rename this class to 
`TransformWithStateStreamingStrategy` to make clear that it's only for 
Streaming workloads?



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala:
##########
@@ -152,22 +159,104 @@ case class TransformWithStateExec(
   override protected def doExecute(): RDD[InternalRow] = {
     metrics // force lazy init at driver
 
-    child.execute().mapPartitionsWithStateStore[InternalRow](
-      getStateInfo,
-      schemaForKeyRow,
-      schemaForValueRow,
-      numColsPrefixKey = 0,
-      session.sqlContext.sessionState,
-      Some(session.sqlContext.streams.stateStoreCoordinator),
-      useColumnFamilies = true
-    ) {
-      case (store: StateStore, singleIterator: Iterator[InternalRow]) =>
-        val processorHandle = new StatefulProcessorHandleImpl(store, 
getStateInfo.queryRunId)
-        assert(processorHandle.getHandleState == 
StatefulProcessorHandleState.CREATED)
-        statefulProcessor.init(processorHandle, outputMode)
-        
processorHandle.setHandleState(StatefulProcessorHandleState.INITIALIZED)
-        val result = processDataWithPartition(singleIterator, store, 
processorHandle)
-        result
+    if (isStreaming) {
+      child.execute().mapPartitionsWithStateStore[InternalRow](
+        getStateInfo,
+        schemaForKeyRow,
+        schemaForValueRow,
+        numColsPrefixKey = 0,
+        session.sqlContext.sessionState,
+        Some(session.sqlContext.streams.stateStoreCoordinator),
+        useColumnFamilies = true
+      ) {
+        case (store: StateStore, singleIterator: Iterator[InternalRow]) =>
+          processData(store, singleIterator)
+      }
+    } else {
+      // If the query is running in batch mode, we need to create a new 
StateStore and instantiate
+      // a temp directory on the executors in mapPartitionsWithIndex.
+      child.execute().mapPartitionsWithIndex[InternalRow](
+        (i, iter) => {
+          val providerId = new StateStoreProviderId(
+            StateStoreId(Utils.createTempDir().getAbsolutePath,

Review Comment:
   Do we need a temporary path per operator? We can create a temp directory 
once and reuse it. StateStore should create sub-directories inside it for operators. 



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to