zhengruifeng commented on code in PR #53035:
URL: https://github.com/apache/spark/pull/53035#discussion_r2525599083


##########
python/pyspark/sql/pandas/functions.py:
##########
@@ -301,6 +303,66 @@ def calculate(iterator: Iterator[pa.Array]) -> Iterator[pa.Array]:
             Therefore, mutating the input arrays is not allowed and will cause incorrect results.
             For the same reason, users should also not rely on the index of the input arrays.
 
+    * Iterator of Arrays to Scalar
+        `Iterator[pyarrow.Array]` -> `Any`
+
+        The function takes an iterator of `pyarrow.Array` and returns a scalar value. This is
+        useful for grouped aggregations where the UDF can process all batches for a group
+        iteratively, which is more memory-efficient than loading all data at once. The returned
+        scalar can be a python primitive type, a numpy data type, or a `pyarrow.Scalar` instance.
+
+        >>> import pandas as pd
+        >>> from typing import Iterator
+        >>> @arrow_udf("double")
+        ... def arrow_mean(it: Iterator[pa.Array]) -> float:
+        ...     sum_val = 0.0
+        ...     cnt = 0
+        ...     for v in it:
+        ...         assert isinstance(v, pa.Array)
+        ...         sum_val += pa.compute.sum(v).as_py()
+        ...         cnt += len(v)
+        ...     return sum_val / cnt
+        ...
+        >>> df = spark.createDataFrame(
+        ...     [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v"))
+        >>> df.groupby("id").agg(arrow_mean(df['v'])).show()  # doctest: +SKIP

Review Comment:
   do not skip the doctests
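To make the point concrete: `# doctest: +SKIP` means the example is never executed, so a stale or wrong expected output can never fail. Below is a minimal stdlib-only sketch (plain floats, no Spark; the helper name `run_docstring` is hypothetical) in which the same deliberately wrong expectation passes under `+SKIP` and fails once the directive is removed.

```python
import doctest


def run_docstring(source: str) -> doctest.TestResults:
    """Run the doctest examples found in `source` and return the result counts."""
    test = doctest.DocTestParser().get_doctest(source, {}, "sketch", "<sketch>", 0)
    runner = doctest.DocTestRunner(verbose=False, optionflags=doctest.ELLIPSIS)
    # Discard the failure report; only the counts matter here.
    return runner.run(test, out=lambda _report: None)


# The mean of group 2 in the PR's data is 6.0; "7.0" is a deliberately wrong expectation.
skipped = """
>>> (3.0 + 5.0 + 10.0) / 3  # doctest: +SKIP
7.0
"""
checked = """
>>> (3.0 + 5.0 + 10.0) / 3
7.0
"""

print(run_docstring(skipped).failed)  # 0: the example never runs
print(run_docstring(checked).failed)  # 1: the mismatch is reported
```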



##########
python/pyspark/sql/pandas/functions.py:
##########
@@ -301,6 +303,66 @@ def calculate(iterator: Iterator[pa.Array]) -> Iterator[pa.Array]:
             Therefore, mutating the input arrays is not allowed and will cause incorrect results.
             For the same reason, users should also not rely on the index of the input arrays.
 
+    * Iterator of Arrays to Scalar
+        `Iterator[pyarrow.Array]` -> `Any`
+
+        The function takes an iterator of `pyarrow.Array` and returns a scalar value. This is
+        useful for grouped aggregations where the UDF can process all batches for a group
+        iteratively, which is more memory-efficient than loading all data at once. The returned
+        scalar can be a python primitive type, a numpy data type, or a `pyarrow.Scalar` instance.
+
+        >>> import pandas as pd
+        >>> from typing import Iterator
+        >>> @arrow_udf("double")
+        ... def arrow_mean(it: Iterator[pa.Array]) -> float:
+        ...     sum_val = 0.0
+        ...     cnt = 0
+        ...     for v in it:
+        ...         assert isinstance(v, pa.Array)
+        ...         sum_val += pa.compute.sum(v).as_py()
+        ...         cnt += len(v)
+        ...     return sum_val / cnt
+        ...
+        >>> df = spark.createDataFrame(
+        ...     [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v"))
+        >>> df.groupby("id").agg(arrow_mean(df['v'])).show()  # doctest: +SKIP
+        +---+---------------+
+        | id|arrow_mean(v)  |
+        +---+---------------+
+        |  1|            1.5|
+        |  2|            6.0|
+        +---+---------------+
+
+    * Iterator of Multiple Arrays to Scalar
+        `Iterator[Tuple[pyarrow.Array, ...]]` -> `Any`
+
+        The function takes an iterator of a tuple of multiple `pyarrow.Array` and returns a
+        scalar value. This is useful for grouped aggregations with multiple input columns.
+
+        >>> from typing import Iterator, Tuple
+        >>> import numpy as np
+        >>> @arrow_udf("double")
+        ... def arrow_weighted_mean(it: Iterator[Tuple[pa.Array, pa.Array]]) -> float:
+        ...     weighted_sum = 0.0
+        ...     weight = 0.0
+        ...     for v, w in it:
+        ...         assert isinstance(v, pa.Array)
+        ...         assert isinstance(w, pa.Array)
+        ...         weighted_sum += np.dot(v, w)
+        ...         weight += pa.compute.sum(w).as_py()
+        ...     return weighted_sum / weight
+        ...
+        >>> df = spark.createDataFrame(
+        ...     [(1, 1.0, 1.0), (1, 2.0, 2.0), (2, 3.0, 1.0), (2, 5.0, 2.0), (2, 10.0, 3.0)],
+        ...     ("id", "v", "w"))
+        >>> df.groupby("id").agg(arrow_weighted_mean(df["v"], df["w"])).show()  # doctest: +SKIP
+        +---+---------------------------------+
+        | id|arrow_weighted_mean(v, w)        |
+        +---+---------------------------------+
+        |  1|               1.6666666666666667|
+        |  2|                7.166666666666667|

Review Comment:
   ```suggestion
           |  1|               1.6666666666666...|
           |  2|                7.166666666666...|
   ```
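The suggestion relies on doctest's `ELLIPSIS` option, under which `...` in the expected output stands for any run of characters, so the test no longer depends on how many trailing digits get printed. As far as I know, PySpark's module-level doctest runners already pass `doctest.ELLIPSIS`; assuming so, here is a stdlib-only sketch that recomputes the two weighted means from the PR's data and checks the suggested row against the row as originally written.

```python
import doctest

# Per-group (v, w) pairs taken from the DataFrame in the docstring example.
groups = {
    1: [(1.0, 1.0), (2.0, 2.0)],
    2: [(3.0, 1.0), (5.0, 2.0), (10.0, 3.0)],
}


def weighted_mean(pairs):
    return sum(v * w for v, w in pairs) / sum(w for _, w in pairs)


means = {k: weighted_mean(p) for k, p in groups.items()}  # 5/3 and 43/6

checker = doctest.OutputChecker()
row_prefix = "|  1|" + " " * 15                 # same padding as the PR's table
want = row_prefix + "1.6666666666666...|\n"      # row from the suggestion
got = row_prefix + repr(means[1]) + "|\n"        # ends in "1.6666666666666667|", as in the PR

print(checker.check_output(want, got, doctest.ELLIPSIS))  # True
print(checker.check_output(want, got, 0))                 # False: exact match required
print(repr(means[2]).startswith("7.166666666666"))        # True
```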



##########
python/pyspark/sql/pandas/functions.py:
##########
@@ -301,6 +303,66 @@ def calculate(iterator: Iterator[pa.Array]) -> Iterator[pa.Array]:
             Therefore, mutating the input arrays is not allowed and will cause incorrect results.
             For the same reason, users should also not rely on the index of the input arrays.
 
+    * Iterator of Arrays to Scalar
+        `Iterator[pyarrow.Array]` -> `Any`
+
+        The function takes an iterator of `pyarrow.Array` and returns a scalar value. This is
+        useful for grouped aggregations where the UDF can process all batches for a group
+        iteratively, which is more memory-efficient than loading all data at once. The returned
+        scalar can be a python primitive type, a numpy data type, or a `pyarrow.Scalar` instance.
+
+        >>> import pandas as pd
+        >>> from typing import Iterator
+        >>> @arrow_udf("double")
+        ... def arrow_mean(it: Iterator[pa.Array]) -> float:
+        ...     sum_val = 0.0
+        ...     cnt = 0
+        ...     for v in it:
+        ...         assert isinstance(v, pa.Array)
+        ...         sum_val += pa.compute.sum(v).as_py()
+        ...         cnt += len(v)
+        ...     return sum_val / cnt
+        ...
+        >>> df = spark.createDataFrame(
+        ...     [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v"))
+        >>> df.groupby("id").agg(arrow_mean(df['v'])).show()  # doctest: +SKIP
+        +---+---------------+
+        | id|arrow_mean(v)  |
+        +---+---------------+
+        |  1|            1.5|
+        |  2|            6.0|
+        +---+---------------+
+
+    * Iterator of Multiple Arrays to Scalar
+        `Iterator[Tuple[pyarrow.Array, ...]]` -> `Any`
+
+        The function takes an iterator of a tuple of multiple `pyarrow.Array` and returns a
+        scalar value. This is useful for grouped aggregations with multiple input columns.
+
+        >>> from typing import Iterator, Tuple
+        >>> import numpy as np
+        >>> @arrow_udf("double")
+        ... def arrow_weighted_mean(it: Iterator[Tuple[pa.Array, pa.Array]]) -> float:
+        ...     weighted_sum = 0.0
+        ...     weight = 0.0
+        ...     for v, w in it:
+        ...         assert isinstance(v, pa.Array)
+        ...         assert isinstance(w, pa.Array)
+        ...         weighted_sum += np.dot(v, w)
+        ...         weight += pa.compute.sum(w).as_py()
+        ...     return weighted_sum / weight
+        ...
+        >>> df = spark.createDataFrame(
+        ...     [(1, 1.0, 1.0), (1, 2.0, 2.0), (2, 3.0, 1.0), (2, 5.0, 2.0), (2, 10.0, 3.0)],
+        ...     ("id", "v", "w"))
+        >>> df.groupby("id").agg(arrow_weighted_mean(df["v"], df["w"])).show()  # doctest: +SKIP

Review Comment:
   ditto



##########
python/pyspark/sql/pandas/serializers.py:
##########
@@ -1200,6 +1200,72 @@ def __repr__(self):
         return "ArrowStreamAggArrowUDFSerializer"
 
 
+# Serializer for SQL_GROUPED_AGG_ARROW_ITER_UDF
+class ArrowStreamAggArrowIterUDFSerializer(ArrowStreamArrowUDFSerializer):

Review Comment:
   we should consolidate it with `ArrowStreamAggArrowUDFSerializer`: make `ArrowStreamAggArrowUDFSerializer` output the iterator and adjust the wrapper of `SQL_GROUPED_AGG_ARROW_UDF` and `SQL_WINDOW_AGG_ARROW_UDF`
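A minimal, dependency-free sketch of what that consolidation could look like, under stated assumptions: plain Python lists stand in for `pyarrow.Array` chunks, and `load_groups`, `wrap_iter_agg` and `wrap_whole_group_agg` are hypothetical names, not PySpark functions. The single loader always hands each group downstream as an iterator of batches; the existing non-iterator aggregate wrappers concatenate first (for real Arrow data that step would presumably be `pyarrow.concat_arrays`), while the new iterator UDF consumes the batches as they arrive.

```python
from itertools import chain
from typing import Callable, Iterable, Iterator, List

Batch = List[float]  # stand-in for one pyarrow.Array chunk of a column


def load_groups(groups: Iterable[Iterable[Batch]]) -> Iterator[Iterator[Batch]]:
    """Single loader: every group is handed downstream as an iterator of batches."""
    for batches in groups:
        yield iter(batches)


def wrap_iter_agg(udf: Callable[[Iterator[Batch]], float]) -> Callable[[Iterator[Batch]], float]:
    # New iterator UDF: consumes the batches as they arrive, no adaptation needed.
    return lambda batches: udf(batches)


def wrap_whole_group_agg(udf: Callable[[Batch], float]) -> Callable[[Iterator[Batch]], float]:
    # Existing grouped/window agg UDFs expect one column per group, so
    # concatenate first (pyarrow.concat_arrays would play this role).
    return lambda batches: udf(list(chain.from_iterable(batches)))


def mean_whole(col: Batch) -> float:
    return sum(col) / len(col)


def mean_iter(batches: Iterator[Batch]) -> float:
    total, count = 0.0, 0
    for batch in batches:
        total += sum(batch)
        count += len(batch)
    return total / count


# Docstring data (v column), grouped by id and split into arbitrary batches.
groups = [[[1.0], [2.0]], [[3.0, 5.0], [10.0]]]
print([wrap_whole_group_agg(mean_whole)(g) for g in load_groups(groups)])  # [1.5, 6.0]
print([wrap_iter_agg(mean_iter)(g) for g in load_groups(groups)])         # [1.5, 6.0]
```

With this shape, the old and new eval types would share one reader and differ only in the thin wrapper.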



##########
python/pyspark/sql/pandas/functions.py:
##########
@@ -301,6 +303,66 @@ def calculate(iterator: Iterator[pa.Array]) -> Iterator[pa.Array]:
             Therefore, mutating the input arrays is not allowed and will cause incorrect results.
             For the same reason, users should also not rely on the index of the input arrays.
 
+    * Iterator of Arrays to Scalar
+        `Iterator[pyarrow.Array]` -> `Any`
+
+        The function takes an iterator of `pyarrow.Array` and returns a scalar value. This is
+        useful for grouped aggregations where the UDF can process all batches for a group
+        iteratively, which is more memory-efficient than loading all data at once. The returned
+        scalar can be a python primitive type, a numpy data type, or a `pyarrow.Scalar` instance.
+
+        >>> import pandas as pd

Review Comment:
   pandas is not used?



##########
python/pyspark/sql/pandas/serializers.py:
##########
@@ -1200,6 +1200,72 @@ def __repr__(self):
         return "ArrowStreamAggArrowUDFSerializer"
 
 
+# Serializer for SQL_GROUPED_AGG_ARROW_ITER_UDF
+class ArrowStreamAggArrowIterUDFSerializer(ArrowStreamArrowUDFSerializer):
+    def __init__(
+        self,
+        timezone,
+        safecheck,
+        assign_cols_by_name,
+        arrow_cast,
+    ):
+        super().__init__(
+            timezone=timezone,
+            safecheck=safecheck,
+            assign_cols_by_name=False,
+            arrow_cast=True,
+        )
+        self._timezone = timezone
+        self._safecheck = safecheck
+        self._assign_cols_by_name = assign_cols_by_name
+        self._arrow_cast = arrow_cast
+
+    def load_stream(self, stream):
+        """
+        Yield column iterators instead of concatenating batches.
+        Each group yields a structure where indexing by column offset gives an iterator of arrays.
+        """
+        import pyarrow as pa
+
+        dataframes_in_group = None
+
+        while dataframes_in_group is None or dataframes_in_group > 0:
+            dataframes_in_group = read_int(stream)
+
+            if dataframes_in_group == 1:
+                batches = list(ArrowStreamSerializer.load_stream(self, stream))

Review Comment:
   we should not load all batches of a group at once; this new API is designed to process each group incrementally
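A stdlib-only sketch of the incremental alternative, assuming a hypothetical flat token stream in place of the Arrow IPC stream: `("group",)` opens a group, `("batch", values)` carries one batch, `("end_group",)` closes it, and `("end",)` ends the input. Each group is exposed as a lazy generator, so only one batch is held at a time; because all groups share one underlying stream, whatever the UDF leaves unread is drained before the next group is read. `iter_group_batches` and `incremental_mean` are illustrative names, not PySpark code.

```python
from typing import Iterator, List, Tuple

Token = Tuple  # hypothetical framing token, see the lead-in


def iter_group_batches(stream: Iterator[Token]) -> Iterator[Iterator[List[float]]]:
    """Yield one lazy batch iterator per group, never buffering a whole group."""
    while True:
        marker = next(stream, ("end",))
        if marker[0] == "end":
            return
        if marker[0] != "group":
            raise ValueError(f"unexpected token: {marker!r}")

        def batches() -> Iterator[List[float]]:
            for token in stream:
                if token[0] == "end_group":
                    return
                yield token[1]

        group = batches()
        yield group
        # The caller may stop early; skip the rest of this group so the
        # shared stream is positioned at the next group marker.
        for _ in group:
            pass


def incremental_mean(batches: Iterator[List[float]]) -> float:
    total, count = 0.0, 0
    for batch in batches:  # one batch in memory at a time
        total += sum(batch)
        count += len(batch)
    return total / count


def make_stream() -> Iterator[Token]:
    # Same data as the docstring example, split into arbitrary batches.
    return iter([
        ("group",), ("batch", [1.0]), ("batch", [2.0]), ("end_group",),
        ("group",), ("batch", [3.0, 5.0]), ("batch", [10.0]), ("end_group",),
        ("end",),
    ])


print([incremental_mean(g) for g in iter_group_batches(make_stream())])  # [1.5, 6.0]
print([next(g) for g in iter_group_batches(make_stream())])  # [[1.0], [3.0, 5.0]]
```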



##########
python/pyspark/sql/pandas/typehints.py:
##########
@@ -226,6 +234,41 @@ def infer_arrow_eval_type(
     if is_iterator_array:
         return ArrowUDFType.SCALAR_ITER
 
+    # Iterator[Tuple[pa.Array, ...]] -> Any

Review Comment:
   let's move the new inference after `pa.Array, ... -> Any`
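A hypothetical, dependency-free sketch of how rule order interacts with `-> Any` rules: a stub `Array` class stands in for `pyarrow.Array`, the labels are plain strings rather than `ArrowUDFType` members, and `infer_kind` is not PySpark's `infer_arrow_eval_type`. Since a `-> Any` rule places no constraint on the return annotation, here the new `Iterator[...] -> Any` rule must come after the `Iterator[Array] -> Iterator[Array]` rule or it would shadow it. In this sketch the two `-> Any` rules are disjoint, so putting the new one after `Array, ... -> Any` mainly keeps the aggregate rules together at the end.

```python
import collections.abc
from typing import Any, Iterator, Optional, Sequence, Tuple, get_args, get_origin


class Array:
    """Stub standing in for pyarrow.Array so the sketch has no dependencies."""


def _is_iterator_of(hint: Any, accept) -> bool:
    args = get_args(hint)
    return get_origin(hint) is collections.abc.Iterator and len(args) == 1 and accept(args[0])


def _is_array_tuple(hint: Any) -> bool:
    args = get_args(hint)
    return get_origin(hint) is tuple and bool(args) and all(
        a is Array or a is Ellipsis for a in args
    )


def infer_kind(params: Sequence[Any], ret: Any) -> Optional[str]:
    all_arrays = bool(params) and all(p is Array for p in params)
    # Array, ... -> Array
    if all_arrays and ret is Array:
        return "SCALAR"
    # Iterator[Array] -> Iterator[Array]
    if (
        len(params) == 1
        and _is_iterator_of(params[0], lambda a: a is Array)
        and _is_iterator_of(ret, lambda a: a is Array)
    ):
        return "SCALAR_ITER"
    # Array, ... -> Any (the return annotation is not inspected)
    if all_arrays:
        return "GROUPED_AGG"
    # New rule, checked last: Iterator[Array] or Iterator[Tuple[Array, ...]] -> Any.
    # The return type is unconstrained, so placing this before the SCALAR_ITER
    # rule would misclassify Iterator[Array] -> Iterator[Array].
    if len(params) == 1 and _is_iterator_of(
        params[0], lambda a: a is Array or _is_array_tuple(a)
    ):
        return "GROUPED_AGG_ITER"
    return None


print(infer_kind([Array], Array))                          # SCALAR
print(infer_kind([Iterator[Array]], Iterator[Array]))      # SCALAR_ITER
print(infer_kind([Array, Array], float))                   # GROUPED_AGG
print(infer_kind([Iterator[Array]], float))                # GROUPED_AGG_ITER
print(infer_kind([Iterator[Tuple[Array, Array]]], float))  # GROUPED_AGG_ITER
```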



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
