bogao007 commented on code in PR #48005:
URL: https://github.com/apache/spark/pull/48005#discussion_r1797417979


##########
python/pyspark/sql/pandas/group_ops.py:
##########
@@ -505,22 +514,75 @@ def transformWithStateUDF(
 
             return result
 
-        if isinstance(outputStructType, str):
-            outputStructType = cast(StructType, _parse_datatype_string(outputStructType))
+        def transformWithStateWithInitStateUDF(
+                statefulProcessorApiClient: StatefulProcessorApiClient,

Review Comment:
   Nit: fix the indent.



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala:
##########
@@ -168,13 +169,10 @@ class TransformWithStateInPandasStateServer(
         val requestedState = message.getSetHandleState.getState
         requestedState match {
           case HandleState.CREATED =>
-            logInfo(log"set handle state to Created")

Review Comment:
   Why do we need to remove these log lines?



##########
python/pyspark/sql/pandas/group_ops.py:
##########
@@ -505,22 +514,75 @@ def transformWithStateUDF(
 
             return result
 
-        if isinstance(outputStructType, str):
-            outputStructType = cast(StructType, _parse_datatype_string(outputStructType))
+        def transformWithStateWithInitStateUDF(
+                statefulProcessorApiClient: StatefulProcessorApiClient,
+                key: Any,
+                inputRows: Iterator["PandasDataFrameLike"],
+                # for non first batch, initialStates will be None
+                initialStates: Iterator["PandasDataFrameLike"] = None
+        ) -> Iterator["PandasDataFrameLike"]:
+            handle = StatefulProcessorHandle(statefulProcessorApiClient)
+
+            if statefulProcessorApiClient.handle_state == StatefulProcessorHandleState.CREATED:
+                statefulProcessor.init(handle)
+                # only process initial state if first batch
+                is_first_batch = statefulProcessorApiClient.is_first_batch()
+                if is_first_batch and initialStates is not None:
+                    seen_init_state_on_key = False
+                    for cur_initial_state in initialStates:
+                        if seen_init_state_on_key:
+                            raise Exception(f"TransformWithStateWithInitState: Cannot have more "
+                                            f"than one row in the initial states for the same key. "
+                                            f"Grouping key: {key}.")
+                        statefulProcessorApiClient.set_implicit_key(key)
+                        statefulProcessor.handleInitialState(key, cur_initial_state)
+                        seen_init_state_on_key = True
+                statefulProcessorApiClient.set_handle_state(
+                    StatefulProcessorHandleState.INITIALIZED
+                )
+
+            # if we don't have state for the given key but in initial state,

Review Comment:
   I guess you wanted to say `if we only have initial state but don't have input rows for the given key, the inputRows iterator could be empty`?



##########
python/pyspark/sql/streaming/stateful_processor.py:
##########
@@ -234,3 +234,12 @@ def close(self) -> None:
         operations.
         """
         ...
+
+    def handleInitialState(
+            self, key: Any, initialState: "PandasDataFrameLike"

Review Comment:
   Nit: indent.



##########
python/pyspark/sql/pandas/group_ops.py:
##########
@@ -505,22 +514,75 @@ def transformWithStateUDF(
 
             return result
 
-        if isinstance(outputStructType, str):
-            outputStructType = cast(StructType, _parse_datatype_string(outputStructType))
+        def transformWithStateWithInitStateUDF(
+                statefulProcessorApiClient: StatefulProcessorApiClient,
+                key: Any,
+                inputRows: Iterator["PandasDataFrameLike"],
+                # for non first batch, initialStates will be None
+                initialStates: Iterator["PandasDataFrameLike"] = None
+        ) -> Iterator["PandasDataFrameLike"]:
+            handle = StatefulProcessorHandle(statefulProcessorApiClient)
+
+            if statefulProcessorApiClient.handle_state == StatefulProcessorHandleState.CREATED:

Review Comment:
   There's something here that isn't very clear to me; could you help me understand it better?
   
   We only call `handleInitialState` when the handle state is `CREATED`, but after we process the initial state of the first grouping key, we update the state to `INITIALIZED`. Wouldn't that skip the initial state for the other grouping keys?
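
Here is a small simulation of the scenario the question describes. Like the question, it assumes that a single `handle_state` is shared by every grouping key the worker processes. `FakeClient`, `HandleState`, and `process_as_written` are hypothetical stand-ins that model only the control flow, not Spark's API.

```python
from enum import Enum
from typing import List


class HandleState(Enum):
    CREATED = 1
    INITIALIZED = 2


class FakeClient:
    """Hypothetical stand-in: one client, so one handle_state, for all keys."""

    def __init__(self) -> None:
        self.handle_state = HandleState.CREATED


def process_as_written(client: FakeClient, keys: List[str], calls: List[str]) -> None:
    # Mirrors the structure under review: initial state is only handled
    # while the shared handle state is still CREATED.
    for key in keys:
        if client.handle_state == HandleState.CREATED:
            calls.append(key)  # stands in for handleInitialState(key, ...)
            client.handle_state = HandleState.INITIALIZED


calls: List[str] = []
process_as_written(FakeClient(), ["a", "b", "c"], calls)
print(calls)  # ['a']: keys "b" and "c" never reach handleInitialState
```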



##########
python/pyspark/sql/streaming/stateful_processor_api_client.py:
##########
@@ -154,6 +155,25 @@ def get_list_state(
             # TODO(SPARK-49233): Classify user facing errors.
             raise PySparkRuntimeError(f"Error initializing value state: " f"{response_message[1]}")
 
+    def is_first_batch(self) -> bool:
+        import pyspark.sql.streaming.StateMessage_pb2 as stateMessage
+
+        is_first_batch = stateMessage.IsFirstBatch()
+        request = stateMessage.UtilsCallCommand(isFirstBatch=is_first_batch)
+        stateful_processor_call = stateMessage.StatefulProcessorCall(utilsCall=request)
+        message = stateMessage.StateRequest(statefulProcessorCall=stateful_processor_call)
+
+        self._send_proto_message(message.SerializeToString())
+        response_message = self._receive_proto_message()
+        status = response_message[0]
+        if status == 0:
+            return True
+        elif status == 1:
+            return False
+        else:
+            # TODO(SPARK-49233): Classify user facing errors.
+            raise PySparkRuntimeError(f"Error getting batch id: " f"{response_message[1]}")

Review Comment:
   Should we have a better error message?
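
One possible direction is sketched below. The `(status, message)` response shape and the status codes 0 and 1 come from the diff. `parse_is_first_batch_response` and `StateServerError` are hypothetical; the latter stands in for `PySparkRuntimeError` so the sketch runs without PySpark. The main change is that the message names the call that failed. The current text says "batch id" even though the method asks whether this is the first batch.

```python
from typing import Tuple


class StateServerError(RuntimeError):
    """Hypothetical stand-in for PySparkRuntimeError so this runs without PySpark."""


def parse_is_first_batch_response(response: Tuple[int, str]) -> bool:
    # Status codes as in the diff: 0 means first batch, 1 means not first batch.
    status, detail = response[0], response[1]
    if status == 0:
        return True
    if status == 1:
        return False
    # Say which call failed and include the unexpected status, instead of
    # the generic "Error getting batch id".
    raise StateServerError(
        "Failed to determine whether this is the first batch "
        f"(is_first_batch returned status {status}): {detail}"
    )
```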



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasPythonRunner.scala:
##########
@@ -50,14 +51,130 @@ class TransformWithStateInPandasPythonRunner(
     initialWorkerConf: Map[String, String],
     override val pythonMetrics: Map[String, SQLMetric],
     jobArtifactUUID: Option[String],
-    groupingKeySchema: StructType)
-  extends BasePythonRunner[InType, OutType](funcs.map(_._1), evalType, argOffsets, jobArtifactUUID)
-  with PythonArrowInput[InType]
-  with BasicPythonArrowOutput
-  with Logging {
+    groupingKeySchema: StructType,
+    hasInitialState: Boolean)
+  extends TransformWithStateInPandasPythonBaseRunner[InType](
+    funcs, evalType, argOffsets, _schema, processorHandle, _timeZoneId,
+    initialWorkerConf, pythonMetrics, jobArtifactUUID, groupingKeySchema, hasInitialState)
+  with PythonArrowInput[InType] {
 
-  private val sqlConf = SQLConf.get
-  private val arrowMaxRecordsPerBatch = sqlConf.arrowMaxRecordsPerBatch
+  private var pandasWriter: BaseStreamingArrowWriter = _
+
+  override protected def writeNextInputToArrowStream(
+      root: VectorSchemaRoot,
+      writer: ArrowStreamWriter,
+      dataOut: DataOutputStream,
+      inputIterator: Iterator[InType]): Boolean = {
+    if (pandasWriter == null) {
+      pandasWriter = new BaseStreamingArrowWriter(root, writer, arrowMaxRecordsPerBatch)
+    }
+
+    if (inputIterator.hasNext) {
+      val startData = dataOut.size()
+      val next = inputIterator.next()
+      val dataIter = next._2
+
+      while (dataIter.hasNext) {
+        val dataRow = dataIter.next()
+        pandasWriter.writeRow(dataRow)
+      }
+      pandasWriter.finalizeCurrentArrowBatch()
+      val deltaData = dataOut.size() - startData
+      pythonMetrics("pythonDataSent") += deltaData
+      true
+    } else {
+      super[PythonArrowInput].close()
+      false
+    }
+  }
+}
+
+/**
+ * Python runner with initial state in TransformWithStateInPandas.
+ * Write input data as one InternalRow(inputRow, initialState) in each row in arrow batch.
+ */
+class TransformWithStateInPandasPythonInitialStateRunner(
+    funcs: Seq[(ChainedPythonFunctions, Long)],
+    evalType: Int,
+    argOffsets: Array[Array[Int]],
+    dataSchema: StructType,
+    initStateSchema: StructType,
+    processorHandle: StatefulProcessorHandleImpl,
+    _timeZoneId: String,
+    initialWorkerConf: Map[String, String],
+    override val pythonMetrics: Map[String, SQLMetric],
+    jobArtifactUUID: Option[String],
+    groupingKeySchema: StructType,
+    hasInitialState: Boolean)
+  extends TransformWithStateInPandasPythonBaseRunner[GroupedInType](
+    funcs, evalType, argOffsets, dataSchema, processorHandle, _timeZoneId,
+    initialWorkerConf, pythonMetrics, jobArtifactUUID, groupingKeySchema, hasInitialState)
+  with PythonArrowInput[GroupedInType] {
+
+  override protected lazy val schema: StructType = new StructType()
+    .add("state", dataSchema)

Review Comment:
   Maybe use `data` or `inputData` to differentiate it from `initState`?



##########
python/pyspark/sql/pandas/group_ops.py:
##########
@@ -505,22 +514,75 @@ def transformWithStateUDF(
 
             return result
 
-        if isinstance(outputStructType, str):
-            outputStructType = cast(StructType, _parse_datatype_string(outputStructType))
+        def transformWithStateWithInitStateUDF(
+                statefulProcessorApiClient: StatefulProcessorApiClient,
+                key: Any,
+                inputRows: Iterator["PandasDataFrameLike"],
+                # for non first batch, initialStates will be None
+                initialStates: Iterator["PandasDataFrameLike"] = None
+        ) -> Iterator["PandasDataFrameLike"]:
+            handle = StatefulProcessorHandle(statefulProcessorApiClient)
+
+            if statefulProcessorApiClient.handle_state == StatefulProcessorHandleState.CREATED:

Review Comment:
   If my understanding is correct, we should move the `handleInitialState` call outside the handle state check and do it after the `init` call.
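
A sketch of that reordering follows. It keeps the same assumption that `handle_state` is shared across keys. `FakeClient`, `run_key`, and the two callbacks are hypothetical. To keep the control flow visible, the sketch also leaves out `set_implicit_key` and the diff's one-row-per-key check.

```python
from enum import Enum
from typing import Any, Callable, Iterable, Optional


class HandleState(Enum):
    CREATED = 1
    INITIALIZED = 2


class FakeClient:
    """Hypothetical stand-in for StatefulProcessorApiClient (illustration only)."""

    def __init__(self, first_batch: bool) -> None:
        self.handle_state = HandleState.CREATED
        self._first_batch = first_batch

    def is_first_batch(self) -> bool:
        return self._first_batch


def run_key(
    client: FakeClient,
    key: Any,
    initial_states: Optional[Iterable[Any]],
    on_init: Callable[[], None],
    on_initial_state: Callable[[Any, Any], None],
) -> None:
    # One-time setup stays behind the CREATED check...
    if client.handle_state == HandleState.CREATED:
        on_init()
        client.handle_state = HandleState.INITIALIZED
    # ...but initial state is handled for every key in the first batch.
    if client.is_first_batch() and initial_states is not None:
        for state in initial_states:
            on_initial_state(key, state)
```

This shape has one cost. In the diff, `is_first_batch()` is a round trip to the state server, so calling it once per key costs more than calling it once. Caching the answer on the client after the first call would avoid that; this is an assumption, not something the PR does.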



##########
python/pyspark/sql/pandas/group_ops.py:
##########
@@ -402,6 +404,9 @@ def transformWithStateInPandas(
             The output mode of the stateful processor.
         timeMode : str
             The time mode semantics of the stateful processor for timers and TTL.
+        initialState: "GroupedData"

Review Comment:
   Let's use something like the following to represent the actual type:
   ```
   :class:`pyspark.sql.types.DataType`
   ```
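
Applied to the new parameter, the entry might read as follows. This assumes the parameter's real type is the `GroupedData` named in the diff (the `DataType` above only illustrates the format) and that the parameter is optional. The stub function and its description are illustrative, not the PR's text. The entry also uses the `name : type` spacing that numpydoc expects and that `timeMode : str` already follows, while the diff's `initialState: "GroupedData"` does not.

```python
def transform_with_state_doc_sketch(initialState=None):
    """
    Hypothetical stub that only demonstrates the docstring layout.

    Parameters
    ----------
    initialState : :class:`pyspark.sql.GroupedData`, optional
        Grouped initial state, processed per grouping key in the first batch.
    """
    return initialState
```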



##########
python/pyspark/sql/pandas/group_ops.py:
##########
@@ -505,22 +514,75 @@ def transformWithStateUDF(
 
             return result
 
-        if isinstance(outputStructType, str):
-            outputStructType = cast(StructType, _parse_datatype_string(outputStructType))
+        def transformWithStateWithInitStateUDF(
+                statefulProcessorApiClient: StatefulProcessorApiClient,
+                key: Any,
+                inputRows: Iterator["PandasDataFrameLike"],
+                # for non first batch, initialStates will be None

Review Comment:
   For non-first batches, would `initialStates` be None or empty?
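
Whichever it turns out to be, the UDF could treat both forms the same way up front. The sketch below is minimal, and `normalize_initial_states` is a hypothetical helper. Annotating the parameter as `Optional[...]` also makes the `None` default explicit. Recent mypy releases require this by default instead of inferring it from `= None`.

```python
from typing import Any, Iterator, Optional


def normalize_initial_states(
    initial_states: Optional[Iterator[Any]] = None,
) -> Iterator[Any]:
    # Map "no initial state" to an empty iterator, whether the caller
    # passed None or an iterator that yields nothing.
    if initial_states is None:
        return iter(())
    return initial_states
```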



##########
python/pyspark/sql/pandas/group_ops.py:
##########
@@ -505,22 +514,75 @@ def transformWithStateUDF(
 
             return result
 
-        if isinstance(outputStructType, str):
-            outputStructType = cast(StructType, _parse_datatype_string(outputStructType))
+        def transformWithStateWithInitStateUDF(
+                statefulProcessorApiClient: StatefulProcessorApiClient,
+                key: Any,
+                inputRows: Iterator["PandasDataFrameLike"],
+                # for non first batch, initialStates will be None
+                initialStates: Iterator["PandasDataFrameLike"] = None
+        ) -> Iterator["PandasDataFrameLike"]:

Review Comment:
   Can we add some comments on the possible input combinations this UDF needs to handle, so people can understand it more easily? IIUC there should be 3 cases:
   - Both `inputRows` and `initialStates` contain data. This would only happen in the first batch, when the associated grouping key has both input data and initial state.
   - Only `inputRows` contains data. This could happen when either the grouping key doesn't have any initial state to process or it's a non-first batch.
   - Only `initialStates` contains data. This could happen when the grouping key doesn't have any associated input data but has initial state to process.
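
The three cases above could be made explicit in code by peeking at each iterator. This is a minimal sketch: `peek` and `classify` are hypothetical helpers. Treating a `None` initial state like an empty one is an assumption, as is the defensive fourth branch for when neither iterator has data.

```python
from itertools import chain
from typing import Any, Iterator, Optional, Tuple


def peek(it: Optional[Iterator[Any]]) -> Tuple[bool, Iterator[Any]]:
    """Return (has_data, iterator) without losing the first element."""
    if it is None:
        return False, iter(())
    try:
        first = next(it)
    except StopIteration:
        return False, iter(())
    # Put the consumed element back in front of the remaining ones.
    return True, chain([first], it)


def classify(
    input_rows: Iterator[Any], initial_states: Optional[Iterator[Any]]
) -> Tuple[str, Iterator[Any], Iterator[Any]]:
    has_input, input_rows = peek(input_rows)
    has_init, initial_states_it = peek(initial_states)
    if has_input and has_init:
        case = "input+initial"  # first batch; key has input data and initial state
    elif has_input:
        case = "input only"  # key has no initial state, or not the first batch
    elif has_init:
        case = "initial only"  # key has initial state but no input rows
    else:
        case = "neither"  # not among the three cases; handled defensively
    return case, input_rows, initial_states_it
```

`peek` re-chains the consumed element, so the returned iterators still yield every row. The UDF could therefore branch on `case` and pass them on unchanged.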



-- 
This is an automated message from the Apache Git Service.