This is an automated email from the ASF dual-hosted git repository.
ueshin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 2cf937f9bac [SPARK-45266][PYTHON] Refactor ResolveFunctions analyzer
rule to delay making lateral join when table arguments are used
2cf937f9bac is described below
commit 2cf937f9bac2131f3657660a8d65d07ab4ece490
Author: Takuya UESHIN <[email protected]>
AuthorDate: Thu Sep 28 10:37:18 2023 -0700
[SPARK-45266][PYTHON] Refactor ResolveFunctions analyzer rule to delay
making lateral join when table arguments are used
### What changes were proposed in this pull request?
Refactors `ResolveFunctions` analyzer rule to delay making lateral join
when table arguments are used.
- Delay making lateral join when table arguments are used to after all the
children are resolved
- Resolve `UnresolvedPolymorphicPythonUDTF` in one place
- Introduce a new error class
`UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY.UNSUPPORTED_TABLE_ARGUMENT` if table
arguments are used improperly.
### Why are the changes needed?
The analyzer rule `ResolveFunctions` became complicated.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Existing tests.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #43042 from ueshin/issues/SPARK-45266/analyzer.
Authored-by: Takuya UESHIN <[email protected]>
Signed-off-by: Takuya UESHIN <[email protected]>
---
.../src/main/resources/error/error-classes.json | 5 +
...ted-subquery-expression-category-error-class.md | 4 +
.../spark/sql/catalyst/analysis/Analyzer.scala | 155 +++++++++------------
.../sql/catalyst/analysis/CheckAnalysis.scala | 5 +
.../spark/sql/catalyst/expressions/PythonUDF.scala | 6 +-
.../named-function-arguments.sql.out | 16 +--
.../results/named-function-arguments.sql.out | 16 +--
.../sql/execution/python/PythonUDTFSuite.scala | 20 ++-
8 files changed, 103 insertions(+), 124 deletions(-)
diff --git a/common/utils/src/main/resources/error/error-classes.json
b/common/utils/src/main/resources/error/error-classes.json
index 0882e387176..58fcedae332 100644
--- a/common/utils/src/main/resources/error/error-classes.json
+++ b/common/utils/src/main/resources/error/error-classes.json
@@ -3484,6 +3484,11 @@
"message" : [
"IN/EXISTS predicate subqueries can only be used in filters, joins,
aggregations, window functions, projections, and UPDATE/MERGE/DELETE
commands<treeNode>."
]
+ },
+ "UNSUPPORTED_TABLE_ARGUMENT" : {
+ "message" : [
+ "Table arguments are used in a function where they are not
supported<treeNode>."
+ ]
}
},
"sqlState" : "0A000"
diff --git
a/docs/sql-error-conditions-unsupported-subquery-expression-category-error-class.md
b/docs/sql-error-conditions-unsupported-subquery-expression-category-error-class.md
index f61ea721aa0..45ad386c666 100644
---
a/docs/sql-error-conditions-unsupported-subquery-expression-category-error-class.md
+++
b/docs/sql-error-conditions-unsupported-subquery-expression-category-error-class.md
@@ -73,4 +73,8 @@ Correlated scalar subqueries can only be used in filters,
aggregations, projecti
IN/EXISTS predicate subqueries can only be used in filters, joins,
aggregations, window functions, projections, and UPDATE/MERGE/DELETE
commands`<treeNode>`.
+## UNSUPPORTED_TABLE_ARGUMENT
+
+Table arguments are used in a function where they are not
supported`<treeNode>`.
+
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 67a958d73f7..cc0bfd3fc31 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -2082,7 +2082,7 @@ class Analyzer(override val catalogManager:
CatalogManager) extends RuleExecutor
case u: UnresolvedTableValuedFunction if
u.functionArgs.forall(_.resolved) =>
withPosition(u) {
try {
- val resolvedTvf = resolveBuiltinOrTempTableFunction(u.name,
u.functionArgs).getOrElse {
+ val resolvedFunc = resolveBuiltinOrTempTableFunction(u.name,
u.functionArgs).getOrElse {
val CatalogAndIdentifier(catalog, ident) =
expandIdentifier(u.name)
if (CatalogV2Util.isSessionCatalog(catalog)) {
v1SessionCatalog.resolvePersistentTableFunction(
@@ -2092,93 +2092,19 @@ class Analyzer(override val catalogManager:
CatalogManager) extends RuleExecutor
catalog, "table-valued functions")
}
}
- // Resolve Python UDTF calls if needed.
- val resolvedFunc = resolvedTvf match {
- case g @ Generate(u: UnresolvedPolymorphicPythonUDTF, _, _, _,
_, _) =>
- val analyzeResult: PythonUDTFAnalyzeResult =
- u.resolveElementMetadata(u.func, u.children)
- g.copy(generator =
- PythonUDTF(u.name, u.func, analyzeResult.schema, u.children,
- u.evalType, u.udfDeterministic, u.resultId,
u.pythonUDTFPartitionColumnIndexes,
- analyzeResult = Some(analyzeResult)))
- case other =>
- other
- }
- val tableArgs = mutable.ArrayBuffer.empty[LogicalPlan]
- val functionTableSubqueryArgs =
-
mutable.ArrayBuffer.empty[FunctionTableSubqueryArgumentExpression]
- val tvf = resolvedFunc.transformAllExpressionsWithPruning(
- _.containsPattern(FUNCTION_TABLE_RELATION_ARGUMENT_EXPRESSION),
ruleId) {
+ resolvedFunc.transformAllExpressionsWithPruning(
+ _.containsPattern(FUNCTION_TABLE_RELATION_ARGUMENT_EXPRESSION))
{
case t: FunctionTableSubqueryArgumentExpression =>
- val alias =
SubqueryAlias.generateSubqueryName(s"_${tableArgs.size}")
- val (
- pythonUDTFName: String,
- pythonUDTFAnalyzeResult: Option[PythonUDTFAnalyzeResult]) =
- resolvedFunc match {
- case Generate(p: PythonUDTF, _, _, _, _, _) =>
- (p.name,
- p.analyzeResult)
- case _ =>
- assert(!t.hasRepartitioning,
- "Cannot evaluate the table-valued function call
because it included the " +
- "PARTITION BY clause, but only Python table
functions support this " +
- "clause")
- ("", None)
- }
- // Check if this is a call to a Python user-defined table
function whose polymorphic
- // 'analyze' method returned metadata indicated requested
partitioning and/or
- // ordering properties of the input relation. In that event,
make sure that the UDTF
- // call did not include any explicit PARTITION BY and/or ORDER
BY clauses for the
- // corresponding TABLE argument, and then update the TABLE
argument representation
- // to apply the requested partitioning and/or ordering.
- pythonUDTFAnalyzeResult.map { analyzeResult =>
- val newTableArgument:
FunctionTableSubqueryArgumentExpression =
- analyzeResult.applyToTableArgument(pythonUDTFName, t)
- tableArgs.append(SubqueryAlias(alias,
newTableArgument.evaluable))
- functionTableSubqueryArgs.append(newTableArgument)
- }.getOrElse {
- tableArgs.append(SubqueryAlias(alias, t.evaluable))
- functionTableSubqueryArgs.append(t)
+ resolvedFunc match {
+ case Generate(_: PythonUDTF, _, _, _, _, _) =>
+ case Generate(_: UnresolvedPolymorphicPythonUDTF, _, _, _,
_, _) =>
+ case _ =>
+ assert(!t.hasRepartitioning,
+ "Cannot evaluate the table-valued function call because
it included the " +
+ "PARTITION BY clause, but only Python table functions
support this " +
+ "clause")
}
- UnresolvedAttribute(Seq(alias, "c"))
- }
- if (tableArgs.nonEmpty) {
- if (!conf.tvfAllowMultipleTableArguments && tableArgs.size > 1) {
- throw
QueryCompilationErrors.tableValuedFunctionTooManyTableArgumentsError(
- tableArgs.size)
- }
- val alias =
SubqueryAlias.generateSubqueryName(s"_${tableArgs.size}")
- // Propagate the column indexes for TABLE arguments to the
PythonUDTF instance.
- def assignUDTFPartitionColumnIndexes(
- fn: PythonUDTFPartitionColumnIndexes => LogicalPlan):
Option[LogicalPlan] = {
- val indexes: Seq[Int] = functionTableSubqueryArgs.headOption
- .map(_.partitioningExpressionIndexes).getOrElse(Seq.empty)
- if (indexes.nonEmpty) {
- Some(fn(PythonUDTFPartitionColumnIndexes(indexes)))
- } else {
- None
- }
- }
- val tvfWithTableColumnIndexes: LogicalPlan = tvf match {
- case g@Generate(p: PythonUDTF, _, _, _, _, _) =>
- assignUDTFPartitionColumnIndexes(
- i => g.copy(generator =
p.copy(pythonUDTFPartitionColumnIndexes = Some(i))))
- .getOrElse(g)
- case g@Generate(p: UnresolvedPolymorphicPythonUDTF, _, _, _,
_, _) =>
- assignUDTFPartitionColumnIndexes(
- i => g.copy(generator =
p.copy(pythonUDTFPartitionColumnIndexes = Some(i))))
- .getOrElse(g)
- case _ =>
- tvf
- }
- Project(
- Seq(UnresolvedStar(Some(Seq(alias)))),
- LateralJoin(
- tableArgs.reduceLeft(Join(_, _, Inner, None, JoinHint.NONE)),
- LateralSubquery(SubqueryAlias(alias,
tvfWithTableColumnIndexes)), Inner, None)
- )
- } else {
- tvf
+ t
}
} catch {
case _: NoSuchFunctionException =>
@@ -2206,6 +2132,46 @@ class Analyzer(override val catalogManager:
CatalogManager) extends RuleExecutor
}
Project(aliases, u.child)
+ case p: LogicalPlan
+ if p.resolved &&
p.containsPattern(FUNCTION_TABLE_RELATION_ARGUMENT_EXPRESSION) =>
+ withPosition(p) {
+ val tableArgs =
+
mutable.ArrayBuffer.empty[(FunctionTableSubqueryArgumentExpression,
LogicalPlan)]
+
+ val tvf = p.transformExpressionsWithPruning(
+ _.containsPattern(FUNCTION_TABLE_RELATION_ARGUMENT_EXPRESSION)) {
+ case t: FunctionTableSubqueryArgumentExpression =>
+ val alias =
SubqueryAlias.generateSubqueryName(s"_${tableArgs.size}")
+ tableArgs.append((t, SubqueryAlias(alias, t.evaluable)))
+ UnresolvedAttribute(Seq(alias, "c"))
+ }
+
+ assert(tableArgs.nonEmpty)
+ if (!conf.tvfAllowMultipleTableArguments && tableArgs.size > 1) {
+ throw
QueryCompilationErrors.tableValuedFunctionTooManyTableArgumentsError(
+ tableArgs.size)
+ }
+ val alias = SubqueryAlias.generateSubqueryName(s"_${tableArgs.size}")
+
+ // Propagate the column indexes for TABLE arguments to the
PythonUDTF instance.
+ val tvfWithTableColumnIndexes = tvf match {
+ case g @ Generate(pyudtf: PythonUDTF, _, _, _, _, _)
+ if tableArgs.head._1.partitioningExpressionIndexes.nonEmpty =>
+ val partitionColumnIndexes =
+
PythonUDTFPartitionColumnIndexes(tableArgs.head._1.partitioningExpressionIndexes)
+ g.copy(generator = pyudtf.copy(
+ pythonUDTFPartitionColumnIndexes =
Some(partitionColumnIndexes)))
+ case _ => tvf
+ }
+
+ Project(
+ Seq(UnresolvedStar(Some(Seq(alias)))),
+ LateralJoin(
+ tableArgs.map(_._2).reduceLeft(Join(_, _, Inner, None,
JoinHint.NONE)),
+ LateralSubquery(SubqueryAlias(alias,
tvfWithTableColumnIndexes)), Inner, None)
+ )
+ }
+
case q: LogicalPlan =>
q.transformExpressionsUpWithPruning(
_.containsAnyPattern(UNRESOLVED_FUNCTION, GENERATOR),
@@ -2251,9 +2217,20 @@ class Analyzer(override val catalogManager:
CatalogManager) extends RuleExecutor
}
case u: UnresolvedPolymorphicPythonUDTF => withPosition(u) {
- val elementSchema = u.resolveElementMetadata(u.func,
u.children).schema
- PythonUDTF(u.name, u.func, elementSchema, u.children,
- u.evalType, u.udfDeterministic, u.resultId,
u.pythonUDTFPartitionColumnIndexes)
+ // Check if this is a call to a Python user-defined table function
whose polymorphic
+ // 'analyze' method returned metadata indicated requested
partitioning and/or
+ // ordering properties of the input relation. In that event, make
sure that the UDTF
+ // call did not include any explicit PARTITION BY and/or ORDER BY
clauses for the
+ // corresponding TABLE argument, and then update the TABLE
argument representation
+ // to apply the requested partitioning and/or ordering.
+ val analyzeResult = u.resolveElementMetadata(u.func, u.children)
+ val newChildren = u.children.map {
+ case t: FunctionTableSubqueryArgumentExpression =>
+ analyzeResult.applyToTableArgument(u.name, t)
+ case c => c
+ }
+ PythonUDTF(u.name, u.func, analyzeResult.schema, newChildren,
+ u.evalType, u.udfDeterministic, u.resultId)
}
}
}
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 83b682bc917..de453f6bc49 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -1075,6 +1075,11 @@ trait CheckAnalysis extends PredicateHelper with
LookupCatalog with QueryErrorsB
// allowed by spark.
checkCorrelationsInSubquery(expr.plan, isLateral = true)
+ case _: FunctionTableSubqueryArgumentExpression =>
+ expr.failAnalysis(
+ errorClass =
"UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY.UNSUPPORTED_TABLE_ARGUMENT",
+ messageParameters = Map("treeNode" -> planToString(plan)))
+
case inSubqueryOrExistsSubquery =>
plan match {
case _: Filter | _: SupportsSubquery | _: Join |
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala
index a615348bc6e..bc74572444c 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala
@@ -180,8 +180,7 @@ case class PythonUDTF(
evalType: Int,
udfDeterministic: Boolean,
resultId: ExprId = NamedExpression.newExprId,
- pythonUDTFPartitionColumnIndexes: Option[PythonUDTFPartitionColumnIndexes]
= None,
- analyzeResult: Option[PythonUDTFAnalyzeResult] = None)
+ pythonUDTFPartitionColumnIndexes: Option[PythonUDTFPartitionColumnIndexes]
= None)
extends UnevaluableGenerator with PythonFuncExpression {
override lazy val canonicalized: Expression = {
@@ -210,8 +209,7 @@ case class UnresolvedPolymorphicPythonUDTF(
evalType: Int,
udfDeterministic: Boolean,
resolveElementMetadata: (PythonFunction, Seq[Expression]) =>
PythonUDTFAnalyzeResult,
- resultId: ExprId = NamedExpression.newExprId,
- pythonUDTFPartitionColumnIndexes: Option[PythonUDTFPartitionColumnIndexes]
= None)
+ resultId: ExprId = NamedExpression.newExprId)
extends UnevaluableGenerator with PythonFuncExpression {
override lazy val resolved = false
diff --git
a/sql/core/src/test/resources/sql-tests/analyzer-results/named-function-arguments.sql.out
b/sql/core/src/test/resources/sql-tests/analyzer-results/named-function-arguments.sql.out
index 11e2651c6f2..4ba47e9e1b4 100644
---
a/sql/core/src/test/resources/sql-tests/analyzer-results/named-function-arguments.sql.out
+++
b/sql/core/src/test/resources/sql-tests/analyzer-results/named-function-arguments.sql.out
@@ -202,21 +202,17 @@ SELECT * FROM explode(collection => TABLE(v))
-- !query analysis
org.apache.spark.sql.catalyst.ExtendedAnalysisException
{
- "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
- "sqlState" : "42K09",
+ "errorClass" :
"UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY.UNSUPPORTED_TABLE_ARGUMENT",
+ "sqlState" : "0A000",
"messageParameters" : {
- "inputSql" : "\"outer(__auto_generated_subquery_name_0.c)\"",
- "inputType" : "\"STRUCT<id: BIGINT>\"",
- "paramIndex" : "1",
- "requiredType" : "(\"ARRAY\" or \"MAP\")",
- "sqlExpr" : "\"explode(outer(__auto_generated_subquery_name_0.c))\""
+ "treeNode" : "'Generate explode(table-argument#x []), false\n: +-
SubqueryAlias v\n: +- View (`v`, [id#xL])\n: +- Project [cast(id#xL
as bigint) AS id#xL]\n: +- Project [id#xL]\n: +- Range
(0, 8, step=1, splits=None)\n+- OneRowRelation\n"
},
"queryContext" : [ {
"objectType" : "",
"objectName" : "",
- "startIndex" : 15,
- "stopIndex" : 45,
- "fragment" : "explode(collection => TABLE(v))"
+ "startIndex" : 37,
+ "stopIndex" : 44,
+ "fragment" : "TABLE(v)"
} ]
}
diff --git
a/sql/core/src/test/resources/sql-tests/results/named-function-arguments.sql.out
b/sql/core/src/test/resources/sql-tests/results/named-function-arguments.sql.out
index 60301862a35..03963ac3ef9 100644
---
a/sql/core/src/test/resources/sql-tests/results/named-function-arguments.sql.out
+++
b/sql/core/src/test/resources/sql-tests/results/named-function-arguments.sql.out
@@ -185,21 +185,17 @@ struct<>
-- !query output
org.apache.spark.sql.catalyst.ExtendedAnalysisException
{
- "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
- "sqlState" : "42K09",
+ "errorClass" :
"UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY.UNSUPPORTED_TABLE_ARGUMENT",
+ "sqlState" : "0A000",
"messageParameters" : {
- "inputSql" : "\"outer(__auto_generated_subquery_name_0.c)\"",
- "inputType" : "\"STRUCT<id: BIGINT>\"",
- "paramIndex" : "1",
- "requiredType" : "(\"ARRAY\" or \"MAP\")",
- "sqlExpr" : "\"explode(outer(__auto_generated_subquery_name_0.c))\""
+ "treeNode" : "'Generate explode(table-argument#x []), false\n: +-
SubqueryAlias v\n: +- View (`v`, [id#xL])\n: +- Project [cast(id#xL
as bigint) AS id#xL]\n: +- Project [id#xL]\n: +- Range
(0, 8, step=1, splits=None)\n+- OneRowRelation\n"
},
"queryContext" : [ {
"objectType" : "",
"objectName" : "",
- "startIndex" : 15,
- "stopIndex" : 45,
- "fragment" : "explode(collection => TABLE(v))"
+ "startIndex" : 37,
+ "stopIndex" : 44,
+ "fragment" : "TABLE(v)"
} ]
}
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDTFSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDTFSuite.scala
index cf687f90287..cdc3ef9e417 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDTFSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDTFSuite.scala
@@ -128,6 +128,8 @@ class PythonUDTFSuite extends QueryTest with
SharedSparkSession {
def failure(plan: LogicalPlan): Unit = {
fail(s"Unexpected plan: $plan")
}
+
+ spark.udtf.registerPython("testUDTF", pythonUDTF)
sql(
"""
|SELECT * FROM testUDTF(
@@ -187,19 +189,15 @@ class PythonUDTFSuite extends QueryTest with
SharedSparkSession {
withTable("t") {
sql("create table t(col array<int>) using parquet")
val query = "select * from explode(table(t))"
- checkError(
+ checkErrorMatchPVals(
exception = intercept[AnalysisException](sql(query)),
- errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
- parameters = Map(
- "sqlExpr" ->
"\"explode(outer(__auto_generated_subquery_name_0.c))\"",
- "paramIndex" -> "1",
- "inputSql" -> "\"outer(__auto_generated_subquery_name_0.c)\"",
- "inputType" -> "\"STRUCT<col: ARRAY<INT>>\"",
- "requiredType" -> "(\"ARRAY\" or \"MAP\")"),
+ errorClass =
"UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY.UNSUPPORTED_TABLE_ARGUMENT",
+ sqlState = None,
+ parameters = Map("treeNode" -> "(?s).*"),
context = ExpectedContext(
- fragment = "explode(table(t))",
- start = 14,
- stop = 30))
+ fragment = "table(t)",
+ start = 22,
+ stop = 29))
}
spark.udtf.registerPython("UDTFCountSumLast", pythonUDTFCountSumLast)
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]