zhengruifeng commented on code in PR #53317:
URL: https://github.com/apache/spark/pull/53317#discussion_r2600639778
##########
python/pyspark/sql/tests/pandas/test_pandas_udf_grouped_agg.py:
##########
@@ -897,6 +897,146 @@ def sum_udf(v):
)
assert_frame_equal(expected, result)
+ def test_iterator_grouped_agg_basic(self):
+ """
+ Test basic functionality of iterator grouped agg pandas UDF with Iterator[pd.Series].
+ """
+ df = self.spark.createDataFrame(
+ [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v")
+ )
+
+ @pandas_udf("double")
+ def pandas_mean_iter(it: Iterator[pd.Series]) -> float:
+ sum_val = 0.0
+ cnt = 0
+ for series in it:
+ assert isinstance(series, pd.Series)
+ sum_val += series.sum()
+ cnt += len(series)
+ return sum_val / cnt if cnt > 0 else 0.0
+
+ result =
df.groupby("id").agg(pandas_mean_iter(df["v"]).alias("mean")).sort("id").collect()
+
+ # Expected means:
+ # Group 1: (1.0 + 2.0) / 2 = 1.5
+ # Group 2: (3.0 + 5.0 + 10.0) / 3 = 6.0
+ expected = [Row(id=1, mean=1.5), Row(id=2, mean=6.0)]
+ self.assertEqual(result, expected)
+
+ def test_iterator_grouped_agg_multiple_columns(self):
+ """
+ Test iterator grouped agg pandas UDF with multiple columns
+ using Iterator[Tuple[pd.Series, ...]].
+ """
+ df = self.spark.createDataFrame(
+ [(1, 1.0, 1.0), (1, 2.0, 2.0), (2, 3.0, 1.0), (2, 5.0, 2.0), (2,
10.0, 3.0)],
+ ("id", "v", "w"),
+ )
+
+ @pandas_udf("double")
+ def pandas_weighted_mean_iter(it: Iterator[Tuple[pd.Series,
pd.Series]]) -> float:
+ import numpy as np
+
+ weighted_sum = 0.0
+ weight = 0.0
+ for v_series, w_series in it:
+ assert isinstance(v_series, pd.Series)
+ assert isinstance(w_series, pd.Series)
+ weighted_sum += np.dot(v_series, w_series)
+ weight += w_series.sum()
+ return weighted_sum / weight if weight > 0 else 0.0
+
+ result = (
+ df.groupby("id")
+ .agg(pandas_weighted_mean_iter(df["v"], df["w"]).alias("wm"))
+ .sort("id")
+ .collect()
+ )
+
+ # Expected weighted means:
+ # Group 1: (1.0*1.0 + 2.0*2.0) / (1.0 + 2.0) = 5.0 / 3.0
+ # Group 2: (3.0*1.0 + 5.0*2.0 + 10.0*3.0) / (1.0 + 2.0 + 3.0) = 43.0 / 6.0
Review Comment:
Are you still assuming the ordering within a group?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]