This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 9287fc9dd73 [SPARK-40618][SQL] Fix bug in MergeScalarSubqueries rule 
with nested subqueries using reference tracking
9287fc9dd73 is described below

commit 9287fc9dd73a0909d5705308532b24528b3f1090
Author: Peter Toth <peter.t...@gmail.com>
AuthorDate: Thu Oct 13 22:06:15 2022 +0800

    [SPARK-40618][SQL] Fix bug in MergeScalarSubqueries rule with nested 
subqueries using reference tracking
    
    ### What changes were proposed in this pull request?
    This PR reverts the previous fix https://github.com/apache/spark/pull/38052 
and adds subquery reference tracking to `MergeScalarSubqueries` to restore 
the previous functionality of merging independent nested subqueries.
    
    ### Why are the changes needed?
    To restore the previous functionality while fixing the bug discovered in 
https://issues.apache.org/jira/browse/SPARK-40618.
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
    Existing and new UTs.
    
    Closes #38093 from peter-toth/SPARK-40618-fix-mergescalarsubqueries.
    
    Authored-by: Peter Toth <peter.t...@gmail.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../catalyst/optimizer/MergeScalarSubqueries.scala | 62 +++++++++++++---------
 .../scala/org/apache/spark/sql/SubquerySuite.scala | 35 +++++++++---
 2 files changed, 67 insertions(+), 30 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeScalarSubqueries.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeScalarSubqueries.scala
index 69f77e8f3f4..1cb3f3f157c 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeScalarSubqueries.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeScalarSubqueries.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.catalyst.optimizer
 
+import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
 
 import org.apache.spark.sql.catalyst.expressions._
@@ -126,8 +127,14 @@ object MergeScalarSubqueries extends Rule[LogicalPlan] {
    *               merged as there can be subqueries that are different 
([[checkIdenticalPlans]] is
    *               false) due to an extra [[Project]] node in one of them. In 
that case
    *               `attributes.size` remains 1 after merging, but the merged 
flag becomes true.
+   * @param references A set of subquery indexes in the cache to track all 
(including transitive)
+   *                   nested subqueries.
    */
-  case class Header(attributes: Seq[Attribute], plan: LogicalPlan, merged: 
Boolean)
+  case class Header(
+      attributes: Seq[Attribute],
+      plan: LogicalPlan,
+      merged: Boolean,
+      references: Set[Int])
 
   private def extractCommonScalarSubqueries(plan: LogicalPlan) = {
     val cache = ArrayBuffer.empty[Header]
@@ -166,26 +173,39 @@ object MergeScalarSubqueries extends Rule[LogicalPlan] {
   // "Header".
   private def cacheSubquery(plan: LogicalPlan, cache: ArrayBuffer[Header]): 
(Int, Int) = {
     val output = plan.output.head
-    cache.zipWithIndex.collectFirst(Function.unlift { case (header, 
subqueryIndex) =>
-      checkIdenticalPlans(plan, header.plan).map { outputMap =>
-        val mappedOutput = mapAttributes(output, outputMap)
-        val headerIndex = header.attributes.indexWhere(_.exprId == 
mappedOutput.exprId)
-        subqueryIndex -> headerIndex
-      }.orElse(tryMergePlans(plan, header.plan).map {
-        case (mergedPlan, outputMap) =>
+    val references = mutable.HashSet.empty[Int]
+    
plan.transformAllExpressionsWithPruning(_.containsAnyPattern(SCALAR_SUBQUERY_REFERENCE))
 {
+      case ssr: ScalarSubqueryReference =>
+        references += ssr.subqueryIndex
+        references ++= cache(ssr.subqueryIndex).references
+        ssr
+    }
+
+    cache.zipWithIndex.collectFirst(Function.unlift {
+      case (header, subqueryIndex) if !references.contains(subqueryIndex) =>
+        checkIdenticalPlans(plan, header.plan).map { outputMap =>
           val mappedOutput = mapAttributes(output, outputMap)
-          var headerIndex = header.attributes.indexWhere(_.exprId == 
mappedOutput.exprId)
-          val newHeaderAttributes = if (headerIndex == -1) {
-            headerIndex = header.attributes.size
-            header.attributes :+ mappedOutput
-          } else {
-            header.attributes
-          }
-          cache(subqueryIndex) = Header(newHeaderAttributes, mergedPlan, true)
+          val headerIndex = header.attributes.indexWhere(_.exprId == 
mappedOutput.exprId)
           subqueryIndex -> headerIndex
-      })
+        }.orElse{
+          tryMergePlans(plan, header.plan).map {
+            case (mergedPlan, outputMap) =>
+              val mappedOutput = mapAttributes(output, outputMap)
+              var headerIndex = header.attributes.indexWhere(_.exprId == 
mappedOutput.exprId)
+              val newHeaderAttributes = if (headerIndex == -1) {
+                headerIndex = header.attributes.size
+                header.attributes :+ mappedOutput
+              } else {
+                header.attributes
+              }
+              cache(subqueryIndex) =
+                Header(newHeaderAttributes, mergedPlan, true, 
header.references ++ references)
+              subqueryIndex -> headerIndex
+          }
+        }
+      case _ => None
     }).getOrElse {
-      cache += Header(Seq(output), plan, false)
+      cache += Header(Seq(output), plan, false, references.toSet)
       cache.length - 1 -> 0
     }
   }
@@ -210,12 +230,6 @@ object MergeScalarSubqueries extends Rule[LogicalPlan] {
       cachedPlan: LogicalPlan): Option[(LogicalPlan, AttributeMap[Attribute])] 
= {
     checkIdenticalPlans(newPlan, cachedPlan).map(cachedPlan -> _).orElse(
       (newPlan, cachedPlan) match {
-        case (_, _) if newPlan.containsPattern(SCALAR_SUBQUERY_REFERENCE) ||
-          cachedPlan.containsPattern(SCALAR_SUBQUERY_REFERENCE) =>
-          // Subquery expressions with nested subquery expressions within are 
not supported for now.
-          // TODO: support this optimization by collecting the transitive 
subquery references in the
-          // new plan and recording them in order to suppress merging the new 
plan into those.
-          None
         case (np: Project, cp: Project) =>
           tryMergePlans(np.child, cp.child).map { case (mergedChild, 
outputMap) =>
             val (mergedProjectList, newOutputMap) =
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
index ca78aaae414..2c8c3eda953 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
@@ -2251,7 +2251,7 @@ class SubquerySuite extends QueryTest
     }
   }
 
-  test("SPARK-40618: Do not merge scalar subqueries with nested subqueries 
inside") {
+  test("Merge non-correlated scalar subqueries from different parent plans") {
     Seq(false, true).foreach { enableAQE =>
       withSQLConf(
         SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> enableAQE.toString) {
@@ -2283,13 +2283,13 @@ class SubquerySuite extends QueryTest
         }
 
         if (enableAQE) {
-          assert(subqueryIds.size == 4, "Missing or unexpected SubqueryExec in 
the plan")
-          assert(reusedSubqueryIds.size == 2,
-            "Missing or unexpected reused ReusedSubqueryExec in the plan")
-        } else {
           assert(subqueryIds.size == 3, "Missing or unexpected SubqueryExec in 
the plan")
           assert(reusedSubqueryIds.size == 3,
             "Missing or unexpected reused ReusedSubqueryExec in the plan")
+        } else {
+          assert(subqueryIds.size == 2, "Missing or unexpected SubqueryExec in 
the plan")
+          assert(reusedSubqueryIds.size == 4,
+            "Missing or unexpected reused ReusedSubqueryExec in the plan")
         }
       }
     }
@@ -2426,9 +2426,32 @@ class SubquerySuite extends QueryTest
     // This test contains a subquery expression with another subquery 
expression nested inside.
     // It acts as a regression test to ensure that the MergeScalarSubqueries 
rule does not attempt
     // to merge them together.
-    withTable("t") {
+    withTable("t", "t2") {
       sql("create table t(col int) using csv")
       checkAnswer(sql("select(select sum((select sum(col) from t)) from t)"), 
Row(null))
+
+      checkAnswer(sql(
+        """
+          |select
+          |  (select sum(
+          |    (select sum(
+          |        (select sum(col) from t))
+          |     from t))
+          |  from t)
+          |""".stripMargin),
+        Row(null))
+
+      sql("create table t2(col int) using csv")
+      checkAnswer(sql(
+        """
+          |select
+          |  (select sum(
+          |    (select sum(
+          |        (select sum(col) from t))
+          |     from t2))
+          |  from t)
+          |""".stripMargin),
+        Row(null))
     }
   }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to