This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-3.5
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.5 by this push:
     new 37307fb0d14e [SPARK-47241][SQL] Fix rule order issues for 
ExtractGenerator
37307fb0d14e is described below

commit 37307fb0d14e03b1085ed12d8d540d2606bf1e9d
Author: Wenchen Fan <wenc...@databricks.com>
AuthorDate: Thu Mar 7 17:02:09 2024 +0800

    [SPARK-47241][SQL] Fix rule order issues for ExtractGenerator
    
    ### What changes were proposed in this pull request?
    
    The rule `ExtractGenerator` does not define any trigger condition when
    rewriting generator functions in `Project`, which makes its behavior unstable
    and heavily dependent on the execution order of analyzer rules.
    
    Two bugs I've found so far:
    1. By design, we want to forbid users from using more than one generator
       function in SELECT. However, we can't really enforce this if the two
       generator functions are not resolved at the same time: the rule thinks
       there is only one generator function (the other is still unresolved) and
       rewrites it. The other one gets resolved later and is rewritten later.
    2. When a generator function is placed after `SELECT *`, it's possible that
       `*` is not yet expanded when we enter `ExtractGenerator`. The rule rewrites
       the generator function: it inserts a `Generate` operator below and adds a
       new column to the projectList for the generator function output. When `*`
       is later expanded to the child plan output, which is now `Generate`, we end
       up with two identical columns for the generator function output.
    
    This PR fixes it by adding a trigger condition when rewriting generator
    functions in `Project`: every expression in the projectList must be either
    resolved or a generator function. This is the same trigger condition we
    already use for `Aggregate`. To avoid breaking changes, this PR also allows
    multiple generator functions in `Project`, which works fine: generators are
    rewritten one at a time, each into its own `Generate` operator.
    
    ### Why are the changes needed?
    
    bug fix
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, multiple generator functions are now allowed in `Project`, and the
    generator function output is no longer duplicated when it follows `SELECT *`.
    
    ### How was this patch tested?
    
    new test
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No
    
    Closes #45350 from cloud-fan/generate.
    
    Lead-authored-by: Wenchen Fan <wenc...@databricks.com>
    Co-authored-by: Wenchen Fan <cloud0...@gmail.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
    (cherry picked from commit 51f4cfa7560bba576577d3a5f254daaad516849d)
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../src/main/resources/error/error-classes.json    |  2 +-
 ...conditions-unsupported-generator-error-class.md |  2 +-
 .../spark/sql/catalyst/analysis/Analyzer.scala     | 43 ++++++++++++++--------
 .../sql/catalyst/analysis/CheckAnalysis.scala      | 10 -----
 .../spark/sql/errors/QueryCompilationErrors.scala  |  3 +-
 .../sql/catalyst/analysis/AnalysisErrorSuite.scala | 14 +------
 .../org/apache/spark/sql/DataFrameSuite.scala      | 14 -------
 .../apache/spark/sql/GeneratorFunctionSuite.scala  | 27 +++++++++++++-
 .../sql/errors/QueryCompilationErrorsSuite.scala   | 12 ------
 .../spark/sql/hive/execution/HiveQuerySuite.scala  | 22 -----------
 10 files changed, 57 insertions(+), 92 deletions(-)

diff --git a/common/utils/src/main/resources/error/error-classes.json 
b/common/utils/src/main/resources/error/error-classes.json
index 2d50fe1a1a1a..b9d4c2c297f8 100644
--- a/common/utils/src/main/resources/error/error-classes.json
+++ b/common/utils/src/main/resources/error/error-classes.json
@@ -3056,7 +3056,7 @@
     "subClass" : {
       "MULTI_GENERATOR" : {
         "message" : [
-          "only one generator allowed per <clause> clause but found <num>: 
<generators>."
+          "only one generator allowed per SELECT clause but found <num>: 
<generators>."
         ]
       },
       "NESTED_IN_EXPRESSIONS" : {
diff --git a/docs/sql-error-conditions-unsupported-generator-error-class.md 
b/docs/sql-error-conditions-unsupported-generator-error-class.md
index 7960c14767d1..38b3bbfaa3c3 100644
--- a/docs/sql-error-conditions-unsupported-generator-error-class.md
+++ b/docs/sql-error-conditions-unsupported-generator-error-class.md
@@ -27,7 +27,7 @@ This error class has the following derived error classes:
 
 ## MULTI_GENERATOR
 
-only one generator allowed per `<clause>` clause but found `<num>`: 
`<generators>`.
+only one generator allowed per SELECT clause but found `<num>`: `<generators>`.
 
 ## NESTED_IN_EXPRESSIONS
 
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 8fe87a05d02d..eae150001249 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -2742,28 +2742,36 @@ class Analyzer(override val catalogManager: 
CatalogManager) extends RuleExecutor
       }
     }
 
+    // We must wait until all expressions except for generator functions are 
resolved before
+    // rewriting generator functions in Project/Aggregate. This is necessary 
to make this rule
+    // stable for different execution orders of analyzer rules. See also 
SPARK-47241.
+    private def canRewriteGenerator(namedExprs: Seq[NamedExpression]): Boolean 
= {
+      namedExprs.forall { ne =>
+        ne.resolved || {
+          trimNonTopLevelAliases(ne) match {
+            case AliasedGenerator(_, _, _) => true
+            case _ => false
+          }
+        }
+      }
+    }
+
     def apply(plan: LogicalPlan): LogicalPlan = 
plan.resolveOperatorsUpWithPruning(
       _.containsPattern(GENERATOR), ruleId) {
       case Project(projectList, _) if projectList.exists(hasNestedGenerator) =>
         val nestedGenerator = projectList.find(hasNestedGenerator).get
         throw 
QueryCompilationErrors.nestedGeneratorError(trimAlias(nestedGenerator))
 
-      case Project(projectList, _) if projectList.count(hasGenerator) > 1 =>
-        val generators = projectList.filter(hasGenerator).map(trimAlias)
-        throw QueryCompilationErrors.moreThanOneGeneratorError(generators, 
"SELECT")
-
       case Aggregate(_, aggList, _) if aggList.exists(hasNestedGenerator) =>
         val nestedGenerator = aggList.find(hasNestedGenerator).get
         throw 
QueryCompilationErrors.nestedGeneratorError(trimAlias(nestedGenerator))
 
       case Aggregate(_, aggList, _) if aggList.count(hasGenerator) > 1 =>
         val generators = aggList.filter(hasGenerator).map(trimAlias)
-        throw QueryCompilationErrors.moreThanOneGeneratorError(generators, 
"aggregate")
+        throw QueryCompilationErrors.moreThanOneGeneratorError(generators)
 
-      case agg @ Aggregate(groupList, aggList, child) if aggList.forall {
-          case AliasedGenerator(_, _, _) => true
-          case other => other.resolved
-        } && aggList.exists(hasGenerator) =>
+      case Aggregate(groupList, aggList, child) if 
canRewriteGenerator(aggList) &&
+          aggList.exists(hasGenerator) =>
         // If generator in the aggregate list was visited, set the boolean 
flag true.
         var generatorVisited = false
 
@@ -2808,16 +2816,16 @@ class Analyzer(override val catalogManager: 
CatalogManager) extends RuleExecutor
         // first for replacing `Project` with `Aggregate`.
         p
 
-      case p @ Project(projectList, child) =>
+      case p @ Project(projectList, child) if canRewriteGenerator(projectList) 
&&
+          projectList.exists(hasGenerator) =>
         val (resolvedGenerator, newProjectList) = projectList
           .map(trimNonTopLevelAliases)
           .foldLeft((None: Option[Generate], Nil: Seq[NamedExpression])) { 
(res, e) =>
             e match {
-              case AliasedGenerator(generator, names, outer) if 
generator.childrenResolved =>
-                // It's a sanity check, this should not happen as the previous 
case will throw
-                // exception earlier.
-                assert(res._1.isEmpty, "More than one generator found in 
SELECT.")
-
+              // If there are more than one generator, we only rewrite the 
first one and wait for
+              // the next analyzer iteration to rewrite the next one.
+              case AliasedGenerator(generator, names, outer) if res._1.isEmpty 
&&
+                  generator.childrenResolved =>
                 val g = Generate(
                   generator,
                   unrequiredChildIndex = Nil,
@@ -2825,7 +2833,6 @@ class Analyzer(override val catalogManager: 
CatalogManager) extends RuleExecutor
                   qualifier = None,
                   generatorOutput = 
ResolveGenerate.makeGeneratorOutput(generator, names),
                   child)
-
                 (Some(g), res._2 ++ g.nullableOutput)
               case other =>
                 (res._1, res._2 :+ other)
@@ -2845,6 +2852,10 @@ class Analyzer(override val catalogManager: 
CatalogManager) extends RuleExecutor
 
       case u: UnresolvedTableValuedFunction => u
 
+      case p: Project => p
+
+      case a: Aggregate => a
+
       case p if p.expressions.exists(hasGenerator) =>
         throw QueryCompilationErrors.generatorOutsideSelectError(p)
     }
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 533ea8a2b799..7f10bdbc80ca 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -64,12 +64,6 @@ trait CheckAnalysis extends PredicateHelper with 
LookupCatalog with QueryErrorsB
       messageParameters = messageParameters)
   }
 
-  protected def containsMultipleGenerators(exprs: Seq[Expression]): Boolean = {
-    exprs.flatMap(_.collect {
-      case e: Generator => e
-    }).length > 1
-  }
-
   protected def hasMapType(dt: DataType): Boolean = {
     dt.existsRecursively(_.isInstanceOf[MapType])
   }
@@ -687,10 +681,6 @@ trait CheckAnalysis extends PredicateHelper with 
LookupCatalog with QueryErrorsB
                 ))
             }
 
-          case p @ Project(exprs, _) if containsMultipleGenerators(exprs) =>
-            val generators = exprs.filter(expr => 
expr.exists(_.isInstanceOf[Generator]))
-            throw QueryCompilationErrors.moreThanOneGeneratorError(generators, 
"SELECT")
-
           case p @ Project(projectList, _) =>
             projectList.foreach(_.transformDownWithPruning(
               _.containsPattern(UNRESOLVED_WINDOW_EXPRESSION)) {
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
index 9dca2c5f2822..a78e092c4bfa 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
@@ -248,11 +248,10 @@ private[sql] object QueryCompilationErrors extends 
QueryErrorsBase with Compilat
       messageParameters = Map("expression" -> 
toSQLExpr(trimmedNestedGenerator)))
   }
 
-  def moreThanOneGeneratorError(generators: Seq[Expression], clause: String): 
Throwable = {
+  def moreThanOneGeneratorError(generators: Seq[Expression]): Throwable = {
     new AnalysisException(
       errorClass = "UNSUPPORTED_GENERATOR.MULTI_GENERATOR",
       messageParameters = Map(
-        "clause" -> clause,
         "num" -> generators.size.toString,
         "generators" -> generators.map(toSQLExpr).mkString(", ")))
   }
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
index e2e980073307..e8dc9061199c 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
@@ -344,11 +344,6 @@ class AnalysisErrorSuite extends AnalysisTest {
       "inputType" -> "\"BOOLEAN\"",
       "requiredType" -> "\"INT\""))
 
-  errorTest(
-    "too many generators",
-    listRelation.select(Explode($"list").as("a"), Explode($"list").as("b")),
-    "only one generator" :: "explode" :: Nil)
-
   errorClassTest(
     "unresolved attributes",
     testRelation.select($"abcd"),
@@ -754,18 +749,11 @@ class AnalysisErrorSuite extends AnalysisTest {
     "SUM_OF_LIMIT_AND_OFFSET_EXCEEDS_MAX_INT",
     Map("limit" -> "1000000000", "offset" -> "2000000000"))
 
-  errorTest(
-    "more than one generators in SELECT",
-    listRelation.select(Explode($"list"), Explode($"list")),
-    "The generator is not supported: only one generator allowed per select 
clause but found 2: " +
-      """"explode(list)", "explode(list)"""" :: Nil
-  )
-
   errorTest(
     "more than one generators for aggregates in SELECT",
     testRelation.select(Explode(CreateArray(min($"a") :: Nil)),
       Explode(CreateArray(max($"a") :: Nil))),
-    "The generator is not supported: only one generator allowed per select 
clause but found 2: " +
+    "The generator is not supported: only one generator allowed per SELECT 
clause but found 2: " +
       """"explode(array(min(a)))", "explode(array(max(a)))"""" :: Nil
   )
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 002719f06896..c586da6105fd 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -368,20 +368,6 @@ class DataFrameSuite extends QueryTest
       Row("a", Seq("a"), 1) :: Nil)
   }
 
-  test("more than one generator in SELECT clause") {
-    val df = Seq((Array("a"), 1)).toDF("a", "b")
-
-    checkError(
-      exception = intercept[AnalysisException] {
-        df.select(explode($"a").as("a"), explode($"a").as("b"))
-      },
-      errorClass = "UNSUPPORTED_GENERATOR.MULTI_GENERATOR",
-      parameters = Map(
-        "clause" -> "SELECT",
-        "num" -> "2",
-        "generators" -> "\"explode(a)\", \"explode(a)\""))
-  }
-
   test("sort after generate with join=true") {
     val df = Seq((Array("a"), 1)).toDF("a", "b")
 
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala
index 0746a4b92af2..7c285759fcd9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala
@@ -432,7 +432,6 @@ class GeneratorFunctionSuite extends QueryTest with 
SharedSparkSession {
         },
         errorClass = "UNSUPPORTED_GENERATOR.MULTI_GENERATOR",
         parameters = Map(
-          "clause" -> "aggregate",
           "num" -> "2",
           "generators" -> ("\"explode(array(min(c2), max(c2)))\", " +
             "\"posexplode(array(min(c2), max(c2)))\"")))
@@ -543,6 +542,32 @@ class GeneratorFunctionSuite extends QueryTest with 
SharedSparkSession {
       checkAnswer(df, Row(0.7604953758285915d))
     }
   }
+
+  test("SPARK-47241: two generator functions in SELECT") {
+    def testTwoGenerators(needImplicitCast: Boolean): Unit = {
+      val df = sql(
+        s"""
+          |SELECT
+          |explode(array('a', 'b')) as c1,
+          |explode(array(0L, ${if (needImplicitCast) "0L + 1" else "1L"})) as 
c2
+          |""".stripMargin)
+      checkAnswer(df, Seq(Row("a", 0L), Row("a", 1L), Row("b", 0L), Row("b", 
1L)))
+    }
+    testTwoGenerators(needImplicitCast = true)
+    testTwoGenerators(needImplicitCast = false)
+  }
+
+  test("SPARK-47241: generator function after wildcard in SELECT") {
+    val df = sql(
+      s"""
+         |SELECT *, explode(array('a', 'b')) as c1
+         |FROM
+         |(
+         |  SELECT id FROM range(1) GROUP BY 1
+         |)
+         |""".stripMargin)
+    checkAnswer(df, Seq(Row(0, "a"), Row(0, "b")))
+  }
 }
 
 case class EmptyGenerator() extends Generator with LeafLike[Expression] {
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryCompilationErrorsSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryCompilationErrorsSuite.scala
index 7f938deaaa64..ac57c958828b 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryCompilationErrorsSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryCompilationErrorsSuite.scala
@@ -646,18 +646,6 @@ class QueryCompilationErrorsSuite
       parameters = Map("expression" -> "\"(explode(array(1, 2, 3)) + 1)\""))
   }
 
-  test("UNSUPPORTED_GENERATOR: only one generator allowed") {
-    val e = intercept[AnalysisException](
-      sql("""select explode(Array(1, 2, 3)), explode(Array(1, 2, 
3))""").collect()
-    )
-
-    checkError(
-      exception = e,
-      errorClass = "UNSUPPORTED_GENERATOR.MULTI_GENERATOR",
-      parameters = Map("clause" -> "SELECT", "num" -> "2",
-        "generators" -> "\"explode(array(1, 2, 3))\", \"explode(array(1, 2, 
3))\""))
-  }
-
   test("UNSUPPORTED_GENERATOR: generators are not supported outside the SELECT 
clause") {
     val e = intercept[AnalysisException](
       sql("""select 1 from t order by explode(Array(1, 2, 3))""").collect()
diff --git 
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
 
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
index 82b88ec9f35d..4b85b37b6c2c 100644
--- 
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
+++ 
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
@@ -161,28 +161,6 @@ class HiveQuerySuite extends HiveComparisonTest with 
SQLTestUtils with BeforeAnd
       |  SELECT key FROM gen_tmp ORDER BY key ASC;
     """.stripMargin)
 
-  test("multiple generators in projection") {
-    checkError(
-      exception = intercept[AnalysisException] {
-        sql("SELECT explode(array(key, key)), explode(array(key, key)) FROM 
src").collect()
-      },
-      errorClass = "UNSUPPORTED_GENERATOR.MULTI_GENERATOR",
-      parameters = Map(
-        "clause" -> "SELECT",
-        "num" -> "2",
-        "generators" -> "\"explode(array(key, key))\", \"explode(array(key, 
key))\""))
-
-    checkError(
-      exception = intercept[AnalysisException] {
-        sql("SELECT explode(array(key, key)) as k1, explode(array(key, key)) 
FROM src").collect()
-      },
-      errorClass = "UNSUPPORTED_GENERATOR.MULTI_GENERATOR",
-      parameters = Map(
-        "clause" -> "SELECT",
-        "num" -> "2",
-        "generators" -> "\"explode(array(key, key))\", \"explode(array(key, 
key))\""))
-  }
-
   createQueryTest("! operator",
     """
       |SELECT a FROM (


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to