This is an automated email from the ASF dual-hosted git repository.

allisonwang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new dab3464701e2 [SPARK-52982][PYTHON] Disallow lateral join with Arrow Python UDTFs
dab3464701e2 is described below

commit dab3464701e2ca4c661ba565d3b25b442e4834c4
Author: Allison Wang <allison.w...@databricks.com>
AuthorDate: Fri Aug 22 14:25:44 2025 -0700

    [SPARK-52982][PYTHON] Disallow lateral join with Arrow Python UDTFs

    ### What changes were proposed in this pull request?

    Arrow-optimized user-defined table functions (UDTFs) have arbitrary output
    cardinality: they can produce a different number of output rows for the same
    batch of input rows. This is fundamentally incompatible with lateral join
    semantics, which combine each row from the left-hand side with the rows
    produced by the right-hand side for that particular input row.

    This PR adds a restriction that blocks Arrow Python UDTFs on the right-hand
    side of a lateral join: `SELECT * FROM table, LATERAL arrow_udtf(...)`.
    The recommended pattern is to pass input rows to an Arrow UDTF via a table
    argument instead: `SELECT * FROM arrow_udtf(table(...))`.

    Note: regular UDTFs with lateral joins are still allowed (unchanged behavior).

    ### Why are the changes needed?

    This prevents runtime errors and inconsistent results while keeping full
    Arrow UDTF functionality for the supported patterns.

    ### Does this PR introduce _any_ user-facing change?

    Yes.

    ### How was this patch tested?

    New unit tests.

    ### Was this patch authored or co-authored using generative AI tooling?

    No.

    Closes #52048 from allisonwang-db/spark-52982-disallow-lateral-join.

    Authored-by: Allison Wang <allison.w...@databricks.com>
    Signed-off-by: Allison Wang <allison.w...@databricks.com>
---
 .../src/main/resources/error/error-conditions.json |   7 ++
 python/pyspark/sql/tests/arrow/test_arrow_udtf.py  | 130 +++++++++++++++++++++
 .../spark/sql/catalyst/analysis/Analyzer.scala     |  15 ++-
 .../sql/catalyst/analysis/CheckAnalysis.scala      |  12 ++
 .../plans/logical/basicLogicalOperators.scala      |   9 ++
 5 files changed, 167 insertions(+), 6 deletions(-)

diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json
index f7db0b6761a5..234b0c3ed02d 100644
--- a/common/utils/src/main/resources/error/error-conditions.json
+++ b/common/utils/src/main/resources/error/error-conditions.json
@@ -4086,6 +4086,13 @@
     ],
     "sqlState" : "42K0L"
   },
+  "LATERAL_JOIN_WITH_ARROW_UDTF_UNSUPPORTED" : {
+    "message" : [
+      "LATERAL JOIN with Arrow-optimized user-defined table functions (UDTFs) is not supported. Arrow UDTFs cannot be used on the right-hand side of a lateral join.",
+      "Please use a regular UDTF instead, or restructure your query to avoid the lateral join."
+    ],
+    "sqlState" : "0A000"
+  },
   "LOAD_DATA_PATH_NOT_EXISTS" : {
     "message" : [
       "LOAD DATA input path does not exist: <path>."
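To make the new error condition concrete, here is a minimal sketch of the
recommended table-argument pattern. This is illustrative only: the
`double_udtf` name is made up, the `arrow_udtf` import path is assumed to
match what the test module below uses, and a running SparkSession `spark`
is assumed.

    from typing import Iterator

    import pyarrow as pa
    import pyarrow.compute as pc

    # Assumed import path for the decorator; not shown in this diff.
    from pyspark.sql.functions import arrow_udtf

    @arrow_udtf(returnType="doubled bigint")
    class DoubleUDTF:
        # With a table argument, eval receives the whole input relation as a pa.Table.
        def eval(self, input_table: "pa.Table") -> Iterator["pa.Table"]:
            yield pa.table({"doubled": pc.multiply(input_table["id"], 2)})

    spark.udtf.register("double_udtf", DoubleUDTF)
    spark.createDataFrame([(1,), (2,), (3,)], "id bigint").createOrReplaceTempView("t")

    # Recommended: pass input rows via a table argument.
    spark.sql("SELECT * FROM double_udtf(table(SELECT id FROM t))").show()

    # Blocked by this change: an Arrow UDTF on the right-hand side of a lateral
    # join, e.g. SELECT * FROM t, LATERAL some_arrow_udtf(t.id), now fails
    # analysis with LATERAL_JOIN_WITH_ARROW_UDTF_UNSUPPORTED.

The table-argument form hands the UDTF whole batches directly, so its
arbitrary per-batch output cardinality never has to be reconciled with
per-row lateral join semantics, which is why it remains supported.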
diff --git a/python/pyspark/sql/tests/arrow/test_arrow_udtf.py b/python/pyspark/sql/tests/arrow/test_arrow_udtf.py
index 50fe6588eb92..75909fc88f82 100644
--- a/python/pyspark/sql/tests/arrow/test_arrow_udtf.py
+++ b/python/pyspark/sql/tests/arrow/test_arrow_udtf.py
@@ -26,6 +26,7 @@ from pyspark.testing import assertDataFrameEqual

 if have_pyarrow:
     import pyarrow as pa
+    import pyarrow.compute as pc


 @unittest.skipIf(not have_pyarrow, pyarrow_requirement_message)
@@ -460,6 +461,135 @@ class ArrowUDTFTestsMixin:
         )
         assertDataFrameEqual(sql_result_df, expected_df)

+    def test_arrow_udtf_lateral_join_disallowed(self):
+        @arrow_udtf(returnType="x int, result int")
+        class SimpleArrowUDTF:
+            def eval(self, input_val: "pa.Array") -> Iterator["pa.Table"]:
+                val = input_val[0].as_py()
+                result_table = pa.table(
+                    {
+                        "x": pa.array([val], type=pa.int32()),
+                        "result": pa.array([val * 2], type=pa.int32()),
+                    }
+                )
+                yield result_table
+
+        self.spark.udtf.register("simple_arrow_udtf", SimpleArrowUDTF)
+
+        test_df = self.spark.createDataFrame([(1,), (2,), (3,)], "id int")
+        test_df.createOrReplaceTempView("test_table")
+
+        with self.assertRaisesRegex(Exception, "LATERAL_JOIN_WITH_ARROW_UDTF_UNSUPPORTED"):
+            self.spark.sql(
+                """
+                SELECT t.id, f.x, f.result
+                FROM test_table t, LATERAL simple_arrow_udtf(t.id) f
+                """
+            )
+
+    def test_arrow_udtf_lateral_join_with_table_argument_disallowed(self):
+        @arrow_udtf(returnType="filtered_id bigint")
+        class MixedArgsUDTF:
+            def eval(self, input_table: "pa.Table") -> Iterator["pa.Table"]:
+                filtered_data = input_table.filter(pc.greater(input_table["id"], 5))
+                result_table = pa.table({"filtered_id": filtered_data["id"]})
+                yield result_table
+
+        self.spark.udtf.register("mixed_args_udtf", MixedArgsUDTF)
+
+        test_df1 = self.spark.createDataFrame([(1,), (2,), (3,)], "id int")
+        test_df1.createOrReplaceTempView("test_table1")
+
+        test_df2 = self.spark.createDataFrame([(6,), (7,), (8,)], "id bigint")
+        test_df2.createOrReplaceTempView("test_table2")
+
+        # Table arguments create nested lateral joins where our CheckAnalysis rule
+        # doesn't trigger, because the Arrow UDTF is in the inner lateral join, not
+        # the outer one our rule checks. So Spark's general lateral join validation
+        # catches this first with NON_DETERMINISTIC_LATERAL_SUBQUERIES.
+        with self.assertRaisesRegex(
+            Exception,
+            "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY.NON_DETERMINISTIC_LATERAL_SUBQUERIES",
+        ):
+            self.spark.sql(
+                """
+                SELECT t1.id, f.filtered_id
+                FROM test_table1 t1, LATERAL mixed_args_udtf(table(SELECT * FROM test_table2)) f
+                """
+            )
+
+    def test_arrow_udtf_with_table_argument_then_lateral_join_allowed(self):
+        @arrow_udtf(returnType="processed_id bigint")
+        class TableArgUDTF:
+            def eval(self, input_table: "pa.Table") -> Iterator["pa.Table"]:
+                processed_data = pc.add(input_table["id"], 100)
+                result_table = pa.table({"processed_id": processed_data})
+                yield result_table
+
+        self.spark.udtf.register("table_arg_udtf", TableArgUDTF)
+
+        source_df = self.spark.createDataFrame([(1,), (2,), (3,)], "id bigint")
+        source_df.createOrReplaceTempView("source_table")
+
+        join_df = self.spark.createDataFrame([("A",), ("B",), ("C",)], "label string")
+        join_df.createOrReplaceTempView("join_table")
+
+        result_df = self.spark.sql(
+            """
+            SELECT f.processed_id, j.label
+            FROM table_arg_udtf(table(SELECT * FROM source_table)) f,
+                 join_table j
+            ORDER BY f.processed_id, j.label
+            """
+        )
+
+        expected_data = [
+            (101, "A"),
+            (101, "B"),
+            (101, "C"),
+            (102, "A"),
+            (102, "B"),
+            (102, "C"),
+            (103, "A"),
+            (103, "B"),
+            (103, "C"),
+        ]
+        expected_df = self.spark.createDataFrame(expected_data, "processed_id bigint, label string")
+        assertDataFrameEqual(result_df, expected_df)
+
+    def test_arrow_udtf_table_argument_with_regular_udtf_lateral_join_allowed(self):
+        @arrow_udtf(returnType="computed_value int")
+        class ComputeUDTF:
+            def eval(self, input_table: "pa.Table") -> Iterator["pa.Table"]:
+                total = pc.sum(input_table["value"]).as_py()
+                result_table = pa.table({"computed_value": pa.array([total], type=pa.int32())})
+                yield result_table
+
+        from pyspark.sql.functions import udtf
+        from pyspark.sql.types import StructType, StructField, IntegerType
+
+        @udtf(returnType=StructType([StructField("multiplied", IntegerType())]))
+        class MultiplyUDTF:
+            def eval(self, input_val: int):
+                yield (input_val * 3,)
+
+        self.spark.udtf.register("compute_udtf", ComputeUDTF)
+        self.spark.udtf.register("multiply_udtf", MultiplyUDTF)
+
+        values_df = self.spark.createDataFrame([(10,), (20,), (30,)], "value int")
+        values_df.createOrReplaceTempView("values_table")
+
+        result_df = self.spark.sql(
+            """
+            SELECT c.computed_value, m.multiplied
+            FROM compute_udtf(table(SELECT * FROM values_table) WITH SINGLE PARTITION) c,
+                 LATERAL multiply_udtf(c.computed_value) m
+            """
+        )
+
+        expected_df = self.spark.createDataFrame([(60, 180)], "computed_value int, multiplied int")
+        assertDataFrameEqual(result_df, expected_df)
+

 class ArrowUDTFTests(ArrowUDTFTestsMixin, ReusedSQLTestCase):
     pass

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index b25e4d5d538f..1896a1c7ac27 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -2258,12 +2258,15 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
             case _ => tvf
           }

-          Project(
-            Seq(UnresolvedStar(Some(Seq(alias)))),
-            LateralJoin(
-              tableArgs.map(_._2).reduceLeft(Join(_, _, Inner, None, JoinHint.NONE)),
-              LateralSubquery(SubqueryAlias(alias, tvfWithTableColumnIndexes)), Inner, None)
-          )
+          val lateralJoin = LateralJoin(
+            tableArgs.map(_._2).reduceLeft(Join(_, _, Inner, None, JoinHint.NONE)),
+            LateralSubquery(SubqueryAlias(alias, tvfWithTableColumnIndexes)), Inner, None)
+
+          // Set the tag so that it can be used to differentiate a lateral join added by a
+          // TABLE argument from one added by the user.
+          lateralJoin.setTagValue(LateralJoin.BY_TABLE_ARGUMENT, ())
+
+          Project(Seq(UnresolvedStar(Some(Seq(alias)))), lateralJoin)
         }

       case q: LogicalPlan =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 0e7975a128bd..2ff842553bee 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.analysis
 import scala.collection.mutable

 import org.apache.spark.{SparkException, SparkThrowable}
+import org.apache.spark.api.python.PythonEvalType
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.ExtendedAnalysisException
 import org.apache.spark.sql.catalyst.analysis.ResolveWithCTE.checkIfSelfReferenceIsPlacedCorrectly
@@ -889,6 +890,17 @@ trait CheckAnalysis extends LookupCatalog with QueryErrorsBase with PlanToString
             messageParameters = Map(
               "invalidExprSqls" -> invalidExprSqls.mkString(", ")))

+        case j @ LateralJoin(_, right, _, _)
+            if j.getTagValue(LateralJoin.BY_TABLE_ARGUMENT).isEmpty =>
+          right.plan.foreach {
+            case Generate(pyudtf: PythonUDTF, _, _, _, _, _)
+                if pyudtf.evalType == PythonEvalType.SQL_ARROW_UDTF =>
+              j.failAnalysis(
+                errorClass = "LATERAL_JOIN_WITH_ARROW_UDTF_UNSUPPORTED",
+                messageParameters = Map.empty)
+            case _ =>
+          }
+
         case _ => // Analysis successful!
       }
     }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index 90ff9146548a..add31448bef7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -2123,6 +2123,15 @@ case class LateralJoin(
   }
 }

+
+object LateralJoin {
+  /**
+   * A tag to identify whether a LateralJoin was added by resolving a TABLE argument.
+   */
+  val BY_TABLE_ARGUMENT = TreeNodeTag[Unit]("by_table_argument")
+}
+
+
 /**
  * A logical plan for as-of join.
  */
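As a closing illustration, a hedged sketch of how the new analysis error
surfaces to PySpark callers. Matching the error-class name against the
exception message mirrors the `assertRaisesRegex` checks in the tests above;
the `echo_udtf` name is illustrative, and the `arrow_udtf` decorator, the
SparkSession `spark`, and the `t` view are reused from the earlier sketch.

    from typing import Iterator

    import pyarrow as pa

    from pyspark.errors import AnalysisException

    @arrow_udtf(returnType="x int")  # scalar-argument Arrow UDTF, illustrative
    class EchoUDTF:
        def eval(self, v: "pa.Array") -> Iterator["pa.Table"]:
            yield pa.table({"x": pa.array([v[0].as_py()], type=pa.int32())})

    spark.udtf.register("echo_udtf", EchoUDTF)

    try:
        # Rejected at analysis time, before any Python worker runs.
        spark.sql("SELECT * FROM t, LATERAL echo_udtf(t.id) f")
    except AnalysisException as e:
        assert "LATERAL_JOIN_WITH_ARROW_UDTF_UNSUPPORTED" in str(e)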