zhengruifeng commented on code in PR #53317:
URL: https://github.com/apache/spark/pull/53317#discussion_r2591294559
##########
python/pyspark/sql/tests/pandas/test_pandas_udf_grouped_agg.py:
##########
@@ -897,6 +897,161 @@ def sum_udf(v):
)
assert_frame_equal(expected, result)
+ def test_iterator_grouped_agg_basic(self):
+ """
+ Test basic functionality of iterator grouped agg pandas UDF with
Iterator[pd.Series].
+ """
+ from typing import Iterator
+
+ df = self.spark.createDataFrame(
+ [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v")
+ )
+
+ @pandas_udf("double")
+ def pandas_mean_iter(it: Iterator[pd.Series]) -> float:
+ sum_val = 0.0
+ cnt = 0
+ for series in it:
+ assert isinstance(series, pd.Series)
+ sum_val += series.sum()
+ cnt += len(series)
+ return sum_val / cnt if cnt > 0 else 0.0
+
+ result =
df.groupby("id").agg(pandas_mean_iter(df["v"]).alias("mean")).sort("id").collect()
+
+ # Expected means:
+ # Group 1: (1.0 + 2.0) / 2 = 1.5
+ # Group 2: (3.0 + 5.0 + 10.0) / 3 = 6.0
+ expected = [(1, 1.5), (2, 6.0)]
+
+ self.assertEqual(len(result), len(expected))
+ for r, (exp_id, exp_mean) in zip(result, expected):
+ self.assertEqual(r["id"], exp_id)
+ self.assertAlmostEqual(r["mean"], exp_mean, places=5)
+
+ def test_iterator_grouped_agg_multiple_columns(self):
+ """
+ Test iterator grouped agg pandas UDF with multiple columns
+ using Iterator[Tuple[pd.Series, ...]].
+ """
+ from typing import Iterator, Tuple
+
+ df = self.spark.createDataFrame(
+ [(1, 1.0, 1.0), (1, 2.0, 2.0), (2, 3.0, 1.0), (2, 5.0, 2.0), (2,
10.0, 3.0)],
+ ("id", "v", "w"),
+ )
+
+ @pandas_udf("double")
+ def pandas_weighted_mean_iter(it: Iterator[Tuple[pd.Series,
pd.Series]]) -> float:
+ import numpy as np
+
+ weighted_sum = 0.0
+ weight = 0.0
+ for v_series, w_series in it:
+ assert isinstance(v_series, pd.Series)
+ assert isinstance(w_series, pd.Series)
+ weighted_sum += np.dot(v_series, w_series)
+ weight += w_series.sum()
+ return weighted_sum / weight if weight > 0 else 0.0
+
+ result = (
+ df.groupby("id")
+ .agg(pandas_weighted_mean_iter(df["v"], df["w"]).alias("wm"))
+ .sort("id")
+ .collect()
+ )
+
+ # Expected weighted means:
+ # Group 1: (1.0*1.0 + 2.0*2.0) / (1.0 + 2.0) = 5.0 / 3.0
+ # Group 2: (3.0*1.0 + 5.0*2.0 + 10.0*3.0) / (1.0 + 2.0 + 3.0) = 43.0 /
6.0
+ expected = [(1, 5.0 / 3.0), (2, 43.0 / 6.0)]
Review Comment:
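ditto: compare the collected rows directly (e.g. `expected = [Row(...), ...]` followed by `self.assertEqual(result, expected)`) instead of looping over them with `assertAlmostEqual`.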
##########
python/pyspark/sql/tests/pandas/test_pandas_udf_grouped_agg.py:
##########
@@ -897,6 +897,161 @@ def sum_udf(v):
)
assert_frame_equal(expected, result)
+ def test_iterator_grouped_agg_basic(self):
+ """
+ Test basic functionality of iterator grouped agg pandas UDF with
Iterator[pd.Series].
+ """
+ from typing import Iterator
+
+ df = self.spark.createDataFrame(
+ [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v")
+ )
+
+ @pandas_udf("double")
+ def pandas_mean_iter(it: Iterator[pd.Series]) -> float:
+ sum_val = 0.0
+ cnt = 0
+ for series in it:
+ assert isinstance(series, pd.Series)
+ sum_val += series.sum()
+ cnt += len(series)
+ return sum_val / cnt if cnt > 0 else 0.0
+
+ result =
df.groupby("id").agg(pandas_mean_iter(df["v"]).alias("mean")).sort("id").collect()
+
+ # Expected means:
+ # Group 1: (1.0 + 2.0) / 2 = 1.5
+ # Group 2: (3.0 + 5.0 + 10.0) / 3 = 6.0
+ expected = [(1, 1.5), (2, 6.0)]
+
+ self.assertEqual(len(result), len(expected))
+ for r, (exp_id, exp_mean) in zip(result, expected):
+ self.assertEqual(r["id"], exp_id)
+ self.assertAlmostEqual(r["mean"], exp_mean, places=5)
+
+ def test_iterator_grouped_agg_multiple_columns(self):
+ """
+ Test iterator grouped agg pandas UDF with multiple columns
+ using Iterator[Tuple[pd.Series, ...]].
+ """
+ from typing import Iterator, Tuple
Review Comment:
ditto: please move this `typing` import to the top of the file as well.
##########
python/pyspark/sql/tests/pandas/test_pandas_udf_grouped_agg.py:
##########
@@ -897,6 +897,161 @@ def sum_udf(v):
)
assert_frame_equal(expected, result)
+ def test_iterator_grouped_agg_basic(self):
+ """
+ Test basic functionality of iterator grouped agg pandas UDF with
Iterator[pd.Series].
+ """
+ from typing import Iterator
Review Comment:
Can you move common imports like this one to the top of this file?
##########
python/pyspark/sql/tests/pandas/test_pandas_udf_typehints.py:
##########
@@ -344,6 +385,58 @@ def weighted_mean(v: pd.Series, w: pd.Series) ->
np.float64:
expected = df.groupby("id").agg(mean(df.v).alias("weighted_mean(v,
1.0)")).sort("id")
assert_frame_equal(expected.toPandas(), actual.toPandas())
+ def test_group_agg_iter_udf_type_hint(self):
+ df = self.spark.createDataFrame(
+ [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v")
+ )
+
+ def pandas_mean_iter(it: Iterator[pd.Series]) -> float:
+ sum_val = 0.0
+ cnt = 0
+ for series in it:
+ sum_val += series.sum()
+ cnt += len(series)
+ return sum_val / cnt if cnt > 0 else 0.0
+
+ pandas_mean_iter = pandas_udf("double")(pandas_mean_iter)
+
+ actual =
df.groupby("id").agg(pandas_mean_iter(df["v"]).alias("mean")).sort("id")
+ expected = df.groupby("id").agg(mean(df["v"]).alias("mean")).sort("id")
+ assert_frame_equal(expected.toPandas(), actual.toPandas())
+
+ # Test with Tuple for multiple columns
+ df2 = self.spark.createDataFrame(
+ [(1, 1.0, 1.0), (1, 2.0, 2.0), (2, 3.0, 1.0), (2, 5.0, 2.0), (2,
10.0, 3.0)],
+ ("id", "v", "w"),
+ )
+
+ def pandas_weighted_mean_iter(it: Iterator[Tuple[pd.Series,
pd.Series]]) -> float:
+ import numpy as np
+
+ weighted_sum = 0.0
+ weight = 0.0
+ for v_series, w_series in it:
+ weighted_sum += np.dot(v_series, w_series)
+ weight += w_series.sum()
+ return weighted_sum / weight if weight > 0 else 0.0
+
+ pandas_weighted_mean_iter =
pandas_udf("double")(pandas_weighted_mean_iter)
+
+ actual2 = (
+ df2.groupby("id")
+ .agg(pandas_weighted_mean_iter(df2["v"], df2["w"]).alias("wm"))
+ .sort("id")
+ )
+ # Expected weighted means:
+ # Group 1: (1.0*1.0 + 2.0*2.0) / (1.0 + 2.0) = 5.0 / 3.0
+ # Group 2: (3.0*1.0 + 5.0*2.0 + 10.0*3.0) / (1.0 + 2.0 + 3.0) = 43.0 /
6.0
+ expected_results = [(1, 5.0 / 3.0), (2, 43.0 / 6.0)]
+ actual_results = actual2.collect()
+ self.assertEqual(len(actual_results), len(expected_results))
+ for (id_val, exp_wm), actual_row in zip(expected_results,
actual_results):
Review Comment:
ditto: why compare the results in such a complex way? Compare the collected rows directly instead.
##########
python/pyspark/sql/tests/pandas/test_pandas_udf_grouped_agg.py:
##########
@@ -897,6 +897,161 @@ def sum_udf(v):
)
assert_frame_equal(expected, result)
+ def test_iterator_grouped_agg_basic(self):
+ """
+ Test basic functionality of iterator grouped agg pandas UDF with
Iterator[pd.Series].
+ """
+ from typing import Iterator
+
+ df = self.spark.createDataFrame(
+ [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v")
+ )
+
+ @pandas_udf("double")
+ def pandas_mean_iter(it: Iterator[pd.Series]) -> float:
+ sum_val = 0.0
+ cnt = 0
+ for series in it:
+ assert isinstance(series, pd.Series)
+ sum_val += series.sum()
+ cnt += len(series)
+ return sum_val / cnt if cnt > 0 else 0.0
+
+ result =
df.groupby("id").agg(pandas_mean_iter(df["v"]).alias("mean")).sort("id").collect()
+
+ # Expected means:
+ # Group 1: (1.0 + 2.0) / 2 = 1.5
+ # Group 2: (3.0 + 5.0 + 10.0) / 3 = 6.0
+ expected = [(1, 1.5), (2, 6.0)]
+
+ self.assertEqual(len(result), len(expected))
+ for r, (exp_id, exp_mean) in zip(result, expected):
+ self.assertEqual(r["id"], exp_id)
+ self.assertAlmostEqual(r["mean"], exp_mean, places=5)
Review Comment:
you can directly compare the rows
```suggestion
expected = [Row... ]
self.assertEqual(result, expected)
```
##########
python/pyspark/sql/tests/pandas/test_pandas_udf_grouped_agg.py:
##########
@@ -897,6 +897,161 @@ def sum_udf(v):
)
assert_frame_equal(expected, result)
+ def test_iterator_grouped_agg_basic(self):
+ """
+ Test basic functionality of iterator grouped agg pandas UDF with
Iterator[pd.Series].
+ """
+ from typing import Iterator
+
+ df = self.spark.createDataFrame(
+ [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v")
+ )
+
+ @pandas_udf("double")
+ def pandas_mean_iter(it: Iterator[pd.Series]) -> float:
+ sum_val = 0.0
+ cnt = 0
+ for series in it:
+ assert isinstance(series, pd.Series)
+ sum_val += series.sum()
+ cnt += len(series)
+ return sum_val / cnt if cnt > 0 else 0.0
+
+ result =
df.groupby("id").agg(pandas_mean_iter(df["v"]).alias("mean")).sort("id").collect()
+
+ # Expected means:
+ # Group 1: (1.0 + 2.0) / 2 = 1.5
+ # Group 2: (3.0 + 5.0 + 10.0) / 3 = 6.0
+ expected = [(1, 1.5), (2, 6.0)]
+
+ self.assertEqual(len(result), len(expected))
+ for r, (exp_id, exp_mean) in zip(result, expected):
+ self.assertEqual(r["id"], exp_id)
+ self.assertAlmostEqual(r["mean"], exp_mean, places=5)
+
+ def test_iterator_grouped_agg_multiple_columns(self):
+ """
+ Test iterator grouped agg pandas UDF with multiple columns
+ using Iterator[Tuple[pd.Series, ...]].
+ """
+ from typing import Iterator, Tuple
+
+ df = self.spark.createDataFrame(
+ [(1, 1.0, 1.0), (1, 2.0, 2.0), (2, 3.0, 1.0), (2, 5.0, 2.0), (2,
10.0, 3.0)],
+ ("id", "v", "w"),
+ )
+
+ @pandas_udf("double")
+ def pandas_weighted_mean_iter(it: Iterator[Tuple[pd.Series,
pd.Series]]) -> float:
+ import numpy as np
+
+ weighted_sum = 0.0
+ weight = 0.0
+ for v_series, w_series in it:
+ assert isinstance(v_series, pd.Series)
+ assert isinstance(w_series, pd.Series)
+ weighted_sum += np.dot(v_series, w_series)
+ weight += w_series.sum()
+ return weighted_sum / weight if weight > 0 else 0.0
+
+ result = (
+ df.groupby("id")
+ .agg(pandas_weighted_mean_iter(df["v"], df["w"]).alias("wm"))
+ .sort("id")
+ .collect()
+ )
+
+ # Expected weighted means:
+ # Group 1: (1.0*1.0 + 2.0*2.0) / (1.0 + 2.0) = 5.0 / 3.0
+ # Group 2: (3.0*1.0 + 5.0*2.0 + 10.0*3.0) / (1.0 + 2.0 + 3.0) = 43.0 /
6.0
+ expected = [(1, 5.0 / 3.0), (2, 43.0 / 6.0)]
+
+ self.assertEqual(len(result), len(expected))
+ for r, (exp_id, exp_wm) in zip(result, expected):
+ self.assertEqual(r["id"], exp_id)
+ self.assertAlmostEqual(r["wm"], exp_wm, places=5)
+
+ def test_iterator_grouped_agg_eval_type(self):
+ """
+ Test that the eval type is correctly inferred for iterator grouped agg
UDFs.
+ """
+ from typing import Iterator, Tuple
+
+ @pandas_udf("double")
+ def pandas_sum_iter(it: Iterator[pd.Series]) -> float:
+ total = 0.0
+ for series in it:
+ total += series.sum()
+ return total
+
+ self.assertEqual(pandas_sum_iter.evalType,
PythonEvalType.SQL_GROUPED_AGG_PANDAS_ITER_UDF)
+
+ @pandas_udf("double")
+ def pandas_sum_iter_tuple(it: Iterator[Tuple[pd.Series, pd.Series]])
-> float:
+ total = 0.0
+ for v, w in it:
+ total += v.sum()
+ return total
+
+ self.assertEqual(
+ pandas_sum_iter_tuple.evalType,
PythonEvalType.SQL_GROUPED_AGG_PANDAS_ITER_UDF
+ )
+
+ def test_iterator_grouped_agg_partial_consumption(self):
+ """
+ Test that iterator grouped agg UDF can partially consume batches.
+ This ensures that batches are processed one by one without loading all
data into memory.
+ """
+ from typing import Iterator
+
+ # Create a dataset with multiple batches per group
+ # Use small batch size to ensure multiple batches per group
+ with self.sql_conf({"spark.sql.execution.arrow.maxRecordsPerBatch":
2}):
+ df = self.spark.createDataFrame(
+ [(1, 1.0), (1, 2.0), (1, 3.0), (1, 4.0), (2, 5.0), (2, 6.0)],
("id", "v")
+ )
+
+ @pandas_udf("double")
+ def pandas_sum_partial(it: Iterator[pd.Series]) -> float:
+ # Only consume first two batches, then return
+ # This tests that partial consumption works correctly
+ total = 0.0
+ count = 0
+ for i, series in enumerate(it):
+ if i < 2: # Only process first 2 batches
+ total += series.sum()
+ count += len(series)
+ else:
+ # Stop early - partial consumption
+ break
+ return total / count if count > 0 else 0.0
+
+ result =
df.groupby("id").agg(pandas_sum_partial(df["v"]).alias("mean")).sort("id")
+
+ # Verify results are correct for partial consumption
+ # With batch size = 2:
+ # Group 1 (id=1): 4 values in 2 batches -> processes both batches
+ # Batch 1: [1.0, 2.0], Batch 2: [3.0, 4.0]
Review Comment:
You are assuming a particular ordering of rows (and therefore of batches) within a group. I suspect this
will make the test unstable (flaky).
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]