Yicong-Huang commented on code in PR #53317:
URL: https://github.com/apache/spark/pull/53317#discussion_r2612481886
##########
python/pyspark/sql/tests/pandas/test_pandas_udf_grouped_agg.py:
##########
@@ -897,6 +897,146 @@ def sum_udf(v):
)
assert_frame_equal(expected, result)
+    def test_iterator_grouped_agg_basic(self):
+        """
+        Test basic functionality of iterator grouped agg pandas UDF with Iterator[pd.Series].
+        """
+        df = self.spark.createDataFrame(
+            [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v")
+        )
+
+        @pandas_udf("double")
+        def pandas_mean_iter(it: Iterator[pd.Series]) -> float:
+            sum_val = 0.0
+            cnt = 0
+            for series in it:
+                assert isinstance(series, pd.Series)
+                sum_val += series.sum()
+                cnt += len(series)
+            return sum_val / cnt if cnt > 0 else 0.0
+
+        result = df.groupby("id").agg(pandas_mean_iter(df["v"]).alias("mean")).sort("id").collect()
+
+        # Expected means:
+        # Group 1: (1.0 + 2.0) / 2 = 1.5
+        # Group 2: (3.0 + 5.0 + 10.0) / 3 = 6.0
+        expected = [Row(id=1, mean=1.5), Row(id=2, mean=6.0)]
+        self.assertEqual(result, expected)
+
+    def test_iterator_grouped_agg_multiple_columns(self):
+        """
+        Test iterator grouped agg pandas UDF with multiple columns
+        using Iterator[Tuple[pd.Series, ...]].
+        """
+        df = self.spark.createDataFrame(
+            [(1, 1.0, 1.0), (1, 2.0, 2.0), (2, 3.0, 1.0), (2, 5.0, 2.0), (2, 10.0, 3.0)],
+            ("id", "v", "w"),
+        )
+
+        @pandas_udf("double")
+        def pandas_weighted_mean_iter(it: Iterator[Tuple[pd.Series, pd.Series]]) -> float:
+            import numpy as np
+
+            weighted_sum = 0.0
+            weight = 0.0
+            for v_series, w_series in it:
+                assert isinstance(v_series, pd.Series)
+                assert isinstance(w_series, pd.Series)
+                weighted_sum += np.dot(v_series, w_series)
+                weight += w_series.sum()
+            return weighted_sum / weight if weight > 0 else 0.0
+
+        result = (
+            df.groupby("id")
+            .agg(pandas_weighted_mean_iter(df["v"], df["w"]).alias("wm"))
+            .sort("id")
+            .collect()
+        )
+
+        # Expected weighted means:
+        # Group 1: (1.0*1.0 + 2.0*2.0) / (1.0 + 2.0) = 5.0 / 3.0
+        # Group 2: (3.0*1.0 + 5.0*2.0 + 10.0*3.0) / (1.0 + 2.0 + 3.0) = 43.0 / 6.0
+        expected = [Row(id=1, wm=5.0 / 3.0), Row(id=2, wm=43.0 / 6.0)]
+        self.assertEqual(result, expected)
+
+    def test_iterator_grouped_agg_eval_type(self):
+        """
+        Test that the eval type is correctly inferred for iterator grouped agg UDFs.
+        """
+
+        @pandas_udf("double")
+        def pandas_sum_iter(it: Iterator[pd.Series]) -> float:
+            total = 0.0
+            for series in it:
+                total += series.sum()
+            return total
+
+        self.assertEqual(pandas_sum_iter.evalType, PythonEvalType.SQL_GROUPED_AGG_PANDAS_ITER_UDF)
+
+        @pandas_udf("double")
+        def pandas_sum_iter_tuple(it: Iterator[Tuple[pd.Series, pd.Series]]) -> float:
+            total = 0.0
+            for v, w in it:
+                total += v.sum()
+            return total
+
+        self.assertEqual(
+            pandas_sum_iter_tuple.evalType, PythonEvalType.SQL_GROUPED_AGG_PANDAS_ITER_UDF
+        )
+
+    def test_iterator_grouped_agg_partial_consumption(self):
+        """
+        Test that iterator grouped agg UDF can partially consume batches.
+        This ensures that batches are processed one by one without loading all data into memory.
+        """
+        # Create a dataset with multiple batches per group
+        # Use small batch size to ensure multiple batches per group
+        # Use same value (1.0) for all records to avoid batch ordering issues
+        with self.sql_conf({"spark.sql.execution.arrow.maxRecordsPerBatch": 2}):
+            # Group 1: 6 values (3 batches) - will process only first 2 batches (partial)
+            # Group 2: 2 values (1 batch) - will process 1 batch (all available)
+            df = self.spark.createDataFrame(
+                [(1, 1.0), (1, 1.0), (1, 1.0), (1, 1.0), (1, 1.0), (1, 1.0), (2, 1.0), (2, 1.0)],
+                ("id", "v"),
+            )
+
+            @pandas_udf("long")
+            def pandas_partial_count(it: Iterator[pd.Series]) -> int:
+                # Process first 2 batches, then stop (partial consumption)
+                total_count = 0
+                for i, series in enumerate(it):
+                    assert isinstance(series, pd.Series)
+                    if i < 2:  # Process first 2 batches
+                        total_count += len(series)
+                    else:
+                        # Stop early - partial consumption
+                        break
+                return total_count
+
+            result = df.groupby("id").agg(pandas_partial_count(df["v"]).alias("count")).sort("id")
+
+            # Verify results are correct for partial consumption
+            # With batch size = 2:
+            # Group 1 (id=1): 6 values in 3 batches -> processes only first 2 batches (partial)
+            # Result: count=4 (only 4 out of 6 values processed)
+            # Group 2 (id=2): 2 values in 1 batch -> processes 1 batch (all available)
+            # Result: count=2
+            actual = result.collect()
+            self.assertEqual(len(actual), 2, "Should have results for both groups")
+
+            # Verify partial consumption works
+            # Group 1: processes only 2 batches (4 values out of 6 total) - partial consumption
+            group1_result = next(row for row in actual if row["id"] == 1)
+            self.assertEqual(
+                group1_result["count"], 4, msg="Group 1 should process only 2 batches (4 values)"
+            )
+
+            # Group 2: processes 1 batch (all 2 values, 1 batch available)
+            group2_result = next(row for row in actual if row["id"] == 2)
+            self.assertEqual(
+                group2_result["count"], 2, msg="Group 2 should process 1 batch (2 values)"
+            )
+
Review Comment:
I checked: the other agg UDFs support struct type but expect a dict of series as input, not a data frame. I am aligning this UDF with the same expectation. If we need to support the `df_for_struct=True` case, we can do it later.
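
For illustration only (not part of this PR's diff): a minimal sketch of the expectation described above. It assumes an active SparkSession and that, with the default `df_for_struct=False`, a struct column reaches an agg UDF as plain dict values rather than a `pd.DataFrame`; the UDF name, sample data, and field names are hypothetical.

```python
import pandas as pd
from pyspark.sql import SparkSession
from pyspark.sql.functions import pandas_udf, struct

spark = SparkSession.builder.getOrCreate()

@pandas_udf("double")
def mean_of_field_v(s: pd.Series) -> float:
    # Assumption: each element of `s` is a dict such as {"v": 1.0, "w": 2.0},
    # i.e. the struct is NOT delivered as a pd.DataFrame (that would be the
    # df_for_struct=True case deferred above).
    return float(sum(x["v"] for x in s) / len(s)) if len(s) > 0 else 0.0

# Hypothetical usage: aggregate over a struct built from two columns, per group.
df = spark.createDataFrame([(1, 1.0, 2.0), (1, 3.0, 4.0), (2, 5.0, 6.0)], ("id", "v", "w"))
df.groupBy("id").agg(mean_of_field_v(struct(df["v"], df["w"])).alias("mean_v")).show()
```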
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]