zhengruifeng commented on code in PR #52716:
URL: https://github.com/apache/spark/pull/52716#discussion_r2467769396


##########
python/pyspark/worker.py:
##########
@@ -2625,6 +2664,12 @@ def read_udfs(pickleSer, infile, eval_type):
             ser = GroupPandasUDFSerializer(
                 timezone, safecheck, _assign_cols_by_name, int_to_decimal_coercion_enabled
             )
+        elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_ITER_UDF:
+            from pyspark.sql.pandas.serializers import GroupPandasIterUDFSerializer

Review Comment:
   let's put the import here
   
   
https://github.com/apache/spark/blob/57b4cd2ccc646990e22643388d512e997cd4299e/python/pyspark/worker.py#L54-L73
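   For illustration, a minimal sketch of what the suggested change could look like, assuming
   the serializer import block at the top of worker.py simply gains the new class (the other
   names shown are taken from this diff, not from the linked lines):
   
   ```python
   # python/pyspark/worker.py -- top-level import block (sketch, not the PR's actual diff)
   from pyspark.sql.pandas.serializers import (
       ArrowStreamPandasUDFSerializer,
       GroupPandasUDFSerializer,
       GroupPandasIterUDFSerializer,  # moved here from the inline import in read_udfs()
   )
   ```
   
   The `SQL_GROUPED_MAP_PANDAS_ITER_UDF` branch in `read_udfs()` can then construct
   `GroupPandasIterUDFSerializer` directly, without the function-local import.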



##########
python/pyspark/sql/pandas/group_ops.py:
##########
@@ -214,22 +220,86 @@ def applyInPandas(
         |  2|          2| 3.0|
         +---+-----------+----+
 
+        The function can also take and return an iterator of `pandas.DataFrame`
+        using type hints.
+
+        >>> from typing import Iterator  # doctest: +SKIP
+        >>> df = spark.createDataFrame(
+        ...     [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
+        ...     ("id", "v"))  # doctest: +SKIP
+        >>> def filter_func(
+        ...     batches: Iterator[pd.DataFrame]
+        ... ) -> Iterator[pd.DataFrame]:  # doctest: +SKIP
+        ...     for batch in batches:
+        ...         # Process and yield each batch independently
+        ...         filtered = batch[batch['v'] > 2.0]
+        ...         if not filtered.empty:
+        ...             yield filtered[['v']]
+        >>> df.groupby("id").applyInPandas(
+        ...     filter_func, schema="v double").show()  # doctest: +SKIP
+        +----+
+        |   v|
+        +----+
+        | 3.0|
+        | 5.0|
+        |10.0|
+        +----+
+
+        Alternatively, the user can pass a function that takes two arguments.
+        In this case, the grouping key(s) will be passed as the first argument and the
+        data will be passed as the second argument. The grouping key(s) will be passed
+        as a tuple of numpy data types. The data will still be passed in as an iterator
+        of `pandas.DataFrame`.
+
+        >>> from typing import Iterator, Tuple, Any  # doctest: +SKIP
+        >>> def transform_func(
+        ...     key: Tuple[Any, ...], batches: Iterator[pd.DataFrame]
+        ... ) -> Iterator[pd.DataFrame]:  # doctest: +SKIP
+        ...     for batch in batches:
+        ...         # Yield transformed results for each batch
+        ...         result = batch.assign(id=key[0], v_doubled=batch['v'] * 2)
+        ...         yield result[['id', 'v_doubled']]
+        >>> df.groupby("id").applyInPandas(
+        ...     transform_func, schema="id long, v_doubled double").show()  # doctest: +SKIP
+        +---+---------+
+        | id|v_doubled|
+        +---+---------+
+        |  1|      2.0|
+        |  1|      4.0|
+        |  2|      6.0|
+        |  2|     10.0|
+        |  2|     20.0|
+        +---+---------+
+
         Notes
         -----
-        This function requires a full shuffle. All the data of a group will be loaded
-        into memory, so the user should be aware of the potential OOM risk if data is
-        skewed and certain groups are too large to fit in memory.
+        This function requires a full shuffle. If the `pandas.DataFrame` API is used,
+        all data of a group will be loaded into memory, so the user should be aware of
+        the potential OOM risk if data is skewed and certain groups are too large to fit
+        in memory; the iterator of `pandas.DataFrame` API can be used to mitigate this.
 
         See Also
         --------
         pyspark.sql.functions.pandas_udf
         """
         from pyspark.sql import GroupedData
         from pyspark.sql.functions import pandas_udf, PandasUDFType
+        from pyspark.sql.pandas.typehints import infer_group_pandas_eval_type_from_func
+        from pyspark.sql.pandas.functions import PythonEvalType
+        import warnings

Review Comment:
   no need to re-import `PythonEvalType` and `warnings`
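   
   For illustration, assuming `PythonEvalType` and `warnings` are already imported at module
   level in group_ops.py (as this comment implies), the function-local imports would reduce
   to something like:
   
   ```python
   # sketch of the trimmed local imports in applyInPandas(); names follow the diff above
   from pyspark.sql import GroupedData
   from pyspark.sql.functions import pandas_udf, PandasUDFType
   from pyspark.sql.pandas.typehints import infer_group_pandas_eval_type_from_func
   ```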



##########
python/pyspark/sql/pandas/serializers.py:
##########
@@ -1245,6 +1245,89 @@ def __repr__(self):
         return "GroupPandasUDFSerializer"
 
 
+class GroupPandasIterUDFSerializer(ArrowStreamPandasUDFSerializer):
+    """
+    Serializer for grouped map Pandas iterator UDFs.
+    
+    Loads grouped data as pandas.Series and serializes results from iterator UDFs.
+    Flattens the (dataframes_generator, arrow_type) tuple by iterating over the
+    generator.
+    """
+    
+    def __init__(
+        self,
+        timezone,
+        safecheck,
+        assign_cols_by_name,
+        int_to_decimal_coercion_enabled,
+    ):
+        super(GroupPandasIterUDFSerializer, self).__init__(
+            timezone=timezone,
+            safecheck=safecheck,
+            assign_cols_by_name=assign_cols_by_name,
+            df_for_struct=False,
+            struct_in_pandas="dict",
+            ndarray_as_list=False,
+            arrow_cast=True,
+            input_types=None,
+            int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled,
+        )
+
+    def load_stream(self, stream):
+        """
+        Deserialize Grouped ArrowRecordBatches and yield a generator of pandas.Series
+        lists (one list per batch), allowing the iterator UDF to process data
+        batch-by-batch.
+        """
+        import pyarrow as pa
+
+        dataframes_in_group = None
+
+        while dataframes_in_group is None or dataframes_in_group > 0:
+            dataframes_in_group = read_int(stream)
+
+            if dataframes_in_group == 1:
+                # Read all Arrow batches for this group first (must read from stream synchronously)
+                batches = list(ArrowStreamSerializer.load_stream(self, stream))

Review Comment:
   I think we cannot load all batches here; the iterator API is designed to avoid
   loading all batches within a group so that it can mitigate OOM.
   
   you can refer to
   
   
https://github.com/apache/spark/blob/7bd18e3852f1a2160fcd0838f7d9937ea34926b4/python/pyspark/sql/pandas/serializers.py#L1136-L1140
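   
   For illustration only, a rough sketch of a lazier `load_stream` in that spirit: yield a
   per-group generator that converts one Arrow batch at a time instead of materializing the
   whole group with `list(...)`. The per-batch conversion via `self.arrow_to_pandas(column, i)`
   is an assumed helper/signature on the base class, not code from this PR, and the caller
   must fully consume each group's generator before the next group header is read from the
   stream.
   
   ```python
   import pyarrow as pa
   
   # Method of GroupPandasIterUDFSerializer (see the diff above); read_int and
   # ArrowStreamSerializer are the same helpers the PR already uses.
   def load_stream(self, stream):
       """
       Deserialize grouped ArrowRecordBatches lazily: for each group, yield a
       generator producing one list of pandas.Series per Arrow batch, so an
       entire group is never held in memory at once.
       """
       dataframes_in_group = None
   
       while dataframes_in_group is None or dataframes_in_group > 0:
           dataframes_in_group = read_int(stream)
   
           if dataframes_in_group == 1:
               def per_batch_series():
                   # Convert each Arrow batch only when the UDF pulls the next
                   # element, instead of list(...)-ing the whole group up front.
                   for batch in ArrowStreamSerializer.load_stream(self, stream):
                       table = pa.Table.from_batches([batch])
                       yield [
                           self.arrow_to_pandas(column, i)  # assumed conversion helper
                           for i, column in enumerate(table.itercolumns())
                       ]
   
               yield per_batch_series()
           elif dataframes_in_group != 0:
               raise ValueError(
                   "Invalid number of dataframes in group {0}".format(dataframes_in_group)
               )
   ```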



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

