zhengruifeng commented on code in PR #53043:
URL: https://github.com/apache/spark/pull/53043#discussion_r2558390249


##########
python/pyspark/sql/pandas/serializers.py:
##########
@@ -1315,22 +1335,25 @@ def process_group(batches: "Iterator[pa.RecordBatch]"):
                 )
 
     def dump_stream(self, iterator, stream):
-        """
-        Flatten the (dataframes_generator, arrow_type) tuples by iterating over each generator.
-        This allows the iterator UDF to stream results without materializing all DataFrames.
-        """
-        # Flatten: (dataframes_generator, arrow_type) -> (df, arrow_type), (df, arrow_type), ...
-        flattened_iter = (
-            (df, arrow_type) for dataframes_gen, arrow_type in iterator for df in dataframes_gen
-        )
-
-        # Convert each (df, arrow_type) to the format expected by parent's dump_stream
-        series_iter = ([(df, arrow_type)] for df, arrow_type in flattened_iter)
+        # Flatten iterator of (generator, arrow_type) into (df, arrow_type) for parent class
+        def flatten_iterator():
+            for (
+                batches_gen,
+                arrow_type,
+            ) in iterator:  # tuple constructed in wrap_grouped_*_pandas_udf
+                # yields df for single UDF or [(df1, type1), (df2, type2), ...] for multiple UDFs

Review Comment:
   is this serializer dedicated to SQL_GROUPED_MAP_PANDAS_UDF and SQL_GROUPED_MAP_PANDAS_ITER_UDF?
   I think they don't support multiple UDFs?
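
The flattening this hunk performs can be sketched in isolation. The sketch below uses toy stand-ins (strings instead of real DataFrames and Arrow types), and `flatten_iterator` here is an illustrative free function, not the actual serializer method:

```python
def flatten_iterator(iterator):
    # Each input element pairs a generator of dataframes with the arrow type
    # produced by the wrapped UDF; yield one [(df, arrow_type)] list per
    # dataframe so a parent dump_stream can serialize batches lazily,
    # without ever materializing all DataFrames of a group at once.
    for dataframes_gen, arrow_type in iterator:
        for df in dataframes_gen:
            yield [(df, arrow_type)]


# Toy stand-ins: strings in place of real DataFrames / Arrow types.
groups = iter([
    ((f"df{i}" for i in range(2)), "type_a"),
    ((f"df{i}" for i in range(2, 3)), "type_b"),
])
print(list(flatten_iterator(groups)))
# [[('df0', 'type_a')], [('df1', 'type_a')], [('df2', 'type_b')]]
```

Because `flatten_iterator` is itself a generator, each inner dataframe generator is only advanced as the parent consumes the stream.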



##########
python/pyspark/worker.py:
##########
@@ -3003,24 +3009,50 @@ def mapper(a):
         )
         parsed_offsets = extract_key_value_indexes(arg_offsets)
 
-        # Create mapper similar to Arrow iterator:
-        # `a` is an iterator of Series lists (one list per batch, containing all columns)
-        # Materialize first batch to get keys, then create generator for value batches
-        def mapper(a):
-            import itertools
+        def series_from_offset(series_list, offsets):
+            return [series_list[o] for o in offsets]
 
-            series_iter = iter(a)
-            # Need to materialize the first series list to get the keys
+        def mapper(series_lists_iter):
+            # `series_lists_iter` is an iterator of Series lists (one list per batch)
+            # Materialize first batch to extract keys (guaranteed to exist for grouped operations)
+            series_iter = iter(series_lists_iter)
             first_series_list = next(series_iter)
 
-            keys = [first_series_list[o] for o in parsed_offsets[0][0]]
+            keys = series_from_offset(first_series_list, parsed_offsets[0][0])
+            # Create generator for value series from all batches
             value_series_gen = (
-                [series_list[o] for o in parsed_offsets[0][1]]
+                series_from_offset(series_list, parsed_offsets[0][1])
                 for series_list in itertools.chain((first_series_list,), series_iter)
             )
-
+            # Call wrapped function which returns (generator, arrow_type)
             return f(keys, value_series_gen)
 
+    elif eval_type in (
+        PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,

Review Comment:
   can we exclude changes for SQL_GROUPED_AGG_PANDAS_UDF and SQL_WINDOW_AGG_PANDAS_UDF to make the PR cleaner?
   we can do it in a separate PR
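
The mapper pattern in this hunk, materializing only the first batch to read the grouping keys and then chaining it back in front of the value generator, can be sketched standalone. `parsed_offsets` here is an illustrative stand-in (key column at offset 0, value columns at offsets 1 and 2), and plain lists stand in for pandas Series:

```python
import itertools

# Illustrative offsets: ([key offsets], [value offsets]) for one UDF.
parsed_offsets = [([0], [1, 2])]


def series_from_offset(series_list, offsets):
    return [series_list[o] for o in offsets]


def mapper(series_lists_iter):
    series_iter = iter(series_lists_iter)
    # Materialize only the first batch to read the grouping keys ...
    first_series_list = next(series_iter)
    keys = series_from_offset(first_series_list, parsed_offsets[0][0])
    # ... then chain it back in front so the value generator still sees
    # every batch, including the one consumed for the keys.
    value_series_gen = (
        series_from_offset(series_list, parsed_offsets[0][1])
        for series_list in itertools.chain((first_series_list,), series_iter)
    )
    return keys, value_series_gen


keys, values = mapper(iter([["k", 10, 20], ["k", 30, 40]]))
print(keys)          # ['k']
print(list(values))  # [[10, 20], [30, 40]]
```

The `itertools.chain((first_series_list,), series_iter)` step is what lets the function peek at batch one without losing it from the lazy stream.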



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

