Yicong-Huang commented on code in PR #53035:
URL: https://github.com/apache/spark/pull/53035#discussion_r2561496317


##########
python/pyspark/sql/pandas/serializers.py:
##########
@@ -1200,6 +1200,86 @@ def __repr__(self):
         return "ArrowStreamAggArrowUDFSerializer"
 
 
+# Serializer for SQL_GROUPED_AGG_ARROW_ITER_UDF
+class ArrowStreamAggArrowIterUDFSerializer(ArrowStreamArrowUDFSerializer):
+    def __init__(
+        self,
+        timezone,
+        safecheck,
+        assign_cols_by_name,
+        arrow_cast,
+    ):
+        super().__init__(
+            timezone=timezone,
+            safecheck=safecheck,
+            assign_cols_by_name=False,
+            arrow_cast=True,
+        )
+        self._timezone = timezone
+        self._safecheck = safecheck
+        self._assign_cols_by_name = assign_cols_by_name
+        self._arrow_cast = arrow_cast
+
+    def load_stream(self, stream):
+        """
+        Yield column iterators instead of concatenating batches.
+        Each group yields a structure where indexing by column offset gives an 
iterator of arrays.
+        """
+        dataframes_in_group = None
+
+        while dataframes_in_group is None or dataframes_in_group > 0:
+            dataframes_in_group = read_int(stream)
+
+            if dataframes_in_group == 1:
+                batches_stream = ArrowStreamSerializer.load_stream(self, 
stream)
+                # Peek at the first batch to get the number of columns
+                batches_iter = iter(batches_stream)
+                try:
+                    first_batch = next(batches_iter)
+                except StopIteration:
+                    # Empty group
+                    yield []
+                    continue
+
+                num_cols = first_batch.num_columns
+
+                # Create a custom class that can be indexed to get column 
iterators
+                # Uses itertools.tee to create independent iterators for each 
column
+                class ColumnIterators:
+                    def __init__(self, first_batch, batches_iter, num_cols):
+                        self._first_batch = first_batch
+                        self._batches_iter = batches_iter
+                        self._num_cols = num_cols
+                        self._teed_iters = None
+
+                    def __getitem__(self, col_idx):
+                        import itertools
+
+                        # Lazily create teed iterators when first column is 
accessed
+                        if self._teed_iters is None:
+                            # Recreate full batch stream including first batch
+                            full_stream = itertools.chain([self._first_batch], 
self._batches_iter)
+                            # Create independent iterators for each column
+                            self._teed_iters = itertools.tee(full_stream, 
self._num_cols)

Review Comment:
   It was introduced to support multiple UDFs, in which case we might need to 
consume the batches multiple times. I checked again, and it seems we should only 
support a single UDF, so `tee` is no longer needed.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to