This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 603dc509821 [SPARK-40434][SS][PYTHON] Implement applyInPandasWithState in PySpark
603dc509821 is described below

commit 603dc5098217d9580f611873165d25392f41cdfe
Author: Jungtaek Lim <kabhwan.opensou...@gmail.com>
AuthorDate: Thu Sep 22 12:35:07 2022 +0900

    [SPARK-40434][SS][PYTHON] Implement applyInPandasWithState in PySpark
    
    ### What changes were proposed in this pull request?
    
    This PR proposes to introduce the new API `applyInPandasWithState` in PySpark, which provides the functionality to perform arbitrary stateful processing in Structured Streaming.
    
    This API pairs with applyInPandas: applyInPandas in PySpark covers the use case of flatMapGroups in the Scala/Java API, and applyInPandasWithState in PySpark covers the use case of flatMapGroupsWithState in the Scala/Java API.
    
    The signature of the API is as follows:
    
    ```
    # call this function after groupBy
    def applyInPandasWithState(
        self,
        func: "PandasGroupedMapFunctionWithState",
        outputStructType: Union[StructType, str],
        stateStructType: Union[StructType, str],
        outputMode: str,
        timeoutConf: str,
    ) -> DataFrame
    ```
    
    and the signature of the user function is as follows:
    
    ```
    def func(
        key: Tuple,
        pdf_iter: Iterator[pandas.DataFrame],
        state: GroupState
    ) -> Iterator[pandas.DataFrame]
    ```
    
    (Please refer to the code diff for the docstring of the new function.)
    
    Major design choices that differ from the existing APIs:
    
    1. The new API is untyped, while flatMapGroupsWithState is a typed API.
    
    This is based on the nature of the Python language: it relies on duck typing, and type annotations are only hints. There is no implementation of a typed API for PySpark DataFrame.
    
    This leads us to design the API as untyped, meaning all types for (input, state, output) should be Row-compatible. While end users are not required to deal with `Row` directly, the model they use for state and output must be convertible to Row with the default encoder. If they want a Python type for state that is not compatible with Row (e.g. a custom class), they need to pickle it and store it as BinaryType.
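
A minimal sketch of the pickling workaround, assuming a hypothetical `SessionInfo` class as the custom state type and a state schema with a single BinaryType field (for example `stateStructType="payload binary"`). Inside a real user function the tuple would be passed to `state.update(...)`; here only the pickle round trip is shown, with plain Python and no Spark session.

```python
import pickle
from dataclasses import dataclass, field
from typing import List, Tuple


# Hypothetical custom state type that is not Row-compatible as-is.
@dataclass
class SessionInfo:
    user: str
    events: List[str] = field(default_factory=list)


def to_state_tuple(info: SessionInfo) -> Tuple[bytes]:
    # Pickle the object into bytes so it fits a single BinaryType field.
    return (pickle.dumps(info),)


def from_state_tuple(value: Tuple[bytes]) -> SessionInfo:
    # Reverse of to_state_tuple: unpack the one-field tuple and unpickle it.
    return pickle.loads(value[0])


info = SessionInfo(user="alice", events=["login", "click"])
packed = to_state_tuple(info)  # the value that would go to state.update(...)
restored = from_state_tuple(packed)
print(restored == info)  # True
```

The trade-off is that Spark only sees opaque bytes, so the state schema no longer describes the object's fields and compatibility across versions depends on pickle rather than on the Row encoders.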
    
    This requires end users to specify the types of state and output via Spark SQL schemas in the method.
    
    Note that this helps to ensure compatibility of state data across Spark versions, as long as the encoders for 1) Python type -> Python Row and 2) Python Row -> UnsafeRow are not changed. We won't change the underlying data layout of UnsafeRow, as that would break all existing stateful queries.
    
    2. The new API passes pandas DataFrames to the user function, while flatMapGroupsWithState passes an iterator of rows.
    
    We decided to follow the user experience applyInPandas provides, for both consistency and performance (Arrow batching, vectorization, etc.). This leads us to design the user function around pandas DataFrames rather than an iterator of rows. While this makes the UX inconsistent with the Scala/Java API, we don't think it will be a problem since pandas is considered the de facto standard for Python data scientists.
    
    3. The new API passes an iterator of pandas DataFrames to the user function and requires it to return an iterator of pandas DataFrames, to address scalability.
    
    applyInPandas has a known limitation in scalability: it requires all data for a specific group to fit into memory. During the design phase of the new API, we decided to address scalability rather than inherit this limitation.
    
    To address scalability, we tweak the user function to receive an iterator (generator) of pandas DataFrames instead of a single pandas DataFrame, and to return an iterator (generator) of pandas DataFrames as well. We think this does not hurt the UX much, since a for loop and `yield` are enough to deal with iterators.
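
To illustrate the for-loop-and-yield style, here is a minimal sketch of a running-count user function, assuming pandas is installed. `FakeState` is a hypothetical stand-in that only loosely mirrors `GroupState` (`exists`, `get`, `update`) so the function can be exercised without a Spark session; it is not part of the PySpark API.

```python
from typing import Iterator, Optional, Tuple

import pandas as pd


class FakeState:
    """Hypothetical stand-in for GroupState, for local experiments only."""

    def __init__(self, value: Optional[Tuple] = None) -> None:
        self._value = value

    @property
    def exists(self) -> bool:
        return self._value is not None

    @property
    def get(self) -> Tuple:
        return self._value

    def update(self, new_value: Tuple) -> None:
        self._value = new_value


def count_fn(
    key: Tuple, pdf_iter: Iterator[pd.DataFrame], state: FakeState
) -> Iterator[pd.DataFrame]:
    # Resume from the previous count if this group already has state.
    total = state.get[0] if state.exists else 0
    # Consume the input one chunk at a time; only the current chunk is held.
    for pdf in pdf_iter:
        total += len(pdf)
    state.update((total,))
    # Emit a single output row for this group.
    yield pd.DataFrame({"id": [key[0]], "count": [total]})


chunks = iter([pd.DataFrame({"v": [1, 2]}), pd.DataFrame({"v": [3]})])
state = FakeState((10,))
out = pd.concat(count_fn(("a",), chunks, state), ignore_index=True)
print(state.get)  # (13,)
```

Because the function never calls `list(pdf_iter)`, memory use is bounded by the largest single chunk rather than by the whole group.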
    
    From the implementation perspective, we split the data for a specific group into multiple chunks; each chunk is stored and sent as a single Arrow RecordBatch, and then finally materialized into a single pandas DataFrame. This way, as long as end users don't materialize many pandas DataFrames from the iterator at the same time, only one chunk is materialized in memory at a time, which is scalable. Similar logic applies to the output of the user function, hence it is scalable as well.
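
The chunking idea can be sketched in plain Python. This is an illustration only, not the Arrow-based implementation: `iter_chunks` and `max_records` are hypothetical names, and lists stand in for Arrow RecordBatches and pandas DataFrames.

```python
from itertools import islice
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")


def iter_chunks(rows: Iterable[T], max_records: int) -> Iterator[List[T]]:
    # Lazily cut one group's rows into chunks of at most max_records rows,
    # so only the chunk currently being yielded is held in memory.
    it = iter(rows)
    while True:
        chunk = list(islice(it, max_records))
        if not chunk:
            return
        yield chunk


# A group produced lazily by a generator, never fully materialized.
group_rows = (i for i in range(7))
sizes = [len(c) for c in iter_chunks(group_rows, max_records=3)]
print(sizes)  # [3, 3, 1]
```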
    
    4. The new API also bin-packs the data from multiple groups into a single Arrow RecordBatch.
    
    Given the API is mainly used for streaming workloads, it is highly likely that the volume of data in a specific group is not large enough to leverage the benefit of Arrow columnar batching, which would hurt performance. To address this, we also do the opposite of what we do for scalability: bin-packing. That is, an Arrow RecordBatch can contain data for multiple groups, as well as part of the data for a specific group. This addresses both aspects of concerns together, scalabilit [...]
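
A simplified sketch of the bin-packing layout, assuming each batch holds at most `max_records` data rows. The per-chunk metadata mirrors the `startOffset`, `numRows`, and `isLastChunk` fields that the serializer in this patch reads from the state column. `pack` and `unpack` are hypothetical helpers, plain lists stand in for Arrow RecordBatches, and the grouping key is stored inline only for simplicity.

```python
from typing import Dict, List, Tuple

Batch = Tuple[List[int], List[dict]]  # (data rows, per-chunk metadata)


def pack(groups: List[Tuple[str, List[int]]], max_records: int) -> List[Batch]:
    """Bin-pack groups into batches of at most max_records rows. A batch may
    hold several groups, and a large group may span several batches."""
    batches: List[Batch] = []
    data: List[int] = []
    infos: List[dict] = []
    for key, rows in groups:
        pos = 0
        while pos < len(rows):
            take = min(max_records - len(data), len(rows) - pos)
            infos.append({
                "key": key,
                "startOffset": len(data),  # offset within the current batch
                "numRows": take,
                "isLastChunk": pos + take == len(rows),
            })
            data.extend(rows[pos:pos + take])
            pos += take
            if len(data) == max_records:  # batch is full: emit it
                batches.append((data, infos))
                data, infos = [], []
    if data:  # flush the partially filled batch
        batches.append((data, infos))
    return batches


def unpack(batches: List[Batch]) -> Dict[str, List[int]]:
    """Reassemble each group's rows by slicing batches with the metadata."""
    result: Dict[str, List[int]] = {}
    for data, infos in batches:
        for info in infos:
            start = info["startOffset"]
            chunk = data[start:start + info["numRows"]]
            result.setdefault(info["key"], []).extend(chunk)
    return result


groups = [("a", [1, 2]), ("b", [3, 4, 5, 6, 7]), ("c", [8])]
batches = pack(groups, max_records=4)
print([[i["key"] for i in infos] for _, infos in batches])  # [['a', 'b'], ['b', 'c']]
print(unpack(batches) == dict(groups))  # True
```

Here group `b` is split across both batches (its first chunk has `isLastChunk` set to False), while `a` and `c` share batches with it, which is exactly the mix of splitting and packing described above.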
    
    Note that we are not implementing all of the features the Scala/Java API provides in this initial phase, e.g. support for batch queries and support for initial state are left as TODOs.
    
    ### Why are the changes needed?
    
    PySpark users don't have a way to perform arbitrary stateful processing in Structured Streaming and are forced to use either Java or Scala, which is unacceptable for users in many cases. This PR enables PySpark users to do it without moving to the Java/Scala world.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes. We are exposing a new public API in PySpark which performs arbitrary stateful processing.
    
    ### How was this patch tested?
    
    N/A. We will make sure test suites are constructed in an E2E manner under [SPARK-40431](https://issues.apache.org/jira/browse/SPARK-40431) - #37894
    
    Closes #37893 from HeartSaVioR/SPARK-40434-on-top-of-SPARK-40433-SPARK-40432.
    
    Lead-authored-by: Jungtaek Lim <kabhwan.opensou...@gmail.com>
    Co-authored-by: Hyukjin Kwon <gurwls...@apache.org>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .../org/apache/spark/api/python/PythonRunner.scala |   2 +
 python/pyspark/rdd.py                              |   2 +
 python/pyspark/sql/pandas/_typing/__init__.pyi     |   6 +
 python/pyspark/sql/pandas/functions.py             |   2 +
 python/pyspark/sql/pandas/group_ops.py             | 125 +++++++-
 python/pyspark/sql/pandas/serializers.py           | 355 ++++++++++++++++++++-
 python/pyspark/sql/streaming/state.py              |  55 +++-
 python/pyspark/sql/udf.py                          |   9 +-
 python/pyspark/worker.py                           | 143 +++++++++
 .../analysis/UnsupportedOperationChecker.scala     |  62 ++++
 .../plans/logical/pythonLogicalOperators.scala     |  34 ++
 .../spark/sql/RelationalGroupedDataset.scala       |  45 +++
 .../spark/sql/execution/SparkStrategies.scala      |  23 ++
 .../spark/sql/execution/arrow/ArrowWriter.scala    |  16 +-
 .../ApplyInPandasWithStatePythonRunner.scala       | 223 +++++++++++++
 .../python/ApplyInPandasWithStateWriter.scala      | 276 ++++++++++++++++
 .../python/FlatMapCoGroupsInPandasExec.scala       |   4 +-
 .../python/FlatMapGroupsInPandasExec.scala         |   2 +-
 .../FlatMapGroupsInPandasWithStateExec.scala       | 214 +++++++++++++
 .../sql/execution/python/PandasGroupUtils.scala    |   7 +-
 .../sql/execution/python/PythonArrowInput.scala    |   1 -
 .../execution/streaming/IncrementalExecution.scala |   9 +
 .../sql/execution/streaming/state/package.scala    |   2 +-
 23 files changed, 1599 insertions(+), 18 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
index 5a13674e8bf..7b31fa93c32 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
@@ -53,6 +53,7 @@ private[spark] object PythonEvalType {
   val SQL_MAP_PANDAS_ITER_UDF = 205
   val SQL_COGROUPED_MAP_PANDAS_UDF = 206
   val SQL_MAP_ARROW_ITER_UDF = 207
+  val SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE = 208
 
   def toString(pythonEvalType: Int): String = pythonEvalType match {
     case NON_UDF => "NON_UDF"
@@ -65,6 +66,7 @@ private[spark] object PythonEvalType {
     case SQL_MAP_PANDAS_ITER_UDF => "SQL_MAP_PANDAS_ITER_UDF"
     case SQL_COGROUPED_MAP_PANDAS_UDF => "SQL_COGROUPED_MAP_PANDAS_UDF"
     case SQL_MAP_ARROW_ITER_UDF => "SQL_MAP_ARROW_ITER_UDF"
+    case SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE => "SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE"
   }
 }
 
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 7ef0014ae75..5f4f4d494e1 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -105,6 +105,7 @@ if TYPE_CHECKING:
         PandasMapIterUDFType,
         PandasCogroupedMapUDFType,
         ArrowMapIterUDFType,
+        PandasGroupedMapUDFWithStateType,
     )
     from pyspark.sql.dataframe import DataFrame
     from pyspark.sql.types import AtomicType, StructType
@@ -147,6 +148,7 @@ class PythonEvalType:
     SQL_MAP_PANDAS_ITER_UDF: "PandasMapIterUDFType" = 205
     SQL_COGROUPED_MAP_PANDAS_UDF: "PandasCogroupedMapUDFType" = 206
     SQL_MAP_ARROW_ITER_UDF: "ArrowMapIterUDFType" = 207
+    SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE: "PandasGroupedMapUDFWithStateType" = 208
 
 
 def portable_hash(x: Hashable) -> int:
diff --git a/python/pyspark/sql/pandas/_typing/__init__.pyi b/python/pyspark/sql/pandas/_typing/__init__.pyi
index 27ac64a7238..acca8c00f2a 100644
--- a/python/pyspark/sql/pandas/_typing/__init__.pyi
+++ b/python/pyspark/sql/pandas/_typing/__init__.pyi
@@ -30,6 +30,7 @@ from typing_extensions import Protocol, Literal
 from types import FunctionType
 
 from pyspark.sql._typing import LiteralType
+from pyspark.sql.streaming.state import GroupState
 from pandas.core.frame import DataFrame as PandasDataFrame
 from pandas.core.series import Series as PandasSeries
 from numpy import ndarray as NDArray
@@ -51,6 +52,7 @@ PandasScalarIterUDFType = Literal[204]
 PandasMapIterUDFType = Literal[205]
 PandasCogroupedMapUDFType = Literal[206]
 ArrowMapIterUDFType = Literal[207]
+PandasGroupedMapUDFWithStateType = Literal[208]
 
 class PandasVariadicScalarToScalarFunction(Protocol):
     def __call__(self, *_: DataFrameOrSeriesLike_) -> DataFrameOrSeriesLike_: ...
@@ -256,6 +258,10 @@ PandasGroupedMapFunction = Union[
     Callable[[Any, DataFrameLike], DataFrameLike],
 ]
 
+PandasGroupedMapFunctionWithState = Callable[
+    [Any, Iterable[DataFrameLike], GroupState], Iterable[DataFrameLike]
+]
+
 class PandasVariadicGroupedAggFunction(Protocol):
     def __call__(self, *_: SeriesLike) -> LiteralType: ...
 
diff --git a/python/pyspark/sql/pandas/functions.py b/python/pyspark/sql/pandas/functions.py
index 94fabdbb295..d0f81e2f633 100644
--- a/python/pyspark/sql/pandas/functions.py
+++ b/python/pyspark/sql/pandas/functions.py
@@ -369,6 +369,7 @@ def pandas_udf(f=None, returnType=None, functionType=None):
         PythonEvalType.SQL_MAP_PANDAS_ITER_UDF,
         PythonEvalType.SQL_MAP_ARROW_ITER_UDF,
         PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF,
+        PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE,
         None,
     ]:  # None means it should infer the type from type hints.
 
@@ -402,6 +403,7 @@ def _create_pandas_udf(f, returnType, evalType):
         PythonEvalType.SQL_MAP_PANDAS_ITER_UDF,
         PythonEvalType.SQL_MAP_ARROW_ITER_UDF,
         PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF,
+        PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE,
     ]:
         # In case of 'SQL_GROUPED_MAP_PANDAS_UDF', deprecation warning is being triggered
         # at `apply` instead.
diff --git a/python/pyspark/sql/pandas/group_ops.py b/python/pyspark/sql/pandas/group_ops.py
index 6178433573e..0945c0078a2 100644
--- a/python/pyspark/sql/pandas/group_ops.py
+++ b/python/pyspark/sql/pandas/group_ops.py
@@ -15,18 +15,20 @@
 # limitations under the License.
 #
 import sys
-from typing import List, Union, TYPE_CHECKING
+from typing import List, Union, TYPE_CHECKING, cast
 import warnings
 
 from pyspark.rdd import PythonEvalType
 from pyspark.sql.column import Column
 from pyspark.sql.dataframe import DataFrame
-from pyspark.sql.types import StructType
+from pyspark.sql.streaming.state import GroupStateTimeout
+from pyspark.sql.types import StructType, _parse_datatype_string
 
 if TYPE_CHECKING:
     from pyspark.sql.pandas._typing import (
         GroupedMapPandasUserDefinedFunction,
         PandasGroupedMapFunction,
+        PandasGroupedMapFunctionWithState,
         PandasCogroupedMapFunction,
     )
     from pyspark.sql.group import GroupedData
@@ -216,6 +218,125 @@ class PandasGroupedOpsMixin:
         jdf = self._jgd.flatMapGroupsInPandas(udf_column._jc.expr())
         return DataFrame(jdf, self.session)
 
+    def applyInPandasWithState(
+        self,
+        func: "PandasGroupedMapFunctionWithState",
+        outputStructType: Union[StructType, str],
+        stateStructType: Union[StructType, str],
+        outputMode: str,
+        timeoutConf: str,
+    ) -> DataFrame:
+        """
+        Applies the given function to each group of data, while maintaining a user-defined
+        per-group state. The result Dataset will represent the flattened record returned by the
+        function.
+
+        For a streaming Dataset, the function will be invoked first for all input groups and then
+        for all timed out states where the input data is set to be empty. Updates to each group's
+        state will be saved across invocations.
+
+        The function should take parameters (key, Iterator[`pandas.DataFrame`], state) and
+        return another Iterator[`pandas.DataFrame`]. The grouping key(s) will be passed as a tuple
+        of numpy data types, e.g., `numpy.int32` and `numpy.float64`. The state will be passed as
+        :class:`pyspark.sql.streaming.state.GroupState`.
+
+        For each group, all columns are passed together as `pandas.DataFrame` to the user function,
+        and the returned `pandas.DataFrame` across all invocations are combined as a
+        :class:`DataFrame`. Note that the user function should not make assumptions about the
+        number of elements in the iterator. To process all data, the user function needs to
+        iterate over all elements and process them. On the other hand, the user function is not
+        strictly required to iterate through all elements in the iterator if it intends to read
+        only part of the data.
+
+        The `outputStructType` should be a :class:`StructType` describing the schema of all
+        elements in the returned value, `pandas.DataFrame`. The column labels of all elements in
+        returned `pandas.DataFrame` must either match the field names in the defined schema if
+        specified as strings, or match the field data types by position if not strings,
+        e.g. integer indices.
+
+        The `stateStructType` should be a :class:`StructType` describing the schema of the
+        user-defined state. The value of the state will be presented as a tuple, and updates
+        should also be performed with a tuple. The corresponding Python types for
+        :class:`DataType` are supported. Please refer to the page
+        https://spark.apache.org/docs/latest/sql-ref-datatypes.html (python tab).
+
+        The size of each DataFrame in both the input and output can be arbitrary. The number of
+        DataFrames in both the input and output can also be arbitrary.
+
+        .. versionadded:: 3.4.0
+
+        Parameters
+        ----------
+        func : function
+            a Python native function to be called on every group. It should take parameters
+            (key, Iterator[`pandas.DataFrame`], state) and return Iterator[`pandas.DataFrame`].
+            Note that the type of the key is tuple and the type of the state is
+            :class:`pyspark.sql.streaming.state.GroupState`.
+        outputStructType : :class:`pyspark.sql.types.DataType` or str
+            the type of the output records. The value can be either a
+            :class:`pyspark.sql.types.DataType` object or a DDL-formatted type string.
+        stateStructType : :class:`pyspark.sql.types.DataType` or str
+            the type of the user-defined state. The value can be either a
+            :class:`pyspark.sql.types.DataType` object or a DDL-formatted type string.
+        outputMode : str
+            the output mode of the function.
+        timeoutConf : str
+            timeout configuration for groups that do not receive data for a while. Valid values
+            are defined in :class:`pyspark.sql.streaming.state.GroupStateTimeout`.
+
+        Examples
+        --------
+        >>> import pandas as pd  # doctest: +SKIP
+        >>> from pyspark.sql.streaming.state import GroupState, GroupStateTimeout
+        >>> def count_fn(key, pdf_iter, state):
+        ...     assert isinstance(state, GroupState)
+        ...     total_len = 0
+        ...     for pdf in pdf_iter:
+        ...         total_len += len(pdf)
+        ...     state.update((total_len,))
+        ...     yield pd.DataFrame({"id": [key[0]], "countAsString": [str(total_len)]})
+        >>> df.groupby("id").applyInPandasWithState(
+        ...     count_fn, outputStructType="id long, countAsString string",
+        ...     stateStructType="len long", outputMode="Update",
+        ...     timeoutConf=GroupStateTimeout.NoTimeout) # doctest: +SKIP
+
+        Notes
+        -----
+        This function requires a full shuffle.
+
+        This API is experimental.
+        """
+
+        from pyspark.sql import GroupedData
+        from pyspark.sql.functions import pandas_udf
+
+        assert isinstance(self, GroupedData)
+        assert timeoutConf in [
+            GroupStateTimeout.NoTimeout,
+            GroupStateTimeout.ProcessingTimeTimeout,
+            GroupStateTimeout.EventTimeTimeout,
+        ]
+
+        if isinstance(outputStructType, str):
+            outputStructType = cast(StructType, _parse_datatype_string(outputStructType))
+        if isinstance(stateStructType, str):
+            stateStructType = cast(StructType, _parse_datatype_string(stateStructType))
+
+        udf = pandas_udf(
+            func,  # type: ignore[call-overload]
+            returnType=outputStructType,
+            functionType=PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE,
+        )
+        df = self._df
+        udf_column = udf(*[df[col] for col in df.columns])
+        jdf = self._jgd.applyInPandasWithState(
+            udf_column._jc.expr(),
+            self.session._jsparkSession.parseDataType(outputStructType.json()),
+            self.session._jsparkSession.parseDataType(stateStructType.json()),
+            outputMode,
+            timeoutConf,
+        )
+        return DataFrame(jdf, self.session)
+
     def cogroup(self, other: "GroupedData") -> "PandasCogroupedOps":
         """
         Cogroups this group with another group so that we can run cogrouped operations.
diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py
index 992e82b403a..ca249c75ea5 100644
--- a/python/pyspark/sql/pandas/serializers.py
+++ b/python/pyspark/sql/pandas/serializers.py
@@ -19,7 +19,9 @@
 Serializers for PyArrow and pandas conversions. See `pyspark.serializers` for more details.
 """
 
-from pyspark.serializers import Serializer, read_int, write_int, UTF8Deserializer
+from pyspark.serializers import Serializer, read_int, write_int, UTF8Deserializer, CPickleSerializer
+from pyspark.sql.pandas.types import to_arrow_type
+from pyspark.sql.types import StringType, StructType, BinaryType, StructField, LongType
 
 
 class SpecialLengths:
@@ -371,3 +373,354 @@ class CogroupUDFSerializer(ArrowStreamPandasUDFSerializer):
                 raise ValueError(
                     "Invalid number of pandas.DataFrames in group {0}".format(dataframes_in_group)
                 )
+
+
+class ApplyInPandasWithStateSerializer(ArrowStreamPandasUDFSerializer):
+    """
+    Serializer used by Python worker to evaluate UDF for applyInPandasWithState.
+
+    Parameters
+    ----------
+    timezone : str
+        A timezone to respect when handling timestamp values
+    safecheck : bool
+        If True, conversion from Arrow to Pandas checks for overflow/truncation
+    assign_cols_by_name : bool
+        If True, then Pandas DataFrames will get columns by name
+    state_object_schema : StructType
+        The type of state object represented as Spark SQL type
+    arrow_max_records_per_batch : int
+        Limit of the number of records that can be written to a single ArrowRecordBatch in memory.
+    """
+
+    def __init__(
+        self,
+        timezone,
+        safecheck,
+        assign_cols_by_name,
+        state_object_schema,
+        arrow_max_records_per_batch,
+    ):
+        super(ApplyInPandasWithStateSerializer, self).__init__(
+            timezone, safecheck, assign_cols_by_name
+        )
+        self.pickleSer = CPickleSerializer()
+        self.utf8_deserializer = UTF8Deserializer()
+        self.state_object_schema = state_object_schema
+
+        self.result_state_df_type = StructType(
+            [
+                StructField("properties", StringType()),
+                StructField("keyRowAsUnsafe", BinaryType()),
+                StructField("object", BinaryType()),
+                StructField("oldTimeoutTimestamp", LongType()),
+            ]
+        )
+
+        self.result_state_pdf_arrow_type = to_arrow_type(self.result_state_df_type)
+        self.arrow_max_records_per_batch = arrow_max_records_per_batch
+
+    def load_stream(self, stream):
+        """
+        Read ArrowRecordBatches from stream, deserialize them to populate a list of pairs
+        (data chunk, state), and convert the data into a list of pandas.Series.
+
+        Please refer to the doc of the inner function `gen_data_and_state` for more details on how
+        this function works overall.
+
+        In addition, this function further groups the return of `gen_data_and_state` by the state
+        instance (same semantic as grouping by grouping key) and produces an iterator of data
+        chunks for each group, so that the caller can lazily materialize the data chunk.
+        """
+
+        import pyarrow as pa
+        import json
+        from itertools import groupby
+        from pyspark.sql.streaming.state import GroupState
+
+        def construct_state(state_info_col):
+            """
+            Construct state instance from the value of state information column.
+            """
+
+            state_info_col_properties = state_info_col["properties"]
+            state_info_col_key_row = state_info_col["keyRowAsUnsafe"]
+            state_info_col_object = state_info_col["object"]
+
+            state_properties = json.loads(state_info_col_properties)
+            if state_info_col_object:
+                state_object = self.pickleSer.loads(state_info_col_object)
+            else:
+                state_object = None
+            state_properties["optionalValue"] = state_object
+
+            return GroupState(
+                keyAsUnsafe=state_info_col_key_row,
+                valueSchema=self.state_object_schema,
+                **state_properties,
+            )
+
+        def gen_data_and_state(batches):
+            """
+            Deserialize ArrowRecordBatches and return a generator of
+            `(a list of pandas.Series, state)`.
+
+            The deserialization logic is as follows:
+
+            1. Read the entire data part from Arrow RecordBatch.
+            2. Read the entire state information part from Arrow RecordBatch.
+            3. Loop through each state information:
+               3.A. Extract the data out from entire data via the information of data range.
+               3.B. Construct a new state instance if the state information is the first occurrence
+                    for the current grouping key.
+               3.C. Leverage the existing state instance if it is already available for the current
+                    grouping key. (Meaning it's not the first occurrence.)
+               3.D. Remove the cached state instance if the state information denotes the data is
+                    the last chunk for the current grouping key.
+
+            This deserialization logic assumes that Arrow RecordBatches contain the data in an
+            order where data chunks for the same grouping key appear sequentially.
+
+            This function must avoid materializing multiple Arrow RecordBatches into memory at the
+            same time. Data chunks from the same grouping key should also appear sequentially, so
+            that they can be further grouped by state instance (the same state instance will be
+            produced for the same grouping key).
+            """
+
+            state_for_current_group = None
+
+            for batch in batches:
+                batch_schema = batch.schema
+                data_schema = pa.schema([batch_schema[i] for i in range(0, len(batch_schema) - 1)])
+                state_schema = pa.schema(
+                    [
+                        batch_schema[-1],
+                    ]
+                )
+
+                batch_columns = batch.columns
+                data_columns = batch_columns[0:-1]
+                state_column = batch_columns[-1]
+
+                data_batch = pa.RecordBatch.from_arrays(data_columns, 
schema=data_schema)
+                state_batch = pa.RecordBatch.from_arrays(
+                    [
+                        state_column,
+                    ],
+                    schema=state_schema,
+                )
+
+                state_arrow = pa.Table.from_batches([state_batch]).itercolumns()
+                state_pandas = [self.arrow_to_pandas(c) for c in state_arrow][0]
+
+                for state_idx in range(0, len(state_pandas)):
+                    state_info_col = state_pandas.iloc[state_idx]
+
+                    if not state_info_col:
+                        # no more data with grouping key + state
+                        break
+
+                    data_start_offset = state_info_col["startOffset"]
+                    num_data_rows = state_info_col["numRows"]
+                    is_last_chunk = state_info_col["isLastChunk"]
+
+                    if state_for_current_group:
+                        # use the state, we already have state for same group and there should be
+                        # some data in same group being processed earlier
+                        state = state_for_current_group
+                    else:
+                        # there is no state being stored for same group, construct one
+                        state = construct_state(state_info_col)
+
+                    if is_last_chunk:
+                        # discard the state being cached for same group
+                        state_for_current_group = None
+                    elif not state_for_current_group:
+                        # there's no cached state but expected to have additional data in same group
+                        # cache the current state
+                        state_for_current_group = state
+
+                    data_batch_for_group = data_batch.slice(data_start_offset, num_data_rows)
+                    data_arrow = pa.Table.from_batches([data_batch_for_group]).itercolumns()
+
+                    data_pandas = [self.arrow_to_pandas(c) for c in data_arrow]
+
+                    # state info
+                    yield (
+                        data_pandas,
+                        state,
+                    )
+
+        _batches = super(ArrowStreamPandasSerializer, self).load_stream(stream)
+
+        data_state_generator = gen_data_and_state(_batches)
+
+        # state will be same object for same grouping key
+        for _state, _data in groupby(data_state_generator, key=lambda x: x[1]):
+            yield (
+                _data,
+                _state,
+            )
+
+    def dump_stream(self, iterator, stream):
+        """
+        Read through an iterator of (iterator of pandas DataFrame, state), serialize them to Arrow
+        RecordBatches, and write batches to stream.
+        """
+
+        import pandas as pd
+        import pyarrow as pa
+
+        def construct_state_pdf(state):
+            """
+            Construct a pandas DataFrame from the state instance.
+            """
+
+            state_properties = state.json().encode("utf-8")
+            state_key_row_as_binary = state._keyAsUnsafe
+            if state.exists:
+                state_object = self.pickleSer.dumps(state._value_schema.toInternal(state._value))
+            else:
+                state_object = None
+            state_old_timeout_timestamp = state.oldTimeoutTimestamp
+
+            state_dict = {
+                "properties": [
+                    state_properties,
+                ],
+                "keyRowAsUnsafe": [
+                    state_key_row_as_binary,
+                ],
+                "object": [
+                    state_object,
+                ],
+                "oldTimeoutTimestamp": [
+                    state_old_timeout_timestamp,
+                ],
+            }
+
+            return pd.DataFrame.from_dict(state_dict)
+
+        def construct_record_batch(pdfs, pdf_data_cnt, pdf_schema, state_pdfs, state_data_cnt):
+            """
+            Construct a new Arrow RecordBatch based on output pandas DataFrames and states. Each
+            one maps to a single struct field of the Arrow schema, hence the returned
+            Arrow RecordBatch will have a schema with two fields, in `data`, `state` order.
+            (Readers are expected to access the field via position rather than the name. We do
+            not guarantee the name of the field.)
+
+            Note that Arrow RecordBatch requires all columns to have the same number of rows,
+            hence this function inserts empty data for state/data with fewer elements to compensate.
+            """
+
+            max_data_cnt = max(pdf_data_cnt, state_data_cnt)
+
+            empty_row_cnt_in_data = max_data_cnt - pdf_data_cnt
+            empty_row_cnt_in_state = max_data_cnt - state_data_cnt
+
+            empty_rows_pdf = pd.DataFrame(
+                dict.fromkeys(pa.schema(pdf_schema).names),
+                index=[x for x in range(0, empty_row_cnt_in_data)],
+            )
+            empty_rows_state = pd.DataFrame(
+                columns=["properties", "keyRowAsUnsafe", "object", "oldTimeoutTimestamp"],
+                index=[x for x in range(0, empty_row_cnt_in_state)],
+            )
+
+            pdfs.append(empty_rows_pdf)
+            state_pdfs.append(empty_rows_state)
+
+            merged_pdf = pd.concat(pdfs, ignore_index=True)
+            merged_state_pdf = pd.concat(state_pdfs, ignore_index=True)
+
+            return self._create_batch(
+                [(merged_pdf, pdf_schema), (merged_state_pdf, self.result_state_pdf_arrow_type)]
+            )
+
+        def serialize_batches():
+            """
+            Read through an iterator of (iterator of pandas DataFrame, state), 
and serialize them
+            to Arrow RecordBatches.
+
+            This function does batching on constructing the Arrow RecordBatch; 
a batch will be
+            serialized to the Arrow RecordBatch when the total number of 
records exceeds the
+            configured threshold.
+            """
+            # a set of variables for the state of current batch which will be converted to Arrow
+            # RecordBatch.
+            pdfs = []
+            state_pdfs = []
+            pdf_data_cnt = 0
+            state_data_cnt = 0
+
+            return_schema = None
+
+            for data in iterator:
+                # data represents the result of each call of user function
+                packaged_result = data[0]
+
+                # There are two results from the call of user function:
+                # 1) iterator of pandas DataFrame (output)
+                # 2) updated state instance
+                pdf_iter = packaged_result[0][0]
+                state = packaged_result[0][1]
+
+                # This is static and won't change across batches.
+                return_schema = packaged_result[1]
+
+                for pdf in pdf_iter:
+                    # We ignore empty pandas DataFrame.
+                    if len(pdf) > 0:
+                        pdf_data_cnt += len(pdf)
+                        pdfs.append(pdf)
+
+                        # If the total number of records in current batch exceeds the configured
+                        # threshold, time to construct the Arrow RecordBatch from the batch.
+                        if pdf_data_cnt > self.arrow_max_records_per_batch:
+                            batch = construct_record_batch(
+                                pdfs, pdf_data_cnt, return_schema, state_pdfs, state_data_cnt
+                            )
+
+                            # Reset the variables to start with new batch for further data.
+                            pdfs = []
+                            state_pdfs = []
+                            pdf_data_cnt = 0
+                            state_data_cnt = 0
+
+                            yield batch
+
+                # This has to be performed 'after' evaluating all elements in iterator, so that
+                # the user function has been completed and the state is guaranteed to be updated.
+                state_pdf = construct_state_pdf(state)
+
+                state_pdfs.append(state_pdf)
+                state_data_cnt += 1
+
+            # processed all output, but current batch may not be flushed yet.
+            if pdf_data_cnt > 0 or state_data_cnt > 0:
+                batch = construct_record_batch(
+                    pdfs, pdf_data_cnt, return_schema, state_pdfs, state_data_cnt
+                )
+
+                yield batch
+
+        def init_stream_yield_batches(batches):
+            """
+            This function helps to ensure the requirement for Pandas UDFs - Pandas UDFs require a
+            START_ARROW_STREAM before the Arrow stream is sent.
+
+            START_ARROW_STREAM should be sent after creating the first record batch so in case of
+            an error, it can be sent back to the JVM before the Arrow stream starts.
+            """
+            should_write_start_length = True
+
+            for batch in batches:
+                if should_write_start_length:
+                    write_int(SpecialLengths.START_ARROW_STREAM, stream)
+                    should_write_start_length = False
+
+                yield batch
+
+        batches_to_write = init_stream_yield_batches(serialize_batches())
+
+        return ArrowStreamSerializer.dump_stream(self, batches_to_write, stream)
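The batching in `serialize_batches` and the padding in `construct_record_batch` above can be illustrated with a minimal, stdlib-only sketch. It is not the serializer itself: plain lists stand in for pandas DataFrames and Arrow columns, `None` stands in for the empty padding rows, and the names `batch_outputs` and `pad_to` are hypothetical.

```python
from typing import Any, Iterable, Iterator, List, Sequence, Tuple


def pad_to(rows: List[Any], n: int) -> List[Any]:
    # Hypothetical helper: append None placeholders until the list has n rows,
    # mirroring how construct_record_batch appends empty rows to the shorter side.
    return rows + [None] * (n - len(rows))


def batch_outputs(
    results: Iterable[Tuple[Iterable[Sequence[Any]], Any]],
    max_records: int,
) -> Iterator[Tuple[List[Any], List[Any]]]:
    """Yield (data_rows, state_rows) batches whose two columns have equal length.

    Each element of `results` is (iterator of output chunks, state) for one group,
    standing in for what the wrapped user function returns per grouping key.
    """
    data: List[Any] = []
    states: List[Any] = []
    for chunks, state in results:
        for chunk in chunks:
            if len(chunk) == 0:
                continue  # empty outputs are skipped, as in serialize_batches
            data.extend(chunk)
            if len(data) > max_records:
                n = max(len(data), len(states))
                yield pad_to(data, n), pad_to(states, n)
                data, states = [], []
        # The state is recorded only after the group's output iterator is exhausted,
        # so the user function has finished updating it.
        states.append(state)
    if data or states:
        # Flush the last, possibly partial, batch.
        n = max(len(data), len(states))
        yield pad_to(data, n), pad_to(states, n)


# Three groups; the second one produces no output rows.
batches = list(
    batch_outputs([([[1, 2], [3]], "s1"), ([[]], "s2"), ([[4, 5, 6]], "s3")], max_records=2)
)
# batches == [([1, 2, 3], [None, None, None]),
#             ([4, 5, 6], ["s1", "s2", None]),
#             ([None], ["s3"])]
```

Note that a group's state can land in a later batch than its output rows, so in this sketch rows at the same position in the two columns are not related; only the column lengths are guaranteed to match.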
diff --git a/python/pyspark/sql/streaming/state.py b/python/pyspark/sql/streaming/state.py
index 842eff32233..66b225e1b10 100644
--- a/python/pyspark/sql/streaming/state.py
+++ b/python/pyspark/sql/streaming/state.py
@@ -20,16 +20,24 @@ from typing import Tuple, Optional
 
 from pyspark.sql.types import DateType, Row, StructType
 
-__all__ = ["GroupStateImpl", "GroupStateTimeout"]
+__all__ = ["GroupState", "GroupStateTimeout"]
 
 
 class GroupStateTimeout:
+    """
+    Represents the type of timeouts possible for the Dataset operations applyInPandasWithState.
+    """
+
     NoTimeout: str = "NoTimeout"
     ProcessingTimeTimeout: str = "ProcessingTimeTimeout"
     EventTimeTimeout: str = "EventTimeTimeout"
 
 
-class GroupStateImpl:
+class GroupState:
+    """
+    Wrapper class for interacting with per-group state data in `applyInPandasWithState`.
+    """
+
     NO_TIMESTAMP: int = -1
 
     def __init__(
@@ -76,10 +84,16 @@ class GroupStateImpl:
 
     @property
     def exists(self) -> bool:
+        """
+        Whether state exists or not.
+        """
         return self._defined
 
     @property
     def get(self) -> Tuple:
+        """
+        Get the state value if it exists, or throw ValueError.
+        """
         if self.exists:
             return tuple(self._value)
         else:
@@ -87,6 +101,9 @@ class GroupStateImpl:
 
     @property
     def getOption(self) -> Optional[Tuple]:
+        """
+        Get the state value if it exists, or return None.
+        """
         if self.exists:
             return tuple(self._value)
         else:
@@ -94,6 +111,10 @@ class GroupStateImpl:
 
     @property
     def hasTimedOut(self) -> bool:
+        """
+        Whether the function has been called because the key has timed out.
+        This can return true only when timeouts are enabled.
+        """
         return self._has_timed_out
 
     # NOTE: this function is only available to PySpark implementation due to underlying
@@ -103,6 +124,9 @@ class GroupStateImpl:
         return self._old_timeout_timestamp
 
     def update(self, newValue: Tuple) -> None:
+        """
+        Update the value of the state. The value of the state cannot be null.
+        """
         if newValue is None:
             raise ValueError("'None' is not a valid state value")
 
@@ -112,11 +136,18 @@ class GroupStateImpl:
         self._removed = False
 
     def remove(self) -> None:
+        """
+        Remove this state.
+        """
         self._defined = False
         self._updated = False
         self._removed = True
 
     def setTimeoutDuration(self, durationMs: int) -> None:
+        """
+        Set the timeout duration in ms for this key.
+        Processing time timeout must be enabled.
+        """
         if isinstance(durationMs, str):
             # TODO(SPARK-40437): Support string representation of durationMs.
             raise ValueError("durationMs should be int but get :%s" % type(durationMs))
@@ -133,6 +164,11 @@ class GroupStateImpl:
 
     # TODO(SPARK-40438): Implement additionalDuration parameter.
     def setTimeoutTimestamp(self, timestampMs: int) -> None:
+        """
+        Set the timeout timestamp for this key as milliseconds in epoch time.
+        This timestamp cannot be older than the current watermark.
+        Event time timeout must be enabled.
+        """
         if self._timeout_conf != GroupStateTimeout.EventTimeTimeout:
             raise RuntimeError(
                 "Cannot set timeout duration without enabling processing time timeout in "
@@ -146,7 +182,7 @@ class GroupStateImpl:
             raise ValueError("Timeout timestamp must be positive")
 
         if (
-            self._event_time_watermark_ms != GroupStateImpl.NO_TIMESTAMP
+            self._event_time_watermark_ms != GroupState.NO_TIMESTAMP
             and timestampMs < self._event_time_watermark_ms
         ):
             raise ValueError(
@@ -157,6 +193,10 @@ class GroupStateImpl:
         self._timeout_timestamp = timestampMs
 
     def getCurrentWatermarkMs(self) -> int:
+        """
+        Get the current event time watermark as milliseconds in epoch time.
+        In a streaming query, this can be called only when watermark is set.
+        """
         if not self._watermark_present:
             raise RuntimeError(
                 "Cannot get event time watermark timestamp without setting watermark before "
@@ -165,6 +205,11 @@ class GroupStateImpl:
         return self._event_time_watermark_ms
 
     def getCurrentProcessingTimeMs(self) -> int:
+        """
+        Get the current processing time as milliseconds in epoch time.
+        In a streaming query, this will return a constant value throughout the duration of a
+        trigger, even if the trigger is re-executed.
+        """
         return self._batch_processing_time_ms
 
     def __str__(self) -> str:
@@ -174,6 +219,10 @@ class GroupStateImpl:
             return "GroupState(<undefined>)"
 
     def json(self) -> str:
+        """
+        Convert the internal values of instance into JSON. This is used to send out the update
+        from Python worker to executor.
+        """
         return json.dumps(
             {
                 # Constructor
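The docstrings above describe the `GroupState` contract. The following stdlib-only sketch uses a hypothetical stand-in class, `SketchGroupState` (not part of PySpark), and a hypothetical user function, `running_count`, to show how a function of the `(key, pdf_iter, state)` shape typically reads and updates state. Real functions receive and yield pandas DataFrames; plain sequences and tuples are used here to stay dependency-free. Because `running_count` is a generator, the state is only updated once its output is consumed, which matches why the serializer reads the state after exhausting the output iterator.

```python
from typing import Iterable, Iterator, Optional, Sequence, Tuple


class SketchGroupState:
    """Hypothetical stand-in for pyspark's GroupState, for illustration only.

    Mirrors the documented exists/get/getOption/update/remove/hasTimedOut contract,
    but none of the timeout or JSON serialization machinery.
    """

    def __init__(self, value: Optional[Tuple] = None, has_timed_out: bool = False) -> None:
        self._value = value
        self._defined = value is not None
        self._has_timed_out = has_timed_out

    @property
    def exists(self) -> bool:
        return self._defined

    @property
    def get(self) -> Tuple:
        if self.exists:
            return tuple(self._value)
        raise ValueError("State is not defined")

    @property
    def getOption(self) -> Optional[Tuple]:
        return tuple(self._value) if self.exists else None

    @property
    def hasTimedOut(self) -> bool:
        return self._has_timed_out

    def update(self, newValue: Tuple) -> None:
        if newValue is None:
            raise ValueError("'None' is not a valid state value")
        self._value = tuple(newValue)
        self._defined = True

    def remove(self) -> None:
        self._value = None
        self._defined = False


def running_count(
    key: Tuple, chunks: Iterable[Sequence], state: SketchGroupState
) -> Iterator[Tuple]:
    # Hypothetical user function: keeps a running row count per grouping key.
    if state.hasTimedOut:
        state.remove()  # evict the group and emit nothing
        return
    previous = (state.getOption or (0,))[0]
    total = previous + sum(len(chunk) for chunk in chunks)
    state.update((total,))
    yield (key[0], total)


state = SketchGroupState()
print(list(running_count(("a",), [[1, 2], [3]], state)), state.get)
```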
diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py
index 6a01e399d04..da9a245bb71 100644
--- a/python/pyspark/sql/udf.py
+++ b/python/pyspark/sql/udf.py
@@ -144,20 +144,23 @@ class UserDefinedFunction:
                     "Invalid return type with scalar Pandas UDFs: %s is "
                     "not supported" % str(self._returnType_placeholder)
                 )
-        elif self.evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF:
+        elif (
+            self.evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF
+            or self.evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE
+        ):
             if isinstance(self._returnType_placeholder, StructType):
                 try:
                     to_arrow_type(self._returnType_placeholder)
                 except TypeError:
                     raise NotImplementedError(
                         "Invalid return type with grouped map Pandas UDFs or "
-                        "at groupby.applyInPandas: %s is not supported"
+                        "at groupby.applyInPandas(WithState): %s is not supported"
                         % str(self._returnType_placeholder)
                     )
             else:
                 raise TypeError(
                     "Invalid return type for grouped map Pandas "
-                    "UDFs or at groupby.applyInPandas: return type must be a "
+                    "UDFs or at groupby.applyInPandas(WithState): return type must be a "
                     "StructType."
                 )
         elif (
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index c486b7bed1d..c1c3669701f 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -23,6 +23,7 @@ import sys
 import time
 from inspect import currentframe, getframeinfo, getfullargspec
 import importlib
+import json
 
 # 'resource' is a Unix specific module.
 has_resource_module = True
@@ -57,6 +58,7 @@ from pyspark.sql.pandas.serializers import (
     ArrowStreamPandasUDFSerializer,
     CogroupUDFSerializer,
     ArrowStreamUDFSerializer,
+    ApplyInPandasWithStateSerializer,
 )
 from pyspark.sql.pandas.types import to_arrow_type
 from pyspark.sql.types import StructType
@@ -207,6 +209,90 @@ def wrap_grouped_map_pandas_udf(f, return_type, argspec):
     return lambda k, v: [(wrapped(k, v), to_arrow_type(return_type))]
 
 
+def wrap_grouped_map_pandas_udf_with_state(f, return_type):
+    """
+    Provides a new lambda instance wrapping user function of applyInPandasWithState.
+
+    The lambda instance receives (key series, iterator of value series, state) and performs
+    some conversion to be adapted with the signature of user function.
+
+    See the function doc of inner function `wrapped` for more details on what adapter does.
+    See the function doc of `mapper` function for
+    `eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE` for more details on
+    the input parameters of lambda function.
+
+    Along with the returned iterator, the lambda instance will also produce the return_type as
+    converted to the arrow schema.
+    """
+
+    def wrapped(key_series, value_series_gen, state):
+        """
+        Provide an adapter of the user function performing below:
+
+        - Extract the first value of all columns in key series and produce as a tuple.
+        - If the state has timed out, call the user function with empty pandas DataFrame.
+        - If not, construct a new generator which converts each element of value series to
+          pandas DataFrame (lazy evaluation), and call the user function with the generator
+        - Verify each element of returned iterator to check the schema of pandas DataFrame.
+        """
+        import pandas as pd
+
+        key = tuple(s[0] for s in key_series)
+
+        if state.hasTimedOut:
+            # Timeout processing passes an empty iterator. Here we return an empty DataFrame instead.
+            values = [
+                pd.DataFrame(columns=pd.concat(next(value_series_gen), axis=1).columns),
+            ]
+        else:
+            values = (pd.concat(x, axis=1) for x in value_series_gen)
+
+        result_iter = f(key, values, state)
+
+        def verify_element(result):
+            if not isinstance(result, pd.DataFrame):
+                raise TypeError(
+                    "The type of element in return iterator of the user-defined function "
+                    "should be pandas.DataFrame, but is {}".format(type(result))
+                )
+            # the number of columns of result have to match the return type
+            # but it is fine for result to have no columns at all if it is empty
+            if not (
+                len(result.columns) == len(return_type)
+                or (len(result.columns) == 0 and result.empty)
+            ):
+                raise RuntimeError(
+                    "Number of columns of the element (pandas.DataFrame) in return iterator "
+                    "doesn't match specified schema. "
+                    "Expected: {} Actual: {}".format(len(return_type), len(result.columns))
+                )
+
+            return result
+
+        if isinstance(result_iter, pd.DataFrame):
+            raise TypeError(
+                "Return type of the user-defined function should be "
+                "iterable of pandas.DataFrame, but is {}".format(type(result_iter))
+            )
+
+        try:
+            iter(result_iter)
+        except TypeError:
+            raise TypeError(
+                "Return type of the user-defined function should be "
+                "iterable, but is {}".format(type(result_iter))
+            )
+
+        result_iter_with_validation = (verify_element(x) for x in result_iter)
+
+        return (
+            result_iter_with_validation,
+            state,
+        )
+
+    return lambda k, v, s: [(wrapped(k, v, s), to_arrow_type(return_type))]
+
+
 def wrap_grouped_agg_pandas_udf(f, return_type):
     arrow_return_type = to_arrow_type(return_type)
 
@@ -311,6 +397,8 @@ def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index):
     elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF:
         argspec = getfullargspec(chained_func)  # signature was lost when wrapping it
         return arg_offsets, wrap_grouped_map_pandas_udf(func, return_type, argspec)
+    elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE:
+        return arg_offsets, wrap_grouped_map_pandas_udf_with_state(func, return_type)
     elif eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF:
         argspec = getfullargspec(chained_func)  # signature was lost when wrapping it
         return arg_offsets, wrap_cogrouped_map_pandas_udf(func, return_type, argspec)
@@ -336,6 +424,7 @@ def read_udfs(pickleSer, infile, eval_type):
         PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
         PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
         PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF,
+        PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE,
     ):
 
         # Load conf used for pandas_udf evaluation
@@ -345,6 +434,10 @@ def read_udfs(pickleSer, infile, eval_type):
             v = utf8_deserializer.loads(infile)
             runner_conf[k] = v
 
+        state_object_schema = None
+        if eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE:
+            state_object_schema = StructType.fromJson(json.loads(utf8_deserializer.loads(infile)))
+
         # NOTE: if timezone is set here, that implies respectSessionTimeZone is True
         timezone = runner_conf.get("spark.sql.session.timeZone", None)
         safecheck = (
@@ -361,6 +454,19 @@ def read_udfs(pickleSer, infile, eval_type):
 
         if eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF:
             ser = CogroupUDFSerializer(timezone, safecheck, assign_cols_by_name)
+        elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE:
+            arrow_max_records_per_batch = runner_conf.get(
+                "spark.sql.execution.arrow.maxRecordsPerBatch", 10000
+            )
+            arrow_max_records_per_batch = int(arrow_max_records_per_batch)
+
+            ser = ApplyInPandasWithStateSerializer(
+                timezone,
+                safecheck,
+                assign_cols_by_name,
+                state_object_schema,
+                arrow_max_records_per_batch,
+            )
         elif eval_type == PythonEvalType.SQL_MAP_ARROW_ITER_UDF:
             ser = ArrowStreamUDFSerializer()
         else:
@@ -486,6 +592,43 @@ def read_udfs(pickleSer, infile, eval_type):
             vals = [a[o] for o in parsed_offsets[0][1]]
             return f(keys, vals)
 
+    elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE:
+        # We assume there is only one UDF here because grouped map doesn't
+        # support combining multiple UDFs.
+        assert num_udfs == 1
+
+        # See FlatMapGroupsInPandas(WithState)Exec for how arg_offsets are used to
+        # distinguish between grouping attributes and data attributes
+        arg_offsets, f = read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index=0)
+        parsed_offsets = extract_key_value_indexes(arg_offsets)
+
+        def mapper(a):
+            """
+            The function receives (iterator of data, state) and performs extraction of key and
+            value from the data, while retaining lazy evaluation.
+
+            See `load_stream` in `ApplyInPandasWithStateSerializer` for more details on the input
+            and see `wrap_grouped_map_pandas_udf_with_state` for more details on how output will
+            be used.
+            """
+            from itertools import tee
+
+            state = a[1]
+            data_gen = (x[0] for x in a[0])
+
+            # We know there should be at least one item in the iterator/generator.
+            # We want to peek the first element to construct the key, hence applying
+            # tee to construct the key while we retain another iterator/generator
+            # for values.
+            keys_gen, values_gen = tee(data_gen)
+            keys_elem = next(keys_gen)
+            keys = [keys_elem[o] for o in parsed_offsets[0][0]]
+
+            # This must be generator comprehension - do not materialize.
+            vals = ([x[o] for o in parsed_offsets[0][1]] for x in values_gen)
+
+            return f(keys, vals, state)
+
     elif eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF:
         # We assume there is only one UDF here because cogrouped map doesn't
         # support combining multiple UDFs.
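The `mapper` above peeks at the first row to build the grouping key while leaving the values lazy. A minimal stdlib sketch of that `itertools.tee` pattern follows; `split_key_and_values` is a hypothetical name, and plain tuples stand in for the Arrow-derived rows the worker actually receives.

```python
from itertools import tee
from typing import Any, Iterable, Iterator, List, Sequence, Tuple


def split_key_and_values(
    rows: Iterable[Sequence[Any]],
    key_offsets: Sequence[int],
    value_offsets: Sequence[int],
) -> Tuple[Tuple[Any, ...], Iterator[List[Any]]]:
    # Duplicate the stream: one copy is used only to peek at the first row,
    # the other feeds the (still lazy) value generator.
    keys_gen, values_gen = tee(rows)
    first = next(keys_gen)  # assumes at least one row per group
    key = tuple(first[o] for o in key_offsets)
    # Generator expression: rows are only pulled when the caller iterates.
    values = ([row[o] for o in value_offsets] for row in values_gen)
    return key, values


key, values = split_key_and_values(iter([("a", 1, 10), ("a", 2, 20)]), [0], [1, 2])
print(key, list(values))  # ('a',) [[1, 10], [2, 20]]
```

Only the first row is pulled eagerly. `keys_gen` is a local that is dropped when the function returns, so (in CPython) `tee` does not keep buffering the remaining rows on its behalf.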
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
index c11ce7d3b90..84795203fd1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
@@ -64,6 +64,7 @@ object UnsupportedOperationChecker extends Logging {
       case s: Aggregate if s.isStreaming => true
       case _ @ Join(left, right, _, _, _) if left.isStreaming && right.isStreaming => true
       case f: FlatMapGroupsWithState if f.isStreaming => true
+      case f: FlatMapGroupsInPandasWithState if f.isStreaming => true
       case d: Deduplicate if d.isStreaming => true
       case _ => false
     }
@@ -142,6 +143,17 @@ object UnsupportedOperationChecker extends Logging {
           " or the output mode is not append on a streaming DataFrames/Datasets")(plan)
     }
 
+    val applyInPandasWithStates = plan.collect {
+      case f: FlatMapGroupsInPandasWithState if f.isStreaming => f
+    }
+
+    // Disallow multiple `applyInPandasWithState`s.
+    if (applyInPandasWithStates.size > 1) {
+      throwError(
+        "Multiple applyInPandasWithStates are not supported on a streaming " +
+          "DataFrames/Datasets")(plan)
+    }
+
     // Disallow multiple streaming aggregations
     val aggregates = collectStreamingAggregates(plan)
 
@@ -311,6 +323,56 @@ object UnsupportedOperationChecker extends Logging {
             }
           }
 
+        // applyInPandasWithState
+        case m: FlatMapGroupsInPandasWithState if m.isStreaming =>
+          // Check compatibility with output modes and aggregations in query
+          val aggsInQuery = collectStreamingAggregates(plan)
+
+          if (aggsInQuery.isEmpty) {
+            // applyInPandasWithState without aggregation: operation's output 
mode must
+            // match query output mode
+            m.outputMode match {
+              case InternalOutputModes.Update if outputMode != InternalOutputModes.Update =>
+                throwError(
+                  "applyInPandasWithState in update mode is not supported with " +
+                    s"$outputMode output mode on a streaming DataFrame/Dataset")
+
+              case InternalOutputModes.Append if outputMode != InternalOutputModes.Append =>
+                throwError(
+                  "applyInPandasWithState in append mode is not supported with " +
+                    s"$outputMode output mode on a streaming DataFrame/Dataset")
+
+              case _ =>
+            }
+          } else {
+            // applyInPandasWithState with aggregation: update operation mode not allowed, and
+            // *groupsWithState after aggregation not allowed
+            if (m.outputMode == InternalOutputModes.Update) {
+              throwError(
+                "applyInPandasWithState in update mode is not supported with " +
+                  "aggregation on a streaming DataFrame/Dataset")
+            } else if (collectStreamingAggregates(m).nonEmpty) {
+              throwError(
+                "applyInPandasWithState in append mode is not supported after " +
+                  "aggregation on a streaming DataFrame/Dataset")
+            }
+          }
+
+          // Check compatibility with timeout configs
+          if (m.timeout == EventTimeTimeout) {
+            // With event time timeout, watermark must be defined.
+            val watermarkAttributes = m.child.output.collect {
+              case a: Attribute if a.metadata.contains(EventTimeWatermark.delayKey) => a
+            }
+            if (watermarkAttributes.isEmpty) {
+              throwError(
+                "Watermark must be specified in the query using " +
+                  "'[Dataset/DataFrame].withWatermark()' for using event-time timeout in a " +
+                  "applyInPandasWithState. Event-time timeout not supported without " +
+                  "watermark.")(plan)
+            }
+          }
+
         case d: Deduplicate if collectStreamingAggregates(d).nonEmpty =>
           throwError("dropDuplicates is not supported after aggregation on a " +
             "streaming DataFrame/Dataset")
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala
index c2f74b35083..e97ff7808f1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala
@@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst.plans.logical
 
 import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression, PythonUDF}
 import org.apache.spark.sql.catalyst.util.truncatedString
+import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode}
+import org.apache.spark.sql.types.StructType
 
 /**
  * FlatMap groups using a udf: pandas.Dataframe -> pandas.DataFrame.
@@ -98,6 +100,38 @@ case class FlatMapCoGroupsInPandas(
     copy(left = newLeft, right = newRight)
 }
 
+/**
+ * Similar to [[FlatMapGroupsWithState]]. Applies func to each unique group
+ * in `child`, based on the evaluation of `groupingAttributes`,
+ * while using state data.
+ * `functionExpr` is invoked with a pandas DataFrame representation and the
+ * grouping key (tuple).
+ *
+ * @param functionExpr function called on each group
+ * @param groupingAttributes used to group the data
+ * @param outputAttrs used to define the output rows
+ * @param stateType used to serialize/deserialize state before calling `functionExpr`
+ * @param outputMode the output mode of `func`
+ * @param timeout used to timeout groups that have not received data in a while
+ * @param child logical plan of the underlying data
+ */
+case class FlatMapGroupsInPandasWithState(
+    functionExpr: Expression,
+    groupingAttributes: Seq[Attribute],
+    outputAttrs: Seq[Attribute],
+    stateType: StructType,
+    outputMode: OutputMode,
+    timeout: GroupStateTimeout,
+    child: LogicalPlan) extends UnaryNode {
+
+  override def output: Seq[Attribute] = outputAttrs
+
+  override def producedAttributes: AttributeSet = AttributeSet(outputAttrs)
+
+  override protected def withNewChildInternal(
+    newChild: LogicalPlan): FlatMapGroupsInPandasWithState = copy(child = newChild)
+}
+
 trait BaseEvalPython extends UnaryNode {
 
   def udfs: Seq[PythonUDF]
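The output-mode and timeout rules that `UnsupportedOperationChecker` (above) enforces for `FlatMapGroupsInPandasWithState` can be summarized as a small decision function. This Python sketch is illustrative only: the function name, its boolean inputs, and the error strings are hypothetical simplifications of the Scala checks, not Spark APIs.

```python
from typing import List


def check_apply_in_pandas_with_state(
    op_mode: str,
    query_mode: str,
    has_agg_in_query: bool,
    has_agg_before_op: bool,
    event_time_timeout: bool,
    has_watermark: bool,
    num_operators: int = 1,
) -> List[str]:
    """Return the reasons a streaming plan would be rejected (empty list = allowed).

    Modes are lowercase strings; op_mode is assumed to be "append" or "update",
    since RelationalGroupedDataset.applyInPandasWithState rejects anything else.
    """
    errors: List[str] = []
    if num_operators > 1:
        errors.append("multiple applyInPandasWithState operators are not supported")
    if not has_agg_in_query:
        # Without aggregation, the operator's mode must match the query's mode.
        if op_mode != query_mode:
            errors.append(f"{op_mode} mode is not supported with {query_mode} output mode")
    elif op_mode == "update":
        errors.append("update mode is not supported with aggregation")
    elif has_agg_before_op:
        # Append mode is only allowed when the aggregation comes after the operator.
        errors.append("append mode is not supported after aggregation")
    if event_time_timeout and not has_watermark:
        errors.append("event-time timeout requires a watermark")
    return errors


print(check_apply_in_pandas_with_state("update", "append", False, False, False, False))
```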
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
index 989ee325218..0429fd27a41 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
@@ -30,9 +30,11 @@ import org.apache.spark.sql.catalyst.encoders.encoderFor
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.streaming.InternalOutputModes
 import org.apache.spark.sql.catalyst.util.toPrettySQL
 import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
 import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression
+import org.apache.spark.sql.streaming.OutputMode
 import org.apache.spark.sql.types.{NumericType, StructType}
 
 /**
@@ -620,6 +622,49 @@ class RelationalGroupedDataset protected[sql](
     Dataset.ofRows(df.sparkSession, plan)
   }
 
+  /**
+   * Applies a grouped vectorized python user-defined function to each group of data.
+   * The user-defined function defines a transformation: iterator of `pandas.DataFrame` ->
+   * iterator of `pandas.DataFrame`.
+   * For each group, all elements in the group are passed as an iterator of `pandas.DataFrame`
+   * along with corresponding state, and the results for all groups are combined into a new
+   * [[DataFrame]].
+   *
+   * This function does not support partial aggregation, and requires shuffling all the data in
+   * the [[DataFrame]].
+   *
+   * This function uses Apache Arrow as serialization format between Java executors and Python
+   * workers.
+   */
+  private[sql] def applyInPandasWithState(
+      func: PythonUDF,
+      outputStructType: StructType,
+      stateStructType: StructType,
+      outputModeStr: String,
+      timeoutConfStr: String): DataFrame = {
+    val timeoutConf = org.apache.spark.sql.execution.streaming
+      .GroupStateImpl.groupStateTimeoutFromString(timeoutConfStr)
+    val outputMode = InternalOutputModes(outputModeStr)
+    if (outputMode != OutputMode.Append && outputMode != OutputMode.Update) {
+      throw new IllegalArgumentException("The output mode of function should be append or update")
+    }
+    val groupingNamedExpressions = groupingExprs.map {
+      case ne: NamedExpression => ne
+      case other => Alias(other, other.toString)()
+    }
+    val groupingAttrs = groupingNamedExpressions.map(_.toAttribute)
+    val outputAttrs = outputStructType.toAttributes
+    val plan = FlatMapGroupsInPandasWithState(
+      func,
+      groupingAttrs,
+      outputAttrs,
+      stateStructType,
+      outputMode,
+      timeoutConf,
+      child = df.logicalPlan)
+    Dataset.ofRows(df.sparkSession, plan)
+  }
+
   override def toString: String = {
     val builder = new StringBuilder
     builder.append("RelationalGroupedDataset: [grouping expressions: [")
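`applyInPandasWithState` above takes the output mode and timeout as strings and accepts only append or update. The Python sketch below mimics that argument check for illustration; `validate_with_state_args` is a hypothetical name, and it assumes that output-mode strings are matched case-insensitively (as `InternalOutputModes` parsing is understood to do) while timeout strings must exactly equal the `GroupStateTimeout` constants from `state.py`.

```python
from typing import Tuple

# Timeout names as defined by GroupStateTimeout in pyspark/sql/streaming/state.py.
_TIMEOUTS = ("NoTimeout", "ProcessingTimeTimeout", "EventTimeTimeout")


def validate_with_state_args(output_mode: str, timeout_conf: str) -> Tuple[str, str]:
    # Assumption: output modes are parsed case-insensitively.
    mode = output_mode.lower()
    if mode not in ("append", "update", "complete"):
        raise ValueError(f"Unknown output mode {output_mode!r}")
    # "complete" is a valid query output mode but not a valid mode for this operator.
    if mode not in ("append", "update"):
        raise ValueError("The output mode of function should be append or update")
    # Assumption: timeout names must match exactly.
    if timeout_conf not in _TIMEOUTS:
        raise ValueError(f"Unknown timeout conf {timeout_conf!r}")
    return mode, timeout_conf


print(validate_with_state_args("Update", "NoTimeout"))  # ('update', 'NoTimeout')
```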
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 6104104c7be..c64a123e3a7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -684,6 +684,25 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
     }
   }
 
+  /**
+   * Strategy to convert [[FlatMapGroupsInPandasWithState]] logical operator to physical operator
+   * in streaming plans. Conversion for batch plans is handled by [[BasicOperators]].
+   */
+  object FlatMapGroupsInPandasWithStateStrategy extends Strategy {
+    override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
+      case FlatMapGroupsInPandasWithState(
+        func, groupAttr, outputAttr, stateType, outputMode, timeout, child) =>
+        val stateVersion = conf.getConf(SQLConf.FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION)
+        val execPlan = python.FlatMapGroupsInPandasWithStateExec(
+          func, groupAttr, outputAttr, stateType, None, stateVersion, outputMode, timeout,
+          batchTimestampMs = None, eventTimeWatermark = None, planLater(child)
+        )
+        execPlan :: Nil
+      case _ =>
+        Nil
+    }
+  }
+
   /**
    * Strategy to convert EvalPython logical operator to physical operator.
    */
@@ -793,6 +812,10 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
           initialStateGroupAttrs, data, initialStateDataAttrs, output, timeout,
           hasInitialState, planLater(initialState), planLater(child)
         ) :: Nil
+      case _: FlatMapGroupsInPandasWithState =>
+        // TODO(SPARK-40443): support applyInPandasWithState in batch query
+        throw new UnsupportedOperationException(
+          "applyInPandasWithState is unsupported in batch query. Use applyInPandas instead.")
       case logical.CoGroup(f, key, lObj, rObj, lGroup, rGroup, lAttr, rAttr, oAttr, left, right) =>
         execution.CoGroupExec(
           f, key, lObj, rObj, lGroup, rGroup, lAttr, rAttr, oAttr,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala
index 7abca5f0e33..2988c0fb518 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala
@@ -44,7 +44,7 @@ object ArrowWriter {
     new ArrowWriter(root, children.toArray)
   }
 
-  private def createFieldWriter(vector: ValueVector): ArrowFieldWriter = {
+  private[sql] def createFieldWriter(vector: ValueVector): ArrowFieldWriter = {
     val field = vector.getField()
     (ArrowUtils.fromArrowField(field), vector) match {
       case (BooleanType, vector: BitVector) => new BooleanWriter(vector)
@@ -98,6 +98,16 @@ class ArrowWriter(val root: VectorSchemaRoot, fields: Array[ArrowFieldWriter]) {
     count += 1
   }
 
+  def sizeInBytes(): Int = {
+    var i = 0
+    var bytes = 0
+    while (i < fields.size) {
+      bytes += fields(i).getSizeInBytes()
+      i += 1
+    }
+    bytes
+  }
+
   def finish(): Unit = {
     root.setRowCount(count)
     fields.foreach(_.finish())
@@ -132,6 +142,10 @@ private[arrow] abstract class ArrowFieldWriter {
     count += 1
   }
 
+  def getSizeInBytes(): Int = {
+    valueVector.getBufferSizeFor(count)
+  }
+
   def finish(): Unit = {
     valueVector.setValueCount(count)
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala
new file mode 100644
index 00000000000..bd8c72029dc
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala
@@ -0,0 +1,223 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.python
+
+import java.io._
+
+import scala.collection.JavaConverters._
+
+import org.apache.arrow.vector.VectorSchemaRoot
+import org.apache.arrow.vector.ipc.ArrowStreamWriter
+import org.json4s._
+import org.json4s.jackson.JsonMethods._
+
+import org.apache.spark.api.python._
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.api.python.PythonSQLUtils
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.execution.python.ApplyInPandasWithStatePythonRunner.{InType, OutType, OutTypeForState, STATE_METADATA_SCHEMA_FROM_PYTHON_WORKER}
+import org.apache.spark.sql.execution.python.ApplyInPandasWithStateWriter.STATE_METADATA_SCHEMA
+import org.apache.spark.sql.execution.streaming.GroupStateImpl
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types._
+import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch}
+
+
+/**
+ * A variant implementation of [[ArrowPythonRunner]] to serve the operation
+ * applyInPandasWithState.
+ *
+ * Unlike the normal ArrowPythonRunner, where both input and output (executor <-> python worker)
+ * are InternalRow, applyInPandasWithState carries side data (state information) along with the
+ * data in both input and output, which requires a different struct in the Arrow RecordBatch.
+ */
+class ApplyInPandasWithStatePythonRunner(
+    funcs: Seq[ChainedPythonFunctions],
+    evalType: Int,
+    argOffsets: Array[Array[Int]],
+    inputSchema: StructType,
+    override protected val timeZoneId: String,
+    initialWorkerConf: Map[String, String],
+    stateEncoder: ExpressionEncoder[Row],
+    keySchema: StructType,
+    outputSchema: StructType,
+    stateValueSchema: StructType)
+  extends BasePythonRunner[InType, OutType](funcs, evalType, argOffsets)
+  with PythonArrowInput[InType]
+  with PythonArrowOutput[OutType] {
+
+  private val sqlConf = SQLConf.get
+
+  override protected val schema: StructType = inputSchema.add("__state", STATE_METADATA_SCHEMA)
+
+  override val simplifiedTraceback: Boolean = sqlConf.pysparkSimplifiedTraceback
+
+  override val bufferSize: Int = {
+    val configuredSize = sqlConf.pandasUDFBufferSize
+    if (configuredSize < 4) {
+      logWarning("Pandas execution requires more than 4 bytes. Please configure bigger value " +
+        s"for the configuration '${SQLConf.PANDAS_UDF_BUFFER_SIZE.key}'. " +
+        "Force using the value '4'.")
+      4
+    } else {
+      configuredSize
+    }
+  }
+
+  private val arrowMaxRecordsPerBatch = sqlConf.arrowMaxRecordsPerBatch
+
+  // applyInPandasWithState has its own mechanism to construct the Arrow RecordBatch instance.
+  // The configurations are applied to both the executor and the Python worker, so set them in
+  // the worker conf to let the Python worker read them properly.
+  override protected val workerConf: Map[String, String] = initialWorkerConf +
+    (SQLConf.ARROW_EXECUTION_MAX_RECORDS_PER_BATCH.key -> arrowMaxRecordsPerBatch.toString)
+
+  private val stateRowDeserializer = stateEncoder.createDeserializer()
+
+  /**
+   * This method sends out additional metadata before sending out the actual data.
+   *
+   * Specifically, this class overrides this method to also write the schema of the state value.
+   */
+  override protected def handleMetadataBeforeExec(stream: DataOutputStream): Unit = {
+    super.handleMetadataBeforeExec(stream)
+    // Also write the schema for state value
+    PythonRDD.writeUTF(stateValueSchema.json, stream)
+  }
+
+  /**
+   * Reads (key, state, values) from the input iterator, constructs Arrow RecordBatches, and
+   * writes the constructed RecordBatches to the writer.
+   *
+   * See [[ApplyInPandasWithStateWriter]] for more details.
+   */
+  protected def writeIteratorToArrowStream(
+      root: VectorSchemaRoot,
+      writer: ArrowStreamWriter,
+      dataOut: DataOutputStream,
+      inputIterator: Iterator[InType]): Unit = {
+    val w = new ApplyInPandasWithStateWriter(root, writer, arrowMaxRecordsPerBatch)
+
+    while (inputIterator.hasNext) {
+      val (keyRow, groupState, dataIter) = inputIterator.next()
+      assert(dataIter.hasNext, "should have at least one data row!")
+      w.startNewGroup(keyRow, groupState)
+
+      while (dataIter.hasNext) {
+        val dataRow = dataIter.next()
+        w.writeRow(dataRow)
+      }
+
+      w.finalizeGroup()
+    }
+
+    w.finalizeData()
+  }
+
+  /**
+   * Deserializes a ColumnarBatch received from the Python worker to produce the output. The schema
+   * of the given ColumnarBatch is also provided.
+   */
+  protected def deserializeColumnarBatch(batch: ColumnarBatch, schema: StructType): OutType = {
+    // This should have at least one row for state. Also, we ensure that all columns across
+    // data and state metadata have the same number of rows, which is required by an Arrow
+    // record batch.
+    assert(batch.numRows() > 0)
+    assert(schema.length == 2)
+
+    def getColumnarBatchForStructTypeColumn(
+        batch: ColumnarBatch,
+        ordinal: Int,
+        expectedType: StructType): ColumnarBatch = {
+      // The UDF returns a StructType column in the ColumnarBatch; select its children here.
+      val structVector = batch.column(ordinal).asInstanceOf[ArrowColumnVector]
+      val dataType = schema(ordinal).dataType.asInstanceOf[StructType]
+      assert(dataType.sameType(expectedType),
+        s"Schema equality check failure! type from Arrow: $dataType, expected type: $expectedType")
+
+      val outputVectors = dataType.indices.map(structVector.getChild)
+      val flattenedBatch = new ColumnarBatch(outputVectors.toArray)
+      flattenedBatch.setNumRows(batch.numRows())
+
+      flattenedBatch
+    }
+
+    def constructIterForData(batch: ColumnarBatch): Iterator[InternalRow] = {
+      val dataBatch = getColumnarBatchForStructTypeColumn(batch, 0, outputSchema)
+      dataBatch.rowIterator.asScala.flatMap { row =>
+        if (row.isNullAt(0)) {
+          // The entire row in record batch seems to be for state metadata.
+          None
+        } else {
+          Some(row)
+        }
+      }
+    }
+
+    def constructIterForState(batch: ColumnarBatch): Iterator[OutTypeForState] = {
+      val stateMetadataBatch = getColumnarBatchForStructTypeColumn(batch, 1,
+        STATE_METADATA_SCHEMA_FROM_PYTHON_WORKER)
+
+      stateMetadataBatch.rowIterator().asScala.flatMap { row =>
+        implicit val formats = org.json4s.DefaultFormats
+
+        if (row.isNullAt(0)) {
+          // The entire row in record batch seems to be for data.
+          None
+        } else {
+          // NOTE: See ApplyInPandasWithStatePythonRunner.STATE_METADATA_SCHEMA_FROM_PYTHON_WORKER
+          // for the schema.
+          val propertiesAsJson = parse(row.getUTF8String(0).toString)
+          val keyRowAsUnsafeAsBinary = row.getBinary(1)
+          val keyRowAsUnsafe = new UnsafeRow(keySchema.fields.length)
+          keyRowAsUnsafe.pointTo(keyRowAsUnsafeAsBinary, keyRowAsUnsafeAsBinary.length)
+          val maybeObjectRow = if (row.isNullAt(2)) {
+            None
+          } else {
+            val pickledStateValue = row.getBinary(2)
+            Some(PythonSQLUtils.toJVMRow(pickledStateValue, stateValueSchema,
+              stateRowDeserializer))
+          }
+          val oldTimeoutTimestamp = row.getLong(3)
+
+          Some((keyRowAsUnsafe, GroupStateImpl.fromJson(maybeObjectRow, propertiesAsJson),
+            oldTimeoutTimestamp))
+        }
+      }
+    }
+
+    (constructIterForState(batch), constructIterForData(batch))
+  }
+}
+
+object ApplyInPandasWithStatePythonRunner {
+  type InType = (UnsafeRow, GroupStateImpl[Row], Iterator[InternalRow])
+  type OutTypeForState = (UnsafeRow, GroupStateImpl[Row], Long)
+  type OutType = (Iterator[OutTypeForState], Iterator[InternalRow])
+
+  val STATE_METADATA_SCHEMA_FROM_PYTHON_WORKER: StructType = StructType(
+    Array(
+      StructField("properties", StringType),
+      StructField("keyRowAsUnsafe", BinaryType),
+      StructField("object", BinaryType),
+      StructField("oldTimeoutTimestamp", LongType)
+    )
+  )
+}
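For readers following the runner above: the Python worker returns each Arrow RecordBatch with two struct columns, one for output data and one for state metadata, padded so that both have the same number of rows, and `deserializeColumnarBatch` keeps only the non-null side of each row for each of the two iterators it returns. The following is a minimal pure-Python sketch of that separation, not the actual implementation. It assumes a hypothetical model in which each row is a `(data, state)` tuple and `None` stands for a null struct; the Scala code checks the first column of each side rather than the whole struct, and `split_output` is an illustrative name.

```python
from typing import Iterator, List, Optional, Tuple

# Hypothetical model of one row of an Arrow RecordBatch sent back by the Python
# worker: (data_struct, state_struct), where None stands for a null struct.
OutputRow = Tuple[Optional[tuple], Optional[dict]]


def split_output(rows: List[OutputRow]) -> Tuple[Iterator[dict], Iterator[tuple]]:
    """Split a batch into (state iterator, data iterator), dropping padding."""
    # Mirrors `assert(batch.numRows() > 0)` in the Scala code.
    assert len(rows) > 0, "a batch should contain at least one row"
    # A None on one side marks a row that only carries the other side's payload.
    state_iter = (state for _, state in rows if state is not None)
    data_iter = (data for data, _ in rows if data is not None)
    return state_iter, data_iter


batch = [
    (("a", 1), {"key": "a"}),  # data row plus the state row for group "a"
    (("a", 2), None),          # data row; the state side is padding
    (None, {"key": "b"}),      # group "b" emitted no data, only a state update
]
states, data = split_output(batch)
print(list(states))  # [{'key': 'a'}, {'key': 'b'}]
print(list(data))    # [('a', 1), ('a', 2)]
```

Returning the state iterator first matches the `(constructIterForState(batch), constructIterForData(batch))` tuple order of `OutType`.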
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStateWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStateWriter.scala
new file mode 100644
index 00000000000..60a228ddd73
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStateWriter.scala
@@ -0,0 +1,276 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.python
+
+import scala.collection.JavaConverters._
+
+import org.apache.arrow.vector.{FieldVector, VectorSchemaRoot}
+import org.apache.arrow.vector.ipc.ArrowStreamWriter
+
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.api.python.PythonSQLUtils
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeRow}
+import org.apache.spark.sql.execution.arrow.ArrowWriter
+import org.apache.spark.sql.execution.arrow.ArrowWriter.createFieldWriter
+import org.apache.spark.sql.execution.streaming.GroupStateImpl
+import org.apache.spark.sql.types.{BinaryType, BooleanType, IntegerType, StringType, StructField, StructType}
+import org.apache.spark.unsafe.types.UTF8String
+
+/**
+ * This class abstracts the complexity of constructing Arrow RecordBatches for data and state with
+ * bin-packing and chunking. The caller only needs to call the public methods of this class
+ * (`startNewGroup`, `writeRow`, `finalizeGroup`, `finalizeData`), and this class writes the data
+ * and state into Arrow RecordBatches, performing bin-packing and chunking internally.
+ *
+ * This class requires that the parameter `root` has been initialized with an Arrow schema like
+ * below:
+ * - data fields
+ * - state field
+ *   - nested schema (refer to ApplyInPandasWithStateWriter.STATE_METADATA_SCHEMA)
+ *
+ * Please refer to the code comments in the implementation to see how data and state are written
+ * to Arrow RecordBatches with bin-packing and chunking in mind.
+ */
+class ApplyInPandasWithStateWriter(
+    root: VectorSchemaRoot,
+    writer: ArrowStreamWriter,
+    arrowMaxRecordsPerBatch: Int) {
+
+  import ApplyInPandasWithStateWriter._
+
+  // Unlike applyInPandas (and other PySpark operators), applyInPandasWithState has to produce
+  // additional data, `state`, along with the input data.
+  //
+  // ArrowStreamWriter supports only a single VectorSchemaRoot, which means all Arrow
+  // RecordBatches sent out from ArrowStreamWriter must have the same schema. Hence we have to
+  // construct a single Arrow schema that contains both data and state, and also construct
+  // ArrowBatches that contain both data and state.
+  //
+  // To achieve this, we extend the schema of the input data with a column for state at the end.
+  // We also logically group the columns by family (data vs. state) and initialize the writers
+  // separately, since it is much easier, and probably more performant, to write the row directly
+  // rather than projecting the row to match the overall schema.
+  //
+  // Although Arrow RecordBatch allows writing the data as columnar, we found that it produces
+  // strange outputs unless all columns have the same number of values. Since there is at least
+  // one data row for a grouping key (we ensure this for timed-out state as well) whereas there
+  // is only one state for a grouping key, we have to fill up the empty rows on the state side
+  // so that both have the same number of rows.
+  private val arrowWriterForData = createArrowWriter(
+    root.getFieldVectors.asScala.toSeq.dropRight(1))
+  private val arrowWriterForState = createArrowWriter(
+    root.getFieldVectors.asScala.toSeq.takeRight(1))
+
+  // - Bin-packing
+  //
+  // We bin-pack the data from multiple groups into one Arrow RecordBatch to gain
+  // performance. In many cases the amount of data per grouping key is quite small, and
+  // sending one small batch per group would not take full advantage of Arrow.
+  //
+  // The Python worker has to split the record batch back into groups to convert each
+  // group's data to Pandas; fortunately, Arrow RecordBatch can slice a range of data and
+  // return a "zero-copy" view. To help split the range for each group, we provide the
+  // "start offset" and the "number of rows" in the state metadata.
+  //
+  // We don't bin-pack all groups into a single record batch: once the number of rows in the
+  // current Arrow RecordBatch reaches a limit, we stop adding the next group to it.
+  //
+  // - Chunking
+  //
+  // We also chunk the data of a single group into multiple Arrow RecordBatches to ensure
+  // scalability. Note that we don't know the volume (number of rows, overall size) of the data
+  // for a specific group key until we have read all of it. The easiest way to handle both
+  // bin-packing and chunking is to check the number of rows in the current Arrow RecordBatch
+  // on every row write.
+  //
+  // - Data and State
+  //
+  // Since we apply bin-packing and chunking, there must be a way to distinguish each chunk
+  // within the data part of the Arrow RecordBatch. We leverage the state metadata to also
+  // carry the "metadata" of the data part, which identifies the chunk within the entire data.
+  // As a result, state metadata has a 1-1 relationship with "chunk", instead of "grouping key".
+  //
+  // - Consideration
+  //
+  // Since the number of rows in an Arrow RecordBatch does not represent its actual size (bytes),
+  // the limit should be set conservatively. Using a small limit does not introduce correctness
+  // issues.
+
+  // variables for tracking current grouping key and state
+  private var currentGroupKeyRow: UnsafeRow = _
+  private var currentGroupState: GroupStateImpl[Row] = _
+
+  // variables for tracking the status of current batch
+  private var totalNumRowsForBatch = 0
+  private var totalNumStatesForBatch = 0
+
+  // variables for tracking the status of current chunk
+  private var startOffsetForCurrentChunk = 0
+  private var numRowsForCurrentChunk = 0
+
+
+  /**
+   * Tells the writer to start a new grouping key.
+   *
+   * @param keyRow The grouping key row for the current group.
+   * @param groupState The instance of GroupStateImpl for the current group.
+   */
+  def startNewGroup(keyRow: UnsafeRow, groupState: GroupStateImpl[Row]): Unit = {
+    currentGroupKeyRow = keyRow
+    currentGroupState = groupState
+  }
+
+  /**
+   * Tells the writer to write a row to the current group.
+   *
+   * @param dataRow The row to write in the current group.
+   */
+  def writeRow(dataRow: InternalRow): Unit = {
+    // If the batch has reached its limit (number of records) and there is more data for the
+    // same group, finalize the current chunk and batch, and start a new batch.
+
+    if (totalNumRowsForBatch >= arrowMaxRecordsPerBatch) {
+      finalizeCurrentChunk(isLastChunkForGroup = false)
+      finalizeCurrentArrowBatch()
+    }
+
+    arrowWriterForData.write(dataRow)
+
+    numRowsForCurrentChunk += 1
+    totalNumRowsForBatch += 1
+  }
+
+  /**
+   * Tells the writer that the current group is finalized and no further rows will be bound to
+   * the current group.
+   */
+  def finalizeGroup(): Unit = {
+    finalizeCurrentChunk(isLastChunkForGroup = true)
+
+    // If the batch has reached its limit (number of records) once all data for the same group
+    // has been received, finalize it and start a new batch.
+    if (totalNumRowsForBatch >= arrowMaxRecordsPerBatch) {
+      finalizeCurrentArrowBatch()
+    }
+  }
+
+  /**
+   * Tells the writer that all groups have been processed.
+   */
+  def finalizeData(): Unit = {
+    if (totalNumRowsForBatch > 0) {
+      // We still have some rows in the current record batch, so finalize them as well.
+      finalizeCurrentArrowBatch()
+    }
+  }
+
+  private def createArrowWriter(fieldVectors: Seq[FieldVector]): ArrowWriter = {
+    val children = fieldVectors.map { vector =>
+      vector.allocateNew()
+      createFieldWriter(vector)
+    }
+
+    new ArrowWriter(root, children.toArray)
+  }
+
+  private def buildStateInfoRow(
+      keyRow: UnsafeRow,
+      groupState: GroupStateImpl[Row],
+      startOffset: Int,
+      numRows: Int,
+      isLastChunk: Boolean): InternalRow = {
+    // NOTE: see ApplyInPandasWithStateWriter.STATE_METADATA_SCHEMA
+    val stateUnderlyingRow = new GenericInternalRow(
+      Array[Any](
+        UTF8String.fromString(groupState.json()),
+        keyRow.getBytes,
+        groupState.getOption.map(PythonSQLUtils.toPyRow).orNull,
+        startOffset,
+        numRows,
+        isLastChunk
+      )
+    )
+    new GenericInternalRow(Array[Any](stateUnderlyingRow))
+  }
+
+  private def finalizeCurrentChunk(isLastChunkForGroup: Boolean): Unit = {
+    val stateInfoRow = buildStateInfoRow(currentGroupKeyRow, currentGroupState,
+      startOffsetForCurrentChunk, numRowsForCurrentChunk, isLastChunkForGroup)
+    arrowWriterForState.write(stateInfoRow)
+    totalNumStatesForBatch += 1
+
+    // The start offset for the next chunk is the total number of rows in the batch so far,
+    // unless the next chunk starts with a new batch.
+    startOffsetForCurrentChunk = totalNumRowsForBatch
+    numRowsForCurrentChunk = 0
+  }
+
+  private def finalizeCurrentArrowBatch(): Unit = {
+    val remainingEmptyStateRows = totalNumRowsForBatch - totalNumStatesForBatch
+    (0 until remainingEmptyStateRows).foreach { _ =>
+      arrowWriterForState.write(EMPTY_STATE_METADATA_ROW)
+    }
+
+    arrowWriterForState.finish()
+    arrowWriterForData.finish()
+    writer.writeBatch()
+    arrowWriterForState.reset()
+    arrowWriterForData.reset()
+
+    startOffsetForCurrentChunk = 0
+    numRowsForCurrentChunk = 0
+    totalNumRowsForBatch = 0
+    totalNumStatesForBatch = 0
+  }
+}
+
+object ApplyInPandasWithStateWriter {
+  // This schema contains both the state metadata and the metadata of the chunk. Refer to the code
+  // comment on "Data and State" for more details.
+  val STATE_METADATA_SCHEMA: StructType = StructType(
+    Array(
+      /*
+       Metadata of the state
+       */
+
+      // properties of state instance (excluding state value) in json format
+      StructField("properties", StringType),
+      // key row as UnsafeRow; the Python worker won't touch this value but sends it back to the
+      // executor when sending a state update
+      StructField("keyRowAsUnsafe", BinaryType),
+      // state value
+      StructField("object", BinaryType),
+
+      /*
+       Metadata of the chunk
+       */
+
+      // start offset of the data chunk from entire data
+      StructField("startOffset", IntegerType),
+      // the number of rows for the data chunk
+      StructField("numRows", IntegerType),
+      // whether the current data chunk is the last one for the current grouping key
+      StructField("isLastChunk", BooleanType)
+    )
+  )
+
+  // Shared instance for empty state metadata rows, to avoid allocating a new row each time.
+  val EMPTY_STATE_METADATA_ROW = new GenericInternalRow(
+    Array[Any](null, null, null, null, null, null))
+}
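The bin-packing and chunking scheme described in `ApplyInPandasWithStateWriter` can be illustrated with a pure-Python sketch. It is a simplified model, not the actual implementation: rows are plain values, the state metadata row is a hypothetical dict holding only the key and the chunk fields (`startOffset`, `numRows`, `isLastChunk`), and `ChunkingWriterSketch` and `reassemble` are illustrative names. The writer follows the same flush rules as `writeRow`, `finalizeGroup` and `finalizeCurrentArrowBatch`, including padding the state side with nulls; `reassemble` is a hypothetical reader showing how the chunk metadata lets a consumer rebuild each group.

```python
from typing import Any, Dict, Iterator, List, Optional, Tuple

StateMeta = Optional[Dict[str, Any]]  # None models EMPTY_STATE_METADATA_ROW


class ChunkingWriterSketch:
    """Pure-Python model of the writer's bin-packing and chunking rules."""

    def __init__(self, max_records_per_batch: int) -> None:
        self.max_records = max_records_per_batch
        self.batches: List[Tuple[List[Any], List[StateMeta]]] = []
        self._data: List[Any] = []
        self._states: List[StateMeta] = []
        self._key: Any = None
        self._start_offset = 0  # start offset of the current chunk in the batch
        self._num_rows = 0      # number of rows in the current chunk

    def start_new_group(self, key: Any) -> None:
        self._key = key

    def write_row(self, row: Any) -> None:
        # Batch is full and the group has more data: close the chunk, flush the batch.
        if len(self._data) >= self.max_records:
            self._finalize_chunk(is_last=False)
            self._finalize_batch()
        self._data.append(row)
        self._num_rows += 1

    def finalize_group(self) -> None:
        self._finalize_chunk(is_last=True)
        if len(self._data) >= self.max_records:
            self._finalize_batch()

    def finalize_data(self) -> None:
        if self._data:
            self._finalize_batch()

    def _finalize_chunk(self, is_last: bool) -> None:
        self._states.append({"key": self._key, "startOffset": self._start_offset,
                             "numRows": self._num_rows, "isLastChunk": is_last})
        self._start_offset = len(self._data)
        self._num_rows = 0

    def _finalize_batch(self) -> None:
        # Pad the state column with nulls so both columns have the same row count.
        self._states.extend([None] * (len(self._data) - len(self._states)))
        self.batches.append((self._data, self._states))
        self._data, self._states = [], []
        self._start_offset = 0
        self._num_rows = 0


def reassemble(batches) -> Iterator[Tuple[Any, List[Any]]]:
    """Hypothetical reader: rebuild (key, rows) per group from chunk metadata."""
    pending: List[Any] = []
    for data, states in batches:
        for meta in states:
            if meta is None:  # padding row, carries no state
                continue
            start = meta["startOffset"]
            pending.extend(data[start:start + meta["numRows"]])
            if meta["isLastChunk"]:
                yield meta["key"], pending
                pending = []


w = ChunkingWriterSketch(max_records_per_batch=3)
for key, rows in [("a", ["a1", "a2", "a3", "a4", "a5"]), ("b", ["b1"])]:
    w.start_new_group(key)
    for r in rows:
        w.write_row(r)
    w.finalize_group()
w.finalize_data()
print([data for data, _ in w.batches])  # [['a1', 'a2', 'a3'], ['a4', 'a5', 'b1']]
print(list(reassemble(w.batches)))  # [('a', ['a1', 'a2', 'a3', 'a4', 'a5']), ('b', ['b1'])]
```

With a limit of 3 rows, group `a` is chunked across two batches, while group `b` is bin-packed into the second batch next to the tail of `a`.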
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala
index e830ea6b546..b39787b12a4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala
@@ -78,8 +78,8 @@ case class FlatMapCoGroupsInPandasExec(
 
   override protected def doExecute(): RDD[InternalRow] = {
 
-    val (leftDedup, leftArgOffsets) = resolveArgOffsets(left, leftGroup)
-    val (rightDedup, rightArgOffsets) = resolveArgOffsets(right, rightGroup)
+    val (leftDedup, leftArgOffsets) = resolveArgOffsets(left.output, leftGroup)
+    val (rightDedup, rightArgOffsets) = resolveArgOffsets(right.output, rightGroup)
 
     // Map cogrouped rows to ArrowPythonRunner results, Only execute if partition is not empty
     left.execute().zipPartitions(right.execute())  { (leftData, rightData) =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala
index 3a3a6022f99..f0e815e966e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala
@@ -75,7 +75,7 @@ case class FlatMapGroupsInPandasExec(
   override protected def doExecute(): RDD[InternalRow] = {
     val inputRDD = child.execute()
 
-    val (dedupAttributes, argOffsets) = resolveArgOffsets(child, groupingAttributes)
+    val (dedupAttributes, argOffsets) = resolveArgOffsets(child.output, groupingAttributes)
 
     // Map grouped rows to ArrowPythonRunner results, Only execute if partition is not empty
     inputRDD.mapPartitionsInternal { iter => if (iter.isEmpty) iter else {
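The new `FlatMapGroupsInPandasWithStateExec` in the next file applies two rules around the Python function. When selecting timed-out groups, a state qualifies only if it has a timeout timestamp set that is earlier than the threshold (the batch timestamp for `ProcessingTimeTimeout`, the event-time watermark for `EventTimeTimeout`). After the function returns, each state is removed, written back, or left untouched depending on `isRemoved`, `isUpdated`, and whether the timeout changed. The following pure-Python sketch restates those rules for illustration only; `NO_TIMESTAMP = -1` is assumed to mirror `GroupStateImpl.NO_TIMESTAMP`, and `StateEntry`, `timed_out` and `state_action` are hypothetical names.

```python
from dataclasses import dataclass
from typing import List, Optional

NO_TIMESTAMP = -1  # assumption: mirrors GroupStateImpl.NO_TIMESTAMP


@dataclass
class StateEntry:
    """Hypothetical stand-in for a stored state row."""
    key: str
    timeout_timestamp: int = NO_TIMESTAMP


def timed_out(states: List[StateEntry], threshold: int) -> List[StateEntry]:
    # Only states with a timeout set, and set earlier than the threshold, time out.
    return [s for s in states
            if s.timeout_timestamp != NO_TIMESTAMP and s.timeout_timestamp < threshold]


def state_action(is_removed: bool, is_updated: bool,
                 new_timeout: Optional[int], old_timeout: int) -> str:
    """Return "remove", "put" or "skip" for one state returned by the worker."""
    # Removed with no timeout left (getTimeoutTimestampMs is empty): drop the state.
    if is_removed and new_timeout is None:
        return "remove"
    current = new_timeout if new_timeout is not None else NO_TIMESTAMP
    if is_updated or is_removed or current != old_timeout:
        return "put"
    return "skip"


entries = [StateEntry("a", 100), StateEntry("b"), StateEntry("c", 300)]
print([s.key for s in timed_out(entries, threshold=200)])  # ['a']
print(state_action(False, False, 500, 100))  # put (timeout changed)
print(state_action(False, False, 100, 100))  # skip (nothing changed)
```

A removed state that still carries a timeout is written back rather than deleted, which keeps its timeout registered.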
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala
new file mode 100644
index 00000000000..159f805f734
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala
@@ -0,0 +1,214 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.python
+
+import org.apache.spark.TaskContext
+import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType}
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder}
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.logical.{EventTimeTimeout, ProcessingTimeTimeout}
+import org.apache.spark.sql.catalyst.plans.physical.Distribution
+import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan, UnaryExecNode}
+import org.apache.spark.sql.execution.python.PandasGroupUtils.resolveArgOffsets
+import org.apache.spark.sql.execution.streaming._
+import org.apache.spark.sql.execution.streaming.GroupStateImpl.NO_TIMESTAMP
+import org.apache.spark.sql.execution.streaming.state.FlatMapGroupsWithStateExecHelper.StateData
+import org.apache.spark.sql.execution.streaming.state.StateStore
+import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode}
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.util.ArrowUtils
+import org.apache.spark.util.CompletionIterator
+
+/**
+ * Physical operator for executing
+ * [[org.apache.spark.sql.catalyst.plans.logical.FlatMapGroupsInPandasWithState]]
+ *
+ * @param functionExpr function called on each group
+ * @param groupingAttributes used to group the data
+ * @param outAttributes used to define the output rows
+ * @param stateType used to serialize/deserialize state before calling `functionExpr`
+ * @param stateInfo `StatefulOperatorStateInfo` to identify the state store for a given operator.
+ * @param stateFormatVersion the version of state format.
+ * @param outputMode the output mode of `functionExpr`
+ * @param timeoutConf used to time out groups that have not received data in a while
+ * @param batchTimestampMs processing timestamp of the current batch.
+ * @param eventTimeWatermark event time watermark for the current batch
+ * @param child physical plan of the underlying data
+ */
+case class FlatMapGroupsInPandasWithStateExec(
+    functionExpr: Expression,
+    groupingAttributes: Seq[Attribute],
+    outAttributes: Seq[Attribute],
+    stateType: StructType,
+    stateInfo: Option[StatefulOperatorStateInfo],
+    stateFormatVersion: Int,
+    outputMode: OutputMode,
+    timeoutConf: GroupStateTimeout,
+    batchTimestampMs: Option[Long],
+    eventTimeWatermark: Option[Long],
+    child: SparkPlan) extends UnaryExecNode with FlatMapGroupsWithStateExecBase {
+
+  // TODO(SPARK-40444): Add the support of initial state.
+  override protected val initialStateDeserializer: Expression = null
+  override protected val initialStateGroupAttrs: Seq[Attribute] = null
+  override protected val initialStateDataAttrs: Seq[Attribute] = null
+  override protected val initialState: SparkPlan = null
+  override protected val hasInitialState: Boolean = false
+
+  override protected val stateEncoder: ExpressionEncoder[Any] =
+    RowEncoder(stateType).resolveAndBind().asInstanceOf[ExpressionEncoder[Any]]
+
+  override def output: Seq[Attribute] = outAttributes
+
+  private val sessionLocalTimeZone = conf.sessionLocalTimeZone
+  private val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf)
+
+  private val pythonFunction = functionExpr.asInstanceOf[PythonUDF].func
+  private val chainedFunc = Seq(ChainedPythonFunctions(Seq(pythonFunction)))
+  private lazy val (dedupAttributes, argOffsets) = resolveArgOffsets(
+    groupingAttributes ++ child.output, groupingAttributes)
+  private lazy val unsafeProj = UnsafeProjection.create(dedupAttributes, child.output)
+
+  override def requiredChildDistribution: Seq[Distribution] =
+    StatefulOperatorPartitioning.getCompatibleDistribution(
+      groupingAttributes, getStateInfo, conf) :: Nil
+
+  override def requiredChildOrdering: Seq[Seq[SortOrder]] = Seq(
+    groupingAttributes.map(SortOrder(_, Ascending)))
+
+  override def shortName: String = "applyInPandasWithState"
+
+  override protected def withNewChildInternal(
+      newChild: SparkPlan): FlatMapGroupsInPandasWithStateExec = copy(child = newChild)
+
+  override def createInputProcessor(
+      store: StateStore): InputProcessor = new InputProcessor(store: StateStore) {
+
+    override def processNewData(dataIter: Iterator[InternalRow]): Iterator[InternalRow] = {
+      val groupedIter = GroupedIterator(dataIter, groupingAttributes, child.output)
+      val processIter = groupedIter.map { case (keyRow, valueRowIter) =>
+        val keyUnsafeRow = keyRow.asInstanceOf[UnsafeRow]
+        val stateData = stateManager.getState(store, keyUnsafeRow)
+        (keyUnsafeRow, stateData, valueRowIter.map(unsafeProj))
+      }
+
+      process(processIter, hasTimedOut = false)
+    }
+
+    override def processNewDataWithInitialState(
+        childDataIter: Iterator[InternalRow],
+        initStateIter: Iterator[InternalRow]): Iterator[InternalRow] = {
+      throw new UnsupportedOperationException("Should not reach here!")
+    }
+
+    override def processTimedOutState(): Iterator[InternalRow] = {
+      if (isTimeoutEnabled) {
+        val timeoutThreshold = timeoutConf match {
+          case ProcessingTimeTimeout => batchTimestampMs.get
+          case EventTimeTimeout => eventTimeWatermark.get
+          case _ =>
+            throw new IllegalStateException(
+              s"Cannot filter timed out keys for $timeoutConf")
+        }
+        val timingOutPairs = stateManager.getAllState(store).filter { state =>
+          state.timeoutTimestamp != NO_TIMESTAMP && state.timeoutTimestamp < timeoutThreshold
+        }
+
+        val processIter = timingOutPairs.map { stateData =>
+          val joinedKeyRow = unsafeProj(
+            new JoinedRow(
+              stateData.keyRow,
+              new GenericInternalRow(Array.fill(dedupAttributes.length)(null: Any))))
+
+          (stateData.keyRow, stateData, Iterator.single(joinedKeyRow))
+        }
+
+        process(processIter, hasTimedOut = true)
+      } else Iterator.empty
+    }
+
+    private def process(
+        iter: Iterator[(UnsafeRow, StateData, Iterator[InternalRow])],
+        hasTimedOut: Boolean): Iterator[InternalRow] = {
+      val runner = new ApplyInPandasWithStatePythonRunner(
+        chainedFunc,
+        PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE,
+        Array(argOffsets),
+        StructType.fromAttributes(dedupAttributes),
+        sessionLocalTimeZone,
+        pythonRunnerConf,
+        stateEncoder.asInstanceOf[ExpressionEncoder[Row]],
+        groupingAttributes.toStructType,
+        outAttributes.toStructType,
+        stateType)
+
+      val context = TaskContext.get()
+
+      val processIter = iter.map { case (keyRow, stateData, valueIter) =>
+        val groupedState = GroupStateImpl.createForStreaming(
+          Option(stateData.stateObj).map { r => assert(r.isInstanceOf[Row]); r },
+          batchTimestampMs.getOrElse(NO_TIMESTAMP),
+          eventTimeWatermark.getOrElse(NO_TIMESTAMP),
+          timeoutConf,
+          hasTimedOut = hasTimedOut,
+          watermarkPresent).asInstanceOf[GroupStateImpl[Row]]
+        (keyRow, groupedState, valueIter)
+      }
+      runner.compute(processIter, context.partitionId(), context).flatMap {
+        case (stateIter, outputIter) =>
+          // When the output iterator is consumed, write the changes to the state store.
+          // Each key's state is independent, so the update timing does not affect the result.
+          def onIteratorCompletion: Unit = {
+            stateIter.foreach { case (keyRow, newGroupState, oldTimeoutTimestamp) =>
+              if (newGroupState.isRemoved && !newGroupState.getTimeoutTimestampMs.isPresent()) {
+                stateManager.removeState(store, keyRow)
+                numRemovedStateRows += 1
+              } else {
+                val currentTimeoutTimestamp = newGroupState.getTimeoutTimestampMs
+                  .orElse(NO_TIMESTAMP)
+                val hasTimeoutChanged = currentTimeoutTimestamp != oldTimeoutTimestamp
+                val shouldWriteState = newGroupState.isUpdated || newGroupState.isRemoved ||
+                  hasTimeoutChanged
+
+                if (shouldWriteState) {
+                  val updatedStateObj = if (newGroupState.exists) newGroupState.get else null
+                  stateManager.putState(store, keyRow, updatedStateObj,
+                    currentTimeoutTimestamp)
+                  numUpdatedStateRows += 1
+                }
+              }
+            }
+          }
+
+          CompletionIterator[InternalRow, Iterator[InternalRow]](
+            outputIter, onIteratorCompletion).map { row =>
+            numOutputRows += 1
+            row
+          }
+      }
+    }
+
+    override protected def callFunctionAndUpdateState(
+        stateData: StateData,
+        valueRowIter: Iterator[InternalRow],
+        hasTimedOut: Boolean): Iterator[InternalRow] = {
+      throw new UnsupportedOperationException("Should not reach here!")
+    }
+  }
+}
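For readers following the state handling above, here is a simplified, hypothetical Scala model of the two per-key decisions the operator makes: which stored states count as timed out (a timeout is set and it is strictly below the threshold, which is the batch timestamp for processing-time timeouts and the event-time watermark for event-time timeouts), and whether to remove, rewrite, or skip a key's state once the Python function has returned. The case classes, the `Action` type, and the `-1` value for `NO_TIMESTAMP` are assumptions for illustration; they stand in for `StateData`, `GroupStateImpl`, and the state manager calls, which are not reproduced here.

```scala
// Hypothetical, simplified model of the per-key state handling in
// FlatMapGroupsInPandasWithStateExec. None of these names are Spark APIs.
object PandasStateModel {
  // Stand-in for NO_TIMESTAMP ("no timeout set"); the value -1 is assumed here.
  val NoTimestamp: Long = -1L

  // Stand-in for a stored state row: the grouping key and its timeout timestamp.
  final case class StoredState(key: String, timeoutTimestamp: Long)

  // A key has timed out when a timeout is set and it is strictly earlier than the
  // threshold (batch timestamp for processing time, watermark for event time).
  def timedOut(states: Seq[StoredState], threshold: Long): Seq[StoredState] =
    states.filter(s => s.timeoutTimestamp != NoTimestamp && s.timeoutTimestamp < threshold)

  // Stand-in for the GroupState the user function leaves behind for one key.
  final case class NewGroupState(
      value: Option[String], // None when the state does not exist
      isUpdated: Boolean,
      isRemoved: Boolean,
      timeoutTimestampMs: Option[Long])

  sealed trait Action
  case object Remove extends Action
  final case class Put(value: Option[String], timeoutTimestamp: Long) extends Action
  case object Keep extends Action

  // Mirrors the branch structure of onIteratorCompletion in the diff.
  def decide(s: NewGroupState, oldTimeoutTimestamp: Long): Action =
    if (s.isRemoved && s.timeoutTimestampMs.isEmpty) {
      Remove // removed with no pending timeout: drop the key entirely
    } else {
      val current = s.timeoutTimestampMs.getOrElse(NoTimestamp)
      val hasTimeoutChanged = current != oldTimeoutTimestamp
      if (s.isUpdated || s.isRemoved || hasTimeoutChanged) Put(s.value, current)
      else Keep // nothing changed: skip the write
    }
}
```

The `Keep` branch corresponds to skipping `putState`, so a key whose state and timeout are both unchanged causes no state store write. A removed state that still carries a timeout is written back with no value, matching `updatedStateObj = null` in the diff.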
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PandasGroupUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PandasGroupUtils.scala
index 2da0000dad4..07887666406 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PandasGroupUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PandasGroupUtils.scala
@@ -24,7 +24,7 @@ import org.apache.spark.TaskContext
 import org.apache.spark.api.python.BasePythonRunner
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeProjection}
-import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan}
+import org.apache.spark.sql.execution.GroupedIterator
 import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch}
 
 /**
@@ -88,9 +88,10 @@ private[python] object PandasGroupUtils {
    * argOffsets[argOffsets[0]+2 .. ] is the arg offsets for data attributes
    */
   def resolveArgOffsets(
-    child: SparkPlan, groupingAttributes: Seq[Attribute]): (Seq[Attribute], Array[Int]) = {
+      attributes: Seq[Attribute],
+      groupingAttributes: Seq[Attribute]): (Seq[Attribute], Array[Int]) = {
 
-    val dataAttributes = child.output.drop(groupingAttributes.length)
+    val dataAttributes = attributes.drop(groupingAttributes.length)
     val groupingIndicesInData = groupingAttributes.map { attribute =>
       dataAttributes.indexWhere(attribute.semanticEquals)
     }
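The refactored `resolveArgOffsets` now takes the attribute sequence directly instead of a `SparkPlan`, so callers can pass any attribute list rather than a child plan's output. Below is a simplified, hypothetical model of the deduplication idea behind it: a grouping column that already appears among the data columns is not sent twice, and its offset points at the copy among the data columns. Columns are plain strings, string equality stands in for `semanticEquals`, and the returned triple is illustrative rather than the exact `argOffsets` array layout the Python worker reads.

```scala
// Hypothetical, simplified model of PandasGroupUtils.resolveArgOffsets.
object ArgOffsetsModel {
  // attributes: grouping columns first, then the data columns.
  // Returns (deduplicated columns, grouping-column offsets, data-column offsets).
  def resolve(
      attributes: Seq[String],
      groupingAttributes: Seq[String]): (Seq[String], Seq[Int], Seq[Int]) = {
    val dataAttributes = attributes.drop(groupingAttributes.length)
    // Index of each grouping column among the data columns, or -1 if absent.
    val groupingIndicesInData = groupingAttributes.map(a => dataAttributes.indexWhere(_ == a))
    val pairs = groupingAttributes.zip(groupingIndicesInData)

    // Grouping columns absent from the data must be sent separately, ahead of the data.
    val nonDupGrouping = pairs.collect { case (a, -1) => a }
    val nonDupSize = nonDupGrouping.length

    // A non-duplicate grouping column points at its own slot; a duplicate one points at
    // its copy among the data columns, which start right after the non-duplicates.
    val groupingOffsets = pairs.map {
      case (a, -1) => nonDupGrouping.indexOf(a)
      case (_, idx) => nonDupSize + idx
    }
    val dataOffsets = nonDupSize until (nonDupSize + dataAttributes.length)

    (nonDupGrouping ++ dataAttributes, groupingOffsets, dataOffsets)
  }
}
```

For example, grouping by `k` and `id` over the columns `k, id, id, v` (where `id` is also a data column) yields the deduplicated columns `k, id, v`, grouping offsets `0, 1`, and data offsets `1, 2`.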
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala
index 6168d0f867a..bf66791183e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala
@@ -76,7 +76,6 @@ private[python] trait PythonArrowInput[IN] { self: BasePythonRunner[IN, _] =>
         val root = VectorSchemaRoot.create(arrowSchema, allocator)
 
         Utils.tryWithSafeFinally {
-          val arrowWriter = ArrowWriter.create(root)
           val writer = new ArrowStreamWriter(root, null, dataOut)
           writer.start()
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
index 3f369ac5e97..f386282a0b3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
@@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.trees.TreePattern._
 import org.apache.spark.sql.execution.{LocalLimitExec, QueryExecution, SparkPlan, SparkPlanner, UnaryExecNode}
 import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, MergingSessionsExec, ObjectHashAggregateExec, SortAggregateExec, UpdatingSessionsExec}
 import org.apache.spark.sql.execution.exchange.ShuffleExchangeLike
+import org.apache.spark.sql.execution.python.FlatMapGroupsInPandasWithStateExec
 import org.apache.spark.sql.execution.streaming.sources.WriteToMicroBatchDataSourceV1
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.streaming.OutputMode
@@ -62,6 +63,7 @@ class IncrementalExecution(
       StreamingJoinStrategy ::
       StatefulAggregationStrategy ::
       FlatMapGroupsWithStateStrategy ::
+      FlatMapGroupsInPandasWithStateStrategy ::
       StreamingRelationStrategy ::
       StreamingDeduplicationStrategy ::
       StreamingGlobalLimitStrategy(outputMode) :: Nil
@@ -210,6 +212,13 @@ class IncrementalExecution(
           hasInitialState = hasInitialState
         )
 
+      case m: FlatMapGroupsInPandasWithStateExec =>
+        m.copy(
+          stateInfo = Some(nextStatefulOperationStateInfo),
+          batchTimestampMs = Some(offsetSeqMetadata.batchTimestampMs),
+          eventTimeWatermark = Some(offsetSeqMetadata.batchWatermarkMs)
+        )
+
       case j: StreamingSymmetricHashJoinExec =>
         j.copy(
           stateInfo = Some(nextStatefulOperationStateInfo),
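The `IncrementalExecution` hunk fills in `stateInfo`, `batchTimestampMs`, and `eventTimeWatermark` on `FlatMapGroupsInPandasWithStateExec` for each micro-batch via `copy`; the node is presumably planned with these fields unset. This matters because `processTimedOutState` calls `batchTimestampMs.get` or `eventTimeWatermark.get` depending on the timeout mode. A stripped-down, hypothetical Scala model of that pattern follows; the names are illustrative and are not Spark classes.

```scala
// Hypothetical model of per-batch field injection; names are illustrative only.
object PerBatchModel {
  final case class StateInfo(operatorId: Long, batchId: Long)

  // Stand-in for the physical node: per-batch fields start out unset.
  final case class PandasWithStateNode(
      stateInfo: Option[StateInfo] = None,
      batchTimestampMs: Option[Long] = None,
      eventTimeWatermark: Option[Long] = None)

  // Stand-in for the offset-log metadata of one micro-batch.
  final case class BatchMetadata(batchId: Long, batchTimestampMs: Long, batchWatermarkMs: Long)

  // Returns a copy with the current batch's values filled in; the input is not mutated.
  def prepareForBatch(
      node: PandasWithStateNode,
      operatorId: Long,
      meta: BatchMetadata): PandasWithStateNode =
    node.copy(
      stateInfo = Some(StateInfo(operatorId, meta.batchId)),
      batchTimestampMs = Some(meta.batchTimestampMs),
      eventTimeWatermark = Some(meta.batchWatermarkMs))
}
```

Because the node is an immutable case class, each batch works on its own copy and the planned node is left untouched.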
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala
index 01ff72bac7b..022fd1239ce 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala
@@ -49,7 +49,7 @@ package object state {
     }
 
     /** Map each partition of an RDD along with data in a [[StateStore]]. */
-    private[streaming] def mapPartitionsWithStateStore[U: ClassTag](
+    def mapPartitionsWithStateStore[U: ClassTag](
         stateInfo: StatefulOperatorStateInfo,
         keySchema: StructType,
         valueSchema: StructType,


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
