zhengruifeng commented on code in PR #53035:
URL: https://github.com/apache/spark/pull/53035#discussion_r2558225810
##########
python/pyspark/sql/tests/arrow/test_arrow_udf_grouped_agg.py:
##########
@@ -1059,6 +1059,90 @@ def my_grouped_agg_arrow_udf(x):
],
)
+ def test_iterator_grouped_agg_single_column(self):
+ """
+ Test iterator API for grouped aggregation with single column.
+ """
+ import pyarrow as pa
+ from typing import Iterator
+
+ @arrow_udf("double")
+ def arrow_mean_iter(it: Iterator[pa.Array]) -> float:
+ sum_val = 0.0
+ cnt = 0
+ for v in it:
+ assert isinstance(v, pa.Array)
+ sum_val += pa.compute.sum(v).as_py()
+ cnt += len(v)
+ return sum_val / cnt if cnt > 0 else 0.0
+
+ df = self.spark.createDataFrame(
+ [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v")
+ )
+
+ result =
df.groupby("id").agg(arrow_mean_iter(df["v"]).alias("mean")).sort("id")
+ expected =
df.groupby("id").agg(sf.mean(df["v"]).alias("mean")).sort("id").collect()
+
+ self.assertEqual(expected, result.collect())
+
+ @unittest.skipIf(not have_numpy, numpy_requirement_message)
+ def test_iterator_grouped_agg_multiple_columns(self):
+ """
+ Test iterator API for grouped aggregation with multiple columns.
+ """
+ import pyarrow as pa
+ import numpy as np
+ from typing import Iterator, Tuple
+
+ @arrow_udf("double")
+ def arrow_weighted_mean_iter(it: Iterator[Tuple[pa.Array, pa.Array]])
-> float:
+ weighted_sum = 0.0
+ weight = 0.0
+ for v, w in it:
+ assert isinstance(v, pa.Array)
+ assert isinstance(w, pa.Array)
+ weighted_sum += np.dot(v, w)
+ weight += pa.compute.sum(w).as_py()
+ return weighted_sum / weight if weight > 0 else 0.0
+
+ df = self.spark.createDataFrame(
+ [(1, 1.0, 1.0), (1, 2.0, 2.0), (2, 3.0, 1.0), (2, 5.0, 2.0), (2,
10.0, 3.0)],
+ ("id", "v", "w"),
+ )
+
+ result = (
+ df.groupby("id")
+ .agg(arrow_weighted_mean_iter(df["v"], df["w"]).alias("wm"))
+ .sort("id")
+ .collect()
+ )
+
+ # Expected weighted means:
+ # Group 1: (1.0*1.0 + 2.0*2.0) / (1.0 + 2.0) = 5.0 / 3.0
+ # Group 2: (3.0*1.0 + 5.0*2.0 + 10.0*3.0) / (1.0 + 2.0 + 3.0) = 43.0 /
6.0
+ expected = [(1, 5.0 / 3.0), (2, 43.0 / 6.0)]
+
+ self.assertEqual(len(result), len(expected))
+ for r, (exp_id, exp_wm) in zip(result, expected):
+ self.assertEqual(r["id"], exp_id)
+ self.assertAlmostEqual(r["wm"], exp_wm, places=5)
+
+ def test_iterator_grouped_agg_eval_type(self):
+ """
+ Test that the eval type is correctly inferred for iterator grouped agg
UDFs.
+ """
+ import pyarrow as pa
+ from typing import Iterator
+
+ @arrow_udf("double")
+ def arrow_sum_iter(it: Iterator[pa.Array]) -> float:
+ total = 0.0
+ for v in it:
+ total += pa.compute.sum(v).as_py()
+ return total
+
+ self.assertEqual(arrow_sum_iter.evalType,
PythonEvalType.SQL_GROUPED_AGG_ARROW_ITER_UDF)
+
Review Comment:
we need a test to make sure partial consumption works
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]