This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 3f64bcaee77b [SPARK-55736][PYTHON][TESTS] Add tests for mismatched 
batch sizes in scalar iter UDFs
3f64bcaee77b is described below

commit 3f64bcaee77b3ebabaadb7852ebb0fb17fb14d26
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Fri Feb 27 15:23:23 2026 +0800

    [SPARK-55736][PYTHON][TESTS] Add tests for mismatched batch sizes in scalar 
iter UDFs
    
    ### What changes were proposed in this pull request?
    Add tests for mismatched batch sizes in scalar iter UDFs
    
    ### Why are the changes needed?
    From the API reference 
https://apache.github.io/spark/api/python/reference/pyspark.sql/api/pyspark.sql.functions.pandas_udf.html?highlight=pandas_udf#pyspark.sql.functions.pandas_udf
    
    > The length of the entire output from the function should be the same 
length of the entire input; therefore, it can prefetch the data from the input 
iterator as long as the lengths are the same.
    
    Scalar iter UDFs don't require the output batches to have the same sizes
    as the input batches, yet the existing examples and tests all cover only
    the matching-size case.
    
    This PR guards the mismatched-batch-size case with a test that feeds 200
    input batches into a UDF that produces a single output batch.
    
    ### Does this PR introduce _any_ user-facing change?
    No, test-only
    
    ### How was this patch tested?
    CI
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No
    
    Closes #54531 from zhengruifeng/test_iter_udf.
    
    Authored-by: Ruifeng Zheng <[email protected]>
    Signed-off-by: Ruifeng Zheng <[email protected]>
---
 .../sql/tests/arrow/test_arrow_udf_scalar.py        | 21 +++++++++++++++++++++
 .../sql/tests/pandas/test_pandas_udf_scalar.py      | 19 +++++++++++++++++++
 2 files changed, 40 insertions(+)

diff --git a/python/pyspark/sql/tests/arrow/test_arrow_udf_scalar.py 
b/python/pyspark/sql/tests/arrow/test_arrow_udf_scalar.py
index 88ea7ad08019..061a80908f08 100644
--- a/python/pyspark/sql/tests/arrow/test_arrow_udf_scalar.py
+++ b/python/pyspark/sql/tests/arrow/test_arrow_udf_scalar.py
@@ -1264,6 +1264,27 @@ class ScalarArrowUDFTestsMixin:
                 ],
             )
 
+    def test_scalar_iter_arrow_udf_with_single_output_batch(self):
+        import pyarrow as pa
+
+        @arrow_udf("long", ArrowUDFType.SCALAR_ITER)
+        def return_one(iterator):
+            rows = 0
+            batches = 0
+            for s in iterator:
+                rows += len(s)
+                batches += 1
+
+            assert rows == 1000, rows
+            assert batches == 200, batches
+            yield pa.array([1] * rows)
+
+        with self.sql_conf({"spark.sql.execution.arrow.maxRecordsPerBatch": 
5}):
+            df = self.spark.range(0, 1000, 1, 1)
+            expected = [Row(one=1) for i in range(1000)]
+            result = df.select(return_one("id").alias("one")).collect()
+            self.assertEqual(expected, result)
+
 
 class ScalarArrowUDFTests(ScalarArrowUDFTestsMixin, ReusedSQLTestCase):
     @classmethod
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py 
b/python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py
index ad3a3bea3e53..2aecb5efe7f8 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py
@@ -2044,6 +2044,25 @@ class ScalarPandasUDFTestsMixin:
                     result = 
df.select(plus_two("id").alias("result")).collect()
                     self.assertEqual(expected, result)
 
+    def test_scalar_iter_pandas_udf_with_single_output_batch(self):
+        @pandas_udf("long", PandasUDFType.SCALAR_ITER)
+        def return_one(iterator):
+            rows = 0
+            batches = 0
+            for s in iterator:
+                rows += len(s)
+                batches += 1
+
+            assert rows == 1000, rows
+            assert batches == 200, batches
+            yield pd.Series([1] * rows)
+
+        with self.sql_conf({"spark.sql.execution.arrow.maxRecordsPerBatch": 
5}):
+            df = self.spark.range(0, 1000, 1, 1)
+            expected = [Row(one=1) for i in range(1000)]
+            result = df.select(return_one("id").alias("one")).collect()
+            self.assertEqual(expected, result)
+
 
 class ScalarPandasUDFTests(ScalarPandasUDFTestsMixin, ReusedSQLTestCase):
     @classmethod


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to