This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 8390b03df62 [SPARK-44200][SQL] Support TABLE argument parser rule for TableValuedFunction
8390b03df62 is described below

commit 8390b03df62e7f808dc214c69e340fc1e70fb517
Author: Takuya UESHIN <ues...@databricks.com>
AuthorDate: Mon Jul 3 16:26:03 2023 +0900

    [SPARK-44200][SQL] Support TABLE argument parser rule for TableValuedFunction
    
    ### What changes were proposed in this pull request?
    
    Adds a new SQL syntax for `TableValuedFunction`.
    
    The syntax supports passing a relation argument in one of two ways:
    
    1. `SELECT ... FROM tvf_call(TABLE t)`
    2. `SELECT ... FROM tvf_call(TABLE (<query>))`
    
    In the former case, the relation argument directly refers to the name of
    a table in the catalog. In the latter case, the relation argument
    comprises a table subquery that may itself refer to one or more tables
    in its own FROM clause.
    
    For example, given the following user-defined table-valued function:
    
    ```py
    from pyspark.sql.functions import udtf
    from pyspark.sql.types import Row

    @udtf(returnType="a: int")
    class TestUDTF:
        def eval(self, row: Row):
            if row[0] > 5:
                yield row[0],

    spark.udtf.register("test_udtf", TestUDTF)

    spark.sql("CREATE OR REPLACE TEMPORARY VIEW v as SELECT id FROM range(0, 8)")
    ```
    
    the following queries should work:
    
    ```py
    >>> spark.sql("SELECT * FROM test_udtf(TABLE v)").collect()
    [Row(a=6), Row(a=7)]
    ```
    
    or
    
    ```py
    >>> spark.sql("SELECT * FROM test_udtf(TABLE (SELECT id + 1 FROM v))").collect()
    [Row(a=6), Row(a=7), Row(a=8)]
    ```
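
    The grammar also accepts TABLE arguments in the named-argument form (`key => TABLE ...`), gated behind the existing `SQLConf.ALLOW_NAMED_FUNCTION_ARGUMENTS` flag; with that flag disabled, the parser raises `namedArgumentsNotEnabledError`. A syntax-level sketch taken from the added `PlanParserSuite` cases, where `my_tvf`, `arg1`, and `arg2` are placeholder names:

    ```py
    # These strings parse with the new rules; they will not resolve unless a
    # matching function `my_tvf` is actually registered.
    "SELECT * FROM my_tvf(arg1 => TABLE v1, arg2 => TABLE (SELECT 1))"
    "SELECT * FROM my_tvf(2, TABLE v1, arg1 => TABLE (SELECT 1))"
    ```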
    
    ### Why are the changes needed?
    
    To support the `TABLE` argument parser rule for `TableValuedFunction`.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, new syntax for SQL.
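
    Note that passing more than one TABLE argument to a single call fails with the new `TABLE_VALUED_FUNCTION_TOO_MANY_TABLE_ARGUMENTS` error unless `spark.sql.tvf.allowMultipleTableArguments.enabled` is set to `true`, in which case the function receives the Cartesian product of the rows of the input tables. A minimal sketch following the added test (the class and function names are illustrative):

    ```py
    @udtf(returnType="a: int, b: int")
    class CartesianUDTF:
        def eval(self, a: Row, b: Row):
            yield a[0], b[0]

    spark.udtf.register("cartesian_udtf", CartesianUDTF)
    spark.conf.set("spark.sql.tvf.allowMultipleTableArguments.enabled", True)

    spark.sql("""
        SELECT * FROM cartesian_udtf(
          TABLE (SELECT id FROM range(0, 2)),
          TABLE (SELECT id FROM range(0, 3)))
    """).count()  # 6 rows: the Cartesian product of the two inputs
    ```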
    
    ### How was this patch tested?
    
    Added related tests in `python/pyspark/sql/tests/test_udtf.py` and `PlanParserSuite.scala`.
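
    For reference, one way to run the new Python suite locally is sketched below (assumes a PySpark dev environment on the PYTHONPATH; the repository's `python/run-tests` script is the usual entry point):

    ```py
    import unittest

    # Discovers and runs the UDTF tests added in this patch, including the
    # new TABLE argument cases.
    unittest.main(module="pyspark.sql.tests.test_udtf", argv=["test_udtf"], exit=False)
    ```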
    
    Closes #41750 from ueshin/issues/SPARK-44200/table_argument.
    
    Authored-by: Takuya UESHIN <ues...@databricks.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .../src/main/resources/error/error-classes.json    |   5 +
 python/pyspark/sql/tests/test_udtf.py              | 194 ++++++++++++++++++++-
 .../spark/sql/catalyst/parser/SqlBaseParser.g4     |  23 ++-
 .../spark/sql/catalyst/analysis/Analyzer.scala     |  30 +++-
 .../FunctionTableSubqueryArgumentExpression.scala  |  65 +++++++
 .../spark/sql/catalyst/parser/AstBuilder.scala     |  37 +++-
 .../plans/logical/basicLogicalOperators.scala      |   5 +
 .../spark/sql/catalyst/trees/TreePatterns.scala    |   1 +
 .../spark/sql/errors/QueryCompilationErrors.scala  |   7 +
 .../org/apache/spark/sql/internal/SQLConf.scala    |  10 ++
 .../sql/catalyst/parser/PlanParserSuite.scala      |  38 ++++
 .../apache/spark/sql/catalyst/plans/PlanTest.scala |   2 +
 .../spark/sql/errors/QueryParsingErrorsSuite.scala |   1 -
 13 files changed, 411 insertions(+), 7 deletions(-)

diff --git a/common/utils/src/main/resources/error/error-classes.json b/common/utils/src/main/resources/error/error-classes.json
index 027d09eae10..753701cf581 100644
--- a/common/utils/src/main/resources/error/error-classes.json
+++ b/common/utils/src/main/resources/error/error-classes.json
@@ -2198,6 +2198,11 @@
     ],
     "sqlState" : "42P01"
   },
+  "TABLE_VALUED_FUNCTION_TOO_MANY_TABLE_ARGUMENTS" : {
+    "message" : [
+      "There are too many table arguments for table-valued function. It allows 
one table argument, but got: <num>. If you want to allow it, please set 
\"spark.sql.allowMultipleTableArguments.enabled\" to \"true\""
+    ]
+  },
   "TASK_WRITE_FAILED" : {
     "message" : [
       "Task failed while writing rows to <path>."
diff --git a/python/pyspark/sql/tests/test_udtf.py b/python/pyspark/sql/tests/test_udtf.py
index ccf271ceec2..43ab0795042 100644
--- a/python/pyspark/sql/tests/test_udtf.py
+++ b/python/pyspark/sql/tests/test_udtf.py
@@ -27,7 +27,7 @@ from pyspark.sql.types import Row
 from pyspark.testing.sqlutils import ReusedSQLTestCase
 
 
-class UDTFTestsMixin(ReusedSQLTestCase):
+class UDTFTestsMixin:
     def test_simple_udtf(self):
         class TestUDTF:
             def eval(self):
@@ -397,6 +397,198 @@ class UDTFTestsMixin(ReusedSQLTestCase):
         with self.assertRaisesRegex(TypeError, err_msg):
             udtf(test_udtf, returnType="a: int")
 
+    def test_udtf_with_table_argument_query(self):
+        class TestUDTF:
+            def eval(self, row: Row):
+                if row["id"] > 5:
+                    yield row["id"],
+
+        func = udtf(TestUDTF, returnType="a: int")
+        self.spark.udtf.register("test_udtf", func)
+        self.assertEqual(
+            self.spark.sql("SELECT * FROM test_udtf(TABLE (SELECT id FROM 
range(0, 8)))").collect(),
+            [Row(a=6), Row(a=7)],
+        )
+
+    def test_udtf_with_int_and_table_argument_query(self):
+        class TestUDTF:
+            def eval(self, i: int, row: Row):
+                if row["id"] > i:
+                    yield row["id"],
+
+        func = udtf(TestUDTF, returnType="a: int")
+        self.spark.udtf.register("test_udtf", func)
+        self.assertEqual(
+            self.spark.sql(
+                "SELECT * FROM test_udtf(5, TABLE (SELECT id FROM range(0, 
8)))"
+            ).collect(),
+            [Row(a=6), Row(a=7)],
+        )
+
+    def test_udtf_with_table_argument_identifier(self):
+        class TestUDTF:
+            def eval(self, row: Row):
+                if row["id"] > 5:
+                    yield row["id"],
+
+        func = udtf(TestUDTF, returnType="a: int")
+        self.spark.udtf.register("test_udtf", func)
+
+        with self.tempView("v"):
+            self.spark.sql("CREATE OR REPLACE TEMPORARY VIEW v as SELECT id 
FROM range(0, 8)")
+            self.assertEqual(
+                self.spark.sql("SELECT * FROM test_udtf(TABLE v)").collect(),
+                [Row(a=6), Row(a=7)],
+            )
+
+    def test_udtf_with_int_and_table_argument_identifier(self):
+        class TestUDTF:
+            def eval(self, i: int, row: Row):
+                if row["id"] > i:
+                    yield row["id"],
+
+        func = udtf(TestUDTF, returnType="a: int")
+        self.spark.udtf.register("test_udtf", func)
+
+        with self.tempView("v"):
+            self.spark.sql("CREATE OR REPLACE TEMPORARY VIEW v as SELECT id 
FROM range(0, 8)")
+            self.assertEqual(
+                self.spark.sql("SELECT * FROM test_udtf(5, TABLE 
v)").collect(),
+                [Row(a=6), Row(a=7)],
+            )
+
+    def test_udtf_with_table_argument_unknown_identifier(self):
+        class TestUDTF:
+            def eval(self, row: Row):
+                if row["id"] > 5:
+                    yield row["id"],
+
+        func = udtf(TestUDTF, returnType="a: int")
+        self.spark.udtf.register("test_udtf", func)
+
+        with self.assertRaisesRegex(AnalysisException, "TABLE_OR_VIEW_NOT_FOUND"):
+            self.spark.sql("SELECT * FROM test_udtf(TABLE v)").collect()
+
+    def test_udtf_with_table_argument_malformed_query(self):
+        class TestUDTF:
+            def eval(self, row: Row):
+                if row["id"] > 5:
+                    yield row["id"],
+
+        func = udtf(TestUDTF, returnType="a: int")
+        self.spark.udtf.register("test_udtf", func)
+
+        with self.assertRaisesRegex(AnalysisException, "TABLE_OR_VIEW_NOT_FOUND"):
+            self.spark.sql("SELECT * FROM test_udtf(TABLE (SELECT * FROM v))").collect()
+
+    def test_udtf_with_table_argument_cte_inside(self):
+        class TestUDTF:
+            def eval(self, row: Row):
+                if row["id"] > 5:
+                    yield row["id"],
+
+        func = udtf(TestUDTF, returnType="a: int")
+        self.spark.udtf.register("test_udtf", func)
+        self.assertEqual(
+            self.spark.sql(
+                """
+                SELECT * FROM test_udtf(TABLE (
+                  WITH t AS (
+                    SELECT id FROM range(0, 8)
+                  )
+                  SELECT * FROM t
+                ))
+                """
+            ).collect(),
+            [Row(a=6), Row(a=7)],
+        )
+
+    def test_udtf_with_table_argument_cte_outside(self):
+        class TestUDTF:
+            def eval(self, row: Row):
+                if row["id"] > 5:
+                    yield row["id"],
+
+        func = udtf(TestUDTF, returnType="a: int")
+        self.spark.udtf.register("test_udtf", func)
+        self.assertEqual(
+            self.spark.sql(
+                """
+                WITH t AS (
+                  SELECT id FROM range(0, 8)
+                )
+                SELECT * FROM test_udtf(TABLE (SELECT id FROM t))
+                """
+            ).collect(),
+            [Row(a=6), Row(a=7)],
+        )
+
+        self.assertEqual(
+            self.spark.sql(
+                """
+                WITH t AS (
+                  SELECT id FROM range(0, 8)
+                )
+                SELECT * FROM test_udtf(TABLE t)
+                """
+            ).collect(),
+            [Row(a=6), Row(a=7)],
+        )
+
+    # TODO(SPARK-44233): Fix the subquery resolution.
+    @unittest.skip("Fails to resolve the subquery.")
+    def test_udtf_with_table_argument_lateral_join(self):
+        class TestUDTF:
+            def eval(self, row: Row):
+                if row["id"] > 5:
+                    yield row["id"],
+
+        func = udtf(TestUDTF, returnType="a: int")
+        self.spark.udtf.register("test_udtf", func)
+        self.assertEqual(
+            self.spark.sql(
+                """
+                SELECT * FROM
+                  range(0, 8) AS t,
+                  LATERAL test_udtf(TABLE t)
+                """
+            ).collect(),
+            [Row(a=6), Row(a=7)],
+        )
+
+    def test_udtf_with_table_argument_multiple(self):
+        class TestUDTF:
+            def eval(self, a: Row, b: Row):
+                yield a[0], b[0]
+
+        func = udtf(TestUDTF, returnType="a: int, b: int")
+        self.spark.udtf.register("test_udtf", func)
+
+        query = """
+          SELECT * FROM test_udtf(
+            TABLE (SELECT id FROM range(0, 2)),
+            TABLE (SELECT id FROM range(0, 3)))
+        """
+
+        with self.sql_conf({"spark.sql.tvf.allowMultipleTableArguments.enabled": False}):
+            with self.assertRaisesRegex(
+                AnalysisException, "TABLE_VALUED_FUNCTION_TOO_MANY_TABLE_ARGUMENTS"
+            ):
+                self.spark.sql(query).collect()
+
+        with self.sql_conf({"spark.sql.tvf.allowMultipleTableArguments.enabled": True}):
+            self.assertEqual(
+                self.spark.sql(query).collect(),
+                [
+                    Row(a=0, b=0),
+                    Row(a=1, b=0),
+                    Row(a=0, b=1),
+                    Row(a=1, b=1),
+                    Row(a=0, b=2),
+                    Row(a=1, b=2),
+                ],
+            )
+
 
 class UDTFTests(UDTFTestsMixin, ReusedSQLTestCase):
     @classmethod
diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4
index ab6c0d0861f..0390785ab5d 100644
--- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4
+++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4
@@ -788,8 +788,29 @@ inlineTable
     : VALUES expression (COMMA expression)* tableAlias
     ;
 
+functionTableSubqueryArgument
+    : TABLE identifierReference
+    | TABLE LEFT_PAREN query RIGHT_PAREN
+    ;
+
+functionTableNamedArgumentExpression
+    : key=identifier FAT_ARROW table=functionTableSubqueryArgument
+    ;
+
+functionTableReferenceArgument
+    : functionTableSubqueryArgument
+    | functionTableNamedArgumentExpression
+    ;
+
+functionTableArgument
+    : functionArgument
+    | functionTableReferenceArgument
+    ;
+
 functionTable
-    : funcName=functionName LEFT_PAREN (functionArgument (COMMA functionArgument)*)? RIGHT_PAREN tableAlias
+    : funcName=functionName LEFT_PAREN
+      (functionTableArgument (COMMA functionTableArgument)*)?
+      RIGHT_PAREN tableAlias
     ;
 
 tableAlias
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 47c266e7d18..94d341ed1d7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -2058,7 +2058,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
      case u: UnresolvedTableValuedFunction if u.functionArgs.forall(_.resolved) =>
         withPosition(u) {
           try {
-            resolveBuiltinOrTempTableFunction(u.name, u.functionArgs).getOrElse {
+            val resolvedFunc = resolveBuiltinOrTempTableFunction(u.name, u.functionArgs).getOrElse {
              val CatalogAndIdentifier(catalog, ident) = expandIdentifier(u.name)
               if (CatalogV2Util.isSessionCatalog(catalog)) {
                 v1SessionCatalog.resolvePersistentTableFunction(
@@ -2068,6 +2068,30 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
                   catalog, "table-valued functions")
               }
             }
+
+            val tableArgs = mutable.ArrayBuffer.empty[LogicalPlan]
+            val tvf = resolvedFunc.transformAllExpressionsWithPruning(
+              _.containsPattern(FUNCTION_TABLE_RELATION_ARGUMENT_EXPRESSION), ruleId)  {
+              case t: FunctionTableSubqueryArgumentExpression =>
+                val alias = SubqueryAlias.generateSubqueryName(s"_${tableArgs.size}")
+                tableArgs.append(SubqueryAlias(alias, t.evaluable))
+                UnresolvedAttribute(Seq(alias, "c"))
+            }
+            if (tableArgs.nonEmpty) {
+              if (!conf.tvfAllowMultipleTableArguments && tableArgs.size > 1) {
+                throw QueryCompilationErrors.tableValuedFunctionTooManyTableArgumentsError(
+                  tableArgs.size)
+              }
+              val alias = SubqueryAlias.generateSubqueryName(s"_${tableArgs.size}")
+              Project(
+                Seq(UnresolvedStar(Some(Seq(alias)))),
+                LateralJoin(
+                  tableArgs.reduceLeft(Join(_, _, Inner, None, JoinHint.NONE)),
+                  LateralSubquery(SubqueryAlias(alias, tvf)), Inner, None)
+              )
+            } else {
+              tvf
+            }
           } catch {
             case _: NoSuchFunctionException =>
               u.failAnalysis(
@@ -2416,6 +2440,8 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
           InSubquery(values, expr.asInstanceOf[ListQuery])
         case s @ LateralSubquery(sub, _, exprId, _, _) if !sub.resolved =>
           resolveSubQuery(s, outer)(LateralSubquery(_, _, exprId))
+        case a @ FunctionTableSubqueryArgumentExpression(sub, _, exprId) if !sub.resolved =>
+          resolveSubQuery(a, outer)(FunctionTableSubqueryArgumentExpression(_, _, exprId))
       }
     }
 
@@ -2436,6 +2462,8 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
         resolveSubQueries(r, r)
       case j: Join if j.childrenResolved && j.duplicateResolved =>
         resolveSubQueries(j, j)
+      case tvf: UnresolvedTableValuedFunction =>
+        resolveSubQueries(tvf, tvf)
       case s: SupportsSubquery if s.childrenResolved =>
         resolveSubQueries(s, s)
     }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/FunctionTableSubqueryArgumentExpression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/FunctionTableSubqueryArgumentExpression.scala
new file mode 100644
index 00000000000..6d502731251
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/FunctionTableSubqueryArgumentExpression.scala
@@ -0,0 +1,65 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.sql.catalyst.plans.logical.{HintInfo, LogicalPlan, Project}
+import org.apache.spark.sql.catalyst.trees.TreePattern.{FUNCTION_TABLE_RELATION_ARGUMENT_EXPRESSION, TreePattern}
+import org.apache.spark.sql.types.DataType
+
+/**
+ * This is the parsed representation of a relation argument for a TableValuedFunction call.
+ * The syntax supports passing such relations one of two ways:
+ *
+ * 1. SELECT ... FROM tvf_call(TABLE t)
+ * 2. SELECT ... FROM tvf_call(TABLE (<query>))
+ *
+ * In the former case, the relation argument directly refers to the name of a
+ * table in the catalog. In the latter case, the relation argument comprises
+ * a table subquery that may itself refer to one or more tables in its own
+ * FROM clause.
+ */
+case class FunctionTableSubqueryArgumentExpression(
+    plan: LogicalPlan,
+    outerAttrs: Seq[Expression] = Seq.empty,
+    exprId: ExprId = NamedExpression.newExprId)
+  extends SubqueryExpression(plan, outerAttrs, exprId, Seq.empty, None) with Unevaluable {
+
+  override def dataType: DataType = plan.schema
+  override def nullable: Boolean = false
+  override def withNewPlan(plan: LogicalPlan): FunctionTableSubqueryArgumentExpression =
+    copy(plan = plan)
+  override def hint: Option[HintInfo] = None
+  override def withNewHint(hint: Option[HintInfo]): FunctionTableSubqueryArgumentExpression =
+    copy()
+  override def toString: String = s"table-argument#${exprId.id} $conditionString"
+  override lazy val canonicalized: Expression = {
+    FunctionTableSubqueryArgumentExpression(
+      plan.canonicalized,
+      outerAttrs.map(_.canonicalized),
+      ExprId(0))
+  }
+
+  override protected def withNewChildrenInternal(
+      newChildren: IndexedSeq[Expression]): FunctionTableSubqueryArgumentExpression =
+    copy(outerAttrs = newChildren)
+
+  final override def nodePatternsInternal: Seq[TreePattern] =
+    Seq(FUNCTION_TABLE_RELATION_ARGUMENT_EXPRESSION)
+
+  lazy val evaluable: LogicalPlan = Project(Seq(Alias(CreateStruct(plan.output), "c")()), plan)
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index 9a395924c45..488b4e46735 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -1551,6 +1551,33 @@ class AstBuilder extends SqlBaseParserBaseVisitor[AnyRef] with SQLConfHelper wit
     RelationTimeTravel(plan, timestamp, version)
   }
 
+  /**
+   * Create a relation argument for a table-valued function argument.
+   */
+  override def visitFunctionTableSubqueryArgument(
+      ctx: FunctionTableSubqueryArgumentContext): Expression = withOrigin(ctx) {
+    val p = Option(ctx.identifierReference).map { r =>
+      createUnresolvedRelation(r)
+    }.getOrElse {
+      plan(ctx.query)
+    }
+    FunctionTableSubqueryArgumentExpression(p)
+  }
+
+  private def extractFunctionTableNamedArgument(
+      expr: FunctionTableReferenceArgumentContext, funcName: String) : Expression = {
+    Option(expr.functionTableNamedArgumentExpression).map { n =>
+      if (conf.getConf(SQLConf.ALLOW_NAMED_FUNCTION_ARGUMENTS)) {
+        NamedArgumentExpression(
+          n.key.getText, visitFunctionTableSubqueryArgument(n.functionTableSubqueryArgument))
+      } else {
+        throw QueryCompilationErrors.namedArgumentsNotEnabledError(funcName, n.key.getText)
+      }
+    }.getOrElse {
+      visitFunctionTableSubqueryArgument(expr.functionTableSubqueryArgument)
+    }
+  }
+
   /**
    * Create a table-valued function call with arguments, e.g. range(1000)
    */
@@ -1569,8 +1596,12 @@ class AstBuilder extends SqlBaseParserBaseVisitor[AnyRef] with SQLConfHelper wit
         if (ident.length > 1) {
          throw QueryParsingErrors.invalidTableValuedFunctionNameError(ident, ctx)
         }
-        val args = func.functionArgument.asScala.map { e =>
-          extractNamedArgument(e, func.functionName.getText)
+        val funcName = func.functionName.getText
+        val args = func.functionTableArgument.asScala.map { e =>
+          Option(e.functionArgument).map(extractNamedArgument(_, funcName))
+            .getOrElse {
+              extractFunctionTableNamedArgument(e.functionTableReferenceArgument, funcName)
+            }
         }.toSeq
 
         val tvf = UnresolvedTableValuedFunction(ident, args)
@@ -1634,7 +1665,7 @@ class AstBuilder extends SqlBaseParserBaseVisitor[AnyRef] with SQLConfHelper wit
      // normal subquery names, so that parent operators can only access the columns in subquery by
      // unqualified names. Users can still use this special qualifier to access columns if they
       // know it, but that's not recommended.
-      SubqueryAlias("__auto_generated_subquery_name", relation)
+      SubqueryAlias(SubqueryAlias.generateSubqueryName(), relation)
     } else {
       mayApplyAliasPlan(ctx.tableAlias, relation)
     }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index e23966775e9..c5ac0304841 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -1707,7 +1707,12 @@ object SubqueryAlias {
       child: LogicalPlan): SubqueryAlias = {
    SubqueryAlias(AliasIdentifier(multipartIdentifier.last, multipartIdentifier.init), child)
   }
+
+  def generateSubqueryName(suffix: String = ""): String = {
+    s"__auto_generated_subquery_name$suffix"
+  }
 }
+
 /**
  * Sample the dataset.
  *
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
index 11d5cf54df4..b806ebbed52 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
@@ -46,6 +46,7 @@ object TreePattern extends Enumeration  {
   val EXISTS_SUBQUERY = Value
   val EXPRESSION_WITH_RANDOM_SEED: Value = Value
   val EXTRACT_VALUE: Value = Value
+  val FUNCTION_TABLE_RELATION_ARGUMENT_EXPRESSION: Value = Value
   val GENERATE: Value = Value
   val GENERATOR: Value = Value
   val HIGH_ORDER_FUNCTION: Value = Value
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
index e02708105d2..48223cb34e1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
@@ -1907,6 +1907,13 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase {
         "ability" -> ability))
   }
 
+  def tableValuedFunctionTooManyTableArgumentsError(num: Int): Throwable = {
+    new AnalysisException(
+      errorClass = "TABLE_VALUED_FUNCTION_TOO_MANY_TABLE_ARGUMENTS",
+      messageParameters = Map("num" -> num.toString)
+    )
+  }
+
  def identifierTooManyNamePartsError(originalIdentifier: String): Throwable = {
     new AnalysisException(
       errorClass = "IDENTIFIER_TOO_MANY_NAME_PARTS",
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 270508139e4..ecff6bef8ae 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -2753,6 +2753,14 @@ object SQLConf {
     .booleanConf
     .createWithDefault(false)
 
+  val TVF_ALLOW_MULTIPLE_TABLE_ARGUMENTS_ENABLED =
+    buildConf("spark.sql.tvf.allowMultipleTableArguments.enabled")
+      .doc("When true, allows multiple table arguments for table-valued 
functions, " +
+        "receiving the cartesian product of all the rows of these tables.")
+      .version("3.5.0")
+      .booleanConf
+      .createWithDefault(false)
+
   val RANGE_EXCHANGE_SAMPLE_SIZE_PER_PARTITION =
     buildConf("spark.sql.execution.rangeExchange.sampleSizePerPartition")
       .internal()
@@ -4926,6 +4934,8 @@ class SQLConf extends Serializable with Logging {
 
  def supportQuotedRegexColumnName: Boolean = getConf(SUPPORT_QUOTED_REGEX_COLUMN_NAME)
 
+  def tvfAllowMultipleTableArguments: Boolean = getConf(TVF_ALLOW_MULTIPLE_TABLE_ARGUMENTS_ENABLED)
+
  def rangeExchangeSampleSizePerPartition: Int = getConf(RANGE_EXCHANGE_SAMPLE_SIZE_PER_PARTITION)
 
   def arrowPySparkEnabled: Boolean = getConf(ARROW_PYSPARK_EXECUTION_ENABLED)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
index 228a287e14f..4bad3ced705 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
@@ -1441,6 +1441,44 @@ class PlanParserSuite extends AnalysisTest {
         NamedArgumentExpression("group", Literal("abc")) :: 
Nil).select(star()))
   }
 
+  test("table valued function with table arguments") {
+    assertEqual(
+      "select * from my_tvf(table v1, table (select 1))",
+      UnresolvedTableValuedFunction("my_tvf",
+        FunctionTableSubqueryArgumentExpression(UnresolvedRelation(Seq("v1"))) ::
+          FunctionTableSubqueryArgumentExpression(
+            Project(Seq(UnresolvedAlias(Literal(1))), OneRowRelation())) :: Nil).select(star()))
+
+    // All named arguments
+    assertEqual(
+      "select * from my_tvf(arg1 => table v1, arg2 => table (select 1))",
+      UnresolvedTableValuedFunction("my_tvf",
+        NamedArgumentExpression("arg1",
+          FunctionTableSubqueryArgumentExpression(UnresolvedRelation(Seq("v1")))) ::
+          NamedArgumentExpression("arg2",
+            FunctionTableSubqueryArgumentExpression(
+              Project(Seq(UnresolvedAlias(Literal(1))), OneRowRelation()))) :: Nil).select(star()))
+
+    // Unnamed and named arguments
+    assertEqual(
+      "select * from my_tvf(2, table v1, arg1 => table (select 1))",
+      UnresolvedTableValuedFunction("my_tvf",
+        Literal(2) ::
+          FunctionTableSubqueryArgumentExpression(UnresolvedRelation(Seq("v1"))) ::
+          NamedArgumentExpression("arg1",
+            FunctionTableSubqueryArgumentExpression(
+              Project(Seq(UnresolvedAlias(Literal(1))), OneRowRelation()))) :: Nil).select(star()))
+
+    // Mixed arguments
+    assertEqual(
+      "select * from my_tvf(arg1 => table v1, 2, arg2 => true)",
+      UnresolvedTableValuedFunction("my_tvf",
+        NamedArgumentExpression("arg1",
+          FunctionTableSubqueryArgumentExpression(UnresolvedRelation(Seq("v1")))) ::
+          Literal(2) ::
+          NamedArgumentExpression("arg2", Literal(true)) :: 
Nil).select(star()))
+  }
+
   test("SPARK-32106: TRANSFORM plan") {
     // verify schema less
     assertEqual(
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
index 911ddfeb13b..ebf48c5f863 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
@@ -96,6 +96,8 @@ trait PlanTestBase extends PredicateHelper with SQLHelper with SQLConfHelper { s
         udf.copy(resultId = ExprId(0))
       case udaf: PythonUDAF =>
         udaf.copy(resultId = ExprId(0))
+      case a: FunctionTableSubqueryArgumentExpression =>
+        a.copy(plan = normalizeExprIds(a.plan), exprId = ExprId(0))
     }
   }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryParsingErrorsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryParsingErrorsSuite.scala
index a7d5046245d..2731760f7ef 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryParsingErrorsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryParsingErrorsSuite.scala
@@ -401,7 +401,6 @@ class QueryParsingErrorsSuite extends QueryTest with SharedSparkSession with SQL
     checkParseSyntaxError("select * from my_tvf(arg1 => )", "')'")
     checkParseSyntaxError("select * from my_tvf(arg1 => , 42)", "','")
     checkParseSyntaxError("select * from my_tvf(my_tvf.arg1 => 'value1')", 
"'=>'")
-    checkParseSyntaxError("select * from my_tvf(arg1 => table t1)", "'t1'", hint = ": extra input 't1'")
   }
 
   test("PARSE_SYNTAX_ERROR: extraneous input") {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
