Yicong-Huang commented on code in PR #52716:
URL: https://github.com/apache/spark/pull/52716#discussion_r2471042308


##########
python/pyspark/sql/pandas/group_ops.py:
##########
@@ -214,22 +220,86 @@ def applyInPandas(
         |  2|          2| 3.0|
         +---+-----------+----+
 
+        The function can also take and return an iterator of `pandas.DataFrame` using type
+        hints.
+
+        >>> from typing import Iterator  # doctest: +SKIP
+        >>> df = spark.createDataFrame(
+        ...     [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
+        ...     ("id", "v"))  # doctest: +SKIP
+        >>> def filter_func(
+        ...     batches: Iterator[pd.DataFrame]
+        ... ) -> Iterator[pd.DataFrame]:  # doctest: +SKIP
+        ...     for batch in batches:
+        ...         # Process and yield each batch independently
+        ...         filtered = batch[batch['v'] > 2.0]
+        ...         if not filtered.empty:
+        ...             yield filtered[['v']]
+        >>> df.groupby("id").applyInPandas(
+        ...     filter_func, schema="v double").show()  # doctest: +SKIP
+        +----+
+        |   v|
+        +----+
+        | 3.0|
+        | 5.0|
+        |10.0|
+        +----+
+
+        Alternatively, the user can pass a function that takes two arguments.
+        In this case, the grouping key(s) will be passed as the first argument and the data will
+        be passed as the second argument. The grouping key(s) will be passed as a tuple of numpy
+        data types. The data will still be passed in as an iterator of `pandas.DataFrame`.
+
+        >>> from typing import Iterator, Tuple, Any  # doctest: +SKIP
+        >>> def transform_func(
+        ...     key: Tuple[Any, ...], batches: Iterator[pd.DataFrame]
+        ... ) -> Iterator[pd.DataFrame]:  # doctest: +SKIP
+        ...     for batch in batches:
+        ...         # Yield transformed results for each batch
+        ...         result = batch.assign(id=key[0], v_doubled=batch['v'] * 2)
+        ...         yield result[['id', 'v_doubled']]
+        >>> df.groupby("id").applyInPandas(
+        ...     transform_func, schema="id long, v_doubled double").show()  # doctest: +SKIP
+        +---+----------+
+        | id| v_doubled|
+        +---+----------+
+        |  1|       2.0|
+        |  1|       4.0|
+        |  2|       6.0|
+        |  2|      10.0|
+        |  2|      20.0|
+        +---+----------+
+
         Notes
         -----
-        This function requires a full shuffle. All the data of a group will be loaded
-        into memory, so the user should be aware of the potential OOM risk if data is skewed
-        and certain groups are too large to fit in memory.
+        This function requires a full shuffle. If using the `pandas.DataFrame` API, all data of a
+        group will be loaded into memory, so the user should be aware of the potential OOM risk if
+        data is skewed and certain groups are too large to fit in memory. The iterator of
+        `pandas.DataFrame` API can be used to mitigate this.
 
         See Also
         --------
         pyspark.sql.functions.pandas_udf
         """
         from pyspark.sql import GroupedData
         from pyspark.sql.functions import pandas_udf, PandasUDFType
+        from pyspark.sql.pandas.typehints import infer_group_pandas_eval_type_from_func
+        from pyspark.sql.pandas.functions import PythonEvalType
+        import warnings

Review Comment:
   removed



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to