This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 50780af2d82 [SPARK-41961][SQL] Support table-valued functions with LATERAL
50780af2d82 is described below

commit 50780af2d82689f7501a82d0ff9d5ace99f0703d
Author: allisonwang-db <[email protected]>
AuthorDate: Sat Jan 14 11:09:58 2023 +0800

    [SPARK-41961][SQL] Support table-valued functions with LATERAL
    
    ### What changes were proposed in this pull request?
    This PR allows table-valued functions to reference columns and aliases in the preceding FROM items using the LATERAL keyword.
    
    ### Why are the changes needed?
    To improve the usability of table-valued functions.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes. Before this PR, users could not invoke a table-valued function in the FROM clause with LATERAL:
    ```
    SELECT * FROM t, LATERAL EXPLODE(ARRAY(t.c1, t.c2));
    [INVALID_SQL_SYNTAX] Invalid SQL syntax: LATERAL can only be used with subquery.
    ```
    After this PR, this query can run successfully.
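    
    For illustration, one of the new golden-file tests added below (inputs/join-lateral.sql with results/join-lateral.sql.out) exercises the feature end to end; the expected rows, shown here as comments, are copied from the golden file, where the t1 test view holds the rows (0, 1) and (1, 2):
    ```
    SELECT * FROM t1, LATERAL EXPLODE(ARRAY(c1, c2)) t2(c3);
    -- c1  c2  c3
    -- 0   1   0
    -- 0   1   1
    -- 1   2   1
    -- 1   2   2
    ```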
    
    ### How was this patch tested?
    
    New SQL query tests.
    
    Closes #39479 from allisonwang-db/spark-41961-lateral-tvf.
    
    Authored-by: allisonwang-db <[email protected]>
    Signed-off-by: Wenchen Fan <[email protected]>
---
 .../spark/sql/catalyst/analysis/Analyzer.scala     |   2 +-
 .../spark/sql/catalyst/optimizer/subquery.scala    |   6 +
 .../spark/sql/catalyst/parser/AstBuilder.scala     |  19 ++-
 .../spark/sql/errors/QueryParsingErrors.scala      |   3 +-
 .../OptimizeOneRowRelationSubquerySuite.scala      |  28 ++++-
 .../resources/sql-tests/inputs/join-lateral.sql    |  19 +++
 .../sql-tests/results/join-lateral.sql.out         | 135 +++++++++++++++++++++
 .../spark/sql/errors/QueryParsingErrorsSuite.scala |   6 +-
 .../execution/datasources/SchemaPruningSuite.scala |  20 +++
 9 files changed, 226 insertions(+), 12 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 3b3a011db97..d6b68a45e77 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -73,7 +73,7 @@ object SimpleAnalyzer extends Analyzer(
     new SessionCatalog(
       new InMemoryCatalog,
       FunctionRegistry.builtin,
-      EmptyTableFunctionRegistry) {
+      TableFunctionRegistry.builtin) {
      override def createDatabase(dbDefinition: CatalogDatabase, ignoreIfExists: Boolean): Unit = {}
     })) {
   override def resolver: Resolver = caseSensitiveResolution
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
index 4d2c64c7c32..faafeecc316 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
@@ -761,6 +761,7 @@ object OptimizeOneRowRelationSubquery extends Rule[LogicalPlan] {
       CollapseProject(EliminateSubqueryAliases(plan), alwaysInline = alwaysInline) match {
         case p @ Project(_, _: OneRowRelation) => Some(p)
         case g @ Generate(_, _, _, _, _, _: OneRowRelation) => Some(g)
+        case p @ Project(_, Generate(_, _, _, _, _, _: OneRowRelation)) => Some(p)
         case _ => None
       }
     }
@@ -787,6 +788,11 @@ object OptimizeOneRowRelationSubquery extends Rule[LogicalPlan] {
           val newGenerator = stripOuterReference(generator)
           g.copy(generator = newGenerator, child = left)
 
+        case Project(projectList, g @ Generate(generator, _, _, _, _, _: OneRowRelation)) =>
+          val newPList = stripOuterReferences(projectList)
+          val newGenerator = stripOuterReference(generator)
+          Project(left.output ++ newPList, g.copy(generator = newGenerator, child = left))
+
         case o =>
           throw SparkException.internalError(
             s"Unexpected plan when optimizing one row relation subquery: $o")
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index eb85aee25ce..e74fe5ce003 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -914,15 +914,19 @@ class AstBuilder extends SqlBaseParserBaseVisitor[AnyRef] with SQLConfHelper wit
    */
   override def visitFromClause(ctx: FromClauseContext): LogicalPlan = withOrigin(ctx) {
     val from = ctx.relation.asScala.foldLeft(null: LogicalPlan) { (left, relation) =>
+      val relationPrimary = relation.relationPrimary()
       val right = if (conf.ansiRelationPrecedence) {
         visitRelation(relation)
       } else {
-        plan(relation.relationPrimary)
+        plan(relationPrimary)
       }
       val join = right.optionalMap(left) { (left, right) =>
         if (relation.LATERAL != null) {
-          if (!relation.relationPrimary.isInstanceOf[AliasedQueryContext]) {
-            throw QueryParsingErrors.invalidLateralJoinRelationError(relation.relationPrimary)
+          relationPrimary match {
+            case _: AliasedQueryContext =>
+            case _: TableValuedFunctionContext =>
+            case other =>
+              throw QueryParsingErrors.invalidLateralJoinRelationError(other)
           }
           LateralJoin(left, LateralSubquery(right), Inner, None)
         } else {
@@ -1295,8 +1299,13 @@ class AstBuilder extends SqlBaseParserBaseVisitor[AnyRef] with SQLConfHelper wit
         case _ => Inner
       }
 
-      if (ctx.LATERAL != null && !ctx.right.isInstanceOf[AliasedQueryContext]) {
-        throw QueryParsingErrors.invalidLateralJoinRelationError(ctx.right)
+      if (ctx.LATERAL != null) {
+        ctx.right match {
+          case _: AliasedQueryContext =>
+          case _: TableValuedFunctionContext =>
+          case other =>
+            throw QueryParsingErrors.invalidLateralJoinRelationError(other)
+        }
       }
 
       // Resolve the join type and join condition
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala
index ef59dfa5517..29766251abd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala
@@ -136,7 +136,8 @@ private[sql] object QueryParsingErrors extends QueryErrorsBase {
     new ParseException(
       errorClass = "INVALID_SQL_SYNTAX",
       messageParameters = Map(
-        "inputString" -> s"${toSQLStmt("LATERAL")} can only be used with subquery."),
+        "inputString" ->
+          s"${toSQLStmt("LATERAL")} can only be used with subquery and table-valued functions."),
       ctx)
   }
 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeOneRowRelationSubquerySuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeOneRowRelationSubquerySuite.scala
index 862af420bdb..a7f7cc1a9fb 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeOneRowRelationSubquerySuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeOneRowRelationSubquerySuite.scala
@@ -17,12 +17,12 @@
 
 package org.apache.spark.sql.catalyst.optimizer
 
-import org.apache.spark.sql.catalyst.analysis.CleanupAliases
+import org.apache.spark.sql.catalyst.analysis.{CleanupAliases, UnresolvedTableValuedFunction, UnresolvedTVFAliases}
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
 import org.apache.spark.sql.catalyst.expressions.{Explode, ScalarSubquery}
 import org.apache.spark.sql.catalyst.plans._
-import org.apache.spark.sql.catalyst.plans.logical.{DomainJoin, LocalRelation, LogicalPlan, OneRowRelation}
+import org.apache.spark.sql.catalyst.plans.logical.{DomainJoin, Generate, LocalRelation, LogicalPlan, OneRowRelation, SubqueryAlias}
 import org.apache.spark.sql.catalyst.rules.RuleExecutor
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.IntegerType
@@ -177,4 +177,28 @@ class OptimizeOneRowRelationSubquerySuite extends PlanTest {
     val optimized = Optimize.execute(query2.analyze)
     assertHasDomainJoin(optimized)
   }
+
+  test("SPARK-41961: optimize lateral subquery with table-valued functions") {
+    // SELECT * FROM t3 JOIN LATERAL EXPLODE(arr)
+    val query1 = t3.lateralJoin(UnresolvedTableValuedFunction("explode", $"arr" :: Nil))
+    comparePlans(
+      Optimize.execute(query1.analyze),
+      t3.generate(Explode($"arr")).analyze)
+
+    // SELECT * FROM t3 JOIN LATERAL EXPLODE(arr) t(v)
+    val query2 = t3.lateralJoin(
+      SubqueryAlias("t",
+        UnresolvedTVFAliases("explode" :: Nil,
+          UnresolvedTableValuedFunction("explode", $"arr" :: Nil), "v" :: Nil)))
+    comparePlans(
+      Optimize.execute(query2.analyze),
+      t3.generate(Explode($"arr")).select($"a", $"b", $"arr", $"col".as("v")).analyze)
+
+    // SELECT col FROM t3 JOIN LATERAL (SELECT * FROM EXPLODE(arr) WHERE col > 0)
+    val query3 = t3.lateralJoin(
+      UnresolvedTableValuedFunction("explode", $"arr" :: Nil).where($"col" > 0))
+    val optimized = Optimize.execute(query3.analyze)
+    assert(optimized.exists(_.isInstanceOf[Generate]))
+    assertHasDomainJoin(optimized)
+  }
 }
diff --git a/sql/core/src/test/resources/sql-tests/inputs/join-lateral.sql b/sql/core/src/test/resources/sql-tests/inputs/join-lateral.sql
index 6bf533d0509..92ea9587dc9 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/join-lateral.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/join-lateral.sql
@@ -177,6 +177,25 @@ SELECT * FROM t3 JOIN LATERAL (SELECT EXPLODE_OUTER(c2));
 SELECT * FROM t3 JOIN LATERAL (SELECT EXPLODE(c2)) t(c3) ON c1 = c3;
 SELECT * FROM t3 LEFT JOIN LATERAL (SELECT EXPLODE(c2)) t(c3) ON c1 = c3;
 
+-- SPARK-41961: lateral join with table-valued functions
+SELECT * FROM LATERAL EXPLODE(ARRAY(1, 2));
+SELECT * FROM t1, LATERAL RANGE(3);
+SELECT * FROM t1, LATERAL EXPLODE(ARRAY(c1, c2)) t2(c3);
+SELECT * FROM t3, LATERAL EXPLODE(c2) t2(v);
+SELECT * FROM t3, LATERAL EXPLODE_OUTER(c2) t2(v);
+SELECT * FROM EXPLODE(ARRAY(1, 2)) t(v), LATERAL (SELECT v + 1);
+
+-- lateral join with table-valued functions and join conditions
+SELECT * FROM t1 JOIN LATERAL EXPLODE(ARRAY(c1, c2)) t(c3) ON t1.c1 = c3;
+SELECT * FROM t3 JOIN LATERAL EXPLODE(c2) t(c3) ON t3.c1 = c3;
+SELECT * FROM t3 LEFT JOIN LATERAL EXPLODE(c2) t(c3) ON t3.c1 = c3;
+
+-- lateral join with table-valued functions in lateral subqueries
+SELECT * FROM t1, LATERAL (SELECT * FROM EXPLODE(ARRAY(c1, c2)));
+SELECT * FROM t1, LATERAL (SELECT t1.c1 + c3 FROM EXPLODE(ARRAY(c1, c2)) t(c3));
+SELECT * FROM t1, LATERAL (SELECT t1.c1 + c3 FROM EXPLODE(ARRAY(c1, c2)) t(c3) WHERE t1.c2 > 1);
+SELECT * FROM t1, LATERAL (SELECT * FROM EXPLODE(ARRAY(c1, c2)) l(x) JOIN EXPLODE(ARRAY(c2, c1)) r(y) ON x = y);
+
+
 -- clean up
 DROP VIEW t1;
 DROP VIEW t2;
diff --git a/sql/core/src/test/resources/sql-tests/results/join-lateral.sql.out b/sql/core/src/test/resources/sql-tests/results/join-lateral.sql.out
index e8e029bbe1e..9d6d542ca56 100644
--- a/sql/core/src/test/resources/sql-tests/results/join-lateral.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/join-lateral.sql.out
@@ -837,6 +837,141 @@ struct<c1:int,c2:array<int>,c3:int>
 NULL   [4]     NULL
 
 
+-- !query
+SELECT * FROM LATERAL EXPLODE(ARRAY(1, 2))
+-- !query schema
+struct<col:int>
+-- !query output
+1
+2
+
+
+-- !query
+SELECT * FROM t1, LATERAL RANGE(3)
+-- !query schema
+struct<c1:int,c2:int,id:bigint>
+-- !query output
+0      1       0
+0      1       1
+0      1       2
+1      2       0
+1      2       1
+1      2       2
+
+
+-- !query
+SELECT * FROM t1, LATERAL EXPLODE(ARRAY(c1, c2)) t2(c3)
+-- !query schema
+struct<c1:int,c2:int,c3:int>
+-- !query output
+0      1       0
+0      1       1
+1      2       1
+1      2       2
+
+
+-- !query
+SELECT * FROM t3, LATERAL EXPLODE(c2) t2(v)
+-- !query schema
+struct<c1:int,c2:array<int>,v:int>
+-- !query output
+0      [0,1]   0
+0      [0,1]   1
+1      [2]     2
+NULL   [4]     4
+
+
+-- !query
+SELECT * FROM t3, LATERAL EXPLODE_OUTER(c2) t2(v)
+-- !query schema
+struct<c1:int,c2:array<int>,v:int>
+-- !query output
+0      [0,1]   0
+0      [0,1]   1
+1      [2]     2
+2      []      NULL
+NULL   [4]     4
+
+
+-- !query
+SELECT * FROM EXPLODE(ARRAY(1, 2)) t(v), LATERAL (SELECT v + 1)
+-- !query schema
+struct<v:int,(outer(t.v) + 1):int>
+-- !query output
+1      2
+2      3
+
+
+-- !query
+SELECT * FROM t1 JOIN LATERAL EXPLODE(ARRAY(c1, c2)) t(c3) ON t1.c1 = c3
+-- !query schema
+struct<c1:int,c2:int,c3:int>
+-- !query output
+0      1       0
+1      2       1
+
+
+-- !query
+SELECT * FROM t3 JOIN LATERAL EXPLODE(c2) t(c3) ON t3.c1 = c3
+-- !query schema
+struct<c1:int,c2:array<int>,c3:int>
+-- !query output
+0      [0,1]   0
+
+
+-- !query
+SELECT * FROM t3 LEFT JOIN LATERAL EXPLODE(c2) t(c3) ON t3.c1 = c3
+-- !query schema
+struct<c1:int,c2:array<int>,c3:int>
+-- !query output
+0      [0,1]   0
+1      [2]     NULL
+2      []      NULL
+NULL   [4]     NULL
+
+
+-- !query
+SELECT * FROM t1, LATERAL (SELECT * FROM EXPLODE(ARRAY(c1, c2)))
+-- !query schema
+struct<c1:int,c2:int,col:int>
+-- !query output
+0      1       0
+0      1       1
+1      2       1
+1      2       2
+
+
+-- !query
+SELECT * FROM t1, LATERAL (SELECT t1.c1 + c3 FROM EXPLODE(ARRAY(c1, c2)) t(c3))
+-- !query schema
+struct<c1:int,c2:int,(outer(spark_catalog.default.t1.c1) + c3):int>
+-- !query output
+0      1       0
+0      1       1
+1      2       2
+1      2       3
+
+
+-- !query
+SELECT * FROM t1, LATERAL (SELECT t1.c1 + c3 FROM EXPLODE(ARRAY(c1, c2)) t(c3) WHERE t1.c2 > 1)
+-- !query schema
+struct<c1:int,c2:int,(outer(spark_catalog.default.t1.c1) + c3):int>
+-- !query output
+1      2       2
+1      2       3
+
+
+-- !query
+SELECT * FROM t1, LATERAL (SELECT * FROM EXPLODE(ARRAY(c1, c2)) l(x) JOIN EXPLODE(ARRAY(c2, c1)) r(y) ON x = y)
+-- !query schema
+struct<c1:int,c2:int,x:int,y:int>
+-- !query output
+0      1       0       0
+0      1       1       1
+1      2       1       1
+1      2       2       2
+
+
 -- !query
 DROP VIEW t1
 -- !query schema
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryParsingErrorsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryParsingErrorsSuite.scala
index 50df5808790..87e9a3d0a12 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryParsingErrorsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryParsingErrorsSuite.scala
@@ -75,14 +75,14 @@ class QueryParsingErrorsSuite extends QueryTest with SharedSparkSession {
       " JOIN LATERAL t2" -> ("JOIN LATERAL t2", 17, 31),
       ", LATERAL (t2 JOIN t3)" -> ("FROM t1, LATERAL (t2 JOIN t3)", 9, 37),
       ", LATERAL (LATERAL t2)" -> ("FROM t1, LATERAL (LATERAL t2)", 9, 37),
-      ", LATERAL VALUES (0, 1)" -> ("FROM t1, LATERAL VALUES (0, 1)", 9, 38),
-      ", LATERAL RANGE(0, 1)" -> ("FROM t1, LATERAL RANGE(0, 1)", 9, 36)
+      ", LATERAL VALUES (0, 1)" -> ("FROM t1, LATERAL VALUES (0, 1)", 9, 38)
     ).foreach { case (sqlText, (fragment, start, stop)) =>
       checkError(
         exception = parseException(s"SELECT * FROM t1$sqlText"),
         errorClass = "INVALID_SQL_SYNTAX",
         sqlState = "42000",
-        parameters = Map("inputString" -> "LATERAL can only be used with subquery."),
+        parameters = Map("inputString" ->
+          "LATERAL can only be used with subquery and table-valued functions."),
         context = ExpectedContext(fragment, start, stop))
     }
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala
index 1f6a6bfda66..f9a8c67fc9f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala
@@ -398,6 +398,26 @@ abstract class SchemaPruningSuite
     }
   }
 
+  testSchemaPruning("SPARK-41961: nested schema pruning on table-valued generator functions") {
+    val query1 = sql("select friend.first from contacts, lateral explode(friends) t(friend)")
+    checkScan(query1, "struct<friends:array<struct<first:string>>>")
+    checkAnswer(query1, Row("Susan") :: Nil)
+
+    // Currently we don't prune the multiple-field case.
+    val query2 = sql(
+      "select friend.first, friend.middle from contacts, lateral explode(friends) t(friend)")
+    checkScan(query2, "struct<friends:array<struct<first:string,middle:string,last:string>>>")
+    checkAnswer(query2, Row("Susan", "Z.") :: Nil)
+
+    val query3 = sql(
+      """
+        |select friend.first, friend.middle, friend
+        |from contacts, lateral explode(friends) t(friend)
+        |""".stripMargin)
+    checkScan(query3, "struct<friends:array<struct<first:string,middle:string,last:string>>>")
+    checkAnswer(query3, Row("Susan", "Z.", Row("Susan", "Z.", "Smith")) :: Nil)
+  }
+
   testSchemaPruning("select one deep nested complex field after repartition") {
     val query = sql("select * from contacts")
       .repartition(100)

