This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch branch-3.5
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.5 by this push:
new 37307fb0d14e [SPARK-47241][SQL] Fix rule order issues for
ExtractGenerator
37307fb0d14e is described below
commit 37307fb0d14e03b1085ed12d8d540d2606bf1e9d
Author: Wenchen Fan <[email protected]>
AuthorDate: Thu Mar 7 17:02:09 2024 +0800
[SPARK-47241][SQL] Fix rule order issues for ExtractGenerator
### What changes were proposed in this pull request?
The rule `ExtractGenerator` does not define any trigger condition when
rewriting generator functions in `Project`, which makes its behavior quite
unstable and heavily dependent on the execution order of analyzer rules.
Two bugs I've found so far:
1. By design, we want to forbid users from using more than one generator
function in SELECT. However, we can't really enforce it if two generator
functions are not resolved at the same time: the rule thinks there is only one
generator function (the other is still unresolved) and rewrites it. The other
one gets resolved later and is then rewritten as well.
2. When a generator function is put after `SELECT *`, it's possible that
`*` is not expanded yet when we enter `ExtractGenerator`. The rule rewrites the
generator function: it inserts a `Generate` operator below and adds a new column
to the projectList for the generator function output. When `*` is later expanded
to the output of the child plan, which is now `Generate`, we end up with two
identical columns for the generator function output.
This PR fixes it by adding a trigger condition when rewriting generator
functions in `Project`: every expression in the projectList must be either
resolved or a generator function. This is the same trigger condition we used for
`Aggregate`. To avoid breaking changes, this PR also allows multiple generator
functions in `Project`, which works totally fine.
### Why are the changes needed?
bug fix
### Does this PR introduce _any_ user-facing change?
Yes, now multiple generator functions are allowed in `Project`. And there
won't be duplicated columns for generator function output.
### How was this patch tested?
new test
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #45350 from cloud-fan/generate.
Lead-authored-by: Wenchen Fan <[email protected]>
Co-authored-by: Wenchen Fan <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
(cherry picked from commit 51f4cfa7560bba576577d3a5f254daaad516849d)
Signed-off-by: Wenchen Fan <[email protected]>
---
.../src/main/resources/error/error-classes.json | 2 +-
...conditions-unsupported-generator-error-class.md | 2 +-
.../spark/sql/catalyst/analysis/Analyzer.scala | 43 ++++++++++++++--------
.../sql/catalyst/analysis/CheckAnalysis.scala | 10 -----
.../spark/sql/errors/QueryCompilationErrors.scala | 3 +-
.../sql/catalyst/analysis/AnalysisErrorSuite.scala | 14 +------
.../org/apache/spark/sql/DataFrameSuite.scala | 14 -------
.../apache/spark/sql/GeneratorFunctionSuite.scala | 27 +++++++++++++-
.../sql/errors/QueryCompilationErrorsSuite.scala | 12 ------
.../spark/sql/hive/execution/HiveQuerySuite.scala | 22 -----------
10 files changed, 57 insertions(+), 92 deletions(-)
diff --git a/common/utils/src/main/resources/error/error-classes.json
b/common/utils/src/main/resources/error/error-classes.json
index 2d50fe1a1a1a..b9d4c2c297f8 100644
--- a/common/utils/src/main/resources/error/error-classes.json
+++ b/common/utils/src/main/resources/error/error-classes.json
@@ -3056,7 +3056,7 @@
"subClass" : {
"MULTI_GENERATOR" : {
"message" : [
- "only one generator allowed per <clause> clause but found <num>:
<generators>."
+ "only one generator allowed per SELECT clause but found <num>:
<generators>."
]
},
"NESTED_IN_EXPRESSIONS" : {
diff --git a/docs/sql-error-conditions-unsupported-generator-error-class.md
b/docs/sql-error-conditions-unsupported-generator-error-class.md
index 7960c14767d1..38b3bbfaa3c3 100644
--- a/docs/sql-error-conditions-unsupported-generator-error-class.md
+++ b/docs/sql-error-conditions-unsupported-generator-error-class.md
@@ -27,7 +27,7 @@ This error class has the following derived error classes:
## MULTI_GENERATOR
-only one generator allowed per `<clause>` clause but found `<num>`:
`<generators>`.
+only one generator allowed per SELECT clause but found `<num>`: `<generators>`.
## NESTED_IN_EXPRESSIONS
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 8fe87a05d02d..eae150001249 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -2742,28 +2742,36 @@ class Analyzer(override val catalogManager:
CatalogManager) extends RuleExecutor
}
}
+ // We must wait until all expressions except for generator functions are
resolved before
+ // rewriting generator functions in Project/Aggregate. This is necessary
to make this rule
+ // stable for different execution orders of analyzer rules. See also
SPARK-47241.
+ private def canRewriteGenerator(namedExprs: Seq[NamedExpression]): Boolean
= {
+ namedExprs.forall { ne =>
+ ne.resolved || {
+ trimNonTopLevelAliases(ne) match {
+ case AliasedGenerator(_, _, _) => true
+ case _ => false
+ }
+ }
+ }
+ }
+
def apply(plan: LogicalPlan): LogicalPlan =
plan.resolveOperatorsUpWithPruning(
_.containsPattern(GENERATOR), ruleId) {
case Project(projectList, _) if projectList.exists(hasNestedGenerator) =>
val nestedGenerator = projectList.find(hasNestedGenerator).get
throw
QueryCompilationErrors.nestedGeneratorError(trimAlias(nestedGenerator))
- case Project(projectList, _) if projectList.count(hasGenerator) > 1 =>
- val generators = projectList.filter(hasGenerator).map(trimAlias)
- throw QueryCompilationErrors.moreThanOneGeneratorError(generators,
"SELECT")
-
case Aggregate(_, aggList, _) if aggList.exists(hasNestedGenerator) =>
val nestedGenerator = aggList.find(hasNestedGenerator).get
throw
QueryCompilationErrors.nestedGeneratorError(trimAlias(nestedGenerator))
case Aggregate(_, aggList, _) if aggList.count(hasGenerator) > 1 =>
val generators = aggList.filter(hasGenerator).map(trimAlias)
- throw QueryCompilationErrors.moreThanOneGeneratorError(generators,
"aggregate")
+ throw QueryCompilationErrors.moreThanOneGeneratorError(generators)
- case agg @ Aggregate(groupList, aggList, child) if aggList.forall {
- case AliasedGenerator(_, _, _) => true
- case other => other.resolved
- } && aggList.exists(hasGenerator) =>
+ case Aggregate(groupList, aggList, child) if
canRewriteGenerator(aggList) &&
+ aggList.exists(hasGenerator) =>
// If generator in the aggregate list was visited, set the boolean
flag true.
var generatorVisited = false
@@ -2808,16 +2816,16 @@ class Analyzer(override val catalogManager:
CatalogManager) extends RuleExecutor
// first for replacing `Project` with `Aggregate`.
p
- case p @ Project(projectList, child) =>
+ case p @ Project(projectList, child) if canRewriteGenerator(projectList)
&&
+ projectList.exists(hasGenerator) =>
val (resolvedGenerator, newProjectList) = projectList
.map(trimNonTopLevelAliases)
.foldLeft((None: Option[Generate], Nil: Seq[NamedExpression])) {
(res, e) =>
e match {
- case AliasedGenerator(generator, names, outer) if
generator.childrenResolved =>
- // It's a sanity check, this should not happen as the previous
case will throw
- // exception earlier.
- assert(res._1.isEmpty, "More than one generator found in
SELECT.")
-
+ // If there are more than one generator, we only rewrite the
first one and wait for
+ // the next analyzer iteration to rewrite the next one.
+ case AliasedGenerator(generator, names, outer) if res._1.isEmpty
&&
+ generator.childrenResolved =>
val g = Generate(
generator,
unrequiredChildIndex = Nil,
@@ -2825,7 +2833,6 @@ class Analyzer(override val catalogManager:
CatalogManager) extends RuleExecutor
qualifier = None,
generatorOutput =
ResolveGenerate.makeGeneratorOutput(generator, names),
child)
-
(Some(g), res._2 ++ g.nullableOutput)
case other =>
(res._1, res._2 :+ other)
@@ -2845,6 +2852,10 @@ class Analyzer(override val catalogManager:
CatalogManager) extends RuleExecutor
case u: UnresolvedTableValuedFunction => u
+ case p: Project => p
+
+ case a: Aggregate => a
+
case p if p.expressions.exists(hasGenerator) =>
throw QueryCompilationErrors.generatorOutsideSelectError(p)
}
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 533ea8a2b799..7f10bdbc80ca 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -64,12 +64,6 @@ trait CheckAnalysis extends PredicateHelper with
LookupCatalog with QueryErrorsB
messageParameters = messageParameters)
}
- protected def containsMultipleGenerators(exprs: Seq[Expression]): Boolean = {
- exprs.flatMap(_.collect {
- case e: Generator => e
- }).length > 1
- }
-
protected def hasMapType(dt: DataType): Boolean = {
dt.existsRecursively(_.isInstanceOf[MapType])
}
@@ -687,10 +681,6 @@ trait CheckAnalysis extends PredicateHelper with
LookupCatalog with QueryErrorsB
))
}
- case p @ Project(exprs, _) if containsMultipleGenerators(exprs) =>
- val generators = exprs.filter(expr =>
expr.exists(_.isInstanceOf[Generator]))
- throw QueryCompilationErrors.moreThanOneGeneratorError(generators,
"SELECT")
-
case p @ Project(projectList, _) =>
projectList.foreach(_.transformDownWithPruning(
_.containsPattern(UNRESOLVED_WINDOW_EXPRESSION)) {
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
index 9dca2c5f2822..a78e092c4bfa 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
@@ -248,11 +248,10 @@ private[sql] object QueryCompilationErrors extends
QueryErrorsBase with Compilat
messageParameters = Map("expression" ->
toSQLExpr(trimmedNestedGenerator)))
}
- def moreThanOneGeneratorError(generators: Seq[Expression], clause: String):
Throwable = {
+ def moreThanOneGeneratorError(generators: Seq[Expression]): Throwable = {
new AnalysisException(
errorClass = "UNSUPPORTED_GENERATOR.MULTI_GENERATOR",
messageParameters = Map(
- "clause" -> clause,
"num" -> generators.size.toString,
"generators" -> generators.map(toSQLExpr).mkString(", ")))
}
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
index e2e980073307..e8dc9061199c 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
@@ -344,11 +344,6 @@ class AnalysisErrorSuite extends AnalysisTest {
"inputType" -> "\"BOOLEAN\"",
"requiredType" -> "\"INT\""))
- errorTest(
- "too many generators",
- listRelation.select(Explode($"list").as("a"), Explode($"list").as("b")),
- "only one generator" :: "explode" :: Nil)
-
errorClassTest(
"unresolved attributes",
testRelation.select($"abcd"),
@@ -754,18 +749,11 @@ class AnalysisErrorSuite extends AnalysisTest {
"SUM_OF_LIMIT_AND_OFFSET_EXCEEDS_MAX_INT",
Map("limit" -> "1000000000", "offset" -> "2000000000"))
- errorTest(
- "more than one generators in SELECT",
- listRelation.select(Explode($"list"), Explode($"list")),
- "The generator is not supported: only one generator allowed per select
clause but found 2: " +
- """"explode(list)", "explode(list)"""" :: Nil
- )
-
errorTest(
"more than one generators for aggregates in SELECT",
testRelation.select(Explode(CreateArray(min($"a") :: Nil)),
Explode(CreateArray(max($"a") :: Nil))),
- "The generator is not supported: only one generator allowed per select
clause but found 2: " +
+ "The generator is not supported: only one generator allowed per SELECT
clause but found 2: " +
""""explode(array(min(a)))", "explode(array(max(a)))"""" :: Nil
)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 002719f06896..c586da6105fd 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -368,20 +368,6 @@ class DataFrameSuite extends QueryTest
Row("a", Seq("a"), 1) :: Nil)
}
- test("more than one generator in SELECT clause") {
- val df = Seq((Array("a"), 1)).toDF("a", "b")
-
- checkError(
- exception = intercept[AnalysisException] {
- df.select(explode($"a").as("a"), explode($"a").as("b"))
- },
- errorClass = "UNSUPPORTED_GENERATOR.MULTI_GENERATOR",
- parameters = Map(
- "clause" -> "SELECT",
- "num" -> "2",
- "generators" -> "\"explode(a)\", \"explode(a)\""))
- }
-
test("sort after generate with join=true") {
val df = Seq((Array("a"), 1)).toDF("a", "b")
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala
index 0746a4b92af2..7c285759fcd9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala
@@ -432,7 +432,6 @@ class GeneratorFunctionSuite extends QueryTest with
SharedSparkSession {
},
errorClass = "UNSUPPORTED_GENERATOR.MULTI_GENERATOR",
parameters = Map(
- "clause" -> "aggregate",
"num" -> "2",
"generators" -> ("\"explode(array(min(c2), max(c2)))\", " +
"\"posexplode(array(min(c2), max(c2)))\"")))
@@ -543,6 +542,32 @@ class GeneratorFunctionSuite extends QueryTest with
SharedSparkSession {
checkAnswer(df, Row(0.7604953758285915d))
}
}
+
+ test("SPARK-47241: two generator functions in SELECT") {
+ def testTwoGenerators(needImplicitCast: Boolean): Unit = {
+ val df = sql(
+ s"""
+ |SELECT
+ |explode(array('a', 'b')) as c1,
+ |explode(array(0L, ${if (needImplicitCast) "0L + 1" else "1L"})) as
c2
+ |""".stripMargin)
+ checkAnswer(df, Seq(Row("a", 0L), Row("a", 1L), Row("b", 0L), Row("b",
1L)))
+ }
+ testTwoGenerators(needImplicitCast = true)
+ testTwoGenerators(needImplicitCast = false)
+ }
+
+ test("SPARK-47241: generator function after wildcard in SELECT") {
+ val df = sql(
+ s"""
+ |SELECT *, explode(array('a', 'b')) as c1
+ |FROM
+ |(
+ | SELECT id FROM range(1) GROUP BY 1
+ |)
+ |""".stripMargin)
+ checkAnswer(df, Seq(Row(0, "a"), Row(0, "b")))
+ }
}
case class EmptyGenerator() extends Generator with LeafLike[Expression] {
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryCompilationErrorsSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryCompilationErrorsSuite.scala
index 7f938deaaa64..ac57c958828b 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryCompilationErrorsSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryCompilationErrorsSuite.scala
@@ -646,18 +646,6 @@ class QueryCompilationErrorsSuite
parameters = Map("expression" -> "\"(explode(array(1, 2, 3)) + 1)\""))
}
- test("UNSUPPORTED_GENERATOR: only one generator allowed") {
- val e = intercept[AnalysisException](
- sql("""select explode(Array(1, 2, 3)), explode(Array(1, 2,
3))""").collect()
- )
-
- checkError(
- exception = e,
- errorClass = "UNSUPPORTED_GENERATOR.MULTI_GENERATOR",
- parameters = Map("clause" -> "SELECT", "num" -> "2",
- "generators" -> "\"explode(array(1, 2, 3))\", \"explode(array(1, 2,
3))\""))
- }
-
test("UNSUPPORTED_GENERATOR: generators are not supported outside the SELECT
clause") {
val e = intercept[AnalysisException](
sql("""select 1 from t order by explode(Array(1, 2, 3))""").collect()
diff --git
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
index 82b88ec9f35d..4b85b37b6c2c 100644
---
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
+++
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
@@ -161,28 +161,6 @@ class HiveQuerySuite extends HiveComparisonTest with
SQLTestUtils with BeforeAnd
| SELECT key FROM gen_tmp ORDER BY key ASC;
""".stripMargin)
- test("multiple generators in projection") {
- checkError(
- exception = intercept[AnalysisException] {
- sql("SELECT explode(array(key, key)), explode(array(key, key)) FROM
src").collect()
- },
- errorClass = "UNSUPPORTED_GENERATOR.MULTI_GENERATOR",
- parameters = Map(
- "clause" -> "SELECT",
- "num" -> "2",
- "generators" -> "\"explode(array(key, key))\", \"explode(array(key,
key))\""))
-
- checkError(
- exception = intercept[AnalysisException] {
- sql("SELECT explode(array(key, key)) as k1, explode(array(key, key))
FROM src").collect()
- },
- errorClass = "UNSUPPORTED_GENERATOR.MULTI_GENERATOR",
- parameters = Map(
- "clause" -> "SELECT",
- "num" -> "2",
- "generators" -> "\"explode(array(key, key))\", \"explode(array(key,
key))\""))
- }
-
createQueryTest("! operator",
"""
|SELECT a FROM (
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]