This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-3.5
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.5 by this push:
     new 763c448759d [SPARK-44714] Ease restriction of LCA resolution regarding 
queries with having
763c448759d is described below

commit 763c448759df02a5370b8b50cb877f855c4eda10
Author: Xinyi Yu <xinyi...@databricks.com>
AuthorDate: Tue Aug 8 17:07:04 2023 +0800

    [SPARK-44714] Ease restriction of LCA resolution regarding queries with 
having
    
    ### What changes were proposed in this pull request?
    This PR eases some restrictions on lateral column alias (LCA) resolution for
queries that contain a HAVING clause.
    
    Previously, LCA resolution did not rewrite the plan (into the new plan shape)
if the query contained `UnresolvedHaving` anywhere. This avoided breaking the
`UnresolvedHaving - Aggregate` plan shape that other rules need to recognize.
However, this limitation was too strict, and it caused resolution deadlocks in
queries that combine HAVING, LCA and window functions. See
https://issues.apache.org/jira/browse/SPARK-42936 for more details and examples.
    
    With this PR, LCA resolution is skipped only for an `Aggregate` whose direct
parent is `UnresolvedHaving`. This is enabled by a new hand-written bottom-up
traversal that does not use the transform or resolve utility functions.
    
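    This PR also identifies a potential correctness issue related to
`TEMP_RESOLVED_COLUMN`: correctness currently depends on the order in which
analyzer rules are applied. This is documented in a TODO comment in the code and
should be addressed as future work.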
    
    ### Why are the changes needed?
    More complete functionality and better user experience.
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
    New tests.
    
    Closes #42276 from anchovYu/lca-limitation-better-error.
    
    Authored-by: Xinyi Yu <xinyi...@databricks.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
    (cherry picked from commit 29e8331681c6214390f426806d19ee9673b073e1)
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../spark/sql/catalyst/analysis/Analyzer.scala     |   2 +
 .../ResolveLateralColumnAliasReference.scala       | 200 ++++++++++++---------
 .../apache/spark/sql/LateralColumnAliasSuite.scala | 145 +++++++++++++++
 3 files changed, 261 insertions(+), 86 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 09467c22e2b..6c5d19f58ac 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -285,6 +285,8 @@ class Analyzer(override val catalogManager: CatalogManager) 
extends RuleExecutor
       AddMetadataColumns ::
       DeduplicateRelations ::
       ResolveReferences ::
+      // Please do not insert any other rules in between. See the TODO 
comments in rule
+      // ResolveLateralColumnAliasReference for more details.
       ResolveLateralColumnAliasReference ::
       ResolveExpressionsWithNamePlaceholders ::
       ResolveDeserializer ::
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAliasReference.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAliasReference.scala
index 5d89de00084..c249a3506f2 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAliasReference.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAliasReference.scala
@@ -22,7 +22,7 @@ import 
org.apache.spark.sql.catalyst.expressions.WindowExpression.hasWindowExpre
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, 
Project}
 import org.apache.spark.sql.catalyst.rules.Rule
-import 
org.apache.spark.sql.catalyst.trees.TreePattern.{LATERAL_COLUMN_ALIAS_REFERENCE,
 TEMP_RESOLVED_COLUMN, UNRESOLVED_HAVING}
+import 
org.apache.spark.sql.catalyst.trees.TreePattern.{LATERAL_COLUMN_ALIAS_REFERENCE,
 TEMP_RESOLVED_COLUMN}
 import org.apache.spark.sql.catalyst.util.toPrettySQL
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.internal.SQLConf
@@ -131,95 +131,97 @@ object ResolveLateralColumnAliasReference extends 
Rule[LogicalPlan] {
       (pList.exists(hasWindowExpression) && p.expressions.forall(_.resolved) 
&& p.childrenResolved)
   }
 
-  override def apply(plan: LogicalPlan): LogicalPlan = {
-    if (!conf.getConf(SQLConf.LATERAL_COLUMN_ALIAS_IMPLICIT_ENABLED)) {
-      plan
-    } else if (plan.containsAnyPattern(TEMP_RESOLVED_COLUMN, 
UNRESOLVED_HAVING)) {
-      // It should not change the plan if `TempResolvedColumn` or 
`UnresolvedHaving` is present in
-      // the query plan. These plans need certain plan shape to get recognized 
and resolved by other
-      // rules, such as Filter/Sort + Aggregate to be matched by 
ResolveAggregateFunctions.
-      // LCA resolution can break the plan shape, like adding Project above 
Aggregate.
-      plan
-    } else {
-      // phase 2: unwrap
-      plan.resolveOperatorsUpWithPruning(
-        _.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE), ruleId) {
-        case p @ Project(projectList, child) if ruleApplicableOnOperator(p, 
projectList)
-          && 
projectList.exists(_.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) =>
-          var aliasMap = AttributeMap.empty[AliasEntry]
-          val referencedAliases = collection.mutable.Set.empty[AliasEntry]
-          def unwrapLCAReference(e: NamedExpression): NamedExpression = {
-            
e.transformWithPruning(_.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) {
-              case lcaRef: LateralColumnAliasReference if 
aliasMap.contains(lcaRef.a) =>
-                val aliasEntry = aliasMap.get(lcaRef.a).get
-                // If there is no chaining of lateral column alias reference, 
push down the alias
-                // and unwrap the LateralColumnAliasReference to the 
NamedExpression inside
-                // If there is chaining, don't resolve and save to future 
rounds
-                if 
(!aliasEntry.alias.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) {
-                  referencedAliases += aliasEntry
-                  lcaRef.ne
-                } else {
-                  lcaRef
-                }
-              case lcaRef: LateralColumnAliasReference if 
!aliasMap.contains(lcaRef.a) =>
-                // It shouldn't happen, but restore to unresolved attribute to 
be safe.
-                UnresolvedAttribute(lcaRef.nameParts)
-            }.asInstanceOf[NamedExpression]
-          }
-          val newProjectList = projectList.zipWithIndex.map {
-            case (a: Alias, idx) =>
-              val lcaResolved = unwrapLCAReference(a)
-              // Insert the original alias instead of rewritten one to detect 
chained LCA
-              aliasMap += (a.toAttribute -> AliasEntry(a, idx))
-              lcaResolved
-            case (e, _) =>
-              unwrapLCAReference(e)
-          }
+  /** Internal application method. A hand-written bottom-up recursive 
traverse. */
+  private def apply0(plan: LogicalPlan): LogicalPlan = {
+    plan match {
+      case p: LogicalPlan if 
!p.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE) =>
+        p
 
-          if (referencedAliases.isEmpty) {
-            p
-          } else {
-            val outerProjectList = collection.mutable.Seq(newProjectList: _*)
-            val innerProjectList =
-              
collection.mutable.ArrayBuffer(child.output.map(_.asInstanceOf[NamedExpression]):
 _*)
-            referencedAliases.foreach { case AliasEntry(alias: Alias, idx) =>
-              outerProjectList.update(idx, alias.toAttribute)
-              innerProjectList += alias
-            }
-            p.copy(
-              projectList = outerProjectList.toSeq,
-              child = Project(innerProjectList.toSeq, child)
-            )
-          }
+      // It should not change the Aggregate (and thus the plan shape) if its 
parent is an
+      // UnresolvedHaving, to avoid breaking the shape pattern 
`UnresolvedHaving - Aggregate`
+      // matched by ResolveAggregateFunctions. See SPARK-42936 and SPARK-44714 
for more details.
+      case u @ UnresolvedHaving(_, agg: Aggregate) =>
+        u.copy(child = agg.mapChildren(apply0))
 
-        case agg @ Aggregate(groupingExpressions, aggregateExpressions, _)
-          if ruleApplicableOnOperator(agg, aggregateExpressions)
-            && 
aggregateExpressions.exists(_.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) 
=>
+      case pOriginal: Project if ruleApplicableOnOperator(pOriginal, 
pOriginal.projectList)
+          && 
pOriginal.projectList.exists(_.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) 
=>
+        val p @ Project(projectList, child) = pOriginal.mapChildren(apply0)
+        var aliasMap = AttributeMap.empty[AliasEntry]
+        val referencedAliases = collection.mutable.Set.empty[AliasEntry]
+        def unwrapLCAReference(e: NamedExpression): NamedExpression = {
+          
e.transformWithPruning(_.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) {
+            case lcaRef: LateralColumnAliasReference if 
aliasMap.contains(lcaRef.a) =>
+              val aliasEntry = aliasMap.get(lcaRef.a).get
+              // If there is no chaining of lateral column alias reference, 
push down the alias
+              // and unwrap the LateralColumnAliasReference to the 
NamedExpression inside
+              // If there is chaining, don't resolve and save to future rounds
+              if 
(!aliasEntry.alias.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) {
+                referencedAliases += aliasEntry
+                lcaRef.ne
+              } else {
+                lcaRef
+              }
+            case lcaRef: LateralColumnAliasReference if 
!aliasMap.contains(lcaRef.a) =>
+              // It shouldn't happen, but restore to unresolved attribute to 
be safe.
+              UnresolvedAttribute(lcaRef.nameParts)
+          }.asInstanceOf[NamedExpression]
+        }
+        val newProjectList = projectList.zipWithIndex.map {
+          case (a: Alias, idx) =>
+            val lcaResolved = unwrapLCAReference(a)
+            // Insert the original alias instead of rewritten one to detect 
chained LCA
+            aliasMap += (a.toAttribute -> AliasEntry(a, idx))
+            lcaResolved
+          case (e, _) =>
+            unwrapLCAReference(e)
+        }
 
-          // Check if current Aggregate is eligible to lift up with Project: 
the aggregate
-          // expression only contains: 1) aggregate functions, 2) grouping 
expressions, 3) leaf
-          // expressions excluding attributes not in grouping expressions
-          // This check is to prevent unnecessary transformation on invalid 
plan, to guarantee it
-          // throws the same exception. For example, cases like non-aggregate 
expressions not
-          // in group by, once transformed, will throw a different exception: 
missing input.
-          def eligibleToLiftUp(exp: Expression): Boolean = {
-            exp match {
-              case _: AggregateExpression => true
-              case e if groupingExpressions.exists(_.semanticEquals(e)) => true
-              case a: Attribute => false
-              case s: ScalarSubquery if s.children.nonEmpty
-                && !groupingExpressions.exists(_.semanticEquals(s)) => false
-              // Manually skip detection on function itself because it can be 
an aggregate function.
-              // This is to avoid expressions like sum(salary) over () 
eligible to lift up.
-              case WindowExpression(function, spec) =>
-                function.children.forall(eligibleToLiftUp) && 
eligibleToLiftUp(spec)
-              case e => e.children.forall(eligibleToLiftUp)
-            }
-          }
-          if (!aggregateExpressions.forall(eligibleToLiftUp)) {
-            return agg
+        if (referencedAliases.isEmpty) {
+          p
+        } else {
+          val outerProjectList = collection.mutable.Seq(newProjectList: _*)
+          val innerProjectList =
+            
collection.mutable.ArrayBuffer(child.output.map(_.asInstanceOf[NamedExpression]):
 _*)
+          referencedAliases.foreach { case AliasEntry(alias: Alias, idx) =>
+            outerProjectList.update(idx, alias.toAttribute)
+            innerProjectList += alias
           }
+          p.copy(
+            projectList = outerProjectList.toSeq,
+            child = Project(innerProjectList.toSeq, child)
+          )
+        }
+
+      case aggOriginal: Aggregate
+        if ruleApplicableOnOperator(aggOriginal, 
aggOriginal.aggregateExpressions)
+          && aggOriginal.aggregateExpressions.exists(
+            _.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) =>
+        val agg @ Aggregate(groupingExpressions, aggregateExpressions, _) =
+          aggOriginal.mapChildren(apply0)
 
+        // Check if current Aggregate is eligible to lift up with Project: the 
aggregate
+        // expression only contains: 1) aggregate functions, 2) grouping 
expressions, 3) leaf
+        // expressions excluding attributes not in grouping expressions
+        // This check is to prevent unnecessary transformation on invalid 
plan, to guarantee it
+        // throws the same exception. For example, cases like non-aggregate 
expressions not
+        // in group by, once transformed, will throw a different exception: 
missing input.
+        def eligibleToLiftUp(exp: Expression): Boolean = {
+          exp match {
+            case _: AggregateExpression => true
+            case e if groupingExpressions.exists(_.semanticEquals(e)) => true
+            case a: Attribute => false
+            case s: ScalarSubquery if s.children.nonEmpty
+              && !groupingExpressions.exists(_.semanticEquals(s)) => false
+            // Manually skip detection on function itself because it can be an 
aggregate function.
+            // This is to avoid expressions like sum(salary) over () eligible 
to lift up.
+            case WindowExpression(function, spec) =>
+              function.children.forall(eligibleToLiftUp) && 
eligibleToLiftUp(spec)
+            case e => e.children.forall(eligibleToLiftUp)
+          }
+        }
+        if (!aggregateExpressions.forall(eligibleToLiftUp)) {
+          agg
+        } else {
           val newAggExprs = collection.mutable.Set.empty[NamedExpression]
           val expressionMap = 
collection.mutable.LinkedHashMap.empty[Expression, NamedExpression]
           // Extract the expressions to keep in the Aggregate. Return the 
transformed expression
@@ -262,7 +264,33 @@ object ResolveLateralColumnAliasReference extends 
Rule[LogicalPlan] {
             projectList = projectExprs,
             child = agg.copy(aggregateExpressions = newAggExprs.toSeq)
           )
-      }
+        }
+
+      case p: LogicalPlan =>
+        p.mapChildren(apply0)
+    }
+  }
+
+  override def apply(plan: LogicalPlan): LogicalPlan = {
+    if (!conf.getConf(SQLConf.LATERAL_COLUMN_ALIAS_IMPLICIT_ENABLED)) {
+      plan
+    } else if (plan.containsAnyPattern(TEMP_RESOLVED_COLUMN)) {
+      // It should not change the plan if `TempResolvedColumn` is present in 
the query plan. These
+      // plans need certain plan shape to get recognized and resolved by other 
rules, such as
+      // Filter/Sort + Aggregate to be matched by ResolveAggregateFunctions. 
LCA resolution can
+      // break the plan shape, like adding Project above Aggregate.
+      // TODO: this condition only guarantees to keep the shape after the plan 
has
+      //  `TempResolvedColumn`. However, it does not consider the case of 
breaking the shape even
+      //  before `TempResolvedColumn` is generated by matching Filter/Sort - 
Aggregate in
+      //  ResolveReferences. Currently the correctness of this case now relies 
on the rule
+      //  application order, that ResolveReference is right before the 
application of
+      //  ResolveLateralColumnAliasReference. The condition in the two rules 
guarantees that the
+      //  case can never happen. We should consider to remove this order 
dependency but still assure
+      //  correctness in the future.
+      plan
+    } else {
+      // phase 2: unwrap
+      apply0(plan)
     }
   }
 }
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala
index 1e3a0d70c7f..cc4aeb42326 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala
@@ -669,6 +669,20 @@ class LateralColumnAliasSuite extends 
LateralColumnAliasSuiteBase {
         s"FROM $testTable GROUP BY dept ORDER BY max(name)"),
       Row(1, 1) :: Row(2, 2) :: Row(6, 6) :: Nil
     )
+    checkAnswer(
+      sql("SELECT dept, avg(salary) AS a, a + 10 FROM employee GROUP BY dept 
ORDER BY max(name)"),
+      Row(1, 9500, 9510) :: Row(2, 11000, 11010) :: Row(6, 12000, 12010) :: Nil
+    )
+    checkAnswer(
+      sql("SELECT dept, avg(salary) AS a, a + 10 AS b " +
+        "FROM employee GROUP BY dept ORDER BY max(name)"),
+      Row(1, 9500, 9510) :: Row(2, 11000, 11010) :: Row(6, 12000, 12010) :: Nil
+    )
+    checkAnswer(
+      sql("SELECT dept, avg(salary) AS a, a + cast(10 as double) AS b " +
+        "FROM employee GROUP BY dept ORDER BY max(name)"),
+      Row(1, 9500, 9510) :: Row(2, 11000, 11010) :: Row(6, 12000, 12010) :: Nil
+    )
 
     // having cond is resolved by aggregate's child
     checkAnswer(
@@ -676,6 +690,21 @@ class LateralColumnAliasSuite extends 
LateralColumnAliasSuiteBase {
         s"FROM $testTable GROUP BY dept HAVING max(name) = 'david'"),
       Row(1250, 2, 11000, 11010) :: Nil
     )
+    checkAnswer(
+      sql("SELECT dept, avg(salary) AS a, a + 10 " +
+        "FROM employee GROUP BY dept HAVING max(bonus) > 1200"),
+      Row(2, 11000, 11010) :: Nil
+    )
+    checkAnswer(
+      sql("SELECT dept, avg(salary) AS a, a + 10 AS b " +
+        "FROM employee GROUP BY dept HAVING max(bonus) > 1200"),
+      Row(2, 11000, 11010) :: Nil
+    )
+    checkAnswer(
+      sql("SELECT dept, avg(salary) AS a, a + cast(10 as double) AS b " +
+        "FROM employee GROUP BY dept HAVING max(bonus) > 1200"),
+      Row(2, 11000, 11010) :: Nil
+    )
     // having cond is resolved by aggregate itself
     checkAnswer(
       sql(s"SELECT avg(bonus) AS a, a FROM $testTable GROUP BY dept HAVING a > 
1200"),
@@ -1139,4 +1168,120 @@ class LateralColumnAliasSuite extends 
LateralColumnAliasSuiteBase {
     // non group by or non aggregate function in Aggregate queries negative 
cases are covered in
     // "Aggregate expressions not eligible to lift up, throws same error as 
inline".
   }
+
+  test("Still resolves when Aggregate with LCA is not the direct child of 
Having") {
+    // Previously there was a limitation of lca that it can't resolve the 
query when it satisfies
+    // all the following criteria:
+    //  1) the main (outer) query has having clause
+    //  2) there is a window expression in the query
+    //  3) in the same SELECT list as the window expression in 2), there is an 
lca
+    // Though 
[UNSUPPORTED_FEATURE.LATERAL_COLUMN_ALIAS_IN_AGGREGATE_WITH_WINDOW_AND_HAVING] 
is
+    // still not supported, after SPARK-44714, a lot other limitations are
+    // lifted because it allows to resolve LCA when the query has 
UnresolvedHaving but its direct
+    // child does not contain an LCA.
+    // Testcases in this test focus on this change regarding enablement of 
resolution.
+
+    // CTE definition contains window and LCA; outer query contains having
+    checkAnswer(
+      sql(
+        s"""
+           |with w as (
+           |  select name, dept, salary, rank() over (partition by dept order 
by salary) as r, r
+           |  from $testTable
+           |)
+           |select dept
+           |from w
+           |group by dept
+           |having max(salary) > 10000
+           |""".stripMargin),
+      Row(2) :: Row(6) :: Nil
+    )
+    checkAnswer(
+      sql(
+        s"""
+           |with w as (
+           |  select name, dept, salary, rank() over (partition by dept order 
by salary) as r, r
+           |  from $testTable
+           |)
+           |select dept as d, d
+           |from w
+           |group by dept
+           |having max(salary) > 10000
+           |""".stripMargin),
+      Row(2, 2) :: Row(6, 6) :: Nil
+    )
+    checkAnswer(
+      sql(
+        s"""
+           |with w as (
+           |  select name, dept, salary, rank() over (partition by dept order 
by salary) as r, r
+           |  from $testTable
+           |)
+           |select dept as d
+           |from w
+           |group by dept
+           |having d = 2
+           |""".stripMargin),
+      Row(2) :: Nil
+    )
+
+    // inner subquery contains window and LCA; outer query contains having
+    checkAnswer(
+      sql(
+        s"""
+          |SELECT
+          |  dept
+          |FROM
+          |   (
+          |    select
+          |      name, dept, salary, rank() over (partition by dept order by 
salary) as r,
+          |      1 as a, a + 1 as e
+          |    FROM
+          |      $testTable
+          |  ) AS inner_t
+          |GROUP BY
+          |  dept
+          |HAVING max(salary) > 10000
+          |""".stripMargin),
+      Row(2) :: Row(6) :: Nil
+    )
+    checkAnswer(
+      sql(
+        s"""
+           |SELECT
+           |  dept as d, d
+           |FROM
+           |   (
+           |    select
+           |      name, dept, salary, rank() over (partition by dept order by 
salary) as r,
+           |      1 as a, a + 1 as e
+           |    FROM
+           |      $testTable
+           |  ) AS inner_t
+           |GROUP BY
+           |  dept
+           |HAVING max(salary) > 10000
+           |""".stripMargin),
+      Row(2, 2) :: Row(6, 6) :: Nil
+    )
+    checkAnswer(
+      sql(
+        s"""
+           |SELECT
+           |  dept as d
+           |FROM
+           |   (
+           |    select
+           |      name, dept, salary, rank() over (partition by dept order by 
salary) as r,
+           |      1 as a, a + 1 as e
+           |    FROM
+           |      $testTable
+           |  ) AS inner_t
+           |GROUP BY
+           |  dept
+           |HAVING d = 2
+           |""".stripMargin),
+      Row(2) :: Nil
+    )
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to