This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 8970415e9c1 [SPARK-43199][SQL] Make InlineCTE idempotent
8970415e9c1 is described below

commit 8970415e9c1384ebe12defb9bfc66dca3c08aa48
Author: Peter Toth <[email protected]>
AuthorDate: Wed Apr 26 15:32:47 2023 +0800

    [SPARK-43199][SQL] Make InlineCTE idempotent
    
    ### What changes were proposed in this pull request?
    This PR fixes `InlineCTE`'s idempotence. E.g. the following query:
    ```
    WITH
      x(r) AS (SELECT random()),
      y(r) AS (SELECT * FROM x),
      z(r) AS (SELECT * FROM x)
    SELECT * FROM z
    ```
    currently breaks idempotence because, in the first round, the reference to `x`
    from `y` is counted when deciding not to inline `x`:
    ```
    === Applying Rule org.apache.spark.sql.catalyst.optimizer.InlineCTE ===
     WithCTE                                                        WithCTE
     :- CTERelationDef 0, false                                     :- 
CTERelationDef 0, false
     :  +- Project [rand()#218 AS r#219]                            :  +- 
Project [rand()#218 AS r#219]
     :     +- Project [random(2957388522017368375) AS rand()#218]   :     +- 
Project [random(2957388522017368375) AS rand()#218]
     :        +- OneRowRelation                                     :        +- 
OneRowRelation
    !:- CTERelationDef 1, false                                     +- Project 
[r#222]
    !:  +- Project [r#219 AS r#221]                                    +- 
Project [r#220 AS r#222]
    !:     +- Project [r#219]                                             +- 
Project [r#220]
    !:        +- CTERelationRef 0, true, [r#219]                             +- 
CTERelationRef 0, true, [r#220]
    !:- CTERelationDef 2, false
    !:  +- Project [r#220 AS r#222]
    !:     +- Project [r#220]
    !:        +- CTERelationRef 0, true, [r#220]
    !+- Project [r#222]
    !   +- CTERelationRef 2, true, [r#222]
    ```
    But in the next round we inline `x` because `y` was removed due to lack of 
references:
    ```
    Once strategy's idempotence is broken for batch Inline CTE
    !WithCTE                                                        Project 
[r#222]
    !:- CTERelationDef 0, false                                     +- Project 
[r#220 AS r#222]
    !:  +- Project [rand()#218 AS r#219]                               +- 
Project [r#220]
    !:     +- Project [random(2957388522017368375) AS rand()#218]         +- 
Project [r#225 AS r#220]
    !:        +- OneRowRelation                                              +- 
Project [rand()#218 AS r#225]
    !+- Project [r#222]                                                         
+- Project [random(2957388522017368375) AS rand()#218]
    !   +- Project [r#220 AS r#222]                                             
   +- OneRowRelation
    !      +- Project [r#220]
    !         +- CTERelationRef 0, true, [r#220]
    ```
    
    ### Why are the changes needed?
    We use `InlineCTE` as an idempotent rule in the `Optimizer`, 
`CheckAnalysis` and `ProgressReporter`.
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
    Added new UT.
    
    Closes #40856 from peter-toth/SPARK-43199-make-inlinecte-idempotent.
    
    Authored-by: Peter Toth <[email protected]>
    Signed-off-by: Wenchen Fan <[email protected]>
---
 .../sql/catalyst/analysis/CheckAnalysis.scala      |  4 +-
 .../spark/sql/catalyst/optimizer/InlineCTE.scala   | 86 ++++++++++++++++------
 .../scala/org/apache/spark/sql/SQLQuerySuite.scala | 11 +++
 3 files changed, 77 insertions(+), 24 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 49f0d438d0a..fd5a86b3ba4 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -145,9 +145,9 @@ trait CheckAnalysis extends PredicateHelper with 
LookupCatalog with QueryErrorsB
 
   def checkAnalysis(plan: LogicalPlan): Unit = {
     val inlineCTE = InlineCTE(alwaysInline = true)
-    val cteMap = mutable.HashMap.empty[Long, (CTERelationDef, Int)]
+    val cteMap = mutable.HashMap.empty[Long, (CTERelationDef, Int, 
mutable.Map[Long, Int])]
     inlineCTE.buildCTEMap(plan, cteMap)
-    cteMap.values.foreach { case (relation, refCount) =>
+    cteMap.values.foreach { case (relation, refCount, _) =>
       // If a CTE relation is never used, it will disappear after inline. Here 
we explicitly check
       // analysis for it, to make sure the entire query plan is valid.
       if (refCount == 0) checkAnalysis0(relation.child)
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InlineCTE.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InlineCTE.scala
index 1e4364b3f4a..8d7ff4cbf16 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InlineCTE.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InlineCTE.scala
@@ -42,8 +42,9 @@ case class InlineCTE(alwaysInline: Boolean = false) extends 
Rule[LogicalPlan] {
 
   override def apply(plan: LogicalPlan): LogicalPlan = {
     if (!plan.isInstanceOf[Subquery] && plan.containsPattern(CTE)) {
-      val cteMap = mutable.HashMap.empty[Long, (CTERelationDef, Int)]
+      val cteMap = mutable.SortedMap.empty[Long, (CTERelationDef, Int, 
mutable.Map[Long, Int])]
       buildCTEMap(plan, cteMap)
+      cleanCTEMap(cteMap)
       val notInlined = mutable.ArrayBuffer.empty[CTERelationDef]
       val inlined = inlineCTE(plan, cteMap, notInlined)
       // CTEs in SQL Commands have been inlined by `CTESubstitution` already, 
so it is safe to add
@@ -68,50 +69,91 @@ case class InlineCTE(alwaysInline: Boolean = false) extends 
Rule[LogicalPlan] {
       cteDef.child.exists(_.expressions.exists(_.isInstanceOf[OuterReference]))
   }
 
+  /**
+   * Accumulates all the CTEs from a plan into a special map.
+   *
+   * @param plan The plan to collect the CTEs from
+   * @param cteMap A mutable map that accumulates the CTEs and their reference 
information by CTE
+   *               ids. The value of the map is a tuple whose elements are:
+   *               - The CTE definition
+   *               - The number of incoming references to the CTE. This 
includes references from
+   *                 other CTEs and regular places.
+   *               - A mutable inner map that tracks outgoing references 
(counts) to other CTEs.
+   * @param outerCTEId While collecting the map we use this optional CTE id to 
identify the
+   *                   current outer CTE.
+   */
   def buildCTEMap(
       plan: LogicalPlan,
-      cteMap: mutable.HashMap[Long, (CTERelationDef, Int)]): Unit = {
+      cteMap: mutable.Map[Long, (CTERelationDef, Int, mutable.Map[Long, Int])],
+      outerCTEId: Option[Long] = None): Unit = {
     plan match {
-      case WithCTE(_, cteDefs) =>
+      case WithCTE(child, cteDefs) =>
+        cteDefs.foreach { cteDef =>
+          cteMap(cteDef.id) = (cteDef, 0, 
mutable.Map.empty.withDefaultValue(0))
+        }
         cteDefs.foreach { cteDef =>
-          cteMap.put(cteDef.id, (cteDef, 0))
+          buildCTEMap(cteDef, cteMap, Some(cteDef.id))
         }
+        buildCTEMap(child, cteMap, outerCTEId)
 
       case ref: CTERelationRef =>
-        val (cteDef, refCount) = cteMap(ref.cteId)
-        cteMap.update(ref.cteId, (cteDef, refCount + 1))
+        val (cteDef, refCount, refMap) = cteMap(ref.cteId)
+        cteMap(ref.cteId) = (cteDef, refCount + 1, refMap)
+        outerCTEId.foreach { cteId =>
+          val (_, _, outerRefMap) = cteMap(cteId)
+          outerRefMap(ref.cteId) += 1
+        }
 
       case _ =>
-    }
-
-    if (plan.containsPattern(CTE)) {
-      plan.children.foreach { child =>
-        buildCTEMap(child, cteMap)
-      }
+        if (plan.containsPattern(CTE)) {
+          plan.children.foreach { child =>
+            buildCTEMap(child, cteMap, outerCTEId)
+          }
 
-      plan.expressions.foreach { expr =>
-        if (expr.containsAllPatterns(PLAN_EXPRESSION, CTE)) {
-          expr.foreach {
-            case e: SubqueryExpression =>
-              buildCTEMap(e.plan, cteMap)
-            case _ =>
+          plan.expressions.foreach { expr =>
+            if (expr.containsAllPatterns(PLAN_EXPRESSION, CTE)) {
+              expr.foreach {
+                case e: SubqueryExpression => buildCTEMap(e.plan, cteMap, 
outerCTEId)
+                case _ =>
+              }
+            }
           }
         }
+    }
+  }
+
+  /**
+   * Cleans the CTE map by removing those CTEs that are not referenced at all and correcting the
+   * reference counts of the CTEs that the removed CTEs referred to.
+   *
+   * @param cteMap A mutable map that accumulates the CTEs and their reference 
information by CTE
+   *               ids. Needs to be sorted to speed up cleaning.
+   */
+  private def cleanCTEMap(
+      cteMap: mutable.SortedMap[Long, (CTERelationDef, Int, mutable.Map[Long, 
Int])]
+    ) = {
+    cteMap.keys.toSeq.reverse.foreach { currentCTEId =>
+      val (_, currentRefCount, refMap) = cteMap(currentCTEId)
+      if (currentRefCount == 0) {
+        refMap.foreach { case (referencedCTEId, uselessRefCount) =>
+          val (cteDef, refCount, refMap) = cteMap(referencedCTEId)
+          cteMap(referencedCTEId) = (cteDef, refCount - uselessRefCount, 
refMap)
+        }
       }
     }
   }
 
   private def inlineCTE(
       plan: LogicalPlan,
-      cteMap: mutable.HashMap[Long, (CTERelationDef, Int)],
+      cteMap: mutable.Map[Long, (CTERelationDef, Int, mutable.Map[Long, Int])],
       notInlined: mutable.ArrayBuffer[CTERelationDef]): LogicalPlan = {
     plan match {
       case WithCTE(child, cteDefs) =>
         cteDefs.foreach { cteDef =>
-          val (cte, refCount) = cteMap(cteDef.id)
+          val (cte, refCount, refMap) = cteMap(cteDef.id)
           if (refCount > 0) {
             val inlined = cte.copy(child = inlineCTE(cte.child, cteMap, 
notInlined))
-            cteMap.update(cteDef.id, (inlined, refCount))
+            cteMap(cteDef.id) = (inlined, refCount, refMap)
             if (!shouldInline(inlined, refCount)) {
               notInlined.append(inlined)
             }
@@ -120,7 +162,7 @@ case class InlineCTE(alwaysInline: Boolean = false) extends 
Rule[LogicalPlan] {
         inlineCTE(child, cteMap, notInlined)
 
       case ref: CTERelationRef =>
-        val (cteDef, refCount) = cteMap(ref.cteId)
+        val (cteDef, refCount, _) = cteMap(ref.cteId)
         if (shouldInline(cteDef, refCount)) {
           if (ref.outputSet == cteDef.outputSet) {
             cteDef.child
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 123364f18ce..e14a01e15a3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -4648,6 +4648,17 @@ class SQLQuerySuite extends QueryTest with 
SharedSparkSession with AdaptiveSpark
       sql("SELECT /*+ hash(t2) */ * FROM t1 join t2 on c1 = c2")
     }
   }
+
+  test("SPARK-43199: InlineCTE is idempotent") {
+    sql(
+      """
+        |WITH
+        |  x(r) AS (SELECT random()),
+        |  y(r) AS (SELECT * FROM x),
+        |  z(r) AS (SELECT * FROM x)
+        |SELECT * FROM z
+        |""".stripMargin).collect()
+  }
 }
 
 case class Foo(bar: Option[String])


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to