peter-toth commented on code in PR #52835:
URL: https://github.com/apache/spark/pull/52835#discussion_r2485847706


##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeScalarSubqueries.scala:
##########
@@ -117,297 +116,126 @@ object MergeScalarSubqueries extends Rule[LogicalPlan] {
     }
   }
 
-  /**
-   * An item in the cache of merged scalar subqueries.
-   *
-   * @param plan The plan of a merged scalar subquery.
-   * @param merged A flag to identify if this item is the result of merging 
subqueries.
-   *               Please note that `attributes.size == 1` doesn't always mean 
that the plan is not
-   *               merged as there can be subqueries that are different 
([[checkIdenticalPlans]] is
-   *               false) due to an extra [[Project]] node in one of them. In 
that case
-   *               `attributes.size` remains 1 after merging, but the merged 
flag becomes true.
-   * @param references A set of subquery indexes in the cache to track all 
(including transitive)
-   *                   nested subqueries.
-   */
-  case class Header(
-      plan: LogicalPlan,
-      merged: Boolean,
-      references: Set[Int])
-
   private def extractCommonScalarSubqueries(plan: LogicalPlan) = {
-    val cache = ArrayBuffer.empty[Header]
-    val planWithReferences = insertReferences(plan, cache)
-    cache.zipWithIndex.foreach { case (header, i) =>
-      cache(i) = cache(i).copy(plan =
-        if (header.merged) {
+    // Collect `ScalarSubquery` plans by level into `PlanMerger`s and insert 
references in place of
+    // `ScalarSubquery`s.
+    val planMergers = ArrayBuffer.empty[PlanMerger]
+    val planWithReferences = insertReferences(plan, planMergers)._1
+
+    // Traverse level by level and convert merged plans to `CTERelationDef`s 
and keep non-merged
+    // ones. While traversing replace references in plans back to 
`CTERelationRef`s or to original
+    // `ScalarSubquery`s. This is safe as a subquery plan at a level can 
reference only lower level
+    // other subqueries.
+    val subqueryPlansByLevel = ArrayBuffer.empty[IndexedSeq[LogicalPlan]]
+    planMergers.foreach { planMerger =>
+      val mergedPlans = planMerger.mergedPlans()
+      subqueryPlansByLevel += mergedPlans.map { mergedPlan =>
+        val planWithoutReferences = if (subqueryPlansByLevel.isEmpty) {
+          // Level 0 plans can't contain references
+          mergedPlan.plan
+        } else {
+          removeReferences(mergedPlan.plan, subqueryPlansByLevel)
+        }
+        if (mergedPlan.merged) {
           CTERelationDef(
-            createProject(header.plan.output, removeReferences(header.plan, 
cache)),
+            Project(
+              Seq(Alias(
+                CreateNamedStruct(
+                  planWithoutReferences.output.flatMap(a => 
Seq(Literal(a.name), a))),
+                "mergedValue")()),
+              planWithoutReferences),
             underSubquery = true)
         } else {
-          removeReferences(header.plan, cache)
-        })
+          planWithoutReferences
+        }
+      }
     }
-    val newPlan = removeReferences(planWithReferences, cache)
-    val subqueryCTEs = 
cache.filter(_.merged).map(_.plan.asInstanceOf[CTERelationDef])
+
+    // Replace references back to `CTERelationRef`s or to original 
`ScalarSubquery`s in the main
+    // plan.
+    val newPlan = removeReferences(planWithReferences, subqueryPlansByLevel)
+
+    // Add `CTERelationDef`s to the plan.
+    val subqueryCTEs = subqueryPlansByLevel.flatMap(_.collect { case cte: 
CTERelationDef => cte })
     if (subqueryCTEs.nonEmpty) {
       WithCTE(newPlan, subqueryCTEs.toSeq)
     } else {
       newPlan
     }
   }
 
-  // First traversal builds up the cache and inserts 
`ScalarSubqueryReference`s to the plan.
-  private def insertReferences(plan: LogicalPlan, cache: ArrayBuffer[Header]): 
LogicalPlan = {
-    plan.transformUpWithSubqueries {
-      case n => 
n.transformExpressionsUpWithPruning(_.containsAnyPattern(SCALAR_SUBQUERY)) {
-        // The subquery could contain a hint that is not propagated once we 
cache it, but as a
-        // non-correlated scalar subquery won't be turned into a Join the loss 
of hints is fine.
+  // First traversal inserts `ScalarSubqueryReference`s to the plan and tries 
to merge subquery
+  // plans by each level.
+  private def insertReferences(
+      plan: LogicalPlan,
+      planMergers: ArrayBuffer[PlanMerger]): (LogicalPlan, Int) = {
+    // The level of a subquery plan is maximum level of its inner subqueries + 
1 or 0 if it has no
+    // inner subqueries.
+    var maxLevel = 0
+    val planWithReferences =
+      
plan.transformAllExpressionsWithPruning(_.containsAnyPattern(SCALAR_SUBQUERY)) {
         case s: ScalarSubquery if !s.isCorrelated && s.deterministic =>
-          val (subqueryIndex, headerIndex) = cacheSubquery(s.plan, cache)
-          ScalarSubqueryReference(subqueryIndex, headerIndex, s.dataType, 
s.exprId)
-      }
-    }
-  }
-
-  // Caching returns the index of the subquery in the cache and the index of 
scalar member in the
-  // "Header".
-  private def cacheSubquery(plan: LogicalPlan, cache: ArrayBuffer[Header]): 
(Int, Int) = {

Review Comment:
   Briefly, `cacheSubquery()` became `PlanMerger.merge()` with some 
modifications.
   
   Because `PlanMerger.merge()` is now a kind of internal API that other rules can use, it returns a `MergeResult` instead of a tuple after this refactor.
   The `Header` objects, which contained the information about the merged plans, became `MergedPlan`s, and rules can access them via `PlanMerger.mergedPlans()`.
   
   The main difference between `cacheSubquery()` and the new 
`PlanMerger.merge()` is that I removed `Header.references` and the related 
safeguards like the `if !references.contains(subqueryIndex)` below.
   Previously, we filled up `Header.references` with nested subquery indices to avoid trying to merge a subquery with other subqueries nested in it.
   E.g.
   ```
   SELECT (
     SELECT avg(a) FROM t WHERE c = (    -- subquery 1
       SELECT sum(b) FROM t2)            -- subquery 2 nested into subquery 1
   )
   ```
   Here it doesn't make sense to try merging the plans of `subquery 1` and `subquery 2`, because the former contains the latter. Trying to merge the two was not merely pointless. In certain unusual cases (see the example in [SPARK-40618](https://issues.apache.org/jira/browse/SPARK-40618)), the merge succeeded and produced invalid plans.
   
   However, `PlanMerger` is meant to be a general plan-merging tool, and future rules will use it to merge plans other than subqueries. So tracking nested subqueries inside it is not the best way to deal with the problem.
   This is why the refactored `MergeScalarSubqueries` takes a different approach: it keeps one `PlanMerger` object per subquery level. For example, `subquery 2` is a level 0 (leaf) subquery, while `subquery 1` is a level 1 subquery because it contains an inner level 0 subquery. The two are therefore added to different `PlanMerger`s. This way, a level n subquery is only ever merged with other level n subqueries.
   



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to