Yicong-Huang commented on code in PR #52716:
URL: https://github.com/apache/spark/pull/52716#discussion_r2476516041


##########
python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py:
##########
@@ -988,6 +988,423 @@ def test_negative_and_zero_batch_size(self):
             with self.sql_conf({"spark.sql.execution.arrow.maxRecordsPerBatch": batch_size}):
                 ApplyInPandasTestsMixin.test_complex_groupby(self)
 
+    def test_apply_in_pandas_iterator_basic(self):
+        from typing import Iterator
+
+        df = self.spark.createDataFrame(
+            [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v")
+        )
+
+        def sum_func(batches: Iterator[pd.DataFrame]) -> Iterator[pd.DataFrame]:
+            total = 0
+            for batch in batches:
+                total += batch["v"].sum()
+            yield pd.DataFrame({"v": [total]})
+
+        result = df.groupby("id").applyInPandas(sum_func, schema="v double").orderBy("v").collect()
+        self.assertEqual(len(result), 2)
+        self.assertEqual(result[0][0], 3.0)
+        self.assertEqual(result[1][0], 18.0)
+
+    def test_apply_in_pandas_iterator_with_keys(self):
+        from typing import Iterator, Tuple, Any
+
+        df = self.spark.createDataFrame(
+            [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v")
+        )
+
+        def sum_func(
+            key: Tuple[Any, ...], batches: Iterator[pd.DataFrame]
+        ) -> Iterator[pd.DataFrame]:
+            total = 0
+            for batch in batches:
+                total += batch["v"].sum()
+            yield pd.DataFrame({"id": [key[0]], "v": [total]})
+
+        result = (
+            df.groupby("id")
+            .applyInPandas(sum_func, schema="id long, v double")
+            .orderBy("id")
+            .collect()
+        )
+        self.assertEqual(len(result), 2)
+        self.assertEqual(result[0][0], 1)
+        self.assertEqual(result[0][1], 3.0)
+        self.assertEqual(result[1][0], 2)
+        self.assertEqual(result[1][1], 18.0)
+
+    def test_apply_in_pandas_iterator_batch_slicing(self):
+        from typing import Iterator
+
+        df = self.spark.range(10000000).select(
+            (sf.col("id") % 2).alias("key"), sf.col("id").alias("v")
+        )
+        cols = {f"col_{i}": sf.col("v") + i for i in range(20)}
+        df = df.withColumns(cols)
+
+        def min_max_v(batches: Iterator[pd.DataFrame]) -> Iterator[pd.DataFrame]:
+            # Collect all batches to compute min/max across the entire group
+            all_data = []
+            key_val = None
+            for batch in batches:
+                all_data.append(batch)
+                if key_val is None:
+                    key_val = batch.key.iloc[0]
+
+            combined = pd.concat(all_data, ignore_index=True)
+            assert len(combined) == 10000000 / 2, len(combined)
+
+            yield pd.DataFrame(
+                {
+                    "key": [key_val],
+                    "min": [combined.v.min()],
+                    "max": [combined.v.max()],
+                }
+            )
+
+        expected = (
+            df.groupby("key").agg(sf.min("v").alias("min"), sf.max("v").alias("max")).sort("key")
+        ).collect()
+
+        for maxRecords, maxBytes in [(1000, 2**31 - 1), (0, 1048576), (1000, 1048576)]:
+            with self.subTest(maxRecords=maxRecords, maxBytes=maxBytes):
+                with self.sql_conf(
+                    {
+                        "spark.sql.execution.arrow.maxRecordsPerBatch": maxRecords,
+                        "spark.sql.execution.arrow.maxBytesPerBatch": maxBytes,
+                    }
+                ):
+                    result = (
+                        df.groupBy("key")
+                        .applyInPandas(min_max_v, "key long, min long, max long")
+                        .sort("key")
+                    ).collect()
+
+                    self.assertEqual(expected, result)
+
+    def test_apply_in_pandas_iterator_with_keys_batch_slicing(self):
+        from typing import Iterator, Tuple, Any
+
+        df = self.spark.range(10000000).select(
+            (sf.col("id") % 2).alias("key"), sf.col("id").alias("v")
+        )
+        cols = {f"col_{i}": sf.col("v") + i for i in range(20)}
+        df = df.withColumns(cols)
+
+        def min_max_v(
+            key: Tuple[Any, ...], batches: Iterator[pd.DataFrame]
+        ) -> Iterator[pd.DataFrame]:
+            # Collect all batches to compute min/max across the entire group
+            all_data = []
+            for batch in batches:
+                all_data.append(batch)
+
+            combined = pd.concat(all_data, ignore_index=True)
+            assert len(combined) == 10000000 / 2, len(combined)
+
+            yield pd.DataFrame(
+                {
+                    "key": [key[0]],
+                    "min": [combined.v.min()],
+                    "max": [combined.v.max()],
+                }
+            )
+
+        expected = (
+            df.groupby("key").agg(sf.min("v").alias("min"), sf.max("v").alias("max")).sort("key")
+        ).collect()
+
+        for maxRecords, maxBytes in [(1000, 2**31 - 1), (0, 1048576), (1000, 1048576)]:
+            with self.subTest(maxRecords=maxRecords, maxBytes=maxBytes):
+                with self.sql_conf(
+                    {
+                        "spark.sql.execution.arrow.maxRecordsPerBatch": maxRecords,
+                        "spark.sql.execution.arrow.maxBytesPerBatch": maxBytes,
+                    }
+                ):
+                    result = (
+                        df.groupBy("key")
+                        .applyInPandas(min_max_v, "key long, min long, max long")
+                        .sort("key")
+                    ).collect()
+
+                    self.assertEqual(expected, result)
+
+    def test_apply_in_pandas_iterator_multiple_output_batches(self):
+        from typing import Iterator
+
+        df = self.spark.createDataFrame(
+            [(1, 1.0), (1, 2.0), (1, 3.0), (2, 4.0), (2, 5.0), (2, 6.0)], ("id", "v")
+        )
+
+        def split_and_yield(batches: Iterator[pd.DataFrame]) -> Iterator[pd.DataFrame]:
+            # Yield multiple output batches for each input batch
+            for batch in batches:
+                for _, row in batch.iterrows():
+                    # Yield each row as a separate batch to test multiple yields
+                    yield pd.DataFrame(
+                        {"id": [row["id"]], "v": [row["v"]], "v_doubled": [row["v"] * 2]}
+                    )
+
+        result = (
+            df.groupby("id")
+            .applyInPandas(split_and_yield, schema="id long, v double, v_doubled double")
+            .orderBy("id", "v")
+            .collect()
+        )
+
+        # Verify that all rows are present after concatenation
+        self.assertEqual(len(result), 6)
+        self.assertEqual(result[0][0], 1)
+        self.assertEqual(result[0][1], 1.0)
+        self.assertEqual(result[0][2], 2.0)
+        self.assertEqual(result[1][0], 1)
+        self.assertEqual(result[1][1], 2.0)
+        self.assertEqual(result[1][2], 4.0)
+        self.assertEqual(result[2][0], 1)
+        self.assertEqual(result[2][1], 3.0)
+        self.assertEqual(result[2][2], 6.0)
+        self.assertEqual(result[3][0], 2)
+        self.assertEqual(result[3][1], 4.0)
+        self.assertEqual(result[3][2], 8.0)
+        self.assertEqual(result[4][0], 2)
+        self.assertEqual(result[4][1], 5.0)
+        self.assertEqual(result[4][2], 10.0)
+        self.assertEqual(result[5][0], 2)
+        self.assertEqual(result[5][1], 6.0)
+        self.assertEqual(result[5][2], 12.0)
+
+    def test_apply_in_pandas_iterator_filter_multiple_batches(self):
+        from typing import Iterator
+
+        df = self.spark.createDataFrame(
+            [(1, i * 1.0) for i in range(20)] + [(2, i * 1.0) for i in range(20)], ("id", "v")
+        )
+
+        def filter_and_yield(batches: Iterator[pd.DataFrame]) -> Iterator[pd.DataFrame]:
+            # Yield filtered results from each batch
+            for batch in batches:
+                # Filter even values and yield
+                even_batch = batch[batch["v"] % 2 == 0]
+                if not even_batch.empty:
+                    yield even_batch
+
+                # Filter odd values and yield separately
+                odd_batch = batch[batch["v"] % 2 == 1]
+                if not odd_batch.empty:
+                    yield odd_batch
+
+        result = (
+            df.groupby("id")
+            .applyInPandas(filter_and_yield, schema="id long, v double")
+            .orderBy("id", "v")
+            .collect()
+        )
+
+        # Verify all 40 rows are present (20 per group)
+        self.assertEqual(len(result), 40)
+
+        # Verify group 1 has all values 0-19
+        group1 = [row for row in result if row[0] == 1]
+        self.assertEqual(len(group1), 20)
+        self.assertEqual([row[1] for row in group1], [float(i) for i in range(20)])
+
+        # Verify group 2 has all values 0-19
+        group2 = [row for row in result if row[0] == 2]
+        self.assertEqual(len(group2), 20)
+        self.assertEqual([row[1] for row in group2], [float(i) for i in range(20)])
+
+    def test_apply_in_pandas_iterator_with_keys_multiple_batches(self):
+        from typing import Iterator, Tuple, Any
+
+        df = self.spark.createDataFrame(
+            [
+                (1, "a", 1.0),
+                (1, "b", 2.0),
+                (1, "c", 3.0),
+                (2, "d", 4.0),
+                (2, "e", 5.0),
+                (2, "f", 6.0),
+            ],
+            ("id", "name", "v"),
+        )
+
+        def process_with_key(
+            key: Tuple[Any, ...], batches: Iterator[pd.DataFrame]
+        ) -> Iterator[pd.DataFrame]:
+            # Yield multiple processed batches, including the key in each output
+            for batch in batches:
+                # Split batch and yield multiple output batches
+                for chunk_size in [1, 2]:
+                    for i in range(0, len(batch), chunk_size):
+                        chunk = batch.iloc[i : i + chunk_size]
+                        if not chunk.empty:
+                            result = chunk.assign(id=key[0], total=chunk["v"].sum())
+                            yield result[["id", "name", "total"]]
+
+        result = (
+            df.groupby("id")
+            .applyInPandas(process_with_key, schema="id long, name string, total double")
+            .orderBy("id", "name")
+            .collect()
+        )
+
+        # Verify we get results (may have duplicates due to splitting)
+        self.assertTrue(len(result) > 6)
+
+        # Verify all original names are present
+        names = [row[1] for row in result]
+        self.assertIn("a", names)
+        self.assertIn("b", names)
+        self.assertIn("c", names)
+        self.assertIn("d", names)
+        self.assertIn("e", names)
+        self.assertIn("f", names)
+
+        # Verify keys are correct
+        for row in result:
+            if row[1] in ["a", "b", "c"]:
+                self.assertEqual(row[0], 1)
+            else:
+                self.assertEqual(row[0], 2)
+
+    def test_apply_in_pandas_iterator_process_multiple_input_batches(self):
+        from typing import Iterator
+        import builtins

Review Comment:
   For some reason, when I call `sum` directly here, it resolves to the Column `sum` rather than Python's built-in `sum` (hence the `import builtins`). Do you know why?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to