This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-3.5
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.5 by this push:
     new 88582e11519b [SPARK-48610][SQL] refactor: use auxiliary idMap instead of OP_ID_TAG
88582e11519b is described below

commit 88582e11519b18549b6ed19868e1749963a5299e
Author: Ziqi Liu <[email protected]>
AuthorDate: Mon Jun 17 14:34:35 2024 +0800

    [SPARK-48610][SQL] refactor: use auxiliary idMap instead of OP_ID_TAG
    
    ### What changes were proposed in this pull request?
    
    Refactor `ExplainUtils.processPlan` to use an auxiliary id map instead of the
    `OP_ID_TAG` tree-node tag.
    
    ### Why are the changes needed?
    
    https://github.com/apache/spark/pull/45282 added `synchronized` to
    `ExplainUtils.processPlan` to avoid a race condition when multiple queries
    refer to the same cached plan.
    
    The lock is too coarse-grained: because it synchronizes on the `ExplainUtils`
    object, every call to `processPlan` is serialized, even for unrelated queries.
    We can fix the root cause of this concurrency issue by removing the use of the
    mutable `OP_ID_TAG`, which is bad practice given the immutable nature of
    `SparkPlan`.
    
    Instead, we can use an auxiliary id map keyed by object identity. Because
    `OP_ID_TAG` is only used within the scope of `ExplainUtils.processPlan`, this
    is safe. The map is stored in a thread local so that the other classes
    involved (such as `QueryPlan`) can access it.
    
    ### Does this PR introduce _any_ user-facing change?
    NO
    
    ### How was this patch tested?
    Existing unit tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    NO
    
    Closes #46965 from liuzqt/SPARK-48610.
    
    Authored-by: Ziqi Liu <[email protected]>
    Signed-off-by: Wenchen Fan <[email protected]>
    (cherry picked from commit d3da240ee3023887062909f99dc382b38b4daf1b)
    Signed-off-by: Wenchen Fan <[email protected]>
---
 .../spark/sql/catalyst/plans/QueryPlan.scala       | 18 +++++-
 .../apache/spark/sql/execution/ExplainUtils.scala  | 75 +++++++++-------------
 2 files changed, 47 insertions(+), 46 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
index aee4790eb42a..12ee0274fd7a 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.catalyst.plans
 
+import java.util.IdentityHashMap
+
 import scala.collection.mutable
 
 import org.apache.spark.sql.AnalysisException
@@ -429,7 +431,8 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]]
   override def verboseString(maxFields: Int): String = simpleString(maxFields)
 
   override def simpleStringWithNodeId(): String = {
-    val operatorId = getTagValue(QueryPlan.OP_ID_TAG).map(id => 
s"$id").getOrElse("unknown")
+    val operatorId = Option(QueryPlan.localIdMap.get().get(this)).map(id => 
s"$id")
+      .getOrElse("unknown")
     s"$nodeName ($operatorId)".trim
   }
 
@@ -449,7 +452,8 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]]
   }
 
   protected def formattedNodeName: String = {
-    val opId = getTagValue(QueryPlan.OP_ID_TAG).map(id => 
s"$id").getOrElse("unknown")
+    val opId = Option(QueryPlan.localIdMap.get().get(this)).map(id => s"$id")
+      .getOrElse("unknown")
     val codegenId =
       getTagValue(QueryPlan.CODEGEN_ID_TAG).map(id => s" [codegen id : 
$id]").getOrElse("")
     s"($opId) $nodeName$codegenId"
@@ -626,9 +630,17 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]]
 }
 
 object QueryPlan extends PredicateHelper {
-  val OP_ID_TAG = TreeNodeTag[Int]("operatorId")
   val CODEGEN_ID_TAG = new TreeNodeTag[Int]("wholeStageCodegenId")
 
+  /**
+   * A thread local map to store the mapping between the query plan and the 
query plan id.
+   * The scope of this thread local is within ExplainUtils.processPlan. The 
reason we define it here
+   * is because [[ QueryPlan ]] also needs this, and it doesn't have access to 
`execution` package
+   * from `catalyst`.
+   */
+  val localIdMap: ThreadLocal[java.util.Map[QueryPlan[_], Int]] = 
ThreadLocal.withInitial(() =>
+    new IdentityHashMap[QueryPlan[_], Int]())
+
   /**
    * Normalize the exprIds in the given expression, by updating the exprId in 
`AttributeReference`
    * with its referenced ordinal from input attributes. It's similar to 
`BindReferences` but we
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExplainUtils.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExplainUtils.scala
index 11f6ae0e47ee..421a963453f0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExplainUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExplainUtils.scala
@@ -17,9 +17,7 @@
 
 package org.apache.spark.sql.execution
 
-import java.util.Collections.newSetFromMap
 import java.util.IdentityHashMap
-import java.util.Set
 
 import scala.collection.mutable.{ArrayBuffer, BitSet}
 
@@ -30,6 +28,8 @@ import 
org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AdaptiveS
 import org.apache.spark.sql.execution.exchange.{Exchange, ReusedExchangeExec}
 
 object ExplainUtils extends AdaptiveSparkPlanHelper {
+  def localIdMap: ThreadLocal[java.util.Map[QueryPlan[_], Int]] = 
QueryPlan.localIdMap
+
   /**
    * Given a input physical plan, performs the following tasks.
    *   1. Computes the whole stage codegen id for current operator and records 
it in the
@@ -80,24 +80,26 @@ object ExplainUtils extends AdaptiveSparkPlanHelper {
    * instances but cached plan is an exception. The 
`InMemoryRelation#innerChildren` use a shared
    * plan instance across multi-queries. Add lock for this method to avoid tag 
race condition.
    */
-  def processPlan[T <: QueryPlan[T]](plan: T, append: String => Unit): Unit = 
synchronized {
+  def processPlan[T <: QueryPlan[T]](plan: T, append: String => Unit): Unit = {
+    val prevIdMap = localIdMap.get()
     try {
-      // Initialize a reference-unique set of Operators to avoid accdiental 
overwrites and to allow
-      // intentional overwriting of IDs generated in previous AQE iteration
-      val operators = newSetFromMap[QueryPlan[_]](new IdentityHashMap())
+      // Initialize a reference-unique id map to store generated ids, which 
also avoid accidental
+      // overwrites and to allow intentional overwriting of IDs generated in 
previous AQE iteration
+      val idMap = new IdentityHashMap[QueryPlan[_], Int]()
+      localIdMap.set(idMap)
       // Initialize an array of ReusedExchanges to help find Adaptively 
Optimized Out
       // Exchanges as part of SPARK-42753
       val reusedExchanges = ArrayBuffer.empty[ReusedExchangeExec]
 
       var currentOperatorID = 0
-      currentOperatorID = generateOperatorIDs(plan, currentOperatorID, 
operators, reusedExchanges,
+      currentOperatorID = generateOperatorIDs(plan, currentOperatorID, idMap, 
reusedExchanges,
         true)
 
       val subqueries = ArrayBuffer.empty[(SparkPlan, Expression, 
BaseSubqueryExec)]
       getSubqueries(plan, subqueries)
 
       currentOperatorID = subqueries.foldLeft(currentOperatorID) {
-        (curId, plan) => generateOperatorIDs(plan._3.child, curId, operators, 
reusedExchanges,
+        (curId, plan) => generateOperatorIDs(plan._3.child, curId, idMap, 
reusedExchanges,
           true)
       }
 
@@ -105,9 +107,9 @@ object ExplainUtils extends AdaptiveSparkPlanHelper {
       val optimizedOutExchanges = ArrayBuffer.empty[Exchange]
       reusedExchanges.foreach{ reused =>
         val child = reused.child
-        if (!operators.contains(child)) {
+        if (!idMap.containsKey(child)) {
           optimizedOutExchanges.append(child)
-          currentOperatorID = generateOperatorIDs(child, currentOperatorID, 
operators,
+          currentOperatorID = generateOperatorIDs(child, currentOperatorID, 
idMap,
             reusedExchanges, false)
         }
       }
@@ -144,7 +146,7 @@ object ExplainUtils extends AdaptiveSparkPlanHelper {
         append("\n")
       }
     } finally {
-      removeTags(plan)
+      localIdMap.set(prevIdMap)
     }
   }
 
@@ -159,13 +161,15 @@ object ExplainUtils extends AdaptiveSparkPlanHelper {
    * @param plan Input query plan to process
    * @param startOperatorID The start value of operation id. The subsequent 
operations will be
    *                        assigned higher value.
-   * @param visited A unique set of operators visited by generateOperatorIds. 
The set is scoped
-   *                at the callsite function processPlan. It serves two 
purpose: Firstly, it is
-   *                used to avoid accidentally overwriting existing IDs that 
were generated in
-   *                the same processPlan call. Secondly, it is used to allow 
for intentional ID
-   *                overwriting as part of SPARK-42753 where an Adaptively 
Optimized Out Exchange
-   *                and its subtree may contain IDs that were generated in a 
previous AQE
-   *                iteration's processPlan call which would result in 
incorrect IDs.
+   * @param idMap   A reference-unique map store operators visited by 
generateOperatorIds and its
+   *                id. This Map is scoped at the callsite function 
processPlan. It serves three
+   *                purpose:
+   *                Firstly, it stores the QueryPlan - generated ID mapping. 
Secondly, it is used to
+   *                avoid accidentally overwriting existing IDs that were 
generated in the same
+   *                processPlan call. Thirdly, it is used to allow for 
intentional ID overwriting as
+   *                part of SPARK-42753 where an Adaptively Optimized Out 
Exchange and its subtree
+   *                may contain IDs that were generated in a previous AQE 
iteration's processPlan
+   *                call which would result in incorrect IDs.
    * @param reusedExchanges A unique set of ReusedExchange nodes visited which 
will be used to
    *                        idenitfy adaptively optimized out exchanges in 
SPARK-42753.
    * @param addReusedExchanges Whether to add ReusedExchange nodes to 
reusedExchanges set. We set it
@@ -177,7 +181,7 @@ object ExplainUtils extends AdaptiveSparkPlanHelper {
   private def generateOperatorIDs(
       plan: QueryPlan[_],
       startOperatorID: Int,
-      visited: Set[QueryPlan[_]],
+      idMap: java.util.Map[QueryPlan[_], Int],
       reusedExchanges: ArrayBuffer[ReusedExchangeExec],
       addReusedExchanges: Boolean): Int = {
     var currentOperationID = startOperatorID
@@ -186,36 +190,35 @@ object ExplainUtils extends AdaptiveSparkPlanHelper {
       return currentOperationID
     }
 
-    def setOpId(plan: QueryPlan[_]): Unit = if (!visited.contains(plan)) {
+    def setOpId(plan: QueryPlan[_]): Unit = idMap.computeIfAbsent(plan, plan 
=> {
       plan match {
         case r: ReusedExchangeExec if addReusedExchanges =>
           reusedExchanges.append(r)
         case _ =>
       }
-      visited.add(plan)
       currentOperationID += 1
-      plan.setTagValue(QueryPlan.OP_ID_TAG, currentOperationID)
-    }
+      currentOperationID
+    })
 
     plan.foreachUp {
       case _: WholeStageCodegenExec =>
       case _: InputAdapter =>
       case p: AdaptiveSparkPlanExec =>
-        currentOperationID = generateOperatorIDs(p.executedPlan, 
currentOperationID, visited,
+        currentOperationID = generateOperatorIDs(p.executedPlan, 
currentOperationID, idMap,
           reusedExchanges, addReusedExchanges)
         if (!p.executedPlan.fastEquals(p.initialPlan)) {
-          currentOperationID = generateOperatorIDs(p.initialPlan, 
currentOperationID, visited,
+          currentOperationID = generateOperatorIDs(p.initialPlan, 
currentOperationID, idMap,
             reusedExchanges, addReusedExchanges)
         }
         setOpId(p)
       case p: QueryStageExec =>
-        currentOperationID = generateOperatorIDs(p.plan, currentOperationID, 
visited,
+        currentOperationID = generateOperatorIDs(p.plan, currentOperationID, 
idMap,
           reusedExchanges, addReusedExchanges)
         setOpId(p)
       case other: QueryPlan[_] =>
         setOpId(other)
         currentOperationID = other.innerChildren.foldLeft(currentOperationID) {
-          (curId, plan) => generateOperatorIDs(plan, curId, visited, 
reusedExchanges,
+          (curId, plan) => generateOperatorIDs(plan, curId, idMap, 
reusedExchanges,
             addReusedExchanges)
         }
     }
@@ -241,7 +244,7 @@ object ExplainUtils extends AdaptiveSparkPlanHelper {
     }
 
     def collectOperatorWithID(plan: QueryPlan[_]): Unit = {
-      plan.getTagValue(QueryPlan.OP_ID_TAG).foreach { id =>
+      Option(ExplainUtils.localIdMap.get().get(plan)).foreach { id =>
         if (collectedOperators.add(id)) operators += plan
       }
     }
@@ -334,20 +337,6 @@ object ExplainUtils extends AdaptiveSparkPlanHelper {
    * `operationId` tag value.
    */
   def getOpId(plan: QueryPlan[_]): String = {
-    plan.getTagValue(QueryPlan.OP_ID_TAG).map(v => s"$v").getOrElse("unknown")
-  }
-
-  def removeTags(plan: QueryPlan[_]): Unit = {
-    def remove(p: QueryPlan[_], children: Seq[QueryPlan[_]]): Unit = {
-      p.unsetTagValue(QueryPlan.OP_ID_TAG)
-      p.unsetTagValue(QueryPlan.CODEGEN_ID_TAG)
-      children.foreach(removeTags)
-    }
-
-    plan foreach {
-      case p: AdaptiveSparkPlanExec => remove(p, Seq(p.executedPlan, 
p.initialPlan))
-      case p: QueryStageExec => remove(p, Seq(p.plan))
-      case plan: QueryPlan[_] => remove(plan, plan.innerChildren)
-    }
+    Option(ExplainUtils.localIdMap.get().get(plan)).map(v => 
s"$v").getOrElse("unknown")
   }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to