This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 50780af2d82 [SPARK-41961][SQL] Support table-valued functions with LATERAL
50780af2d82 is described below
commit 50780af2d82689f7501a82d0ff9d5ace99f0703d
Author: allisonwang-db <[email protected]>
AuthorDate: Sat Jan 14 11:09:58 2023 +0800
[SPARK-41961][SQL] Support table-valued functions with LATERAL
### What changes were proposed in this pull request?
This PR allows table-valued functions to reference columns and aliases in
the preceding FROM items using the LATERAL keyword.
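For example (adapted from the SQL tests added in this PR), both the TVF
arguments and the join condition can reference columns of the preceding
FROM item t1:
```
SELECT * FROM t1 JOIN LATERAL EXPLODE(ARRAY(c1, c2)) t(c3) ON t1.c1 = c3;
```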
### Why are the changes needed?
To improve the usability of table-valued functions.
### Does this PR introduce _any_ user-facing change?
Yes. Before this PR, users could not invoke a table-valued function in the
FROM clause with LATERAL:
```
SELECT * FROM t, LATERAL EXPLODE(ARRAY(t.c1, t.c2));
[INVALID_SQL_SYNTAX] Invalid SQL syntax: LATERAL can only be used with subquery.
```
After this PR, this query runs successfully.
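As a minimal sketch, assuming a hypothetical one-row view t with c1 = 1 and
c2 = 2 (sample data not part of this patch), the query now produces one row
per array element:
```
CREATE OR REPLACE TEMP VIEW t AS SELECT 1 AS c1, 2 AS c2;  -- hypothetical sample data
SELECT * FROM t, LATERAL EXPLODE(ARRAY(t.c1, t.c2)) t2(v);
-- c1  c2  v
-- 1   2   1
-- 1   2   2
```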
### How was this patch tested?
New SQL query tests.
Closes #39479 from allisonwang-db/spark-41961-lateral-tvf.
Authored-by: allisonwang-db <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
.../spark/sql/catalyst/analysis/Analyzer.scala | 2 +-
.../spark/sql/catalyst/optimizer/subquery.scala | 6 +
.../spark/sql/catalyst/parser/AstBuilder.scala | 19 ++-
.../spark/sql/errors/QueryParsingErrors.scala | 3 +-
.../OptimizeOneRowRelationSubquerySuite.scala | 28 ++++-
.../resources/sql-tests/inputs/join-lateral.sql | 19 +++
.../sql-tests/results/join-lateral.sql.out | 135 +++++++++++++++++++++
.../spark/sql/errors/QueryParsingErrorsSuite.scala | 6 +-
.../execution/datasources/SchemaPruningSuite.scala | 20 +++
9 files changed, 226 insertions(+), 12 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 3b3a011db97..d6b68a45e77 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -73,7 +73,7 @@ object SimpleAnalyzer extends Analyzer(
new SessionCatalog(
new InMemoryCatalog,
FunctionRegistry.builtin,
- EmptyTableFunctionRegistry) {
+ TableFunctionRegistry.builtin) {
override def createDatabase(dbDefinition: CatalogDatabase, ignoreIfExists: Boolean): Unit = {}
})) {
override def resolver: Resolver = caseSensitiveResolution
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
index 4d2c64c7c32..faafeecc316 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
@@ -761,6 +761,7 @@ object OptimizeOneRowRelationSubquery extends Rule[LogicalPlan] {
CollapseProject(EliminateSubqueryAliases(plan), alwaysInline = alwaysInline) match {
case p @ Project(_, _: OneRowRelation) => Some(p)
case g @ Generate(_, _, _, _, _, _: OneRowRelation) => Some(g)
+ case p @ Project(_, Generate(_, _, _, _, _, _: OneRowRelation)) => Some(p)
case _ => None
}
}
@@ -787,6 +788,11 @@ object OptimizeOneRowRelationSubquery extends Rule[LogicalPlan] {
val newGenerator = stripOuterReference(generator)
g.copy(generator = newGenerator, child = left)
+ case Project(projectList, g @ Generate(generator, _, _, _, _, _: OneRowRelation)) =>
+ val newPList = stripOuterReferences(projectList)
+ val newGenerator = stripOuterReference(generator)
+ Project(left.output ++ newPList, g.copy(generator = newGenerator, child = left))
+
case o =>
throw SparkException.internalError(
s"Unexpected plan when optimizing one row relation subquery: $o")
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index eb85aee25ce..e74fe5ce003 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -914,15 +914,19 @@ class AstBuilder extends SqlBaseParserBaseVisitor[AnyRef] with SQLConfHelper wit
*/
override def visitFromClause(ctx: FromClauseContext): LogicalPlan = withOrigin(ctx) {
val from = ctx.relation.asScala.foldLeft(null: LogicalPlan) { (left, relation) =>
+ val relationPrimary = relation.relationPrimary()
val right = if (conf.ansiRelationPrecedence) {
visitRelation(relation)
} else {
- plan(relation.relationPrimary)
+ plan(relationPrimary)
}
val join = right.optionalMap(left) { (left, right) =>
if (relation.LATERAL != null) {
- if (!relation.relationPrimary.isInstanceOf[AliasedQueryContext]) {
- throw QueryParsingErrors.invalidLateralJoinRelationError(relation.relationPrimary)
+ relationPrimary match {
+ case _: AliasedQueryContext =>
+ case _: TableValuedFunctionContext =>
+ case other =>
+ throw QueryParsingErrors.invalidLateralJoinRelationError(other)
}
LateralJoin(left, LateralSubquery(right), Inner, None)
} else {
@@ -1295,8 +1299,13 @@ class AstBuilder extends SqlBaseParserBaseVisitor[AnyRef] with SQLConfHelper wit
case _ => Inner
}
- if (ctx.LATERAL != null && !ctx.right.isInstanceOf[AliasedQueryContext]) {
- throw QueryParsingErrors.invalidLateralJoinRelationError(ctx.right)
+ if (ctx.LATERAL != null) {
+ ctx.right match {
+ case _: AliasedQueryContext =>
+ case _: TableValuedFunctionContext =>
+ case other =>
+ throw QueryParsingErrors.invalidLateralJoinRelationError(other)
+ }
}
// Resolve the join type and join condition
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala
index ef59dfa5517..29766251abd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala
@@ -136,7 +136,8 @@ private[sql] object QueryParsingErrors extends QueryErrorsBase {
new ParseException(
errorClass = "INVALID_SQL_SYNTAX",
messageParameters = Map(
- "inputString" -> s"${toSQLStmt("LATERAL")} can only be used with
subquery."),
+ "inputString" ->
+ s"${toSQLStmt("LATERAL")} can only be used with subquery and
table-valued functions."),
ctx)
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeOneRowRelationSubquerySuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeOneRowRelationSubquerySuite.scala
index 862af420bdb..a7f7cc1a9fb 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeOneRowRelationSubquerySuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeOneRowRelationSubquerySuite.scala
@@ -17,12 +17,12 @@
package org.apache.spark.sql.catalyst.optimizer
-import org.apache.spark.sql.catalyst.analysis.CleanupAliases
+import org.apache.spark.sql.catalyst.analysis.{CleanupAliases, UnresolvedTableValuedFunction, UnresolvedTVFAliases}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{Explode, ScalarSubquery}
import org.apache.spark.sql.catalyst.plans._
-import org.apache.spark.sql.catalyst.plans.logical.{DomainJoin, LocalRelation, LogicalPlan, OneRowRelation}
+import org.apache.spark.sql.catalyst.plans.logical.{DomainJoin, Generate, LocalRelation, LogicalPlan, OneRowRelation, SubqueryAlias}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.IntegerType
@@ -177,4 +177,28 @@ class OptimizeOneRowRelationSubquerySuite extends PlanTest {
val optimized = Optimize.execute(query2.analyze)
assertHasDomainJoin(optimized)
}
+
+ test("SPARK-41961: optimize lateral subquery with table-valued functions") {
+ // SELECT * FROM t3 JOIN LATERAL EXPLODE(arr)
+ val query1 = t3.lateralJoin(UnresolvedTableValuedFunction("explode", $"arr" :: Nil))
+ comparePlans(
+ Optimize.execute(query1.analyze),
+ t3.generate(Explode($"arr")).analyze)
+
+ // SELECT * FROM t3 JOIN LATERAL EXPLODE(arr) t(v)
+ val query2 = t3.lateralJoin(
+ SubqueryAlias("t",
+ UnresolvedTVFAliases("explode" :: Nil,
+ UnresolvedTableValuedFunction("explode", $"arr" :: Nil), "v" :: Nil)))
+ comparePlans(
+ Optimize.execute(query2.analyze),
+ t3.generate(Explode($"arr")).select($"a", $"b", $"arr", $"col".as("v")).analyze)
+
+ // SELECT col FROM t3 JOIN LATERAL (SELECT * FROM EXPLODE(arr) WHERE col > 0)
+ val query3 = t3.lateralJoin(
+ UnresolvedTableValuedFunction("explode", $"arr" :: Nil).where($"col" > 0))
+ val optimized = Optimize.execute(query3.analyze)
+ assert(optimized.exists(_.isInstanceOf[Generate]))
+ assertHasDomainJoin(optimized)
+ }
}
diff --git a/sql/core/src/test/resources/sql-tests/inputs/join-lateral.sql b/sql/core/src/test/resources/sql-tests/inputs/join-lateral.sql
index 6bf533d0509..92ea9587dc9 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/join-lateral.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/join-lateral.sql
@@ -177,6 +177,25 @@ SELECT * FROM t3 JOIN LATERAL (SELECT EXPLODE_OUTER(c2));
SELECT * FROM t3 JOIN LATERAL (SELECT EXPLODE(c2)) t(c3) ON c1 = c3;
SELECT * FROM t3 LEFT JOIN LATERAL (SELECT EXPLODE(c2)) t(c3) ON c1 = c3;
+-- SPARK-41961: lateral join with table-valued functions
+SELECT * FROM LATERAL EXPLODE(ARRAY(1, 2));
+SELECT * FROM t1, LATERAL RANGE(3);
+SELECT * FROM t1, LATERAL EXPLODE(ARRAY(c1, c2)) t2(c3);
+SELECT * FROM t3, LATERAL EXPLODE(c2) t2(v);
+SELECT * FROM t3, LATERAL EXPLODE_OUTER(c2) t2(v);
+SELECT * FROM EXPLODE(ARRAY(1, 2)) t(v), LATERAL (SELECT v + 1);
+
+-- lateral join with table-valued functions and join conditions
+SELECT * FROM t1 JOIN LATERAL EXPLODE(ARRAY(c1, c2)) t(c3) ON t1.c1 = c3;
+SELECT * FROM t3 JOIN LATERAL EXPLODE(c2) t(c3) ON t3.c1 = c3;
+SELECT * FROM t3 LEFT JOIN LATERAL EXPLODE(c2) t(c3) ON t3.c1 = c3;
+
+-- lateral join with table-valued functions in lateral subqueries
+SELECT * FROM t1, LATERAL (SELECT * FROM EXPLODE(ARRAY(c1, c2)));
+SELECT * FROM t1, LATERAL (SELECT t1.c1 + c3 FROM EXPLODE(ARRAY(c1, c2)) t(c3));
+SELECT * FROM t1, LATERAL (SELECT t1.c1 + c3 FROM EXPLODE(ARRAY(c1, c2)) t(c3) WHERE t1.c2 > 1);
+SELECT * FROM t1, LATERAL (SELECT * FROM EXPLODE(ARRAY(c1, c2)) l(x) JOIN EXPLODE(ARRAY(c2, c1)) r(y) ON x = y);
+
-- clean up
DROP VIEW t1;
DROP VIEW t2;
diff --git a/sql/core/src/test/resources/sql-tests/results/join-lateral.sql.out b/sql/core/src/test/resources/sql-tests/results/join-lateral.sql.out
index e8e029bbe1e..9d6d542ca56 100644
--- a/sql/core/src/test/resources/sql-tests/results/join-lateral.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/join-lateral.sql.out
@@ -837,6 +837,141 @@ struct<c1:int,c2:array<int>,c3:int>
NULL [4] NULL
+-- !query
+SELECT * FROM LATERAL EXPLODE(ARRAY(1, 2))
+-- !query schema
+struct<col:int>
+-- !query output
+1
+2
+
+
+-- !query
+SELECT * FROM t1, LATERAL RANGE(3)
+-- !query schema
+struct<c1:int,c2:int,id:bigint>
+-- !query output
+0 1 0
+0 1 1
+0 1 2
+1 2 0
+1 2 1
+1 2 2
+
+
+-- !query
+SELECT * FROM t1, LATERAL EXPLODE(ARRAY(c1, c2)) t2(c3)
+-- !query schema
+struct<c1:int,c2:int,c3:int>
+-- !query output
+0 1 0
+0 1 1
+1 2 1
+1 2 2
+
+
+-- !query
+SELECT * FROM t3, LATERAL EXPLODE(c2) t2(v)
+-- !query schema
+struct<c1:int,c2:array<int>,v:int>
+-- !query output
+0 [0,1] 0
+0 [0,1] 1
+1 [2] 2
+NULL [4] 4
+
+
+-- !query
+SELECT * FROM t3, LATERAL EXPLODE_OUTER(c2) t2(v)
+-- !query schema
+struct<c1:int,c2:array<int>,v:int>
+-- !query output
+0 [0,1] 0
+0 [0,1] 1
+1 [2] 2
+2 [] NULL
+NULL [4] 4
+
+
+-- !query
+SELECT * FROM EXPLODE(ARRAY(1, 2)) t(v), LATERAL (SELECT v + 1)
+-- !query schema
+struct<v:int,(outer(t.v) + 1):int>
+-- !query output
+1 2
+2 3
+
+
+-- !query
+SELECT * FROM t1 JOIN LATERAL EXPLODE(ARRAY(c1, c2)) t(c3) ON t1.c1 = c3
+-- !query schema
+struct<c1:int,c2:int,c3:int>
+-- !query output
+0 1 0
+1 2 1
+
+
+-- !query
+SELECT * FROM t3 JOIN LATERAL EXPLODE(c2) t(c3) ON t3.c1 = c3
+-- !query schema
+struct<c1:int,c2:array<int>,c3:int>
+-- !query output
+0 [0,1] 0
+
+
+-- !query
+SELECT * FROM t3 LEFT JOIN LATERAL EXPLODE(c2) t(c3) ON t3.c1 = c3
+-- !query schema
+struct<c1:int,c2:array<int>,c3:int>
+-- !query output
+0 [0,1] 0
+1 [2] NULL
+2 [] NULL
+NULL [4] NULL
+
+
+-- !query
+SELECT * FROM t1, LATERAL (SELECT * FROM EXPLODE(ARRAY(c1, c2)))
+-- !query schema
+struct<c1:int,c2:int,col:int>
+-- !query output
+0 1 0
+0 1 1
+1 2 1
+1 2 2
+
+
+-- !query
+SELECT * FROM t1, LATERAL (SELECT t1.c1 + c3 FROM EXPLODE(ARRAY(c1, c2)) t(c3))
+-- !query schema
+struct<c1:int,c2:int,(outer(spark_catalog.default.t1.c1) + c3):int>
+-- !query output
+0 1 0
+0 1 1
+1 2 2
+1 2 3
+
+
+-- !query
+SELECT * FROM t1, LATERAL (SELECT t1.c1 + c3 FROM EXPLODE(ARRAY(c1, c2)) t(c3) WHERE t1.c2 > 1)
+-- !query schema
+struct<c1:int,c2:int,(outer(spark_catalog.default.t1.c1) + c3):int>
+-- !query output
+1 2 2
+1 2 3
+
+
+-- !query
+SELECT * FROM t1, LATERAL (SELECT * FROM EXPLODE(ARRAY(c1, c2)) l(x) JOIN EXPLODE(ARRAY(c2, c1)) r(y) ON x = y)
+-- !query schema
+struct<c1:int,c2:int,x:int,y:int>
+-- !query output
+0 1 0 0
+0 1 1 1
+1 2 1 1
+1 2 2 2
+
+
-- !query
DROP VIEW t1
-- !query schema
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryParsingErrorsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryParsingErrorsSuite.scala
index 50df5808790..87e9a3d0a12 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryParsingErrorsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryParsingErrorsSuite.scala
@@ -75,14 +75,14 @@ class QueryParsingErrorsSuite extends QueryTest with SharedSparkSession {
" JOIN LATERAL t2" -> ("JOIN LATERAL t2", 17, 31),
", LATERAL (t2 JOIN t3)" -> ("FROM t1, LATERAL (t2 JOIN t3)", 9, 37),
", LATERAL (LATERAL t2)" -> ("FROM t1, LATERAL (LATERAL t2)", 9, 37),
- ", LATERAL VALUES (0, 1)" -> ("FROM t1, LATERAL VALUES (0, 1)", 9, 38),
- ", LATERAL RANGE(0, 1)" -> ("FROM t1, LATERAL RANGE(0, 1)", 9, 36)
+ ", LATERAL VALUES (0, 1)" -> ("FROM t1, LATERAL VALUES (0, 1)", 9, 38)
).foreach { case (sqlText, (fragment, start, stop)) =>
checkError(
exception = parseException(s"SELECT * FROM t1$sqlText"),
errorClass = "INVALID_SQL_SYNTAX",
sqlState = "42000",
- parameters = Map("inputString" -> "LATERAL can only be used with subquery."),
+ parameters = Map("inputString" ->
+ "LATERAL can only be used with subquery and table-valued
functions."),
context = ExpectedContext(fragment, start, stop))
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala
index 1f6a6bfda66..f9a8c67fc9f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala
@@ -398,6 +398,26 @@ abstract class SchemaPruningSuite
}
}
+ testSchemaPruning("SPARK-41961: nested schema pruning on table-valued generator functions") {
+ val query1 = sql("select friend.first from contacts, lateral explode(friends) t(friend)")
+ checkScan(query1, "struct<friends:array<struct<first:string>>>")
+ checkAnswer(query1, Row("Susan") :: Nil)
+
+ // Currently we don't prune multiple field case.
+ val query2 = sql(
+ "select friend.first, friend.middle from contacts, lateral
explode(friends) t(friend)")
+ checkScan(query2,
"struct<friends:array<struct<first:string,middle:string,last:string>>>")
+ checkAnswer(query2, Row("Susan", "Z.") :: Nil)
+
+ val query3 = sql(
+ """
+ |select friend.first, friend.middle, friend
+ |from contacts, lateral explode(friends) t(friend)
+ |""".stripMargin)
+ checkScan(query3, "struct<friends:array<struct<first:string,middle:string,last:string>>>")
+ checkAnswer(query3, Row("Susan", "Z.", Row("Susan", "Z.", "Smith")) :: Nil)
+ }
+
testSchemaPruning("select one deep nested complex field after repartition") {
val query = sql("select * from contacts")
.repartition(100)
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]