Yicong-Huang commented on code in PR #52716:
URL: https://github.com/apache/spark/pull/52716#discussion_r2471039518


##########
python/pyspark/sql/pandas/serializers.py:
##########
@@ -1245,6 +1245,89 @@ def __repr__(self):
         return "GroupPandasUDFSerializer"
 
 
+class GroupPandasIterUDFSerializer(ArrowStreamPandasUDFSerializer):
+    """
+    Serializer for grouped map Pandas iterator UDFs.
+    
+    Loads grouped data as pandas.Series and serializes results from iterator 
UDFs.
+    Flattens the (dataframes_generator, arrow_type) tuple by iterating over 
the generator.
+    """
+    
+    def __init__(
+        self,
+        timezone,
+        safecheck,
+        assign_cols_by_name,
+        int_to_decimal_coercion_enabled,
+    ):
+        super(GroupPandasIterUDFSerializer, self).__init__(
+            timezone=timezone,
+            safecheck=safecheck,
+            assign_cols_by_name=assign_cols_by_name,
+            df_for_struct=False,
+            struct_in_pandas="dict",
+            ndarray_as_list=False,
+            arrow_cast=True,
+            input_types=None,
+            int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled,
+        )
+
+    def load_stream(self, stream):
+        """
+        Deserialize Grouped ArrowRecordBatches and yield a generator of 
pandas.Series lists
+        (one list per batch), allowing the iterator UDF to process data 
batch-by-batch.
+        """
+        import pyarrow as pa
+
+        dataframes_in_group = None
+
+        while dataframes_in_group is None or dataframes_in_group > 0:
+            dataframes_in_group = read_int(stream)
+
+            if dataframes_in_group == 1:
+                # Read all Arrow batches for this group first (must read from 
stream synchronously)
+                batches = list(ArrowStreamSerializer.load_stream(self, stream))

Review Comment:
   I think this was from an old commit, but I've updated it to use the iterator pattern. Thanks for catching it!
   
   I think we can deduplicate this code in a follow-up PR.



##########
python/pyspark/worker.py:
##########
@@ -2625,6 +2664,12 @@ def read_udfs(pickleSer, infile, eval_type):
             ser = GroupPandasUDFSerializer(
                 timezone, safecheck, _assign_cols_by_name, 
int_to_decimal_coercion_enabled
             )
+        elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_ITER_UDF:
+            from pyspark.sql.pandas.serializers import 
GroupPandasIterUDFSerializer

Review Comment:
   Moved the import.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to