This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git

The following commit(s) were added to refs/heads/master by this push:
     new 39f8331cda44  [SPARK-53013][PYTHON] Fix Arrow-optimized Python UDTF returning no rows on lateral join
39f8331cda44 is described below

commit 39f8331cda44caf7594c0803e6aa47132f55c0e7
Author: Takuya Ueshin <ues...@databricks.com>
AuthorDate: Thu Jul 31 07:43:09 2025 +0900

[SPARK-53013][PYTHON] Fix Arrow-optimized Python UDTF returning no rows on lateral join

### What changes were proposed in this pull request?

Fixes Arrow-optimized Python UDTF returning no rows on lateral join.

### Why are the changes needed?

The new code path for Arrow-optimized Python UDTF won't work well with lateral join when the UDTF returns no rows.

```py
>>> @udtf(returnType="a: int", useArrow=True)
... class TestUDTF:
...     def eval(self, i: int):
...         for n in range(i):
...             yield n,
...
>>> spark.range(3, numPartitions=1).lateralJoin(TestUDTF(col("id").outer())).show()
+---+---+
| id|  a|
+---+---+
|  0|  0|
|  1|  0|
|  1|  1|
+---+---+
```

The UDTF returns no rows when it takes `0`, so there should be no rows for `id == 0`. Instead, the result should be:

```py
+---+---+
| id|  a|
+---+---+
|  1|  0|
|  2|  0|
|  2|  1|
+---+---+
```

The bug was introduced at https://github.com/apache/spark/pull/51659.
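The core of the fix is that an empty UDTF result must still surface as a single empty Arrow `RecordBatch` carrying the declared schema, rather than as no batches at all, so the lateral join still receives an (empty) result for that input row. A minimal standalone sketch of that behavior, assuming only `pyarrow` and using a hypothetical `to_batches` helper (an illustration, not the actual `python/pyspark/worker.py` code path):

```py
# Illustration only; `to_batches` is a hypothetical helper, not Spark's
# internal convert_to_arrow. Assumes pyarrow is installed.
import pyarrow as pa

schema = pa.schema([pa.field("a", pa.int32())])

def to_batches(rows):
    if not rows:
        # Emit one explicitly empty batch with the declared schema, so a
        # consumer still receives a result (with zero rows) for this input.
        return [pa.RecordBatch.from_pylist([], schema=schema)]
    return pa.Table.from_pylist(rows, schema=schema).to_batches()

print(to_batches([]))                    # one RecordBatch with 0 rows
print(to_batches([{"a": 0}, {"a": 1}]))  # batches holding the two rows
```

The `convert_to_arrow` change in the diff below does the same thing inside the worker: it now returns record batches directly, emitting one empty `RecordBatch` when the UDTF yields nothing, so the empty case no longer disappears before it reaches the join.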
### Does this PR introduce _any_ user-facing change?

Yes, Arrow-optimized Python UDTF returning no rows will work with lateral join.

### How was this patch tested?

Added the related tests.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #51717 from ueshin/issues/SPARK-53013/lateral_empty.

Authored-by: Takuya Ueshin <ues...@databricks.com>
Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 python/pyspark/sql/tests/test_udtf.py | 17 +++++++++++++++--
 python/pyspark/worker.py              | 25 ++++++++++++++-----------
 2 files changed, 29 insertions(+), 13 deletions(-)

diff --git a/python/pyspark/sql/tests/test_udtf.py b/python/pyspark/sql/tests/test_udtf.py
index 7f812ad20f59..1c473daff74e 100644
--- a/python/pyspark/sql/tests/test_udtf.py
+++ b/python/pyspark/sql/tests/test_udtf.py
@@ -183,8 +183,6 @@ class BaseUDTFTestsMixin:
                 yield 1,
                 yield 2,
 
-        self.spark.udtf.register("testUDTF", TestUDTF)
-
         assertDataFrameEqual(
             self.spark.range(3, numPartitions=1).lateralJoin(TestUDTF()),
             [
@@ -197,6 +195,21 @@
             ],
         )
 
+        @udtf(returnType="a: int")
+        class TestUDTF:
+            def eval(self, i: int):
+                for n in range(i):
+                    yield n,
+
+        assertDataFrameEqual(
+            self.spark.range(3, numPartitions=1).lateralJoin(TestUDTF(col("id").outer())),
+            [
+                Row(id=1, a=0),
+                Row(id=2, a=0),
+                Row(id=2, a=1),
+            ],
+        )
+
     def test_udtf_eval_with_return_stmt(self):
         class TestUDTF:
             def eval(self, a: int, b: int):
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index d330de85c4ee..b278273a161b 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -1289,7 +1289,6 @@ def use_legacy_pandas_udf_conversion(runner_conf):
 def read_udtf(pickleSer, infile, eval_type):
     prefers_large_var_types = False
     legacy_pandas_conversion = False
-    input_schema = None
 
     if eval_type == PythonEvalType.SQL_ARROW_TABLE_UDF:
         runner_conf = {}
@@ -1306,7 +1305,9 @@ def read_udtf(pickleSer, infile, eval_type):
             ).lower()
             == "true"
         )
-        input_schema = _parse_datatype_json_string(utf8_deserializer.loads(infile))
+        input_types = [
+            field.dataType for field in _parse_datatype_json_string(utf8_deserializer.loads(infile))
+        ]
         if legacy_pandas_conversion:
             # NOTE: if timezone is set here, that implies respectSessionTimeZone is True
             safecheck = (
@@ -1322,7 +1323,6 @@ def read_udtf(pickleSer, infile, eval_type):
                 == "true"
             )
             timezone = runner_conf.get("spark.sql.session.timeZone", None)
-            input_types = [field.dataType for field in input_schema]
             ser = ArrowStreamPandasUDTFSerializer(
                 timezone,
                 safecheck,
@@ -1843,7 +1843,10 @@ def read_udtf(pickleSer, infile, eval_type):
         def convert_to_arrow(data: Iterable):
             data = list(check_return_value(data))
             if len(data) == 0:
-                return pa.Table.from_pylist(data, schema=pa.schema(list(arrow_return_type)))
+                # Return one empty RecordBatch to match the left side of the lateral join
+                return [
+                    pa.RecordBatch.from_pylist(data, schema=pa.schema(list(arrow_return_type)))
+                ]
 
             def raise_conversion_error(original_exception):
                 raise PySparkRuntimeError(
@@ -1856,7 +1859,7 @@ def read_udtf(pickleSer, infile, eval_type):
                 ) from original_exception
 
             try:
-                return LocalDataToArrowConversion.convert(
+                table = LocalDataToArrowConversion.convert(
                     data, return_type, prefers_large_var_types
                 )
             except PySparkValueError as e:
@@ -1878,15 +1881,17 @@ def read_udtf(pickleSer, infile, eval_type):
             except Exception as e:
                 raise_conversion_error(e)
 
+            return verify_result(table).to_batches()
+
         def evaluate(*args: list, num_rows=1):
             if len(args) == 0:
                 for _ in range(num_rows):
-                    for batch in verify_result(convert_to_arrow(func())).to_batches():
+                    for batch in convert_to_arrow(func()):
                         yield batch, arrow_return_type
             else:
                 for row in zip(*args):
-                    for batch in verify_result(convert_to_arrow(func(*row))).to_batches():
+                    for batch in convert_to_arrow(func(*row)):
                         yield batch, arrow_return_type
 
         return evaluate
@@ -1906,10 +1911,8 @@ def read_udtf(pickleSer, infile, eval_type):
     def mapper(_, it):
         try:
             converters = [
-                ArrowTableToRowsConversion._create_converter(
-                    field.dataType, none_on_identity=True
-                )
-                for field in input_schema
+                ArrowTableToRowsConversion._create_converter(dt, none_on_identity=True)
+                for dt in input_types
             ]
             for a in it:
                 pylist = [

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org