This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 39f8331cda44 [SPARK-53013][PYTHON] Fix Arrow-optimized Python UDTF returning no rows on lateral join
39f8331cda44 is described below

commit 39f8331cda44caf7594c0803e6aa47132f55c0e7
Author: Takuya Ueshin <ues...@databricks.com>
AuthorDate: Thu Jul 31 07:43:09 2025 +0900

    [SPARK-53013][PYTHON] Fix Arrow-optimized Python UDTF returning no rows on lateral join
    
    ### What changes were proposed in this pull request?
    
    Fixes Arrow-optimized Python UDTF returning no rows on lateral join.
    
    ### Why are the changes needed?
    
    The new code path for Arrow-optimized Python UDTF does not work correctly with lateral join when the UDTF returns no rows.
    
    ```py
    >>> udtf(returnType="a: int", useArrow=True)
    ... class TestUDTF:
    ...   def eval(self, i: int):
    ...     for n in range(i):
    ...       yield n,
    ...
    >>> spark.range(3, numPartitions=1).lateralJoin(TestUDTF(col("id").outer())).show()
    +---+---+
    | id|  a|
    +---+---+
    |  0|  0|
    |  1|  0|
    |  1|  1|
    +---+---+
    ```
    
    The UDTF returns no rows when it takes `0`, so there should be no rows for `id == 0`. Instead, the result should be:
    
    ```py
    +---+---+
    | id|  a|
    +---+---+
    |  1|  0|
    |  2|  0|
    |  2|  1|
    +---+---+
    ```
    
    The bug was introduced at https://github.com/apache/spark/pull/51659.
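    
    The mechanism can be illustrated with plain PyArrow (a minimal sketch, not part of this patch; the schema below just stands in for the UDTF's Arrow return type):
    
    ```py
    import pyarrow as pa
    
    # Illustrative schema standing in for the UDTF's Arrow return type.
    schema = pa.schema([pa.field("a", pa.int32())])
    
    # Old path: converting an empty result to a Table and calling to_batches()
    # yields no batches at all, so nothing is emitted for that input row and
    # the lateral join output is shifted, as shown above.
    assert pa.Table.from_pylist([], schema=schema).to_batches() == []
    
    # New path: emit a single zero-row RecordBatch instead, so every input row
    # still produces a batch and the stream stays aligned with the left side.
    empty_batch = pa.RecordBatch.from_pylist([], schema=schema)
    assert empty_batch.num_rows == 0
    ```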
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, an Arrow-optimized Python UDTF returning no rows now works correctly with lateral join.
    
    ### How was this patch tested?
    
    Added the related tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #51717 from ueshin/issues/SPARK-53013/lateral_empty.
    
    Authored-by: Takuya Ueshin <ues...@databricks.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 python/pyspark/sql/tests/test_udtf.py | 17 +++++++++++++++--
 python/pyspark/worker.py              | 25 ++++++++++++++-----------
 2 files changed, 29 insertions(+), 13 deletions(-)

diff --git a/python/pyspark/sql/tests/test_udtf.py b/python/pyspark/sql/tests/test_udtf.py
index 7f812ad20f59..1c473daff74e 100644
--- a/python/pyspark/sql/tests/test_udtf.py
+++ b/python/pyspark/sql/tests/test_udtf.py
@@ -183,8 +183,6 @@ class BaseUDTFTestsMixin:
                 yield 1,
                 yield 2,
 
-        self.spark.udtf.register("testUDTF", TestUDTF)
-
         assertDataFrameEqual(
             self.spark.range(3, numPartitions=1).lateralJoin(TestUDTF()),
             [
@@ -197,6 +195,21 @@ class BaseUDTFTestsMixin:
             ],
         )
 
+        @udtf(returnType="a: int")
+        class TestUDTF:
+            def eval(self, i: int):
+                for n in range(i):
+                    yield n,
+
+        assertDataFrameEqual(
+            self.spark.range(3, numPartitions=1).lateralJoin(TestUDTF(col("id").outer())),
+            [
+                Row(id=1, a=0),
+                Row(id=2, a=0),
+                Row(id=2, a=1),
+            ],
+        )
+
     def test_udtf_eval_with_return_stmt(self):
         class TestUDTF:
             def eval(self, a: int, b: int):
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index d330de85c4ee..b278273a161b 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -1289,7 +1289,6 @@ def use_legacy_pandas_udf_conversion(runner_conf):
 def read_udtf(pickleSer, infile, eval_type):
     prefers_large_var_types = False
     legacy_pandas_conversion = False
-    input_schema = None
 
     if eval_type == PythonEvalType.SQL_ARROW_TABLE_UDF:
         runner_conf = {}
@@ -1306,7 +1305,9 @@ def read_udtf(pickleSer, infile, eval_type):
             ).lower()
             == "true"
         )
-        input_schema = _parse_datatype_json_string(utf8_deserializer.loads(infile))
+        input_types = [
+            field.dataType for field in _parse_datatype_json_string(utf8_deserializer.loads(infile))
+        ]
         if legacy_pandas_conversion:
             # NOTE: if timezone is set here, that implies respectSessionTimeZone is True
             safecheck = (
@@ -1322,7 +1323,6 @@ def read_udtf(pickleSer, infile, eval_type):
                 == "true"
             )
             timezone = runner_conf.get("spark.sql.session.timeZone", None)
-            input_types = [field.dataType for field in input_schema]
             ser = ArrowStreamPandasUDTFSerializer(
                 timezone,
                 safecheck,
@@ -1843,7 +1843,10 @@ def read_udtf(pickleSer, infile, eval_type):
             def convert_to_arrow(data: Iterable):
                 data = list(check_return_value(data))
                 if len(data) == 0:
-                    return pa.Table.from_pylist(data, schema=pa.schema(list(arrow_return_type)))
+                    # Return one empty RecordBatch to match the left side of the lateral join
+                    return [
+                        pa.RecordBatch.from_pylist(data, schema=pa.schema(list(arrow_return_type)))
+                    ]
 
                 def raise_conversion_error(original_exception):
                     raise PySparkRuntimeError(
@@ -1856,7 +1859,7 @@ def read_udtf(pickleSer, infile, eval_type):
                     ) from original_exception
 
                 try:
-                    return LocalDataToArrowConversion.convert(
+                    table = LocalDataToArrowConversion.convert(
                         data, return_type, prefers_large_var_types
                     )
                 except PySparkValueError as e:
@@ -1878,15 +1881,17 @@ def read_udtf(pickleSer, infile, eval_type):
                 except Exception as e:
                     raise_conversion_error(e)
 
+                return verify_result(table).to_batches()
+
             def evaluate(*args: list, num_rows=1):
                 if len(args) == 0:
                     for _ in range(num_rows):
-                        for batch in verify_result(convert_to_arrow(func())).to_batches():
+                        for batch in convert_to_arrow(func()):
                             yield batch, arrow_return_type
 
                 else:
                     for row in zip(*args):
-                        for batch in verify_result(convert_to_arrow(func(*row))).to_batches():
+                        for batch in convert_to_arrow(func(*row)):
                             yield batch, arrow_return_type
 
             return evaluate
@@ -1906,10 +1911,8 @@ def read_udtf(pickleSer, infile, eval_type):
         def mapper(_, it):
             try:
                 converters = [
-                    ArrowTableToRowsConversion._create_converter(
-                        field.dataType, none_on_identity=True
-                    )
-                    for field in input_schema
+                    ArrowTableToRowsConversion._create_converter(dt, none_on_identity=True)
+                    for dt in input_types
                 ]
                 for a in it:
                     pylist = [

