This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 8390b03df62  [SPARK-44200][SQL] Support TABLE argument parser rule for TableValuedFunction
8390b03df62 is described below

commit 8390b03df62e7f808dc214c69e340fc1e70fb517
Author: Takuya UESHIN <ues...@databricks.com>
AuthorDate: Mon Jul 3 16:26:03 2023 +0900

    [SPARK-44200][SQL] Support TABLE argument parser rule for TableValuedFunction

    ### What changes were proposed in this pull request?

    Adds a new SQL syntax for `TableValuedFunction`. The syntax supports passing such relations in one of two ways:

    1. `SELECT ... FROM tvf_call(TABLE t)`
    2. `SELECT ... FROM tvf_call(TABLE (<query>))`

    In the former case, the relation argument directly refers to the name of a table in the catalog. In the latter case, the relation argument comprises a table subquery that may itself refer to one or more tables in its own FROM clause.

    For example, given the following user-defined table function:

    ```py
    @udtf(returnType="a: int")
    class TestUDTF:
        def eval(self, row: Row):
            if row[0] > 5:
                yield row[0],

    spark.udtf.register("test_udtf", TestUDTF)
    spark.sql("CREATE OR REPLACE TEMPORARY VIEW v as SELECT id FROM range(0, 8)")
    ```

    the following SQL queries work:

    ```py
    >>> spark.sql("SELECT * FROM test_udtf(TABLE v)").collect()
    [Row(a=6), Row(a=7)]
    ```

    or

    ```py
    >>> spark.sql("SELECT * FROM test_udtf(TABLE (SELECT id + 1 FROM v))").collect()
    [Row(a=6), Row(a=7), Row(a=8)]
    ```

    ### Why are the changes needed?

    To support the `TABLE` argument parser rule for `TableValuedFunction`.

    ### Does this PR introduce _any_ user-facing change?

    Yes, this adds new SQL syntax.

    ### How was this patch tested?

    Added the related tests.

    Closes #41750 from ueshin/issues/SPARK-44200/table_argument.

    Authored-by: Takuya UESHIN <ues...@databricks.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .../src/main/resources/error/error-classes.json    |   5 +
 python/pyspark/sql/tests/test_udtf.py              | 194 ++++++++++++++++++++-
 .../spark/sql/catalyst/parser/SqlBaseParser.g4     |  23 ++-
 .../spark/sql/catalyst/analysis/Analyzer.scala     |  30 +++-
 .../FunctionTableSubqueryArgumentExpression.scala  |  65 +++++++
 .../spark/sql/catalyst/parser/AstBuilder.scala     |  37 +++-
 .../plans/logical/basicLogicalOperators.scala      |   5 +
 .../spark/sql/catalyst/trees/TreePatterns.scala    |   1 +
 .../spark/sql/errors/QueryCompilationErrors.scala  |   7 +
 .../org/apache/spark/sql/internal/SQLConf.scala    |  10 ++
 .../sql/catalyst/parser/PlanParserSuite.scala      |  38 ++++
 .../apache/spark/sql/catalyst/plans/PlanTest.scala |   2 +
 .../spark/sql/errors/QueryParsingErrorsSuite.scala |   1 -
 13 files changed, 411 insertions(+), 7 deletions(-)

diff --git a/common/utils/src/main/resources/error/error-classes.json b/common/utils/src/main/resources/error/error-classes.json
index 027d09eae10..753701cf581 100644
--- a/common/utils/src/main/resources/error/error-classes.json
+++ b/common/utils/src/main/resources/error/error-classes.json
@@ -2198,6 +2198,11 @@
     ],
     "sqlState" : "42P01"
   },
+  "TABLE_VALUED_FUNCTION_TOO_MANY_TABLE_ARGUMENTS" : {
+    "message" : [
+      "There are too many table arguments for table-valued function. It allows one table argument, but got: <num>. If you want to allow it, please set \"spark.sql.allowMultipleTableArguments.enabled\" to \"true\""
+    ]
+  },
   "TASK_WRITE_FAILED" : {
     "message" : [
       "Task failed while writing rows to <path>."
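A minimal sketch of a query that trips this new error class, assuming the `test_udtf` function registered in the Python tests below. Note that the conf key actually added later in this patch is `spark.sql.tvf.allowMultipleTableArguments.enabled`, while the message text cites `spark.sql.allowMultipleTableArguments.enabled`:

```py
# Hypothetical session: two TABLE arguments in a single TVF call fail analysis
# with TABLE_VALUED_FUNCTION_TOO_MANY_TABLE_ARGUMENTS while the flag is off
# (it defaults to false).
spark.sql("""
    SELECT * FROM test_udtf(
      TABLE (SELECT id FROM range(0, 2)),
      TABLE (SELECT id FROM range(0, 3)))
""").collect()
```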
diff --git a/python/pyspark/sql/tests/test_udtf.py b/python/pyspark/sql/tests/test_udtf.py
index ccf271ceec2..43ab0795042 100644
--- a/python/pyspark/sql/tests/test_udtf.py
+++ b/python/pyspark/sql/tests/test_udtf.py
@@ -27,7 +27,7 @@ from pyspark.sql.types import Row
 from pyspark.testing.sqlutils import ReusedSQLTestCase
 
 
-class UDTFTestsMixin(ReusedSQLTestCase):
+class UDTFTestsMixin:
     def test_simple_udtf(self):
         class TestUDTF:
             def eval(self):
@@ -397,6 +397,198 @@ class UDTFTestsMixin(ReusedSQLTestCase):
         with self.assertRaisesRegex(TypeError, err_msg):
             udtf(test_udtf, returnType="a: int")
 
+    def test_udtf_with_table_argument_query(self):
+        class TestUDTF:
+            def eval(self, row: Row):
+                if row["id"] > 5:
+                    yield row["id"],
+
+        func = udtf(TestUDTF, returnType="a: int")
+        self.spark.udtf.register("test_udtf", func)
+        self.assertEqual(
+            self.spark.sql("SELECT * FROM test_udtf(TABLE (SELECT id FROM range(0, 8)))").collect(),
+            [Row(a=6), Row(a=7)],
+        )
+
+    def test_udtf_with_int_and_table_argument_query(self):
+        class TestUDTF:
+            def eval(self, i: int, row: Row):
+                if row["id"] > i:
+                    yield row["id"],
+
+        func = udtf(TestUDTF, returnType="a: int")
+        self.spark.udtf.register("test_udtf", func)
+        self.assertEqual(
+            self.spark.sql(
+                "SELECT * FROM test_udtf(5, TABLE (SELECT id FROM range(0, 8)))"
+            ).collect(),
+            [Row(a=6), Row(a=7)],
+        )
+
+    def test_udtf_with_table_argument_identifier(self):
+        class TestUDTF:
+            def eval(self, row: Row):
+                if row["id"] > 5:
+                    yield row["id"],
+
+        func = udtf(TestUDTF, returnType="a: int")
+        self.spark.udtf.register("test_udtf", func)
+
+        with self.tempView("v"):
+            self.spark.sql("CREATE OR REPLACE TEMPORARY VIEW v as SELECT id FROM range(0, 8)")
+            self.assertEqual(
+                self.spark.sql("SELECT * FROM test_udtf(TABLE v)").collect(),
+                [Row(a=6), Row(a=7)],
+            )
+
+    def test_udtf_with_int_and_table_argument_identifier(self):
+        class TestUDTF:
+            def eval(self, i: int, row: Row):
+                if row["id"] > i:
+                    yield row["id"],
+
+        func = udtf(TestUDTF, returnType="a: int")
+        self.spark.udtf.register("test_udtf", func)
+
+        with self.tempView("v"):
+            self.spark.sql("CREATE OR REPLACE TEMPORARY VIEW v as SELECT id FROM range(0, 8)")
+            self.assertEqual(
+                self.spark.sql("SELECT * FROM test_udtf(5, TABLE v)").collect(),
+                [Row(a=6), Row(a=7)],
+            )
+
+    def test_udtf_with_table_argument_unknown_identifier(self):
+        class TestUDTF:
+            def eval(self, row: Row):
+                if row["id"] > 5:
+                    yield row["id"],
+
+        func = udtf(TestUDTF, returnType="a: int")
+        self.spark.udtf.register("test_udtf", func)
+
+        with self.assertRaisesRegex(AnalysisException, "TABLE_OR_VIEW_NOT_FOUND"):
+            self.spark.sql("SELECT * FROM test_udtf(TABLE v)").collect()
+
+    def test_udtf_with_table_argument_malformed_query(self):
+        class TestUDTF:
+            def eval(self, row: Row):
+                if row["id"] > 5:
+                    yield row["id"],
+
+        func = udtf(TestUDTF, returnType="a: int")
+        self.spark.udtf.register("test_udtf", func)
+
+        with self.assertRaisesRegex(AnalysisException, "TABLE_OR_VIEW_NOT_FOUND"):
+            self.spark.sql("SELECT * FROM test_udtf(TABLE (SELECT * FROM v))").collect()
+
+    def test_udtf_with_table_argument_cte_inside(self):
+        class TestUDTF:
+            def eval(self, row: Row):
+                if row["id"] > 5:
+                    yield row["id"],
+
+        func = udtf(TestUDTF, returnType="a: int")
+        self.spark.udtf.register("test_udtf", func)
+        self.assertEqual(
+            self.spark.sql(
+                """
+                SELECT * FROM test_udtf(TABLE (
+                  WITH t AS (
+                    SELECT id FROM range(0, 8)
+                  )
+                  SELECT * FROM t
+                ))
+                """
+            ).collect(),
+            [Row(a=6), Row(a=7)],
+        )
+
+    def test_udtf_with_table_argument_cte_outside(self):
+        class TestUDTF:
+            def eval(self, row: Row):
+                if row["id"] > 5:
+                    yield row["id"],
+
+        func = udtf(TestUDTF, returnType="a: int")
+        self.spark.udtf.register("test_udtf", func)
+        self.assertEqual(
+            self.spark.sql(
+                """
+                WITH t AS (
+                  SELECT id FROM range(0, 8)
+                )
+                SELECT * FROM test_udtf(TABLE (SELECT id FROM t))
+                """
+            ).collect(),
+            [Row(a=6), Row(a=7)],
+        )
+
+        self.assertEqual(
+            self.spark.sql(
+                """
+                WITH t AS (
+                  SELECT id FROM range(0, 8)
+                )
+                SELECT * FROM test_udtf(TABLE t)
+                """
+            ).collect(),
+            [Row(a=6), Row(a=7)],
+        )
+
+    # TODO(SPARK-44233): Fix the subquery resolution.
+    @unittest.skip("Fails to resolve the subquery.")
+    def test_udtf_with_table_argument_lateral_join(self):
+        class TestUDTF:
+            def eval(self, row: Row):
+                if row["id"] > 5:
+                    yield row["id"],
+
+        func = udtf(TestUDTF, returnType="a: int")
+        self.spark.udtf.register("test_udtf", func)
+        self.assertEqual(
+            self.spark.sql(
+                """
+                SELECT * FROM
+                  range(0, 8) AS t,
+                  LATERAL test_udtf(TABLE t)
+                """
+            ).collect(),
+            [Row(a=6), Row(a=7)],
+        )
+
+    def test_udtf_with_table_argument_multiple(self):
+        class TestUDTF:
+            def eval(self, a: Row, b: Row):
+                yield a[0], b[0]
+
+        func = udtf(TestUDTF, returnType="a: int, b: int")
+        self.spark.udtf.register("test_udtf", func)
+
+        query = """
+          SELECT * FROM test_udtf(
+            TABLE (SELECT id FROM range(0, 2)),
+            TABLE (SELECT id FROM range(0, 3)))
+        """
+
+        with self.sql_conf({"spark.sql.tvf.allowMultipleTableArguments.enabled": False}):
+            with self.assertRaisesRegex(
+                AnalysisException, "TABLE_VALUED_FUNCTION_TOO_MANY_TABLE_ARGUMENTS"
+            ):
+                self.spark.sql(query).collect()
+
+        with self.sql_conf({"spark.sql.tvf.allowMultipleTableArguments.enabled": True}):
+            self.assertEqual(
+                self.spark.sql(query).collect(),
+                [
+                    Row(a=0, b=0),
+                    Row(a=1, b=0),
+                    Row(a=0, b=1),
+                    Row(a=1, b=1),
+                    Row(a=0, b=2),
+                    Row(a=1, b=2),
+                ],
+            )
+
 
 class UDTFTests(UDTFTestsMixin, ReusedSQLTestCase):
     @classmethod
diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4
index ab6c0d0861f..0390785ab5d 100644
--- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4
+++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4
@@ -788,8 +788,29 @@ inlineTable
     : VALUES expression (COMMA expression)* tableAlias
     ;
 
+functionTableSubqueryArgument
+    : TABLE identifierReference
+    | TABLE LEFT_PAREN query RIGHT_PAREN
+    ;
+
+functionTableNamedArgumentExpression
+    : key=identifier FAT_ARROW table=functionTableSubqueryArgument
+    ;
+
+functionTableReferenceArgument
+    : functionTableSubqueryArgument
+    | functionTableNamedArgumentExpression
+    ;
+
+functionTableArgument
+    : functionArgument
+    | functionTableReferenceArgument
+    ;
+
 functionTable
-    : funcName=functionName LEFT_PAREN (functionArgument (COMMA functionArgument)*)? RIGHT_PAREN tableAlias
+    : funcName=functionName LEFT_PAREN
+      (functionTableArgument (COMMA functionTableArgument)*)?
+      RIGHT_PAREN tableAlias
     ;
 
 tableAlias
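The grammar above accepts a TABLE argument anywhere an ordinary function argument may appear, positionally or by name. A few query shapes the new rules parse, mirroring the `PlanParserSuite` cases later in this patch (`my_tvf`, `v1`, and `arg1` are placeholder names):

```py
# TABLE with an identifier, and TABLE with a parenthesized subquery.
spark.sql("SELECT * FROM my_tvf(TABLE v1, TABLE (SELECT 1))")
# Named table argument, parsed via functionTableNamedArgumentExpression.
spark.sql("SELECT * FROM my_tvf(arg1 => TABLE v1)")
# Scalar arguments and table arguments may be mixed in one call.
spark.sql("SELECT * FROM my_tvf(2, TABLE v1, arg1 => TABLE (SELECT 1))")
```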
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 47c266e7d18..94d341ed1d7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -2058,7 +2058,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
       case u: UnresolvedTableValuedFunction if u.functionArgs.forall(_.resolved) =>
         withPosition(u) {
           try {
-            resolveBuiltinOrTempTableFunction(u.name, u.functionArgs).getOrElse {
+            val resolvedFunc = resolveBuiltinOrTempTableFunction(u.name, u.functionArgs).getOrElse {
               val CatalogAndIdentifier(catalog, ident) = expandIdentifier(u.name)
               if (CatalogV2Util.isSessionCatalog(catalog)) {
                 v1SessionCatalog.resolvePersistentTableFunction(
@@ -2068,6 +2068,30 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
                   catalog, "table-valued functions")
               }
             }
+
+            val tableArgs = mutable.ArrayBuffer.empty[LogicalPlan]
+            val tvf = resolvedFunc.transformAllExpressionsWithPruning(
+              _.containsPattern(FUNCTION_TABLE_RELATION_ARGUMENT_EXPRESSION), ruleId) {
+              case t: FunctionTableSubqueryArgumentExpression =>
+                val alias = SubqueryAlias.generateSubqueryName(s"_${tableArgs.size}")
+                tableArgs.append(SubqueryAlias(alias, t.evaluable))
+                UnresolvedAttribute(Seq(alias, "c"))
+            }
+            if (tableArgs.nonEmpty) {
+              if (!conf.tvfAllowMultipleTableArguments && tableArgs.size > 1) {
+                throw QueryCompilationErrors.tableValuedFunctionTooManyTableArgumentsError(
+                  tableArgs.size)
+              }
+              val alias = SubqueryAlias.generateSubqueryName(s"_${tableArgs.size}")
+              Project(
+                Seq(UnresolvedStar(Some(Seq(alias)))),
+                LateralJoin(
+                  tableArgs.reduceLeft(Join(_, _, Inner, None, JoinHint.NONE)),
+                  LateralSubquery(SubqueryAlias(alias, tvf)), Inner, None)
+              )
+            } else {
+              tvf
+            }
           } catch {
             case _: NoSuchFunctionException =>
               u.failAnalysis(
@@ -2416,6 +2440,8 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
           InSubquery(values, expr.asInstanceOf[ListQuery])
         case s @ LateralSubquery(sub, _, exprId, _, _) if !sub.resolved =>
           resolveSubQuery(s, outer)(LateralSubquery(_, _, exprId))
+        case a @ FunctionTableSubqueryArgumentExpression(sub, _, exprId) if !sub.resolved =>
+          resolveSubQuery(a, outer)(FunctionTableSubqueryArgumentExpression(_, _, exprId))
       }
     }
 
@@ -2436,6 +2462,8 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
         resolveSubQueries(r, r)
       case j: Join if j.childrenResolved && j.duplicateResolved =>
         resolveSubQueries(j, j)
+      case tvf: UnresolvedTableValuedFunction =>
+        resolveSubQueries(tvf, tvf)
       case s: SupportsSubquery if s.childrenResolved =>
         resolveSubQueries(s, s)
     }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/FunctionTableSubqueryArgumentExpression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/FunctionTableSubqueryArgumentExpression.scala
new file mode 100644
index 00000000000..6d502731251
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/FunctionTableSubqueryArgumentExpression.scala
@@ -0,0 +1,65 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.sql.catalyst.plans.logical.{HintInfo, LogicalPlan, Project}
+import org.apache.spark.sql.catalyst.trees.TreePattern.{FUNCTION_TABLE_RELATION_ARGUMENT_EXPRESSION, TreePattern}
+import org.apache.spark.sql.types.DataType
+
+/**
+ * This is the parsed representation of a relation argument for a TableValuedFunction call.
+ * The syntax supports passing such relations one of two ways:
+ *
+ * 1. SELECT ... FROM tvf_call(TABLE t)
+ * 2. SELECT ... FROM tvf_call(TABLE (<query>))
+ *
+ * In the former case, the relation argument directly refers to the name of a
+ * table in the catalog. In the latter case, the relation argument comprises
+ * a table subquery that may itself refer to one or more tables in its own
+ * FROM clause.
+ */
+case class FunctionTableSubqueryArgumentExpression(
+    plan: LogicalPlan,
+    outerAttrs: Seq[Expression] = Seq.empty,
+    exprId: ExprId = NamedExpression.newExprId)
+  extends SubqueryExpression(plan, outerAttrs, exprId, Seq.empty, None) with Unevaluable {
+
+  override def dataType: DataType = plan.schema
+  override def nullable: Boolean = false
+  override def withNewPlan(plan: LogicalPlan): FunctionTableSubqueryArgumentExpression =
+    copy(plan = plan)
+  override def hint: Option[HintInfo] = None
+  override def withNewHint(hint: Option[HintInfo]): FunctionTableSubqueryArgumentExpression =
+    copy()
+  override def toString: String = s"table-argument#${exprId.id} $conditionString"
+  override lazy val canonicalized: Expression = {
+    FunctionTableSubqueryArgumentExpression(
+      plan.canonicalized,
+      outerAttrs.map(_.canonicalized),
+      ExprId(0))
+  }
+
+  override protected def withNewChildrenInternal(
+      newChildren: IndexedSeq[Expression]): FunctionTableSubqueryArgumentExpression =
+    copy(outerAttrs = newChildren)
+
+  final override def nodePatternsInternal: Seq[TreePattern] =
+    Seq(FUNCTION_TABLE_RELATION_ARGUMENT_EXPRESSION)
+
+  lazy val evaluable: LogicalPlan = Project(Seq(Alias(CreateStruct(plan.output), "c")()), plan)
+}
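Together, the Analyzer change and `evaluable` above turn each TABLE argument into a lateral join: each argument plan is wrapped in a Project that packs its output columns into a single struct column `c` under a generated alias, and the TVF call becomes a lateral subquery over the join of those aliases. A rough SQL analogy of the rewrite, not the literal plan (and, per SPARK-44233 above, the hand-written LATERAL form itself does not resolve yet):

```py
# Input query:
#     SELECT * FROM test_udtf(TABLE v)
# is analyzed roughly as if the user had written:
#     SELECT __auto_generated_subquery_name_1.*
#     FROM (SELECT struct(id) AS c FROM v) AS __auto_generated_subquery_name_0,
#          LATERAL test_udtf(__auto_generated_subquery_name_0.c)
#            AS __auto_generated_subquery_name_1
```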
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index 9a395924c45..488b4e46735 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -1551,6 +1551,33 @@ class AstBuilder extends SqlBaseParserBaseVisitor[AnyRef] with SQLConfHelper wit
     RelationTimeTravel(plan, timestamp, version)
   }
 
+  /**
+   * Create a relation argument for a table-valued function argument.
+   */
+  override def visitFunctionTableSubqueryArgument(
+      ctx: FunctionTableSubqueryArgumentContext): Expression = withOrigin(ctx) {
+    val p = Option(ctx.identifierReference).map { r =>
+      createUnresolvedRelation(r)
+    }.getOrElse {
+      plan(ctx.query)
+    }
+    FunctionTableSubqueryArgumentExpression(p)
+  }
+
+  private def extractFunctionTableNamedArgument(
+      expr: FunctionTableReferenceArgumentContext, funcName: String) : Expression = {
+    Option(expr.functionTableNamedArgumentExpression).map { n =>
+      if (conf.getConf(SQLConf.ALLOW_NAMED_FUNCTION_ARGUMENTS)) {
+        NamedArgumentExpression(
+          n.key.getText, visitFunctionTableSubqueryArgument(n.functionTableSubqueryArgument))
+      } else {
+        throw QueryCompilationErrors.namedArgumentsNotEnabledError(funcName, n.key.getText)
+      }
+    }.getOrElse {
+      visitFunctionTableSubqueryArgument(expr.functionTableSubqueryArgument)
+    }
+  }
+
   /**
    * Create a table-valued function call with arguments, e.g. range(1000)
    */
@@ -1569,8 +1596,12 @@ class AstBuilder extends SqlBaseParserBaseVisitor[AnyRef] with SQLConfHelper wit
     if (ident.length > 1) {
       throw QueryParsingErrors.invalidTableValuedFunctionNameError(ident, ctx)
     }
-    val args = func.functionArgument.asScala.map { e =>
-      extractNamedArgument(e, func.functionName.getText)
+    val funcName = func.functionName.getText
+    val args = func.functionTableArgument.asScala.map { e =>
+      Option(e.functionArgument).map(extractNamedArgument(_, funcName))
+        .getOrElse {
+          extractFunctionTableNamedArgument(e.functionTableReferenceArgument, funcName)
+        }
     }.toSeq
 
     val tvf = UnresolvedTableValuedFunction(ident, args)
@@ -1634,7 +1665,7 @@ class AstBuilder extends SqlBaseParserBaseVisitor[AnyRef] with SQLConfHelper wit
       // normal subquery names, so that parent operators can only access the columns in subquery by
       // unqualified names. Users can still use this special qualifier to access columns if they
      // know it, but that's not recommended.
-      SubqueryAlias("__auto_generated_subquery_name", relation)
+      SubqueryAlias(SubqueryAlias.generateSubqueryName(), relation)
     } else {
       mayApplyAliasPlan(ctx.tableAlias, relation)
     }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index e23966775e9..c5ac0304841 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -1707,7 +1707,12 @@ object SubqueryAlias {
       child: LogicalPlan): SubqueryAlias = {
     SubqueryAlias(AliasIdentifier(multipartIdentifier.last, multipartIdentifier.init), child)
   }
+
+  def generateSubqueryName(suffix: String = ""): String = {
+    s"__auto_generated_subquery_name$suffix"
+  }
 }
+
 /**
  * Sample the dataset.
  *
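Note that `extractFunctionTableNamedArgument` above gates the `key => TABLE ...` form behind `SQLConf.ALLOW_NAMED_FUNCTION_ARGUMENTS`. A minimal sketch, assuming that conf's key is `spark.sql.allowNamedFunctionArguments` and a hypothetical registered `my_tvf` accepting a table argument `arg1`:

```py
# Disabled: a named TABLE argument is rejected via namedArgumentsNotEnabledError.
spark.conf.set("spark.sql.allowNamedFunctionArguments", False)
spark.sql("SELECT * FROM my_tvf(arg1 => TABLE v1)")  # raises an error

# Enabled: the argument parses into a NamedArgumentExpression wrapping a
# FunctionTableSubqueryArgumentExpression.
spark.conf.set("spark.sql.allowNamedFunctionArguments", True)
spark.sql("SELECT * FROM my_tvf(arg1 => TABLE v1)")
```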
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
index 11d5cf54df4..b806ebbed52 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
@@ -46,6 +46,7 @@ object TreePattern extends Enumeration {
   val EXISTS_SUBQUERY = Value
   val EXPRESSION_WITH_RANDOM_SEED: Value = Value
   val EXTRACT_VALUE: Value = Value
+  val FUNCTION_TABLE_RELATION_ARGUMENT_EXPRESSION: Value = Value
   val GENERATE: Value = Value
   val GENERATOR: Value = Value
   val HIGH_ORDER_FUNCTION: Value = Value
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
index e02708105d2..48223cb34e1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
@@ -1907,6 +1907,13 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase {
         "ability" -> ability))
   }
 
+  def tableValuedFunctionTooManyTableArgumentsError(num: Int): Throwable = {
+    new AnalysisException(
+      errorClass = "TABLE_VALUED_FUNCTION_TOO_MANY_TABLE_ARGUMENTS",
+      messageParameters = Map("num" -> num.toString)
+    )
+  }
+
   def identifierTooManyNamePartsError(originalIdentifier: String): Throwable = {
     new AnalysisException(
       errorClass = "IDENTIFIER_TOO_MANY_NAME_PARTS",
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 270508139e4..ecff6bef8ae 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -2753,6 +2753,14 @@ object SQLConf {
       .booleanConf
       .createWithDefault(false)
 
+  val TVF_ALLOW_MULTIPLE_TABLE_ARGUMENTS_ENABLED =
+    buildConf("spark.sql.tvf.allowMultipleTableArguments.enabled")
+      .doc("When true, allows multiple table arguments for table-valued functions, " +
+        "receiving the cartesian product of all the rows of these tables.")
+      .version("3.5.0")
+      .booleanConf
+      .createWithDefault(false)
+
   val RANGE_EXCHANGE_SAMPLE_SIZE_PER_PARTITION =
     buildConf("spark.sql.execution.rangeExchange.sampleSizePerPartition")
       .internal()
@@ -4926,6 +4934,8 @@ class SQLConf extends Serializable with Logging {
 
   def supportQuotedRegexColumnName: Boolean = getConf(SUPPORT_QUOTED_REGEX_COLUMN_NAME)
 
+  def tvfAllowMultipleTableArguments: Boolean = getConf(TVF_ALLOW_MULTIPLE_TABLE_ARGUMENTS_ENABLED)
+
   def rangeExchangeSampleSizePerPartition: Int = getConf(RANGE_EXCHANGE_SAMPLE_SIZE_PER_PARTITION)
 
   def arrowPySparkEnabled: Boolean = getConf(ARROW_PYSPARK_EXECUTION_ENABLED)
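A usage sketch for the new conf, mirroring `test_udtf_with_table_argument_multiple` above: when enabled, a call with several TABLE arguments receives the Cartesian product of their rows.

```py
# Assumes the two-argument test_udtf registered in the Python tests above.
spark.conf.set("spark.sql.tvf.allowMultipleTableArguments.enabled", True)
rows = spark.sql("""
    SELECT * FROM test_udtf(
      TABLE (SELECT id FROM range(0, 2)),
      TABLE (SELECT id FROM range(0, 3)))
""").collect()
# 2 x 3 = 6 rows: every row of the first table paired with every row of the second.
```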
arguments") { + assertEqual( + "select * from my_tvf(table v1, table (select 1))", + UnresolvedTableValuedFunction("my_tvf", + FunctionTableSubqueryArgumentExpression(UnresolvedRelation(Seq("v1"))) :: + FunctionTableSubqueryArgumentExpression( + Project(Seq(UnresolvedAlias(Literal(1))), OneRowRelation())) :: Nil).select(star())) + + // All named arguments + assertEqual( + "select * from my_tvf(arg1 => table v1, arg2 => table (select 1))", + UnresolvedTableValuedFunction("my_tvf", + NamedArgumentExpression("arg1", + FunctionTableSubqueryArgumentExpression(UnresolvedRelation(Seq("v1")))) :: + NamedArgumentExpression("arg2", + FunctionTableSubqueryArgumentExpression( + Project(Seq(UnresolvedAlias(Literal(1))), OneRowRelation()))) :: Nil).select(star())) + + // Unnamed and named arguments + assertEqual( + "select * from my_tvf(2, table v1, arg1 => table (select 1))", + UnresolvedTableValuedFunction("my_tvf", + Literal(2) :: + FunctionTableSubqueryArgumentExpression(UnresolvedRelation(Seq("v1"))) :: + NamedArgumentExpression("arg1", + FunctionTableSubqueryArgumentExpression( + Project(Seq(UnresolvedAlias(Literal(1))), OneRowRelation()))) :: Nil).select(star())) + + // Mixed arguments + assertEqual( + "select * from my_tvf(arg1 => table v1, 2, arg2 => true)", + UnresolvedTableValuedFunction("my_tvf", + NamedArgumentExpression("arg1", + FunctionTableSubqueryArgumentExpression(UnresolvedRelation(Seq("v1")))) :: + Literal(2) :: + NamedArgumentExpression("arg2", Literal(true)) :: Nil).select(star())) + } + test("SPARK-32106: TRANSFORM plan") { // verify schema less assertEqual( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala index 911ddfeb13b..ebf48c5f863 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala @@ -96,6 +96,8 @@ trait PlanTestBase extends PredicateHelper with SQLHelper with SQLConfHelper { s udf.copy(resultId = ExprId(0)) case udaf: PythonUDAF => udaf.copy(resultId = ExprId(0)) + case a: FunctionTableSubqueryArgumentExpression => + a.copy(plan = normalizeExprIds(a.plan), exprId = ExprId(0)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryParsingErrorsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryParsingErrorsSuite.scala index a7d5046245d..2731760f7ef 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryParsingErrorsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryParsingErrorsSuite.scala @@ -401,7 +401,6 @@ class QueryParsingErrorsSuite extends QueryTest with SharedSparkSession with SQL checkParseSyntaxError("select * from my_tvf(arg1 => )", "')'") checkParseSyntaxError("select * from my_tvf(arg1 => , 42)", "','") checkParseSyntaxError("select * from my_tvf(my_tvf.arg1 => 'value1')", "'=>'") - checkParseSyntaxError("select * from my_tvf(arg1 => table t1)", "'t1'", hint = ": extra input 't1'") } test("PARSE_SYNTAX_ERROR: extraneous input") { --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org