Yicong-Huang commented on code in PR #53035:
URL: https://github.com/apache/spark/pull/53035#discussion_r2561500721
##########
python/pyspark/sql/pandas/serializers.py:
##########
@@ -1200,6 +1200,72 @@ def __repr__(self):
return "ArrowStreamAggArrowUDFSerializer"
+# Serializer for SQL_GROUPED_AGG_ARROW_ITER_UDF
+class ArrowStreamAggArrowIterUDFSerializer(ArrowStreamArrowUDFSerializer):
+ def __init__(
+ self,
+ timezone,
+ safecheck,
+ assign_cols_by_name,
+ arrow_cast,
+ ):
+ super().__init__(
+ timezone=timezone,
+ safecheck=safecheck,
+ assign_cols_by_name=False,
+ arrow_cast=True,
+ )
+ self._timezone = timezone
+ self._safecheck = safecheck
+ self._assign_cols_by_name = assign_cols_by_name
+ self._arrow_cast = arrow_cast
+
+ def load_stream(self, stream):
+ """
+ Yield column iterators instead of concatenating batches.
+ Each group yields a structure where indexing by column offset gives an
iterator of arrays.
+ """
+ import pyarrow as pa
+
+ dataframes_in_group = None
+
+ while dataframes_in_group is None or dataframes_in_group > 0:
+ dataframes_in_group = read_int(stream)
+
+ if dataframes_in_group == 1:
+ batches = list(ArrowStreamSerializer.load_stream(self, stream))
Review Comment:
Thanks for spotting this! Yes, `tee` was introduced to support multiple UDFs,
so that we might need to consume the batches multiple times. I checked again,
and it seems we should only support a single UDF, so `tee` is not needed
anymore and I have removed it.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]