This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-3.5
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.5 by this push:
     new 2f4a71252cc [SPARK-44897][SQL] Propagating local properties to 
subquery broadcast exec
2f4a71252cc is described below

commit 2f4a71252cc2ccc4bb1a9c388391c6df1b15a1f7
Author: Michael Chen <[email protected]>
AuthorDate: Mon Aug 28 10:56:54 2023 +0800

    [SPARK-44897][SQL] Propagating local properties to subquery broadcast exec
    
    ### What changes were proposed in this pull request?
    https://issues.apache.org/jira/browse/SPARK-32748 previously proposed 
propagating these local properties to the subquery broadcast exec threads but 
was then reverted since it was said that local properties would already be 
propagated to the broadcast threads.
    I believe this is not always true. In the scenario where a separate 
`BroadcastExchangeExec` is the first to compute the broadcast, this is fine. 
However, in the scenario where the `SubqueryBroadcastExec` is the first to 
compute the broadcast, the local properties that are propagated to the 
broadcast threads would not have been propagated correctly. This is because the 
local properties from the subquery broadcast exec were not propagated to its 
Future thread.
    It is difficult to write a unit test that reproduces this behavior because 
usually `BroadcastExchangeExec` is the first to compute the broadcast variable. 
However, by adding a `Thread.sleep(10)` to `SubqueryBroadcastExec.doPrepare` 
after `relationFuture` is initialized, the added test will consistently fail.
    
    ### Why are the changes needed?
    Local properties are not propagated correctly to `SubqueryBroadcastExec`
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    The following test can reproduce the bug and verify the solution by adding a 
sleep to `SubqueryBroadcastExec.doPrepare`:
    ```
    protected override def doPrepare(): Unit = {
        relationFuture
        Thread.sleep(10)
    }
    ```
    
    ```test("SPARK-44897 propagate local properties to subquery broadcast 
execution thread") {
        withSQLConf(StaticSQLConf.BROADCAST_EXCHANGE_MAX_THREAD_THRESHOLD.key 
-> "1") {
          withTable("a", "b") {
            val confKey = "spark.sql.y"
            val confValue1 = UUID.randomUUID().toString()
            val confValue2 = UUID.randomUUID().toString()
            Seq((confValue1, "1")).toDF("key", "value")
              .write
              .format("parquet")
              .partitionBy("key")
              .mode("overwrite")
              .saveAsTable("a")
            val df1 = spark.table("a")
    
            def generateBroadcastDataFrame(confKey: String, confValue: String): 
Dataset[String] = {
              val df = spark.range(1).mapPartitions { _ =>
                Iterator(TaskContext.get.getLocalProperty(confKey))
              }.filter($"value".contains(confValue)).as("c")
              df.hint("broadcast")
            }
    
            // set local property and assert
            val df2 = generateBroadcastDataFrame(confKey, confValue1)
            spark.sparkContext.setLocalProperty(confKey, confValue1)
            val checkDF = df1.join(df2).where($"a.key" === 
$"c.value").select($"a.key", $"c.value")
            val checks = checkDF.collect()
            assert(checks.forall(_.toSeq == Seq(confValue1, confValue1)))
    
            // change local property and re-assert
            Seq((confValue2, "1")).toDF("key", "value")
              .write
              .format("parquet")
              .partitionBy("key")
              .mode("overwrite")
              .saveAsTable("b")
            val df3 = spark.table("b")
            val df4 = generateBroadcastDataFrame(confKey, confValue2)
            spark.sparkContext.setLocalProperty(confKey, confValue2)
            val checks2DF = df3.join(df4).where($"b.key" === 
$"c.value").select($"b.key", $"c.value")
            val checks2 = checks2DF.collect()
            assert(checks2.forall(_.toSeq == Seq(confValue2, confValue2)))
            assert(checks2.nonEmpty)
          }
        }
      }
      ```
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No
    
    Closes #42587 from 
ChenMichael/SPARK-44897-local-property-propagation-to-subquery-broadcast-exec.
    
    Authored-by: Michael Chen <[email protected]>
    Signed-off-by: Wenchen Fan <[email protected]>
    (cherry picked from commit 4a4856207d414ba88a8edabeb70e20765460ef1a)
    Signed-off-by: Wenchen Fan <[email protected]>
---
 .../spark/sql/execution/SubqueryBroadcastExec.scala       | 15 ++++++++++-----
 1 file changed, 10 insertions(+), 5 deletions(-)

diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SubqueryBroadcastExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SubqueryBroadcastExec.scala
index 22d042ccefb..05657fe62e8 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SubqueryBroadcastExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SubqueryBroadcastExec.scala
@@ -17,7 +17,9 @@
 
 package org.apache.spark.sql.execution
 
-import scala.concurrent.{ExecutionContext, Future}
+import java.util.concurrent.{Future => JFuture}
+
+import scala.concurrent.ExecutionContext
 import scala.concurrent.duration.Duration
 
 import org.apache.spark.rdd.RDD
@@ -27,6 +29,7 @@ import org.apache.spark.sql.catalyst.plans.QueryPlan
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.execution.joins.{HashedRelation, HashJoin, 
LongHashedRelation}
 import org.apache.spark.sql.execution.metric.SQLMetrics
+import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf}
 import org.apache.spark.util.ThreadUtils
 
 /**
@@ -70,10 +73,11 @@ case class SubqueryBroadcastExec(
   }
 
   @transient
-  private lazy val relationFuture: Future[Array[InternalRow]] = {
+  private lazy val relationFuture: JFuture[Array[InternalRow]] = {
     // relationFuture is used in "doExecute". Therefore we can get the 
execution id correctly here.
     val executionId = 
sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
-    Future {
+    SQLExecution.withThreadLocalCaptured[Array[InternalRow]](
+      session, SubqueryBroadcastExec.executionContext) {
       // This will run in another thread. Set the execution id so that we can 
connect these jobs
       // with the correct execution.
       SQLExecution.withExecutionId(session, executionId) {
@@ -104,7 +108,7 @@ case class SubqueryBroadcastExec(
 
         rows
       }
-    }(SubqueryBroadcastExec.executionContext)
+    }
   }
 
   protected override def doPrepare(): Unit = {
@@ -127,5 +131,6 @@ case class SubqueryBroadcastExec(
 
 object SubqueryBroadcastExec {
   private[execution] val executionContext = 
ExecutionContext.fromExecutorService(
-    ThreadUtils.newDaemonCachedThreadPool("dynamicpruning", 16))
+    ThreadUtils.newDaemonCachedThreadPool("dynamicpruning",
+      
SQLConf.get.getConf(StaticSQLConf.BROADCAST_EXCHANGE_MAX_THREAD_THRESHOLD)))
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to