HeartSaVioR commented on code in PR #48005:
URL: https://github.com/apache/spark/pull/48005#discussion_r1827111574


##########
python/pyspark/sql/pandas/group_ops.py:
##########
@@ -409,6 +410,9 @@ def transformWithStateInPandas(
             The output mode of the stateful processor.
         timeMode : str
             The time mode semantics of the stateful processor for timers and TTL.
+        initialState : :class:`pyspark.sql.GroupedData`
+            Optional. The grouped dataframe on given grouping key as initial states used for initialization

Review Comment:
   nit: The method docs for the Scala and PySpark versions have now diverged, not only in the type (which is expected) but also in the description itself.
   
   For example, here is the explanation of `initialState` in the Scala API:
   > User provided initial state that will be used to initiate state for the query in the first batch.
   
   It's probably better to revisit both API docs at some point and keep them in sync.
   
   Before doing that, I think the phrase `on given grouping key` is redundant and confusing. We should already be checking the compatibility of the grouping key between the two groups (the current Dataset and the Dataset for initialState), right? If so, we can just remove it.



##########
python/pyspark/sql/pandas/group_ops.py:
##########
@@ -551,25 +550,103 @@ def transformWithStateUDF(
             # TODO(SPARK-49603) set the handle state in the lazily initialized iterator
 
             result = itertools.chain(*result_iter_list)
+            return result
+
+        def transformWithStateUDF(
+            statefulProcessorApiClient: StatefulProcessorApiClient,
+            key: Any,
+            inputRows: Iterator["PandasDataFrameLike"],
+        ) -> Iterator["PandasDataFrameLike"]:
+            handle = StatefulProcessorHandle(statefulProcessorApiClient)
+
+            if statefulProcessorApiClient.handle_state == StatefulProcessorHandleState.CREATED:
+                statefulProcessor.init(handle)
+                statefulProcessorApiClient.set_handle_state(
+                    StatefulProcessorHandleState.INITIALIZED
+                )
+
+            result = handle_data_with_timers(statefulProcessorApiClient, key, inputRows)
+            return result
+
+        def transformWithStateWithInitStateUDF(
+            statefulProcessorApiClient: StatefulProcessorApiClient,
+            key: Any,
+            inputRows: Iterator["PandasDataFrameLike"],
+            initialStates: Iterator["PandasDataFrameLike"] = None,
+        ) -> Iterator["PandasDataFrameLike"]:
+            """
+            UDF for TWS operator with non-empty initial states. Possible input combinations
+            of inputRows and initialStates iterator:
+            - Both `inputRows` and `initialStates` are non-empty: for the given key, both input rows
+              and initial states contains the grouping key, both input rows and initial states contains data.
+            - `InitialStates` is non-empty, while `initialStates` is empty. For the given key, only
+              initial states contains the grouping key and data, and it is first batch.
+            - `initialStates` is empty, while `inputRows` is not empty. For the given grouping key, only inputRows
+              contains the grouping key and data, and it is first batch.
+            - `initialStates` is None, while `inputRows` is not empty. This is not first batch. `initialStates`
+              is initialized to the positional value as None.
+            """
+            handle = StatefulProcessorHandle(statefulProcessorApiClient)
+
+            if statefulProcessorApiClient.handle_state == StatefulProcessorHandleState.CREATED:
+                statefulProcessor.init(handle)
+                statefulProcessorApiClient.set_handle_state(
+                    StatefulProcessorHandleState.INITIALIZED
+                )
+
+            # only process initial state if first batch
+            is_first_batch = statefulProcessorApiClient.is_first_batch()
+            if is_first_batch and initialStates is not None:

Review Comment:
   I'd expect the caller to handle this: providing initialStates for a non-first batch already adds unnecessary overhead, and ideally the caller should provide None for non-first batches. I'm OK with double-checking here for safety, but I'd rather do the opposite: assert that `initialStates` is None whenever it is not the first batch.
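A minimal sketch of the suggested check, assuming the UDF keeps the `is_first_batch` flag and the `initialStates` argument from the diff above. `check_initial_states` is a hypothetical helper name, not part of the PR:

```python
from typing import Any, Iterator, Optional


def check_initial_states(
    is_first_batch: bool, initial_states: Optional[Iterator[Any]]
) -> None:
    """Hypothetical guard for the UDF: outside the first batch the caller is
    expected to pass None, so anything else indicates a caller bug."""
    # Equivalent to asserting: not (not is_first_batch and initial_states is not None)
    assert is_first_batch or initial_states is None, (
        "initialStates must be None when it is not the first batch"
    )
```

With this in place, the `is_first_batch and initialStates is not None` condition in the UDF only has to decide whether to process the initial state, not whether the caller misbehaved.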



##########
python/pyspark/sql/streaming/stateful_processor_api_client.py:
##########
@@ -338,6 +339,27 @@ def get_map_state(
             # TODO(SPARK-49233): Classify user facing errors.
             raise PySparkRuntimeError(f"Error initializing map state: " f"{response_message[1]}")
 
+    def is_first_batch(self) -> bool:
+        import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage
+
+        is_first_batch = stateMessage.IsFirstBatch()
+        request = stateMessage.UtilsCallCommand(isFirstBatch=is_first_batch)
+        stateful_processor_call = stateMessage.StatefulProcessorCall(utilsCall=request)
+        message = stateMessage.StateRequest(statefulProcessorCall=stateful_processor_call)
+
+        self._send_proto_message(message.SerializeToString())
+        response_message = self._receive_proto_message()
+        status = response_message[0]
+        if status == 0:
+            return True
+        elif status == 1:
+            return False
+        else:

Review Comment:
   For all other calls, the status that lands in this error branch is mostly status == 1.
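A sketch of the client-side parsing under the convention stated for the server's other boolean calls (0 for true, 2 for false, 1 reserved for errors). `parse_bool_status` is a hypothetical helper; the client code above raises `PySparkRuntimeError`, while this sketch uses the built-in `RuntimeError` to stay self-contained:

```python
from typing import Any, Sequence


def parse_bool_status(response_message: Sequence[Any]) -> bool:
    """Hypothetical helper: map a (status, message, ...) response to a bool.

    Assumes the convention used by the other calls:
    0 -> True, 2 -> False, anything else (normally 1) -> error.
    """
    status = response_message[0]
    if status == 0:
        return True
    elif status == 2:
        return False
    else:
        # The real client would raise PySparkRuntimeError here.
        raise RuntimeError(f"Error checking first batch: {response_message[1]}")
```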



##########
python/pyspark/sql/pandas/group_ops.py:
##########
@@ -551,25 +550,103 @@ def transformWithStateUDF(
             # TODO(SPARK-49603) set the handle state in the lazily initialized iterator
 
             result = itertools.chain(*result_iter_list)
+            return result
+
+        def transformWithStateUDF(
+            statefulProcessorApiClient: StatefulProcessorApiClient,
+            key: Any,
+            inputRows: Iterator["PandasDataFrameLike"],
+        ) -> Iterator["PandasDataFrameLike"]:
+            handle = StatefulProcessorHandle(statefulProcessorApiClient)
+
+            if statefulProcessorApiClient.handle_state == StatefulProcessorHandleState.CREATED:
+                statefulProcessor.init(handle)
+                statefulProcessorApiClient.set_handle_state(
+                    StatefulProcessorHandleState.INITIALIZED
+                )
+
+            result = handle_data_with_timers(statefulProcessorApiClient, key, inputRows)
+            return result
+
+        def transformWithStateWithInitStateUDF(
+            statefulProcessorApiClient: StatefulProcessorApiClient,
+            key: Any,
+            inputRows: Iterator["PandasDataFrameLike"],
+            initialStates: Iterator["PandasDataFrameLike"] = None,
+        ) -> Iterator["PandasDataFrameLike"]:
+            """
+            UDF for TWS operator with non-empty initial states. Possible input combinations
+            of inputRows and initialStates iterator:
+            - Both `inputRows` and `initialStates` are non-empty: for the given key, both input rows
+              and initial states contains the grouping key, both input rows and initial states contains data.
+            - `InitialStates` is non-empty, while `initialStates` is empty. For the given key, only
+              initial states contains the grouping key and data, and it is first batch.
+            - `initialStates` is empty, while `inputRows` is not empty. For the given grouping key, only inputRows
+              contains the grouping key and data, and it is first batch.
+            - `initialStates` is None, while `inputRows` is not empty. This is not first batch. `initialStates`
+              is initialized to the positional value as None.
+            """
+            handle = StatefulProcessorHandle(statefulProcessorApiClient)
+
+            if statefulProcessorApiClient.handle_state == StatefulProcessorHandleState.CREATED:
+                statefulProcessor.init(handle)
+                statefulProcessorApiClient.set_handle_state(
+                    StatefulProcessorHandleState.INITIALIZED
+                )
+
+            # only process initial state if first batch
+            is_first_batch = statefulProcessorApiClient.is_first_batch()
+            if is_first_batch and initialStates is not None:
+                for cur_initial_state in initialStates:
+                    statefulProcessorApiClient.set_implicit_key(key)
+                    # TODO(SPARK-50194) integration with new timer API & initial state timer register
+                    statefulProcessor.handleInitialState(key, cur_initial_state)
+
+            # if we don't have input rows for the given key but only have initial state
+            # for the grouping key, the inputRows iterator could be empty
+            input_rows_empty = False
+            try:
+                first = next(inputRows)
+            except StopIteration:
+                input_rows_empty = True
+            else:
+                inputRows = itertools.chain([first], inputRows)
+
+            if not input_rows_empty:

Review Comment:
   If you don't have a test covering this scenario, please add it as well.
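The peek in the diff above (pull one element, then re-chain it) can be factored into a small helper that is easy to unit-test for this scenario, where a key has initial state but no input rows. `peek_non_empty` is a hypothetical name, and plain lists stand in for pandas DataFrames:

```python
import itertools
from typing import Iterator, Tuple, TypeVar

T = TypeVar("T")


def peek_non_empty(it: Iterator[T]) -> Tuple[bool, Iterator[T]]:
    """Return (is_empty, iterator) without losing the first element."""
    try:
        first = next(it)
    except StopIteration:
        return True, iter(())
    # Put the consumed element back in front of the remaining ones.
    return False, itertools.chain([first], it)
```

A test for the scenario itself would feed an empty `inputRows` iterator together with a non-empty initial state for the same key and check that the initial state is still applied.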



##########
python/pyspark/sql/pandas/group_ops.py:
##########
@@ -551,25 +550,103 @@ def transformWithStateUDF(
             # TODO(SPARK-49603) set the handle state in the lazily initialized iterator
 
             result = itertools.chain(*result_iter_list)
+            return result
+
+        def transformWithStateUDF(
+            statefulProcessorApiClient: StatefulProcessorApiClient,
+            key: Any,
+            inputRows: Iterator["PandasDataFrameLike"],
+        ) -> Iterator["PandasDataFrameLike"]:
+            handle = StatefulProcessorHandle(statefulProcessorApiClient)
+
+            if statefulProcessorApiClient.handle_state == StatefulProcessorHandleState.CREATED:
+                statefulProcessor.init(handle)
+                statefulProcessorApiClient.set_handle_state(
+                    StatefulProcessorHandleState.INITIALIZED
+                )
+
+            result = handle_data_with_timers(statefulProcessorApiClient, key, inputRows)
+            return result
+
+        def transformWithStateWithInitStateUDF(
+            statefulProcessorApiClient: StatefulProcessorApiClient,
+            key: Any,
+            inputRows: Iterator["PandasDataFrameLike"],
+            initialStates: Iterator["PandasDataFrameLike"] = None,
+        ) -> Iterator["PandasDataFrameLike"]:
+            """
+            UDF for TWS operator with non-empty initial states. Possible input combinations
+            of inputRows and initialStates iterator:
+            - Both `inputRows` and `initialStates` are non-empty: for the given key, both input rows
+              and initial states contains the grouping key, both input rows and initial states contains data.
+            - `InitialStates` is non-empty, while `initialStates` is empty. For the given key, only
+              initial states contains the grouping key and data, and it is first batch.
+            - `initialStates` is empty, while `inputRows` is not empty. For the given grouping key, only inputRows
+              contains the grouping key and data, and it is first batch.
+            - `initialStates` is None, while `inputRows` is not empty. This is not first batch. `initialStates`

Review Comment:
   This represents the difference between an empty Dataset (or iterator) and None, right? Just to make it clear.
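To make the distinction concrete, here is a sketch that classifies the four combinations listed in the docstring, treating `None` (no initial-state input, i.e. not the first batch) differently from an empty collection (initial state provided, but no rows for this key). `classify_inputs` and its labels are hypothetical, and lists stand in for the iterators of pandas DataFrames:

```python
from typing import Any, List, Optional


def classify_inputs(input_rows: List[Any], initial_states: Optional[List[Any]]) -> str:
    """Hypothetical classifier for the UDF's input combinations."""
    if initial_states is None:
        # No initial-state input at all: this is not the first batch.
        return "not first batch"
    # From here on initial_states is a (possibly empty) list: first batch.
    if input_rows and initial_states:
        return "first batch: input rows and initial state"
    if initial_states:
        return "first batch: initial state only"
    if input_rows:
        return "first batch: input rows only"
    return "first batch: no data for this key"
```

The first check is the important one: `[]` and `None` are both falsy in Python, so the UDF has to test `is None` explicitly rather than rely on truthiness.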



##########
python/pyspark/sql/pandas/group_ops.py:
##########
@@ -551,25 +550,103 @@ def transformWithStateUDF(
             # TODO(SPARK-49603) set the handle state in the lazily initialized iterator
 
             result = itertools.chain(*result_iter_list)
+            return result
+
+        def transformWithStateUDF(
+            statefulProcessorApiClient: StatefulProcessorApiClient,
+            key: Any,
+            inputRows: Iterator["PandasDataFrameLike"],
+        ) -> Iterator["PandasDataFrameLike"]:
+            handle = StatefulProcessorHandle(statefulProcessorApiClient)
+
+            if statefulProcessorApiClient.handle_state == StatefulProcessorHandleState.CREATED:
+                statefulProcessor.init(handle)
+                statefulProcessorApiClient.set_handle_state(
+                    StatefulProcessorHandleState.INITIALIZED
+                )
+
+            result = handle_data_with_timers(statefulProcessorApiClient, key, inputRows)
+            return result
+
+        def transformWithStateWithInitStateUDF(
+            statefulProcessorApiClient: StatefulProcessorApiClient,
+            key: Any,
+            inputRows: Iterator["PandasDataFrameLike"],
+            initialStates: Iterator["PandasDataFrameLike"] = None,
+        ) -> Iterator["PandasDataFrameLike"]:
+            """
+            UDF for TWS operator with non-empty initial states. Possible input combinations
+            of inputRows and initialStates iterator:
+            - Both `inputRows` and `initialStates` are non-empty: for the given key, both input rows

Review Comment:
   nit: `both input rows and initial states contains the grouping key` sounds redundant since we already call out `for the given key`. inputRows and initialStates are expected to be flattened Datasets (not grouped ones), right? Their grouping key is the given key.



##########
python/pyspark/sql/pandas/group_ops.py:
##########
@@ -551,25 +550,103 @@ def transformWithStateUDF(
             # TODO(SPARK-49603) set the handle state in the lazily initialized iterator
 
             result = itertools.chain(*result_iter_list)
+            return result
+
+        def transformWithStateUDF(
+            statefulProcessorApiClient: StatefulProcessorApiClient,
+            key: Any,
+            inputRows: Iterator["PandasDataFrameLike"],
+        ) -> Iterator["PandasDataFrameLike"]:
+            handle = StatefulProcessorHandle(statefulProcessorApiClient)
+
+            if statefulProcessorApiClient.handle_state == StatefulProcessorHandleState.CREATED:
+                statefulProcessor.init(handle)
+                statefulProcessorApiClient.set_handle_state(
+                    StatefulProcessorHandleState.INITIALIZED
+                )
+
+            result = handle_data_with_timers(statefulProcessorApiClient, key, inputRows)
+            return result
+
+        def transformWithStateWithInitStateUDF(
+            statefulProcessorApiClient: StatefulProcessorApiClient,
+            key: Any,
+            inputRows: Iterator["PandasDataFrameLike"],
+            initialStates: Iterator["PandasDataFrameLike"] = None,
+        ) -> Iterator["PandasDataFrameLike"]:
+            """
+            UDF for TWS operator with non-empty initial states. Possible input combinations
+            of inputRows and initialStates iterator:
+            - Both `inputRows` and `initialStates` are non-empty: for the given key, both input rows
+              and initial states contains the grouping key, both input rows and initial states contains data.
+            - `InitialStates` is non-empty, while `initialStates` is empty. For the given key, only
+              initial states contains the grouping key and data, and it is first batch.
+            - `initialStates` is empty, while `inputRows` is not empty. For the given grouping key, only inputRows
+              contains the grouping key and data, and it is first batch.
+            - `initialStates` is None, while `inputRows` is not empty. This is not first batch. `initialStates`
+              is initialized to the positional value as None.
+            """
+            handle = StatefulProcessorHandle(statefulProcessorApiClient)
+
+            if statefulProcessorApiClient.handle_state == StatefulProcessorHandleState.CREATED:
+                statefulProcessor.init(handle)
+                statefulProcessorApiClient.set_handle_state(
+                    StatefulProcessorHandleState.INITIALIZED
+                )
+
+            # only process initial state if first batch
+            is_first_batch = statefulProcessorApiClient.is_first_batch()
+            if is_first_batch and initialStates is not None:
+                for cur_initial_state in initialStates:
+                    statefulProcessorApiClient.set_implicit_key(key)
+                    # TODO(SPARK-50194) integration with new timer API & initial state timer register
+                    statefulProcessor.handleInitialState(key, cur_initial_state)
+
+            # if we don't have input rows for the given key but only have initial state
+            # for the grouping key, the inputRows iterator could be empty
+            input_rows_empty = False
+            try:
+                first = next(inputRows)
+            except StopIteration:
+                input_rows_empty = True
+            else:
+                inputRows = itertools.chain([first], inputRows)
+
+            if not input_rows_empty:

Review Comment:
   Wait, isn't there a case where the inputRows iterator is empty but a timer is expected to expire?
   
   My understanding is that if you pass transformWithStateWithInitStateUDF as the udf, it will be used for all batches, right? Then how could this handle the case where, in batch N, a grouping key has no data but has a timer to expire?
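A toy model of the concern, with hypothetical names throughout and plain dicts in place of Spark state: if per-key work is skipped whenever a key's input rows are empty, a key whose only work in batch N is an expiring timer never gets its timer fired. The sketch visits every key that has data or a registered timer, and checks timers independently of whether rows exist:

```python
from typing import Dict, List


def process_batch(
    rows_by_key: Dict[str, List[int]],
    timers: Dict[str, int],
    watermark_ms: int,
) -> List[str]:
    """Hypothetical per-batch driver: handle data and expired timers per key."""
    events = []
    # Visit keys with data OR a registered timer, not just keys with data.
    for key in sorted(set(rows_by_key) | set(timers)):
        rows = rows_by_key.get(key, [])
        if rows:
            events.append(f"{key}: handled {len(rows)} rows")
        expiry = timers.get(key)
        if expiry is not None and expiry <= watermark_ms:
            events.append(f"{key}: timer expired at {expiry}")
    return events
```

In the usage below, key `b` has no rows but still reports its expired timer; a loop driven only by non-empty input would drop that event.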



##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala:
##########
@@ -168,26 +168,53 @@ case class FlatMapGroupsInPandasWithState(
  * methods for new rows in each trigger and the user's state/state variables will be stored
  * persistently across invocations.
  * @param functionExpr function called on each group
- * @param groupingAttributes used to group the data
+ * @param groupingAttributesLen length of the seq of grouping attributes for input dataframe
  * @param outputAttrs used to define the output rows
  * @param outputMode defines the output mode for the statefulProcessor
  * @param timeMode the time mode semantics of the stateful processor for timers and TTL.
  * @param child logical plan of the underlying data
+ * @param initialState logical plan of initial state
+ * @param initGroupingAttrsLen length of the seq of grouping attributes for initial state dataframe
  */
 case class TransformWithStateInPandas(
     functionExpr: Expression,
-    groupingAttributes: Seq[Attribute],
+    groupingAttributesLen: Int,
     outputAttrs: Seq[Attribute],
     outputMode: OutputMode,
     timeMode: TimeMode,
-    child: LogicalPlan) extends UnaryNode {
+    child: LogicalPlan,
+    hasInitialState: Boolean,
+    initialState: LogicalPlan,
+    initGroupingAttrsLen: Int,
+    initialStateSchema: StructType) extends BinaryNode {
+  override def left: LogicalPlan = child
+
+  override def right: LogicalPlan = initialState
 
   override def output: Seq[Attribute] = outputAttrs
 
   override def producedAttributes: AttributeSet = AttributeSet(outputAttrs)
 
-  override protected def withNewChildInternal(
-      newChild: LogicalPlan): TransformWithStateInPandas = copy(child = newChild)
+  override lazy val references: AttributeSet =
+    AttributeSet(leftAttributes ++ rightAttributes ++ functionExpr.references) -- producedAttributes
+
+  override protected def withNewChildrenInternal(
+      newLeft: LogicalPlan, newRight: LogicalPlan): TransformWithStateInPandas =
+    copy(child = newLeft, initialState = newRight)
+
+  // We call the following attributes from `SparkStrategies` because attribute were
+  // resolved when we got there; if we directly pass Seq of attributes before it
+  // is fully analyzed, we will have conflicting attributes when resolving (initial state
+  // and child have the same names on grouping attributes)
+  def leftAttributes: Seq[Attribute] = left.output.take(groupingAttributesLen)

Review Comment:
   You can also assert that this node is resolved when this method is called, e.g.
   
   ```
   assert(resolved, "this method is expected to be called after resolution")
   ```
   
   Also, we could turn the code comment into a method doc so that it's easily visible in the IDE.



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala:
##########
@@ -788,15 +788,20 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
    */
   object TransformWithStateInPandasStrategy extends Strategy {
     override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
-      case TransformWithStateInPandas(
-        func, groupingAttributes, outputAttrs, outputMode, timeMode, child) =>
+      case t@TransformWithStateInPandas(

Review Comment:
   nit: spaces around `@`



##########
python/pyspark/sql/streaming/stateful_processor.py:
##########
@@ -426,3 +426,10 @@ def close(self) -> None:
         operations.
         """
         ...
+
+    def handleInitialState(self, key: Any, initialState: "PandasDataFrameLike") -> None:
+        """
+        Optional to implement. Will act as no-op if not defined or no initial state input. Function

Review Comment:
   nit: Same here; it has diverged from the method doc for the Scala API.
   
   > Function that will be invoked only in the first batch for users to process initial states.
   
   We can either revisit this as a whole, or at least sync the docs for the code changed in this PR while we are here.
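For reference, a sketch of the Python method with its docstring aligned to the Scala wording quoted above. The class name is hypothetical, `Any` replaces the `"PandasDataFrameLike"` annotation to keep it self-contained, and the no-op body follows the "optional to implement" behavior described in the diff:

```python
from typing import Any


class StatefulProcessorSketch:
    """Hypothetical stand-in for the PySpark StatefulProcessor base class."""

    def handleInitialState(self, key: Any, initialState: Any) -> None:
        """
        Function that will be invoked only in the first batch for users to process
        initial states. Optional to implement; acts as a no-op if not overridden.
        """
        pass
```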



##########
python/pyspark/sql/pandas/group_ops.py:
##########
@@ -551,25 +550,103 @@ def transformWithStateUDF(
             # TODO(SPARK-49603) set the handle state in the lazily initialized iterator
 
             result = itertools.chain(*result_iter_list)
+            return result
+
+        def transformWithStateUDF(
+            statefulProcessorApiClient: StatefulProcessorApiClient,
+            key: Any,
+            inputRows: Iterator["PandasDataFrameLike"],
+        ) -> Iterator["PandasDataFrameLike"]:
+            handle = StatefulProcessorHandle(statefulProcessorApiClient)
+
+            if statefulProcessorApiClient.handle_state == StatefulProcessorHandleState.CREATED:
+                statefulProcessor.init(handle)
+                statefulProcessorApiClient.set_handle_state(
+                    StatefulProcessorHandleState.INITIALIZED
+                )
+
+            result = handle_data_with_timers(statefulProcessorApiClient, key, inputRows)
+            return result
+
+        def transformWithStateWithInitStateUDF(
+            statefulProcessorApiClient: StatefulProcessorApiClient,
+            key: Any,
+            inputRows: Iterator["PandasDataFrameLike"],
+            initialStates: Iterator["PandasDataFrameLike"] = None,
+        ) -> Iterator["PandasDataFrameLike"]:
+            """
+            UDF for TWS operator with non-empty initial states. Possible input combinations
+            of inputRows and initialStates iterator:
+            - Both `inputRows` and `initialStates` are non-empty: for the given key, both input rows
+              and initial states contains the grouping key, both input rows and initial states contains data.
+            - `InitialStates` is non-empty, while `initialStates` is empty. For the given key, only

Review Comment:
   nit: `InitialStates` is non-empty, while `initialStates` is empty.
   
   You may want to change one of them.



##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala:
##########
@@ -168,26 +168,53 @@ case class FlatMapGroupsInPandasWithState(
  * methods for new rows in each trigger and the user's state/state variables will be stored
  * persistently across invocations.
  * @param functionExpr function called on each group
- * @param groupingAttributes used to group the data
+ * @param groupingAttributesLen length of the seq of grouping attributes for input dataframe

Review Comment:
   Shall we describe the prerequisite for the child (left) and initialState (right), i.e. that the grouping key attributes should be duplicated at the front, followed by the output attributes? That would make it easier to understand how this works and how to provide the params from the call site.



##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala:
##########
@@ -168,26 +168,53 @@ case class FlatMapGroupsInPandasWithState(
  * methods for new rows in each trigger and the user's state/state variables will be stored
  * persistently across invocations.
  * @param functionExpr function called on each group
- * @param groupingAttributes used to group the data
+ * @param groupingAttributesLen length of the seq of grouping attributes for input dataframe
  * @param outputAttrs used to define the output rows
  * @param outputMode defines the output mode for the statefulProcessor
  * @param timeMode the time mode semantics of the stateful processor for timers and TTL.
  * @param child logical plan of the underlying data
+ * @param initialState logical plan of initial state
+ * @param initGroupingAttrsLen length of the seq of grouping attributes for initial state dataframe
  */
 case class TransformWithStateInPandas(
     functionExpr: Expression,
-    groupingAttributes: Seq[Attribute],
+    groupingAttributesLen: Int,
     outputAttrs: Seq[Attribute],
     outputMode: OutputMode,
     timeMode: TimeMode,
-    child: LogicalPlan) extends UnaryNode {
+    child: LogicalPlan,
+    hasInitialState: Boolean,
+    initialState: LogicalPlan,
+    initGroupingAttrsLen: Int,
+    initialStateSchema: StructType) extends BinaryNode {
+  override def left: LogicalPlan = child
+
+  override def right: LogicalPlan = initialState
 
   override def output: Seq[Attribute] = outputAttrs
 
   override def producedAttributes: AttributeSet = AttributeSet(outputAttrs)
 
-  override protected def withNewChildInternal(
-      newChild: LogicalPlan): TransformWithStateInPandas = copy(child = newChild)
+  override lazy val references: AttributeSet =
+    AttributeSet(leftAttributes ++ rightAttributes ++ functionExpr.references) -- producedAttributes
+
+  override protected def withNewChildrenInternal(
+      newLeft: LogicalPlan, newRight: LogicalPlan): TransformWithStateInPandas =
+    copy(child = newLeft, initialState = newRight)
+
+  // We call the following attributes from `SparkStrategies` because attribute were
+  // resolved when we got there; if we directly pass Seq of attributes before it
+  // is fully analyzed, we will have conflicting attributes when resolving (initial state
+  // and child have the same names on grouping attributes)
+  def leftAttributes: Seq[Attribute] = left.output.take(groupingAttributesLen)
+
+  def rightAttributes: Seq[Attribute] = if (hasInitialState) {

Review Comment:
   Same here: leave a method doc stating the requirement.



##########
sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala:
##########
@@ -474,24 +474,61 @@ class RelationalGroupedDataset protected[sql](
       func: Column,
       outputStructType: StructType,
       outputModeStr: String,
-      timeModeStr: String): DataFrame = {
-    val groupingNamedExpressions = groupingExprs.map {
-      case ne: NamedExpression => ne
-      case other => Alias(other, other.toString)()
+      timeModeStr: String,
+      initialState: RelationalGroupedDataset): DataFrame = {
+    def exprToAttr(expr: Seq[Expression]): Seq[Attribute] = {
+      expr.map {
+        case ne: NamedExpression => ne
+        case other => Alias(other, other.toString)()
+      }.map(_.toAttribute)
     }
-    val groupingAttrs = groupingNamedExpressions.map(_.toAttribute)
+
+    val groupingAttrs = exprToAttr(groupingExprs)
     val outputAttrs = toAttributes(outputStructType)
     val outputMode = InternalOutputModes(outputModeStr)
     val timeMode = TimeModes(timeModeStr)
 
-    val plan = TransformWithStateInPandas(
-      func.expr,
-      groupingAttrs,
-      outputAttrs,
-      outputMode,
-      timeMode,
-      child = df.logicalPlan
-    )
+    val plan: LogicalPlan = if (initialState == null) {
+      TransformWithStateInPandas(

Review Comment:
   Are the grouping attributes guaranteed to be placed first? And are they also duplicated along with the output attributes?
   
   I suspect we have to do the projection here, at least for the child. If not, please leave a code comment explaining why.
   
   If we need a projection and no test fails without it, please add a test that picks non-contiguous columns as the grouping key, for both the non-initial-state and initial-state cases.
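A plain-Python sketch of the layout in question (no Spark involved; column names and the `project_key_first` helper are hypothetical). With non-contiguous grouping columns, taking the first N columns only yields the grouping key after a projection that places the key columns first, in front of the full row:

```python
from typing import List, Sequence, Tuple


def project_key_first(
    columns: Sequence[str], rows: List[Tuple], key_columns: Sequence[str]
) -> Tuple[List[str], List[Tuple]]:
    """Reorder each row as (key columns..., all original columns...)."""
    key_idx = [columns.index(c) for c in key_columns]
    new_columns = list(key_columns) + list(columns)
    # Duplicate the key values at the front; keep the full original row after them.
    new_rows = [tuple(row[i] for i in key_idx) + tuple(row) for row in rows]
    return new_columns, new_rows
```

With columns `["id", "value", "region"]` and key `["id", "region"]`, the unprojected row `(1, 10, "us")` gives `(1, 10)` for its first two columns, which mixes the key with a value column; a test with non-contiguous grouping columns catches exactly this.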



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala:
##########
@@ -348,6 +351,23 @@ class TransformWithStateInPandasStateServer(
     }
   }
 
+  private[sql] def handleStatefulProcessorUtilRequest(message: UtilsCallCommand): Unit = {
+    message.getMethodCase match {
+      case UtilsCallCommand.MethodCase.ISFIRSTBATCH =>
+        if (!hasInitialState) {
+          // In physical planning, hasInitialState will always be flipped

Review Comment:
   What's the reason we have to deduce isFirstBatch from hasInitialState? They aren't exactly the same, right? hasInitialState can be false even for the first batch. How will we deal with it if isFirstBatch is used outside the context of initial state?
   
   Please avoid coupling that introduces dependencies and assumptions, and try not to couple with how the "current" caller uses it. If you have to rely on an assumption (e.g. you can't carry more info over in the plan), please find a dependency that gives a consistent result. For example, the state operator info has a version from which you can get batchId as version - 1, and isFirstBatch is batchId == 0.
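The derivation suggested above, sketched as a standalone function. The name is hypothetical, and it takes the comment's relationship (batchId = version - 1) as given rather than verifying it against the state store:

```python
def is_first_batch_from_version(store_version: int) -> bool:
    """Derive isFirstBatch from the state operator version, independently of
    whether an initial state was provided."""
    batch_id = store_version - 1  # per the review comment: batchId = version - 1
    return batch_id == 0
```

Because it never reads `hasInitialState`, the answer is the same with or without an initial state, which also gives the negative test a clear expectation.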



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasPythonRunner.scala:
##########
@@ -30,15 +30,16 @@ import org.apache.spark.api.python.{BasePythonRunner, ChainedPythonFunctions, Py
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.execution.metric.SQLMetric
-import org.apache.spark.sql.execution.python.TransformWithStateInPandasPythonRunner.{InType, OutType}
+import org.apache.spark.sql.execution.python.TransformWithStateInPandasPythonRunner._

Review Comment:
   nit: we generally avoid wildcard imports unless there are tens of names to import.



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala:
##########
@@ -788,15 +788,20 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
    */
   object TransformWithStateInPandasStrategy extends Strategy {
     override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
-      case TransformWithStateInPandas(
-        func, groupingAttributes, outputAttrs, outputMode, timeMode, child) =>
+      case t@TransformWithStateInPandas(
+      func, _, outputAttrs, outputMode, timeMode, child,
+      hasInitialState, i, _, initialStateSchema) =>

Review Comment:
   nit: let's make sure the param names are understandable. We don't shorten the others, only `i` for initialState; why not just use the full name?



##########
sql/core/src/test/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServerSuite.scala:
##########
@@ -561,6 +565,13 @@ class TransformWithStateInPandasStateServerSuite extends SparkFunSuite with Befo
     verify(arrowStreamWriter).finalizeCurrentArrowBatch()
   }
 
+  test("stateful processor - is first batch") {

Review Comment:
   Once my comment above is addressed, please add a negative test as well.



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala:
##########
@@ -348,6 +351,23 @@ class TransformWithStateInPandasStateServer(
     }
   }
 
+  private[sql] def handleStatefulProcessorUtilRequest(message: UtilsCallCommand): Unit = {
+    message.getMethodCase match {
+      case UtilsCallCommand.MethodCase.ISFIRSTBATCH =>
+        if (!hasInitialState) {
+          // In physical planning, hasInitialState will always be flipped
+          // if it is not first batch
+          sendResponse(1)

Review Comment:
   We have been using code 1 to mean "error", not "true"/"false"; for a boolean result we use 2 to mean false. Let's be consistent.
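To make the convention concrete, here is a minimal Python sketch of how a client could decode a boolean reply under it. It assumes 0 is the success / "true" code (this comment does not state it), 1 is reserved for errors, and 2 means "false"; `decode_boolean_status` and the constant names are hypothetical, not the actual PySpark client API.

```python
# Hypothetical status codes mirroring the convention described above.
# Assumption: 0 signals success ("true"); 1 is reserved for errors; 2 means "false".
STATUS_OK = 0
STATUS_ERROR = 1
STATUS_FALSE = 2


def decode_boolean_status(status: int, message: str = "") -> bool:
    """Map a state-server status code to a boolean, raising on errors."""
    if status == STATUS_OK:
        return True
    if status == STATUS_FALSE:
        return False
    if status == STATUS_ERROR:
        raise RuntimeError(f"State server returned an error: {message}")
    raise ValueError(f"Unexpected status code: {status}")
```

Under this reading, a reply that means "false" should be sent as 2; a client following the convention would treat 1 as a failure rather than as a boolean answer.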



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasPythonRunner.scala:
##########
@@ -51,15 +52,137 @@ class TransformWithStateInPandasPythonRunner(
     override val pythonMetrics: Map[String, SQLMetric],
     jobArtifactUUID: Option[String],
     groupingKeySchema: StructType,
-    batchTimestampMs: Option[Long] = None,
-    eventTimeWatermarkForEviction: Option[Long] = None)
-  extends BasePythonRunner[InType, OutType](funcs.map(_._1), evalType, argOffsets, jobArtifactUUID)
-  with PythonArrowInput[InType]
-  with BasicPythonArrowOutput
-  with Logging {
+    batchTimestampMs: Option[Long],
+    eventTimeWatermarkForEviction: Option[Long],
+    hasInitialState: Boolean)
+  extends TransformWithStateInPandasPythonBaseRunner[InType](
+    funcs, evalType, argOffsets, _schema, processorHandle, _timeZoneId,
+    initialWorkerConf, pythonMetrics, jobArtifactUUID, groupingKeySchema,
+    batchTimestampMs, eventTimeWatermarkForEviction, hasInitialState)
+    with PythonArrowInput[InType] {
 
-  private val sqlConf = SQLConf.get
-  private val arrowMaxRecordsPerBatch = sqlConf.arrowMaxRecordsPerBatch
+  private var pandasWriter: BaseStreamingArrowWriter = _
+
+  override protected def writeNextInputToArrowStream(
+      root: VectorSchemaRoot,
+      writer: ArrowStreamWriter,
+      dataOut: DataOutputStream,
+      inputIterator: Iterator[InType]): Boolean = {
+    if (pandasWriter == null) {
+      pandasWriter = new BaseStreamingArrowWriter(root, writer, arrowMaxRecordsPerBatch)
+    }
+
+    if (inputIterator.hasNext) {
+      val startData = dataOut.size()
+      val next = inputIterator.next()
+      val dataIter = next._2
+
+      while (dataIter.hasNext) {
+        val dataRow = dataIter.next()
+        pandasWriter.writeRow(dataRow)
+      }
+      pandasWriter.finalizeCurrentArrowBatch()
+      val deltaData = dataOut.size() - startData
+      pythonMetrics("pythonDataSent") += deltaData
+      true
+    } else {
+      super[PythonArrowInput].close()
+      false
+    }
+  }
+}
+
+/**
+ * Python runner with initial state in TransformWithStateInPandas.
+ * Write input data as one InternalRow(inputRow, initialState) in each row in arrow batch.
+ */
+class TransformWithStateInPandasPythonInitialStateRunner(
+    funcs: Seq[(ChainedPythonFunctions, Long)],
+    evalType: Int,
+    argOffsets: Array[Array[Int]],
+    dataSchema: StructType,
+    initStateSchema: StructType,
+    processorHandle: StatefulProcessorHandleImpl,
+    _timeZoneId: String,
+    initialWorkerConf: Map[String, String],
+    override val pythonMetrics: Map[String, SQLMetric],
+    jobArtifactUUID: Option[String],
+    groupingKeySchema: StructType,
+    batchTimestampMs: Option[Long],
+    eventTimeWatermarkForEviction: Option[Long],
+    hasInitialState: Boolean)
+  extends TransformWithStateInPandasPythonBaseRunner[GroupedInType](
+    funcs, evalType, argOffsets, dataSchema, processorHandle, _timeZoneId,
+    initialWorkerConf, pythonMetrics, jobArtifactUUID, groupingKeySchema,
+    batchTimestampMs, eventTimeWatermarkForEviction, hasInitialState)
+    with PythonArrowInput[GroupedInType] {

Review Comment:
   ditto



##########
sql/core/src/test/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServerSuite.scala:
##########
@@ -61,6 +62,8 @@ class TransformWithStateInPandasStateServerSuite extends SparkFunSuite with Befo
   var arrowStreamWriter: BaseStreamingArrowWriter = _
   var batchTimestampMs: Option[Long] = _
   var eventTimeWatermarkForEviction: Option[Long] = _
+  var initialStateSchema: StructType = StructType(Seq())

Review Comment:
   nit: these two are not used.



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasPythonRunner.scala:
##########
@@ -51,15 +52,137 @@ class TransformWithStateInPandasPythonRunner(
     override val pythonMetrics: Map[String, SQLMetric],
     jobArtifactUUID: Option[String],
     groupingKeySchema: StructType,
-    batchTimestampMs: Option[Long] = None,
-    eventTimeWatermarkForEviction: Option[Long] = None)
-  extends BasePythonRunner[InType, OutType](funcs.map(_._1), evalType, argOffsets, jobArtifactUUID)
-  with PythonArrowInput[InType]
-  with BasicPythonArrowOutput
-  with Logging {
+    batchTimestampMs: Option[Long],
+    eventTimeWatermarkForEviction: Option[Long],
+    hasInitialState: Boolean)
+  extends TransformWithStateInPandasPythonBaseRunner[InType](
+    funcs, evalType, argOffsets, _schema, processorHandle, _timeZoneId,
+    initialWorkerConf, pythonMetrics, jobArtifactUUID, groupingKeySchema,
+    batchTimestampMs, eventTimeWatermarkForEviction, hasInitialState)
+    with PythonArrowInput[InType] {

Review Comment:
   nit: shift left 2 spaces



##########
python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py:
##########
@@ -536,6 +536,108 @@ def check_results(batch_df, batch_id):
             EventTimeStatefulProcessor(), check_results
         )
 
+    def _test_transform_with_state_init_state_in_pandas(self, stateful_processor, check_results):
+        input_path = tempfile.mkdtemp()
+        self._prepare_test_resource1(input_path)

Review Comment:
   I see you are covering both cases in this test, which is great!
   
   * grouping key in input, but not in initial state (1)
   * grouping key in initial state, but not in input (3)
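For reference, the key combinations this test exercises can be enumerated with plain set operations. This is a minimal sketch, assuming the grouping keys on each side are known up front; `classify_grouping_keys` is a hypothetical helper, and the (1)/(3) labels above presumably refer to the combinations listed in the UDF docstring, which this sketch does not try to reproduce.

```python
from typing import Dict, Hashable, Iterable, Set


def classify_grouping_keys(
    input_keys: Iterable[Hashable], init_state_keys: Iterable[Hashable]
) -> Dict[str, Set[Hashable]]:
    """Split grouping keys by whether they appear in the input, the initial state, or both."""
    inputs = set(input_keys)
    inits = set(init_state_keys)
    return {
        "input_only": inputs - inits,       # input rows, but no initial state
        "both": inputs & inits,             # input rows and initial state for the same key
        "init_state_only": inits - inputs,  # initial state, but no input rows
    }


cases = classify_grouping_keys(["0", "1"], ["0", "3"])
```

With these sample keys, `cases["input_only"]` is `{'1'}`, `cases["both"]` is `{'0'}` and `cases["init_state_only"]` is `{'3'}`, so a single test input can cover all three combinations.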



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasPythonRunner.scala:
##########
@@ -51,15 +52,137 @@ class TransformWithStateInPandasPythonRunner(
     override val pythonMetrics: Map[String, SQLMetric],
     jobArtifactUUID: Option[String],
     groupingKeySchema: StructType,
-    batchTimestampMs: Option[Long] = None,
-    eventTimeWatermarkForEviction: Option[Long] = None)
-  extends BasePythonRunner[InType, OutType](funcs.map(_._1), evalType, argOffsets, jobArtifactUUID)
-  with PythonArrowInput[InType]
-  with BasicPythonArrowOutput
-  with Logging {
+    batchTimestampMs: Option[Long],
+    eventTimeWatermarkForEviction: Option[Long],
+    hasInitialState: Boolean)
+  extends TransformWithStateInPandasPythonBaseRunner[InType](
+    funcs, evalType, argOffsets, _schema, processorHandle, _timeZoneId,
+    initialWorkerConf, pythonMetrics, jobArtifactUUID, groupingKeySchema,
+    batchTimestampMs, eventTimeWatermarkForEviction, hasInitialState)
+    with PythonArrowInput[InType] {
 
-  private val sqlConf = SQLConf.get
-  private val arrowMaxRecordsPerBatch = sqlConf.arrowMaxRecordsPerBatch
+  private var pandasWriter: BaseStreamingArrowWriter = _
+
+  override protected def writeNextInputToArrowStream(
+      root: VectorSchemaRoot,
+      writer: ArrowStreamWriter,
+      dataOut: DataOutputStream,
+      inputIterator: Iterator[InType]): Boolean = {
+    if (pandasWriter == null) {
+      pandasWriter = new BaseStreamingArrowWriter(root, writer, arrowMaxRecordsPerBatch)
+    }
+
+    if (inputIterator.hasNext) {
+      val startData = dataOut.size()
+      val next = inputIterator.next()
+      val dataIter = next._2
+
+      while (dataIter.hasNext) {
+        val dataRow = dataIter.next()
+        pandasWriter.writeRow(dataRow)
+      }
+      pandasWriter.finalizeCurrentArrowBatch()
+      val deltaData = dataOut.size() - startData
+      pythonMetrics("pythonDataSent") += deltaData
+      true
+    } else {
+      super[PythonArrowInput].close()
+      false
+    }
+  }
+}
+
+/**
+ * Python runner with initial state in TransformWithStateInPandas.
+ * Write input data as one InternalRow(inputRow, initialState) in each row in arrow batch.
+ */
+class TransformWithStateInPandasPythonInitialStateRunner(
+    funcs: Seq[(ChainedPythonFunctions, Long)],
+    evalType: Int,
+    argOffsets: Array[Array[Int]],
+    dataSchema: StructType,
+    initStateSchema: StructType,
+    processorHandle: StatefulProcessorHandleImpl,
+    _timeZoneId: String,
+    initialWorkerConf: Map[String, String],
+    override val pythonMetrics: Map[String, SQLMetric],
+    jobArtifactUUID: Option[String],
+    groupingKeySchema: StructType,
+    batchTimestampMs: Option[Long],
+    eventTimeWatermarkForEviction: Option[Long],
+    hasInitialState: Boolean)
+  extends TransformWithStateInPandasPythonBaseRunner[GroupedInType](
+    funcs, evalType, argOffsets, dataSchema, processorHandle, _timeZoneId,
+    initialWorkerConf, pythonMetrics, jobArtifactUUID, groupingKeySchema,
+    batchTimestampMs, eventTimeWatermarkForEviction, hasInitialState)
+    with PythonArrowInput[GroupedInType] {
+
+  override protected lazy val schema: StructType = new StructType()
+    .add("inputData", dataSchema)
+    .add("initState", initStateSchema)
+
+  private var pandasWriter: BaseStreamingArrowWriter = _
+
+  override protected def writeNextInputToArrowStream(
+      root: VectorSchemaRoot,
+      writer: ArrowStreamWriter,
+      dataOut: DataOutputStream,
+      inputIterator:
+      Iterator[GroupedInType]): Boolean = {
+    if (pandasWriter == null) {
+      pandasWriter = new BaseStreamingArrowWriter(root, writer, arrowMaxRecordsPerBatch)
+    }
+
+    if (inputIterator.hasNext) {
+      val startData = dataOut.size()
+      // a new grouping key with data & init state iter
+      val next = inputIterator.next()
+      val dataIter = next._2
+      val initIter = next._3
+
+      while (dataIter.hasNext || initIter.hasNext) {
+        val dataRow =
+          if (dataIter.hasNext) dataIter.next()
+          else InternalRow.empty
+        val initRow =
+          if (initIter.hasNext) initIter.next()
+          else InternalRow.empty
+        pandasWriter.writeRow(InternalRow(dataRow, initRow))
+      }
+      pandasWriter.finalizeCurrentArrowBatch()
+      val deltaData = dataOut.size() - startData
+      pythonMetrics("pythonDataSent") += deltaData
+      true
+    } else {
+      super[PythonArrowInput].close()
+      false
+    }
+  }
+}
+
+/**
+ * Base Python runner implementation for TransformWithStateInPandas.
+ */
+abstract class TransformWithStateInPandasPythonBaseRunner[I](
+    funcs: Seq[(ChainedPythonFunctions, Long)],
+    evalType: Int,
+    argOffsets: Array[Array[Int]],
+    _schema: StructType,
+    processorHandle: StatefulProcessorHandleImpl,
+    _timeZoneId: String,
+    initialWorkerConf: Map[String, String],
+    override val pythonMetrics: Map[String, SQLMetric],
+    jobArtifactUUID: Option[String],
+    groupingKeySchema: StructType,
+    batchTimestampMs: Option[Long],
+    eventTimeWatermarkForEviction: Option[Long],
+    hasInitialState: Boolean)
+  extends BasePythonRunner[I, ColumnarBatch](funcs.map(_._1), evalType, argOffsets, jobArtifactUUID)
+    with PythonArrowInput[I]

Review Comment:
   ditto for all `with` lines



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasPythonRunner.scala:
##########
@@ -51,15 +52,137 @@ class TransformWithStateInPandasPythonRunner(
     override val pythonMetrics: Map[String, SQLMetric],
     jobArtifactUUID: Option[String],
     groupingKeySchema: StructType,
-    batchTimestampMs: Option[Long] = None,
-    eventTimeWatermarkForEviction: Option[Long] = None)
-  extends BasePythonRunner[InType, OutType](funcs.map(_._1), evalType, argOffsets, jobArtifactUUID)
-  with PythonArrowInput[InType]
-  with BasicPythonArrowOutput
-  with Logging {
+    batchTimestampMs: Option[Long],
+    eventTimeWatermarkForEviction: Option[Long],
+    hasInitialState: Boolean)
+  extends TransformWithStateInPandasPythonBaseRunner[InType](
+    funcs, evalType, argOffsets, _schema, processorHandle, _timeZoneId,
+    initialWorkerConf, pythonMetrics, jobArtifactUUID, groupingKeySchema,
+    batchTimestampMs, eventTimeWatermarkForEviction, hasInitialState)
+    with PythonArrowInput[InType] {
 
-  private val sqlConf = SQLConf.get
-  private val arrowMaxRecordsPerBatch = sqlConf.arrowMaxRecordsPerBatch
+  private var pandasWriter: BaseStreamingArrowWriter = _
+
+  override protected def writeNextInputToArrowStream(
+      root: VectorSchemaRoot,
+      writer: ArrowStreamWriter,
+      dataOut: DataOutputStream,
+      inputIterator: Iterator[InType]): Boolean = {
+    if (pandasWriter == null) {
+      pandasWriter = new BaseStreamingArrowWriter(root, writer, arrowMaxRecordsPerBatch)
+    }
+
+    if (inputIterator.hasNext) {
+      val startData = dataOut.size()
+      val next = inputIterator.next()
+      val dataIter = next._2
+
+      while (dataIter.hasNext) {
+        val dataRow = dataIter.next()
+        pandasWriter.writeRow(dataRow)
+      }
+      pandasWriter.finalizeCurrentArrowBatch()
+      val deltaData = dataOut.size() - startData
+      pythonMetrics("pythonDataSent") += deltaData
+      true
+    } else {
+      super[PythonArrowInput].close()
+      false
+    }
+  }
+}
+
+/**
+ * Python runner with initial state in TransformWithStateInPandas.
+ * Write input data as one InternalRow(inputRow, initialState) in each row in arrow batch.
+ */
+class TransformWithStateInPandasPythonInitialStateRunner(
+    funcs: Seq[(ChainedPythonFunctions, Long)],
+    evalType: Int,
+    argOffsets: Array[Array[Int]],
+    dataSchema: StructType,
+    initStateSchema: StructType,
+    processorHandle: StatefulProcessorHandleImpl,
+    _timeZoneId: String,
+    initialWorkerConf: Map[String, String],
+    override val pythonMetrics: Map[String, SQLMetric],
+    jobArtifactUUID: Option[String],
+    groupingKeySchema: StructType,
+    batchTimestampMs: Option[Long],
+    eventTimeWatermarkForEviction: Option[Long],
+    hasInitialState: Boolean)
+  extends TransformWithStateInPandasPythonBaseRunner[GroupedInType](
+    funcs, evalType, argOffsets, dataSchema, processorHandle, _timeZoneId,
+    initialWorkerConf, pythonMetrics, jobArtifactUUID, groupingKeySchema,
+    batchTimestampMs, eventTimeWatermarkForEviction, hasInitialState)
+    with PythonArrowInput[GroupedInType] {
+
+  override protected lazy val schema: StructType = new StructType()
+    .add("inputData", dataSchema)
+    .add("initState", initStateSchema)
+
+  private var pandasWriter: BaseStreamingArrowWriter = _
+
+  override protected def writeNextInputToArrowStream(
+      root: VectorSchemaRoot,
+      writer: ArrowStreamWriter,
+      dataOut: DataOutputStream,
+      inputIterator:
+      Iterator[GroupedInType]): Boolean = {

Review Comment:
   nit: move this up onto the previous line (any reason it's placed on the next line?)



##########
python/pyspark/sql/pandas/serializers.py:
##########
@@ -1190,3 +1190,70 @@ def dump_stream(self, iterator, stream):
         """
         result = [(b, t) for x in iterator for y, t in x for b in y]
         super().dump_stream(result, stream)
+
+
+class TransformWithStateInPandasInitStateSerializer(TransformWithStateInPandasSerializer):
+    """
+    Serializer used by Python worker to evaluate UDF for
+    :meth:`pyspark.sql.GroupedData.transformWithStateInPandasInitStateSerializer`.
+    Parameters
+    ----------
+    Same as input parameters in TransformWithStateInPandasSerializer.
+    """
+
+    def __init__(self, timezone, safecheck, assign_cols_by_name, arrow_max_records_per_batch):
+        super(TransformWithStateInPandasInitStateSerializer, self).__init__(
+            timezone, safecheck, assign_cols_by_name, arrow_max_records_per_batch
+        )
+        self.init_key_offsets = None
+
+    def load_stream(self, stream):
+        import pyarrow as pa
+
+        def generate_data_batches(batches):
+            """
+            Deserialize ArrowRecordBatches and return a generator of pandas.Series list.
+            The deserialization logic assumes that Arrow RecordBatches contain the data with the
+            ordering that data chunks for same grouping key will appear sequentially.
+            See `TransformWithStateInPandasPythonBaseRunner` for arrow batch schema sent from JVM.
+            This function flatten the columns of input rows and initial state rows and feed them into
+            the data generator.
+            """
+
+            def flatten_columns(cur_batch, col_name):
+                state_column = cur_batch.column(cur_batch.schema.get_field_index(col_name))
+                state_field_names = [
+                    state_column.type[i].name for i in range(state_column.type.num_fields)
+                ]
+                state_field_arrays = [
+                    state_column.field(i) for i in range(state_column.type.num_fields)
+                ]
+                table_from_fields = pa.Table.from_arrays(
+                    state_field_arrays, names=state_field_names
+                )
+                return table_from_fields
+
+            for batch in batches:

Review Comment:
   It might be better to add a brief comment on how the batch is constructed or what its key characteristics are, or at least where to read the code to understand the data structure. I read this code before the part that builds the batch, and had to assume that a batch only contains data from a single grouping key; otherwise this wouldn't work.
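The layout in question comes from `TransformWithStateInPandasPythonInitialStateRunner.writeNextInputToArrowStream`: for each grouping key, the JVM zips the input rows with the initial-state rows, pads the shorter side with `InternalRow.empty`, and calls `finalizeCurrentArrowBatch()` once per key, so a batch does not mix keys. The pure-Python sketch below models that pairing with tuples and dicts instead of Arrow, with `None` standing in for the empty padding row (how the padding actually surfaces in Arrow is not shown here). `pair_rows`, `unpair_batch` and the sample column names are hypothetical. It shows why the reader depends on the one-key-per-batch property: unpairing only drops the padding, so rows from different keys would end up in the same lists.

```python
from itertools import zip_longest
from typing import Any, Dict, Iterable, List, Optional, Tuple

Row = Dict[str, Any]


def pair_rows(
    data_rows: Iterable[Row], init_rows: Iterable[Row]
) -> List[Tuple[Optional[Row], Optional[Row]]]:
    """Model the JVM side: zip input and initial-state rows of ONE grouping key,
    padding the shorter side (None stands in for InternalRow.empty)."""
    return list(zip_longest(data_rows, init_rows, fillvalue=None))


def unpair_batch(
    batch: List[Tuple[Optional[Row], Optional[Row]]]
) -> Tuple[List[Row], List[Row]]:
    """Model the Python side: split one batch back into input rows and
    initial-state rows by dropping the padding.

    Assumes the batch holds rows of a single grouping key; otherwise the
    returned lists would mix rows from different keys.
    """
    data = [d for d, _ in batch if d is not None]
    init = [i for _, i in batch if i is not None]
    return data, init


# Hypothetical sample rows for grouping key "0".
batch = pair_rows(
    [{"id": "0", "temperature": 123}, {"id": "0", "temperature": 46}],
    [{"id": "0", "initVal": 789}],
)
inputs, init_states = unpair_batch(batch)
```

Here the second pair is `({"id": "0", "temperature": 46}, None)`, and unpairing recovers both input rows and the single initial-state row for key "0".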



##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala:
##########
@@ -168,26 +168,53 @@ case class FlatMapGroupsInPandasWithState(
  * methods for new rows in each trigger and the user's state/state variables will be stored
  * persistently across invocations.
  * @param functionExpr function called on each group
- * @param groupingAttributes used to group the data
+ * @param groupingAttributesLen length of the seq of grouping attributes for input dataframe
  * @param outputAttrs used to define the output rows
  * @param outputMode defines the output mode for the statefulProcessor
  * @param timeMode the time mode semantics of the stateful processor for timers and TTL.
  * @param child logical plan of the underlying data
+ * @param initialState logical plan of initial state
+ * @param initGroupingAttrsLen length of the seq of grouping attributes for initial state dataframe
  */
 case class TransformWithStateInPandas(
     functionExpr: Expression,
-    groupingAttributes: Seq[Attribute],
+    groupingAttributesLen: Int,
     outputAttrs: Seq[Attribute],
     outputMode: OutputMode,
     timeMode: TimeMode,
-    child: LogicalPlan) extends UnaryNode {
+    child: LogicalPlan,
+    hasInitialState: Boolean,
+    initialState: LogicalPlan,
+    initGroupingAttrsLen: Int,
+    initialStateSchema: StructType) extends BinaryNode {
+  override def left: LogicalPlan = child
+
+  override def right: LogicalPlan = initialState
 
   override def output: Seq[Attribute] = outputAttrs
 
   override def producedAttributes: AttributeSet = AttributeSet(outputAttrs)
 
-  override protected def withNewChildInternal(
-      newChild: LogicalPlan): TransformWithStateInPandas = copy(child = newChild)
+  override lazy val references: AttributeSet =
+    AttributeSet(leftAttributes ++ rightAttributes ++ functionExpr.references) -- producedAttributes
+
+  override protected def withNewChildrenInternal(
+      newLeft: LogicalPlan, newRight: LogicalPlan): TransformWithStateInPandas =
+    copy(child = newLeft, initialState = newRight)
+
+  // We call the following attributes from `SparkStrategies` because attribute were
+  // resolved when we got there; if we directly pass Seq of attributes before it
+  // is fully analyzed, we will have conflicting attributes when resolving (initial state
+  // and child have the same names on grouping attributes)
+  def leftAttributes: Seq[Attribute] = left.output.take(groupingAttributesLen)

Review Comment:
   Also, the code comment doesn't need to explain where this method is called; otherwise we would have to update it whenever there is a new caller. Let's just describe the requirement from this method's point of view: it must only be called after this node is resolved.



##########
python/pyspark/sql/pandas/group_ops.py:
##########
@@ -551,25 +550,103 @@ def transformWithStateUDF(
             # TODO(SPARK-49603) set the handle state in the lazily initialized iterator
 
             result = itertools.chain(*result_iter_list)
+            return result
+
+        def transformWithStateUDF(
+            statefulProcessorApiClient: StatefulProcessorApiClient,
+            key: Any,
+            inputRows: Iterator["PandasDataFrameLike"],
+        ) -> Iterator["PandasDataFrameLike"]:
+            handle = StatefulProcessorHandle(statefulProcessorApiClient)
+
+            if statefulProcessorApiClient.handle_state == StatefulProcessorHandleState.CREATED:
+                statefulProcessor.init(handle)
+                statefulProcessorApiClient.set_handle_state(
+                    StatefulProcessorHandleState.INITIALIZED
+                )
+
+            result = handle_data_with_timers(statefulProcessorApiClient, key, inputRows)
+            return result
+
+        def transformWithStateWithInitStateUDF(
+            statefulProcessorApiClient: StatefulProcessorApiClient,
+            key: Any,
+            inputRows: Iterator["PandasDataFrameLike"],
+            initialStates: Iterator["PandasDataFrameLike"] = None,
+        ) -> Iterator["PandasDataFrameLike"]:
+            """
+            UDF for TWS operator with non-empty initial states. Possible input combinations
+            of inputRows and initialStates iterator:
+            - Both `inputRows` and `initialStates` are non-empty: for the given key, both input rows

Review Comment:
   ditto for all others



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasPythonRunner.scala:
##########
@@ -51,15 +52,137 @@ class TransformWithStateInPandasPythonRunner(
     override val pythonMetrics: Map[String, SQLMetric],
     jobArtifactUUID: Option[String],
     groupingKeySchema: StructType,
-    batchTimestampMs: Option[Long] = None,
-    eventTimeWatermarkForEviction: Option[Long] = None)
-  extends BasePythonRunner[InType, OutType](funcs.map(_._1), evalType, argOffsets, jobArtifactUUID)
-  with PythonArrowInput[InType]
-  with BasicPythonArrowOutput
-  with Logging {
+    batchTimestampMs: Option[Long],
+    eventTimeWatermarkForEviction: Option[Long],
+    hasInitialState: Boolean)
+  extends TransformWithStateInPandasPythonBaseRunner[InType](
+    funcs, evalType, argOffsets, _schema, processorHandle, _timeZoneId,
+    initialWorkerConf, pythonMetrics, jobArtifactUUID, groupingKeySchema,
+    batchTimestampMs, eventTimeWatermarkForEviction, hasInitialState)
+    with PythonArrowInput[InType] {
 
-  private val sqlConf = SQLConf.get
-  private val arrowMaxRecordsPerBatch = sqlConf.arrowMaxRecordsPerBatch
+  private var pandasWriter: BaseStreamingArrowWriter = _
+
+  override protected def writeNextInputToArrowStream(
+      root: VectorSchemaRoot,
+      writer: ArrowStreamWriter,
+      dataOut: DataOutputStream,
+      inputIterator: Iterator[InType]): Boolean = {
+    if (pandasWriter == null) {
+      pandasWriter = new BaseStreamingArrowWriter(root, writer, arrowMaxRecordsPerBatch)
+    }
+
+    if (inputIterator.hasNext) {
+      val startData = dataOut.size()
+      val next = inputIterator.next()
+      val dataIter = next._2
+
+      while (dataIter.hasNext) {
+        val dataRow = dataIter.next()
+        pandasWriter.writeRow(dataRow)
+      }
+      pandasWriter.finalizeCurrentArrowBatch()
+      val deltaData = dataOut.size() - startData
+      pythonMetrics("pythonDataSent") += deltaData
+      true
+    } else {
+      super[PythonArrowInput].close()
+      false
+    }
+  }
+}
+
+/**
+ * Python runner with initial state in TransformWithStateInPandas.
+ * Write input data as one InternalRow(inputRow, initialState) in each row in arrow batch.
+ */
+class TransformWithStateInPandasPythonInitialStateRunner(
+    funcs: Seq[(ChainedPythonFunctions, Long)],
+    evalType: Int,
+    argOffsets: Array[Array[Int]],
+    dataSchema: StructType,
+    initStateSchema: StructType,
+    processorHandle: StatefulProcessorHandleImpl,
+    _timeZoneId: String,
+    initialWorkerConf: Map[String, String],
+    override val pythonMetrics: Map[String, SQLMetric],
+    jobArtifactUUID: Option[String],
+    groupingKeySchema: StructType,
+    batchTimestampMs: Option[Long],
+    eventTimeWatermarkForEviction: Option[Long],
+    hasInitialState: Boolean)
+  extends TransformWithStateInPandasPythonBaseRunner[GroupedInType](
+    funcs, evalType, argOffsets, dataSchema, processorHandle, _timeZoneId,
+    initialWorkerConf, pythonMetrics, jobArtifactUUID, groupingKeySchema,
+    batchTimestampMs, eventTimeWatermarkForEviction, hasInitialState)
+    with PythonArrowInput[GroupedInType] {
+
+  override protected lazy val schema: StructType = new StructType()
+    .add("inputData", dataSchema)
+    .add("initState", initStateSchema)
+
+  private var pandasWriter: BaseStreamingArrowWriter = _
+
+  override protected def writeNextInputToArrowStream(
+      root: VectorSchemaRoot,
+      writer: ArrowStreamWriter,
+      dataOut: DataOutputStream,
+      inputIterator:
+      Iterator[GroupedInType]): Boolean = {

Review Comment:
   If the combined line exceeds 100 chars, keep only `: Boolean = {` on this line, indented 2 spaces to the left of the parameters.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
