HyukjinKwon commented on code in PR #37893:
URL: https://github.com/apache/spark/pull/37893#discussion_r977150660


##########
python/pyspark/sql/pandas/group_ops.py:
##########
@@ -216,6 +218,125 @@ def applyInPandas(
         jdf = self._jgd.flatMapGroupsInPandas(udf_column._jc.expr())
         return DataFrame(jdf, self.session)
 
+    def applyInPandasWithState(
+        self,
+        func: "PandasGroupedMapFunctionWithState",
+        outputStructType: Union[StructType, str],
+        stateStructType: Union[StructType, str],
+        outputMode: str,
+        timeoutConf: str,
+    ) -> DataFrame:
+        """
+        Applies the given function to each group of data, while maintaining a user-defined
+        per-group state. The result Dataset will represent the flattened record returned by the
+        function.
+
+        For a streaming Dataset, the function will be invoked first for all input groups and then
+        for all timed out states where the input data is set to be empty. Updates to each group's
+        state will be saved across invocations.

Review Comment:
   ```suggestion
           For a streaming :class:`DataFrame`, the function will be invoked first for all input groups
           and then for all timed out states where the input data is set to be empty. Updates to
           each group's state will be saved across invocations.
   ```



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala:
##########
@@ -98,6 +98,16 @@ class ArrowWriter(val root: VectorSchemaRoot, fields: Array[ArrowFieldWriter]) {
     count += 1
   }
 
+  def sizeInBytes(): Int = {

Review Comment:
   I think we don't need `sizeInBytes` and `getSizeInBytes` anymore.



##########
python/pyspark/sql/pandas/group_ops.py:
##########
@@ -216,6 +218,125 @@ def applyInPandas(
         jdf = self._jgd.flatMapGroupsInPandas(udf_column._jc.expr())
         return DataFrame(jdf, self.session)
 
+    def applyInPandasWithState(
+        self,
+        func: "PandasGroupedMapFunctionWithState",
+        outputStructType: Union[StructType, str],
+        stateStructType: Union[StructType, str],
+        outputMode: str,
+        timeoutConf: str,
+    ) -> DataFrame:
+        """
+        Applies the given function to each group of data, while maintaining a user-defined
+        per-group state. The result Dataset will represent the flattened record returned by the
+        function.
+
+        For a streaming Dataset, the function will be invoked first for all input groups and then
+        for all timed out states where the input data is set to be empty. Updates to each group's
+        state will be saved across invocations.
+
+        The function should take parameters (key, Iterator[`pandas.DataFrame`], state) and
+        return another Iterator[`pandas.DataFrame`]. The grouping key(s) will be passed as a tuple
+        of numpy data types, e.g., `numpy.int32` and `numpy.float64`. The state will be passed as
+        :class:`pyspark.sql.streaming.state.GroupState`.
+
+        For each group, all columns are passed together as `pandas.DataFrame` to the user-function,
+        and the returned `pandas.DataFrame` across all invocations are combined as a
+        :class:`DataFrame`. Note that the user function should not make a guess of the number of
+        elements in the iterator. To process all data, the user function needs to iterate all
+        elements and process them. On the other hand, the user function is not strictly required to
+        iterate through all elements in the iterator if it intends to read a part of data.
+
+        The `outputStructType` should be a :class:`StructType` describing the schema of all
+        elements in the returned value, `pandas.DataFrame`. The column labels of all elements in
+        returned `pandas.DataFrame` must either match the field names in the defined schema if
+        specified as strings, or match the field data types by position if not strings,
+        e.g. integer indices.
+
+        The `stateStructType` should be :class:`StructType` describing the schema of the
+        user-defined state. The value of the state will be presented as a tuple, as well as the
+        update should be performed with the tuple. The corresponding Python types for
+        :class:DataType are supported. Please refer to the page
+        https://spark.apache.org/docs/latest/sql-ref-datatypes.html (python tab).
+
+        The size of each DataFrame in both the input and output can be arbitrary. The number of
+        DataFrames in both the input and output can also be arbitrary.
+
+        .. versionadded:: 3.4.0
+
+        Parameters
+        ----------
+        func : function
+            a Python native function to be called on every group. It should take parameters
+            (key, Iterator[`pandas.DataFrame`], state) and return Iterator[`pandas.DataFrame`].
+            Note that the type of the key is tuple and the type of the state is
+            :class:`pyspark.sql.streaming.state.GroupState`.
+        outputStructType : :class:`pyspark.sql.types.DataType` or str
+            the type of the output records. The value can be either a
+            :class:`pyspark.sql.types.DataType` object or a DDL-formatted type string.
+        stateStructType : :class:`pyspark.sql.types.DataType` or str
+            the type of the user-defined state. The value can be either a
+            :class:`pyspark.sql.types.DataType` object or a DDL-formatted type string.
+        outputMode : str
+            the output mode of the function.
+        timeoutConf : str
+            timeout configuration for groups that do not receive data for a while. valid values
+            are defined in :class:`pyspark.sql.streaming.state.GroupStateTimeout`.
+
+        Examples
+        --------
+        >>> import pandas as pd  # doctest: +SKIP
+        >>> from pyspark.sql.streaming.state import GroupStateTimeout
+        >>> def count_fn(key, pdf_iter, state):
+        ...     assert isinstance(state, GroupStateImpl)
+        ...     total_len = 0
+        ...     for pdf in pdf_iter:
+        ...         total_len += len(pdf)
+        ...     state.update((total_len,))
+        ...     yield pd.DataFrame({"id": [key[0]], "countAsString": [str(total_len)]})

Review Comment:
   ```suggestion
           ...     yield pd.DataFrame({"id": [key[0]], "countAsString": [str(total_len)]})
           ...
   ```
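   As a rough usage sketch to go with this example (the rate source, DDL schema strings, "Update" mode, and NoTimeout below are illustrative assumptions, not part of this diff):

   ```python
   import pandas as pd
   from pyspark.sql import SparkSession
   from pyspark.sql.streaming.state import GroupStateTimeout

   spark = SparkSession.builder.getOrCreate()

   def count_fn(key, pdf_iter, state):
       total_len = 0
       for pdf in pdf_iter:
           total_len += len(pdf)
       state.update((total_len,))  # the state value is written back as a tuple
       yield pd.DataFrame({"id": [key[0]], "countAsString": [str(total_len)]})

   # Any streaming source works; a rate source just gives us a numeric key to group on.
   sdf = spark.readStream.format("rate").load().selectExpr("value % 10 AS id")

   out = sdf.groupBy("id").applyInPandasWithState(
       count_fn,
       outputStructType="id long, countAsString string",
       stateStructType="len long",
       outputMode="Update",
       timeoutConf=GroupStateTimeout.NoTimeout,
   )
   # (a writeStream sink would be needed to actually start the query)
   ```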



##########
python/pyspark/sql/pandas/group_ops.py:
##########
@@ -216,6 +218,125 @@ def applyInPandas(
         jdf = self._jgd.flatMapGroupsInPandas(udf_column._jc.expr())
         return DataFrame(jdf, self.session)
 
+    def applyInPandasWithState(
+        self,
+        func: "PandasGroupedMapFunctionWithState",
+        outputStructType: Union[StructType, str],
+        stateStructType: Union[StructType, str],
+        outputMode: str,
+        timeoutConf: str,
+    ) -> DataFrame:
+        """
+        Applies the given function to each group of data, while maintaining a user-defined
+        per-group state. The result Dataset will represent the flattened record returned by the
+        function.
+
+        For a streaming Dataset, the function will be invoked first for all input groups and then
+        for all timed out states where the input data is set to be empty. Updates to each group's
+        state will be saved across invocations.
+
+        The function should take parameters (key, Iterator[`pandas.DataFrame`], state) and
+        return another Iterator[`pandas.DataFrame`]. The grouping key(s) will be passed as a tuple
+        of numpy data types, e.g., `numpy.int32` and `numpy.float64`. The state will be passed as
+        :class:`pyspark.sql.streaming.state.GroupState`.
+
+        For each group, all columns are passed together as `pandas.DataFrame` to the user-function,
+        and the returned `pandas.DataFrame` across all invocations are combined as a
+        :class:`DataFrame`. Note that the user function should not make a guess of the number of
+        elements in the iterator. To process all data, the user function needs to iterate all
+        elements and process them. On the other hand, the user function is not strictly required to
+        iterate through all elements in the iterator if it intends to read a part of data.
+
+        The `outputStructType` should be a :class:`StructType` describing the schema of all
+        elements in the returned value, `pandas.DataFrame`. The column labels of all elements in
+        returned `pandas.DataFrame` must either match the field names in the defined schema if
+        specified as strings, or match the field data types by position if not strings,
+        e.g. integer indices.
+
+        The `stateStructType` should be :class:`StructType` describing the schema of the
+        user-defined state. The value of the state will be presented as a tuple, as well as the
+        update should be performed with the tuple. The corresponding Python types for
+        :class:DataType are supported. Please refer to the page
+        https://spark.apache.org/docs/latest/sql-ref-datatypes.html (python tab).
+
+        The size of each DataFrame in both the input and output can be arbitrary. The number of
+        DataFrames in both the input and output can also be arbitrary.

Review Comment:
   ```suggestion
           The size of each `pandas.DataFrame` in both the input and output can be arbitrary.
           The number of DataFrames in both the input and output can also be arbitrary.
   ```
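   To picture the "arbitrary number and size" point, the user function is free to re-batch its output however it likes; a hypothetical sketch (state handling omitted for brevity):

   ```python
   import pandas as pd

   def rebatch_fn(key, pdf_iter, state):
       # Output batching need not mirror input batching: each input frame is split into
       # fixed-size chunks, so both the number and the size of yielded frames vary.
       chunk = 100
       for pdf in pdf_iter:
           for start in range(0, len(pdf), chunk):
               yield pdf.iloc[start:start + chunk]
   ```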



##########
python/pyspark/sql/pandas/group_ops.py:
##########
@@ -216,6 +218,125 @@ def applyInPandas(
         jdf = self._jgd.flatMapGroupsInPandas(udf_column._jc.expr())
         return DataFrame(jdf, self.session)
 
+    def applyInPandasWithState(
+        self,
+        func: "PandasGroupedMapFunctionWithState",
+        outputStructType: Union[StructType, str],
+        stateStructType: Union[StructType, str],
+        outputMode: str,
+        timeoutConf: str,
+    ) -> DataFrame:
+        """
+        Applies the given function to each group of data, while maintaining a user-defined
+        per-group state. The result Dataset will represent the flattened record returned by the
+        function.
+
+        For a streaming Dataset, the function will be invoked first for all input groups and then
+        for all timed out states where the input data is set to be empty. Updates to each group's
+        state will be saved across invocations.
+
+        The function should take parameters (key, Iterator[`pandas.DataFrame`], state) and
+        return another Iterator[`pandas.DataFrame`]. The grouping key(s) will be passed as a tuple
+        of numpy data types, e.g., `numpy.int32` and `numpy.float64`. The state will be passed as
+        :class:`pyspark.sql.streaming.state.GroupState`.
+
+        For each group, all columns are passed together as `pandas.DataFrame` to the user-function,
+        and the returned `pandas.DataFrame` across all invocations are combined as a
+        :class:`DataFrame`. Note that the user function should not make a guess of the number of
+        elements in the iterator. To process all data, the user function needs to iterate all
+        elements and process them. On the other hand, the user function is not strictly required to
+        iterate through all elements in the iterator if it intends to read a part of data.
+
+        The `outputStructType` should be a :class:`StructType` describing the schema of all
+        elements in the returned value, `pandas.DataFrame`. The column labels of all elements in
+        returned `pandas.DataFrame` must either match the field names in the defined schema if
+        specified as strings, or match the field data types by position if not strings,
+        e.g. integer indices.
+
+        The `stateStructType` should be :class:`StructType` describing the schema of the
+        user-defined state. The value of the state will be presented as a tuple, as well as the
+        update should be performed with the tuple. The corresponding Python types for
+        :class:DataType are supported. Please refer to the page
+        https://spark.apache.org/docs/latest/sql-ref-datatypes.html (python tab).
+
+        The size of each DataFrame in both the input and output can be arbitrary. The number of
+        DataFrames in both the input and output can also be arbitrary.

Review Comment:
   I think we can extract some of these notes from the description into a `Notes` section. But no biggie.
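   For example, the iterator caveats could move into a numpydoc `Notes` block along these lines (wording is just a sketch):

   ```
   Notes
   -----
   The user function should not guess the number of elements in the iterator. To process
   all data, it must iterate over every element; it is not, however, required to consume
   the whole iterator if it only needs part of the data.
   ```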



##########
python/pyspark/sql/pandas/group_ops.py:
##########
@@ -216,6 +218,125 @@ def applyInPandas(
         jdf = self._jgd.flatMapGroupsInPandas(udf_column._jc.expr())
         return DataFrame(jdf, self.session)
 
+    def applyInPandasWithState(
+        self,
+        func: "PandasGroupedMapFunctionWithState",
+        outputStructType: Union[StructType, str],
+        stateStructType: Union[StructType, str],
+        outputMode: str,
+        timeoutConf: str,
+    ) -> DataFrame:
+        """
+        Applies the given function to each group of data, while maintaining a user-defined
+        per-group state. The result Dataset will represent the flattened record returned by the
+        function.
+
+        For a streaming Dataset, the function will be invoked first for all input groups and then
+        for all timed out states where the input data is set to be empty. Updates to each group's
+        state will be saved across invocations.
+
+        The function should take parameters (key, Iterator[`pandas.DataFrame`], state) and
+        return another Iterator[`pandas.DataFrame`]. The grouping key(s) will be passed as a tuple
+        of numpy data types, e.g., `numpy.int32` and `numpy.float64`. The state will be passed as
+        :class:`pyspark.sql.streaming.state.GroupState`.
+
+        For each group, all columns are passed together as `pandas.DataFrame` to the user-function,
+        and the returned `pandas.DataFrame` across all invocations are combined as a
+        :class:`DataFrame`. Note that the user function should not make a guess of the number of
+        elements in the iterator. To process all data, the user function needs to iterate all
+        elements and process them. On the other hand, the user function is not strictly required to
+        iterate through all elements in the iterator if it intends to read a part of data.
+
+        The `outputStructType` should be a :class:`StructType` describing the schema of all
+        elements in the returned value, `pandas.DataFrame`. The column labels of all elements in
+        returned `pandas.DataFrame` must either match the field names in the defined schema if
+        specified as strings, or match the field data types by position if not strings,
+        e.g. integer indices.
+
+        The `stateStructType` should be :class:`StructType` describing the schema of the
+        user-defined state. The value of the state will be presented as a tuple, as well as the
+        update should be performed with the tuple. The corresponding Python types for
+        :class:DataType are supported. Please refer to the page
+        https://spark.apache.org/docs/latest/sql-ref-datatypes.html (python tab).

Review Comment:
   ```suggestion
           https://spark.apache.org/docs/latest/sql-ref-datatypes.html (Python tab).
   ```
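   A small illustration of the tuple round-trip described in that paragraph, assuming a hypothetical state schema of "count long, last_event string" (the schema, column, and function names here are made up):

   ```python
   import pandas as pd

   def track_fn(key, pdf_iter, state):
       # The state is read as a plain Python tuple in stateStructType field order ...
       count, last_event = state.get if state.exists else (0, "")
       for pdf in pdf_iter:
           count += len(pdf)
           last_event = str(pdf["event"].iloc[-1])  # "event" is a hypothetical input column
       # ... and written back the same way.
       state.update((count, last_event))
       yield pd.DataFrame({"id": [key[0]], "count": [count]})
   ```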



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

