This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch branch-3.1
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.1 by this push:
     new 9025cc5  [SPARK-33273][SQL] Fix a race condition in subquery execution
9025cc5 is described below

commit 9025cc5c20c41f82f26d7ce20e9e2baf815676a0
Author: Wenchen Fan <[email protected]>
AuthorDate: Tue Dec 15 18:29:28 2020 +0900

    [SPARK-33273][SQL] Fix a race condition in subquery execution
    
    ### What changes were proposed in this pull request?
    
    If we call `SubqueryExec.executeTake`, it calls `SubqueryExec.execute`,
    which triggers codegen of the query plan and creates an RDD. However,
    `SubqueryExec` already has a thread (`SubqueryExec.relationFuture`) that
    executes the query plan. As a result, two threads can trigger codegen of
    the same query plan at the same time.
    
    Spark codegen is not thread-safe, because some operators keep shared
    mutable state during code generation (for example,
    `HashAggregateExec.bufferVars`). This race in `SubqueryExec` can therefore
    lead to correctness bugs.
    
    Since https://issues.apache.org/jira/browse/SPARK-33119, `ScalarSubquery`
    calls `SubqueryExec.executeTake`, which is why flaky tests have started to
    appear.
    
    This PR fixes the bug by reimplementing PR #30016
    (https://github.com/apache/spark/pull/30016). The number of rows to collect
    is now passed to `SubqueryExec` at planning time, so that `executeTake` can
    be used inside `SubqueryExec.relationFuture`. Callers always call
    `SubqueryExec.executeCollect`. This PR also adds checks to ensure that only
    `SubqueryExec.executeCollect` is called: `doExecute`, `executeTake` and
    `executeTail` now throw `IllegalStateException`.
    
    ### Why are the changes needed?
    
    To fix a correctness bug.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Ran `build/sbt "sql/testOnly *SQLQueryTestSuite  -- -z scalar-subquery-select"`
    more than 10 times. Previously it failed; now it passes.
    
    Closes #30765 from cloud-fan/bug.
    
    Authored-by: Wenchen Fan <[email protected]>
    Signed-off-by: HyukjinKwon <[email protected]>
---
 .../adaptive/InsertAdaptiveSparkPlan.scala         |  3 +-
 .../sql/execution/basicPhysicalOperators.scala     | 35 +++++++++++++++++-----
 .../org/apache/spark/sql/execution/subquery.scala  |  6 ++--
 3 files changed, 33 insertions(+), 11 deletions(-)

diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/InsertAdaptiveSparkPlan.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/InsertAdaptiveSparkPlan.scala
index f8478f8..cd0503f 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/InsertAdaptiveSparkPlan.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/InsertAdaptiveSparkPlan.scala
@@ -120,7 +120,8 @@ case class InsertAdaptiveSparkPlan(
           if !subqueryMap.contains(exprId.id) =>
         val executedPlan = compileSubquery(p)
         verifyAdaptivePlan(executedPlan, p)
-        val subquery = SubqueryExec(s"subquery#${exprId.id}", executedPlan)
+        val subquery = SubqueryExec.createForScalarSubquery(
+          s"subquery#${exprId.id}", executedPlan)
         subqueryMap.put(exprId.id, subquery)
       case expressions.InSubquery(_, ListQuery(query, _, exprId, _))
           if !subqueryMap.contains(exprId.id) =>
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
index 006fa0f..d651132 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
@@ -764,7 +764,7 @@ abstract class BaseSubqueryExec extends SparkPlan {
 /**
  * Physical plan for a subquery.
  */
-case class SubqueryExec(name: String, child: SparkPlan)
+case class SubqueryExec(name: String, child: SparkPlan, maxNumRows: 
Option[Int] = None)
   extends BaseSubqueryExec with UnaryExecNode {
 
   override lazy val metrics = Map(
@@ -783,7 +783,11 @@ case class SubqueryExec(name: String, child: SparkPlan)
       SQLExecution.withExecutionId(sqlContext.sparkSession, executionId) {
         val beforeCollect = System.nanoTime()
         // Note that we use .executeCollect() because we don't want to convert 
data to Scala types
-        val rows: Array[InternalRow] = child.executeCollect()
+        val rows: Array[InternalRow] = if (maxNumRows.isDefined) {
+          child.executeTake(maxNumRows.get)
+        } else {
+          child.executeCollect()
+        }
         val beforeBuild = System.nanoTime()
         longMetric("collectTime") += NANOSECONDS.toMillis(beforeBuild - 
beforeCollect)
         val dataSize = 
rows.map(_.asInstanceOf[UnsafeRow].getSizeInBytes.toLong).sum
@@ -796,28 +800,45 @@ case class SubqueryExec(name: String, child: SparkPlan)
   }
 
   protected override def doCanonicalize(): SparkPlan = {
-    SubqueryExec("Subquery", child.canonicalized)
+    SubqueryExec("Subquery", child.canonicalized, maxNumRows)
   }
 
   protected override def doPrepare(): Unit = {
     relationFuture
   }
 
+  // `SubqueryExec` should only be used by calling `executeCollect`. It 
launches a new thread to
+  // collect the result of `child`. We should not trigger codegen of `child` 
again in other threads,
+  // as generating code is not thread-safe.
+  override def executeCollect(): Array[InternalRow] = {
+    ThreadUtils.awaitResult(relationFuture, Duration.Inf)
+  }
+
   protected override def doExecute(): RDD[InternalRow] = {
-    child.execute()
+    throw new IllegalStateException("SubqueryExec.doExecute should never be 
called")
   }
 
-  override def executeCollect(): Array[InternalRow] = {
-    ThreadUtils.awaitResult(relationFuture, Duration.Inf)
+  override def executeTake(n: Int): Array[InternalRow] = {
+    throw new IllegalStateException("SubqueryExec.executeTake should never be 
called")
+  }
+
+  override def executeTail(n: Int): Array[InternalRow] = {
+    throw new IllegalStateException("SubqueryExec.executeTail should never be 
called")
   }
 
-  override def stringArgs: Iterator[Any] = super.stringArgs ++ 
Iterator(s"[id=#$id]")
+  override def stringArgs: Iterator[Any] = Iterator(name, child) ++ 
Iterator(s"[id=#$id]")
 }
 
 object SubqueryExec {
   private[execution] val executionContext = 
ExecutionContext.fromExecutorService(
     ThreadUtils.newDaemonCachedThreadPool("subquery",
       SQLConf.get.getConf(StaticSQLConf.SUBQUERY_MAX_THREAD_THRESHOLD)))
+
+  def createForScalarSubquery(name: String, child: SparkPlan): SubqueryExec = {
+    // Scalar subquery needs only one row. We require 2 rows here to validate 
if the scalar query is
+    // invalid(return more than one row). We don't need all the rows as it may 
OOM.
+    SubqueryExec(name, child, maxNumRows = Some(2))
+  }
 }
 
 /**
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala
index 5e222d2..0080b73 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala
@@ -80,8 +80,7 @@ case class ScalarSubquery(
   @volatile private var updated: Boolean = false
 
   def updateResult(): Unit = {
-    // Only return the first two rows as an array to avoid Driver OOM.
-    val rows = plan.executeTake(2)
+    val rows = plan.executeCollect()
     if (rows.length > 1) {
       sys.error(s"more than one row returned by a subquery used as an 
expression:\n$plan")
     }
@@ -178,7 +177,8 @@ case class PlanSubqueries(sparkSession: SparkSession) 
extends Rule[SparkPlan] {
       case subquery: expressions.ScalarSubquery =>
         val executedPlan = QueryExecution.prepareExecutedPlan(sparkSession, 
subquery.plan)
         ScalarSubquery(
-          SubqueryExec(s"scalar-subquery#${subquery.exprId.id}", executedPlan),
+          SubqueryExec.createForScalarSubquery(
+            s"scalar-subquery#${subquery.exprId.id}", executedPlan),
           subquery.exprId)
       case expressions.InSubquery(values, ListQuery(query, _, exprId, _)) =>
         val expr = if (values.length == 1) {


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to