Yicong-Huang commented on code in PR #52716:
URL: https://github.com/apache/spark/pull/52716#discussion_r2476523786
##########
python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py:
##########
@@ -988,6 +988,423 @@ def test_negative_and_zero_batch_size(self):
        with self.sql_conf({"spark.sql.execution.arrow.maxRecordsPerBatch": batch_size}):
            ApplyInPandasTestsMixin.test_complex_groupby(self)
+
+    def test_apply_in_pandas_iterator_basic(self):
+        from typing import Iterator
+
+        df = self.spark.createDataFrame(
+            [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v")
+        )
+
+        def sum_func(batches: Iterator[pd.DataFrame]) -> Iterator[pd.DataFrame]:
+            total = 0
+            for batch in batches:
+                total += batch["v"].sum()
+            yield pd.DataFrame({"v": [total]})
+
+        result = df.groupby("id").applyInPandas(sum_func, schema="v double").orderBy("v").collect()
+        self.assertEqual(len(result), 2)
+        self.assertEqual(result[0][0], 3.0)
+        self.assertEqual(result[1][0], 18.0)
+
+    def test_apply_in_pandas_iterator_with_keys(self):
+        from typing import Iterator, Tuple, Any
+
+        df = self.spark.createDataFrame(
+            [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v")
+        )
+
+        def sum_func(
+            key: Tuple[Any, ...], batches: Iterator[pd.DataFrame]
+        ) -> Iterator[pd.DataFrame]:
+            total = 0
+            for batch in batches:
+                total += batch["v"].sum()
+            yield pd.DataFrame({"id": [key[0]], "v": [total]})
+
+        result = (
+            df.groupby("id")
+            .applyInPandas(sum_func, schema="id long, v double")
+            .orderBy("id")
+            .collect()
+        )
+        self.assertEqual(len(result), 2)
+        self.assertEqual(result[0][0], 1)
+        self.assertEqual(result[0][1], 3.0)
+        self.assertEqual(result[1][0], 2)
+        self.assertEqual(result[1][1], 18.0)
+
+    def test_apply_in_pandas_iterator_batch_slicing(self):
+        from typing import Iterator
+
+        df = self.spark.range(10000000).select(
+            (sf.col("id") % 2).alias("key"), sf.col("id").alias("v")
+        )
+        cols = {f"col_{i}": sf.col("v") + i for i in range(20)}
+        df = df.withColumns(cols)
+
+        def min_max_v(batches: Iterator[pd.DataFrame]) -> Iterator[pd.DataFrame]:
+            # Collect all batches to compute min/max across the entire group
+            all_data = []
+            key_val = None
+            for batch in batches:
+                all_data.append(batch)
+                if key_val is None:
+                    key_val = batch.key.iloc[0]
+
+            combined = pd.concat(all_data, ignore_index=True)
+            assert len(combined) == 10000000 / 2, len(combined)
+
+            yield pd.DataFrame(
+                {
+                    "key": [key_val],
+                    "min": [combined.v.min()],
+                    "max": [combined.v.max()],
+                }
+            )
+
+        expected = (
+            df.groupby("key").agg(sf.min("v").alias("min"), sf.max("v").alias("max")).sort("key")
+        ).collect()
+
+        for maxRecords, maxBytes in [(1000, 2**31 - 1), (0, 1048576), (1000, 1048576)]:
+            with self.subTest(maxRecords=maxRecords, maxBytes=maxBytes):
+                with self.sql_conf(
+                    {
+                        "spark.sql.execution.arrow.maxRecordsPerBatch": maxRecords,
+                        "spark.sql.execution.arrow.maxBytesPerBatch": maxBytes,
+                    }
+                ):
+                    result = (
+                        df.groupBy("key")
+                        .applyInPandas(min_max_v, "key long, min long, max long")
+                        .sort("key")
+                    ).collect()
+
+                    self.assertEqual(expected, result)
+
+    def test_apply_in_pandas_iterator_with_keys_batch_slicing(self):
+        from typing import Iterator, Tuple, Any
Review Comment:
moved
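
For readers following along, the pattern these new tests exercise is the iterator form of applyInPandas, where the grouped-map UDF receives each group as an iterator of pandas DataFrame batches rather than one fully materialized DataFrame. A minimal sketch of that usage, assuming an active SparkSession named spark (the data, column names, and function name below are illustrative, not taken from the PR):

    from typing import Iterator
    import pandas as pd

    # Hypothetical example data; assumes an existing SparkSession `spark`.
    df = spark.createDataFrame([(1, 1.0), (1, 2.0), (2, 3.0)], ("id", "v"))

    def group_sum(batches: Iterator[pd.DataFrame]) -> Iterator[pd.DataFrame]:
        # Each group's rows arrive split into batches; accumulate across them
        # and emit a single result row for the group.
        total = 0.0
        for batch in batches:
            total += batch["v"].sum()
        yield pd.DataFrame({"v": [total]})

    df.groupby("id").applyInPandas(group_sum, schema="v double").show()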