zhengruifeng commented on code in PR #53035:
URL: https://github.com/apache/spark/pull/53035#discussion_r2562845325


##########
python/pyspark/sql/pandas/serializers.py:
##########
@@ -1200,6 +1200,86 @@ def __repr__(self):
         return "ArrowStreamAggArrowUDFSerializer"
 
 
+# Serializer for SQL_GROUPED_AGG_ARROW_ITER_UDF
+class ArrowStreamAggArrowIterUDFSerializer(ArrowStreamArrowUDFSerializer):
+    def __init__(
+        self,
+        timezone,
+        safecheck,
+        assign_cols_by_name,
+        arrow_cast,
+    ):
+        super().__init__(
+            timezone=timezone,
+            safecheck=safecheck,
+            assign_cols_by_name=False,
+            arrow_cast=True,
+        )
+        self._timezone = timezone
+        self._safecheck = safecheck
+        self._assign_cols_by_name = assign_cols_by_name
+        self._arrow_cast = arrow_cast
+
+    def load_stream(self, stream):
+        """
+        Yield column iterators instead of concatenating batches.
+        Each group yields a structure where indexing by column offset gives an 
iterator of arrays.
+        """
+        dataframes_in_group = None
+
+        while dataframes_in_group is None or dataframes_in_group > 0:
+            dataframes_in_group = read_int(stream)
+
+            if dataframes_in_group == 1:
+                batches_stream = ArrowStreamSerializer.load_stream(self, 
stream)
+                # Peek at the first batch to get the number of columns
+                batches_iter = iter(batches_stream)
+                try:
+                    first_batch = next(batches_iter)
+                except StopIteration:
+                    # Empty group
+                    yield []
+                    continue
+
+                num_cols = first_batch.num_columns
+
+                # Create a custom class that can be indexed to get column 
iterators
+                # Uses itertools.tee to create independent iterators for each 
column
+                class ColumnIterators:
+                    def __init__(self, first_batch, batches_iter, num_cols):
+                        self._first_batch = first_batch
+                        self._batches_iter = batches_iter
+                        self._num_cols = num_cols
+                        self._teed_iters = None
+
+                    def __getitem__(self, col_idx):
+                        import itertools
+
+                        # Lazily create teed iterators when first column is 
accessed
+                        if self._teed_iters is None:
+                            # Recreate full batch stream including first batch
+                            full_stream = itertools.chain([self._first_batch], 
self._batches_iter)
+                            # Create independent iterators for each column
+                            self._teed_iters = itertools.tee(full_stream, 
self._num_cols)

Review Comment:
   > I checked again and seems we should only support a single UDF.
   
   I don't think so; this new UDF should work in multiple-UDF cases, e.g.
   
   `df.groupby(...).agg(udf1(...), udf2(...))`, where both `udf1` and `udf2` are of the new UDF type.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to