This is an automated email from the ASF dual-hosted git repository.

allisonwang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new dab3464701e2 [SPARK-52982][PYTHON] Disallow lateral join with Arrow Python UDTFs
dab3464701e2 is described below

commit dab3464701e2ca4c661ba565d3b25b442e4834c4
Author: Allison Wang <allison.w...@databricks.com>
AuthorDate: Fri Aug 22 14:25:44 2025 -0700

    [SPARK-52982][PYTHON] Disallow lateral join with Arrow Python UDTFs
    
    ### What changes were proposed in this pull request?
    
    Arrow-optimized user-defined table functions (UDTFs) have arbitrary output
    cardinality: a single batch of input rows can produce any number of output
    rows. This is fundamentally incompatible with lateral join semantics, which
    combine each row from the LHS with the rows the RHS produces for that
    particular input row. This PR blocks Arrow Python UDTFs on the right-hand
    side of a lateral join: `SELECT * FROM table, LATERAL arrow_udtf(...)`.
    
    Instead, it is recommended to pass input to an Arrow UDTF as a table
    argument: `SELECT * FROM arrow_udtf(table(...))`.
    
    Note: regular (non-Arrow) UDTFs in lateral joins remain allowed (unchanged
    behavior).
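
    A minimal sketch of the new behavior, adapted from the tests below. It
    assumes `arrow_udtf` is importable from `pyspark.sql.functions` (the
    import sits outside the quoted test hunk), that `spark` is an active
    SparkSession, and the view name `t` is only illustrative:

        from typing import Iterator
        import pyarrow as pa
        from pyspark.sql.functions import arrow_udtf  # assumed import path

        @arrow_udtf(returnType="x int, result int")
        class SimpleArrowUDTF:
            def eval(self, input_val: "pa.Array") -> Iterator["pa.Table"]:
                # One input value may yield any number of output rows.
                val = input_val[0].as_py()
                yield pa.table({
                    "x": pa.array([val], type=pa.int32()),
                    "result": pa.array([val * 2], type=pa.int32()),
                })

        spark.udtf.register("simple_arrow_udtf", SimpleArrowUDTF)
        spark.createDataFrame([(1,), (2,)], "id int").createOrReplaceTempView("t")

        # Disallowed after this change: analysis now fails with
        # LATERAL_JOIN_WITH_ARROW_UDTF_UNSUPPORTED.
        spark.sql("SELECT * FROM t, LATERAL simple_arrow_udtf(t.id) f")

        # Recommended instead: pass the input as a TABLE argument, in which
        # case eval receives a pa.Table rather than per-row pa.Array values:
        #   SELECT * FROM arrow_udtf(table(SELECT id FROM t))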
    
    ### Why are the changes needed?
    
    Prevents runtime errors and inconsistent results while maintaining full
    Arrow UDTF functionality for supported patterns.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes. Queries that place an Arrow Python UDTF on the right-hand side of a
    lateral join now fail analysis with the new
    `LATERAL_JOIN_WITH_ARROW_UDTF_UNSUPPORTED` error.
    
    ### How was this patch tested?
    
    New unit tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No
    
    Closes #52048 from allisonwang-db/spark-52982-disallow-lateral-join.
    
    Authored-by: Allison Wang <allison.w...@databricks.com>
    Signed-off-by: Allison Wang <allison.w...@databricks.com>
---
 .../src/main/resources/error/error-conditions.json |   7 ++
 python/pyspark/sql/tests/arrow/test_arrow_udtf.py  | 130 +++++++++++++++++++++
 .../spark/sql/catalyst/analysis/Analyzer.scala     |  15 ++-
 .../sql/catalyst/analysis/CheckAnalysis.scala      |  12 ++
 .../plans/logical/basicLogicalOperators.scala      |   9 ++
 5 files changed, 167 insertions(+), 6 deletions(-)

diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json
index f7db0b6761a5..234b0c3ed02d 100644
--- a/common/utils/src/main/resources/error/error-conditions.json
+++ b/common/utils/src/main/resources/error/error-conditions.json
@@ -4086,6 +4086,13 @@
     ],
     "sqlState" : "42K0L"
   },
+  "LATERAL_JOIN_WITH_ARROW_UDTF_UNSUPPORTED" : {
+    "message" : [
+      "LATERAL JOIN with Arrow-optimized user-defined table functions (UDTFs) 
is not supported. Arrow UDTFs cannot be used on the right-hand side of a 
lateral join.",
+      "Please use a regular UDTF instead, or restructure your query to avoid 
the lateral join."
+    ],
+    "sqlState" : "0A000"
+  },
   "LOAD_DATA_PATH_NOT_EXISTS" : {
     "message" : [
       "LOAD DATA input path does not exist: <path>."
diff --git a/python/pyspark/sql/tests/arrow/test_arrow_udtf.py b/python/pyspark/sql/tests/arrow/test_arrow_udtf.py
index 50fe6588eb92..75909fc88f82 100644
--- a/python/pyspark/sql/tests/arrow/test_arrow_udtf.py
+++ b/python/pyspark/sql/tests/arrow/test_arrow_udtf.py
@@ -26,6 +26,7 @@ from pyspark.testing import assertDataFrameEqual
 
 if have_pyarrow:
     import pyarrow as pa
+    import pyarrow.compute as pc
 
 
 @unittest.skipIf(not have_pyarrow, pyarrow_requirement_message)
@@ -460,6 +461,135 @@ class ArrowUDTFTestsMixin:
         )
         assertDataFrameEqual(sql_result_df, expected_df)
 
+    def test_arrow_udtf_lateral_join_disallowed(self):
+        @arrow_udtf(returnType="x int, result int")
+        class SimpleArrowUDTF:
+            def eval(self, input_val: "pa.Array") -> Iterator["pa.Table"]:
+                val = input_val[0].as_py()
+                result_table = pa.table(
+                    {
+                        "x": pa.array([val], type=pa.int32()),
+                        "result": pa.array([val * 2], type=pa.int32()),
+                    }
+                )
+                yield result_table
+
+        self.spark.udtf.register("simple_arrow_udtf", SimpleArrowUDTF)
+
+        test_df = self.spark.createDataFrame([(1,), (2,), (3,)], "id int")
+        test_df.createOrReplaceTempView("test_table")
+
+        with self.assertRaisesRegex(Exception, "LATERAL_JOIN_WITH_ARROW_UDTF_UNSUPPORTED"):
+            self.spark.sql(
+                """
+                SELECT t.id, f.x, f.result
+                FROM test_table t, LATERAL simple_arrow_udtf(t.id) f
+                """
+            )
+
+    def test_arrow_udtf_lateral_join_with_table_argument_disallowed(self):
+        @arrow_udtf(returnType="filtered_id bigint")
+        class MixedArgsUDTF:
+            def eval(self, input_table: "pa.Table") -> Iterator["pa.Table"]:
+                filtered_data = input_table.filter(pc.greater(input_table["id"], 5))
+                result_table = pa.table({"filtered_id": filtered_data["id"]})
+                yield result_table
+
+        self.spark.udtf.register("mixed_args_udtf", MixedArgsUDTF)
+
+        test_df1 = self.spark.createDataFrame([(1,), (2,), (3,)], "id int")
+        test_df1.createOrReplaceTempView("test_table1")
+
+        test_df2 = self.spark.createDataFrame([(6,), (7,), (8,)], "id bigint")
+        test_df2.createOrReplaceTempView("test_table2")
+
+        # Table arguments create nested lateral joins where our CheckAnalysis rule doesn't trigger
+        # because the Arrow UDTF is in the inner lateral join, not the outer one our rule checks.
+        # So Spark's general lateral join validation catches this first with
+        # NON_DETERMINISTIC_LATERAL_SUBQUERIES.
+        with self.assertRaisesRegex(
+            Exception,
+            "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY.NON_DETERMINISTIC_LATERAL_SUBQUERIES",
+        ):
+            self.spark.sql(
+                """
+                SELECT t1.id, f.filtered_id
+                FROM test_table1 t1, LATERAL mixed_args_udtf(table(SELECT * FROM test_table2)) f
+                """
+            )
+
+    def test_arrow_udtf_with_table_argument_then_lateral_join_allowed(self):
+        @arrow_udtf(returnType="processed_id bigint")
+        class TableArgUDTF:
+            def eval(self, input_table: "pa.Table") -> Iterator["pa.Table"]:
+                processed_data = pc.add(input_table["id"], 100)
+                result_table = pa.table({"processed_id": processed_data})
+                yield result_table
+
+        self.spark.udtf.register("table_arg_udtf", TableArgUDTF)
+
+        source_df = self.spark.createDataFrame([(1,), (2,), (3,)], "id bigint")
+        source_df.createOrReplaceTempView("source_table")
+
+        join_df = self.spark.createDataFrame([("A",), ("B",), ("C",)], "label string")
+        join_df.createOrReplaceTempView("join_table")
+
+        result_df = self.spark.sql(
+            """
+            SELECT f.processed_id, j.label
+            FROM table_arg_udtf(table(SELECT * FROM source_table)) f,
+                join_table j
+            ORDER BY f.processed_id, j.label
+            """
+        )
+
+        expected_data = [
+            (101, "A"),
+            (101, "B"),
+            (101, "C"),
+            (102, "A"),
+            (102, "B"),
+            (102, "C"),
+            (103, "A"),
+            (103, "B"),
+            (103, "C"),
+        ]
+        expected_df = self.spark.createDataFrame(expected_data, "processed_id bigint, label string")
+        assertDataFrameEqual(result_df, expected_df)
+
+    def test_arrow_udtf_table_argument_with_regular_udtf_lateral_join_allowed(self):
+        @arrow_udtf(returnType="computed_value int")
+        class ComputeUDTF:
+            def eval(self, input_table: "pa.Table") -> Iterator["pa.Table"]:
+                total = pc.sum(input_table["value"]).as_py()
+                result_table = pa.table({"computed_value": pa.array([total], type=pa.int32())})
+                yield result_table
+
+        from pyspark.sql.functions import udtf
+        from pyspark.sql.types import StructType, StructField, IntegerType
+
+        @udtf(returnType=StructType([StructField("multiplied", IntegerType())]))
+        class MultiplyUDTF:
+            def eval(self, input_val: int):
+                yield (input_val * 3,)
+
+        self.spark.udtf.register("compute_udtf", ComputeUDTF)
+        self.spark.udtf.register("multiply_udtf", MultiplyUDTF)
+
+        values_df = self.spark.createDataFrame([(10,), (20,), (30,)], "value int")
+        values_df.createOrReplaceTempView("values_table")
+
+        result_df = self.spark.sql(
+            """
+            SELECT c.computed_value, m.multiplied
+            FROM compute_udtf(table(SELECT * FROM values_table) WITH SINGLE PARTITION) c,
+                LATERAL multiply_udtf(c.computed_value) m
+            """
+        )
+
+        expected_df = self.spark.createDataFrame([(60, 180)], "computed_value int, multiplied int")
+        assertDataFrameEqual(result_df, expected_df)
+
 
 class ArrowUDTFTests(ArrowUDTFTestsMixin, ReusedSQLTestCase):
     pass
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index b25e4d5d538f..1896a1c7ac27 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -2258,12 +2258,15 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
             case _ => tvf
           }
 
-          Project(
-            Seq(UnresolvedStar(Some(Seq(alias)))),
-            LateralJoin(
-              tableArgs.map(_._2).reduceLeft(Join(_, _, Inner, None, JoinHint.NONE)),
-              LateralSubquery(SubqueryAlias(alias, tvfWithTableColumnIndexes)), Inner, None)
-          )
+          val lateralJoin = LateralJoin(
+            tableArgs.map(_._2).reduceLeft(Join(_, _, Inner, None, JoinHint.NONE)),
+            LateralSubquery(SubqueryAlias(alias, tvfWithTableColumnIndexes)), Inner, None)
+
+          // Set a tag so that a lateral join added by a TABLE argument can be
+          // distinguished from one written by the user.
+          lateralJoin.setTagValue(LateralJoin.BY_TABLE_ARGUMENT, ())
+
+          Project(Seq(UnresolvedStar(Some(Seq(alias)))), lateralJoin)
         }
 
       case q: LogicalPlan =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 0e7975a128bd..2ff842553bee 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.analysis
 import scala.collection.mutable
 
 import org.apache.spark.{SparkException, SparkThrowable}
+import org.apache.spark.api.python.PythonEvalType
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.ExtendedAnalysisException
import org.apache.spark.sql.catalyst.analysis.ResolveWithCTE.checkIfSelfReferenceIsPlacedCorrectly
@@ -889,6 +890,17 @@ trait CheckAnalysis extends LookupCatalog with QueryErrorsBase with PlanToString
               messageParameters = Map(
                 "invalidExprSqls" -> invalidExprSqls.mkString(", ")))
 
+          case j @ LateralJoin(_, right, _, _)
+              if j.getTagValue(LateralJoin.BY_TABLE_ARGUMENT).isEmpty =>
+            right.plan.foreach {
+              case Generate(pyudtf: PythonUDTF, _, _, _, _, _)
+                  if pyudtf.evalType == PythonEvalType.SQL_ARROW_UDTF =>
+                j.failAnalysis(
+                  errorClass = "LATERAL_JOIN_WITH_ARROW_UDTF_UNSUPPORTED",
+                  messageParameters = Map.empty)
+              case _ =>
+            }
+
           case _ => // Analysis successful!
         }
     }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index 90ff9146548a..add31448bef7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -2123,6 +2123,15 @@ case class LateralJoin(
   }
 }
 
+
+object LateralJoin {
+  /**
+   * A tag to identify whether a LateralJoin was added by resolving a table argument.
+   */
+  val BY_TABLE_ARGUMENT = TreeNodeTag[Unit]("by_table_argument")
+}
+
+
 /**
  * A logical plan for as-of join.
  */


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
