LuciferYang commented on code in PR #56700:
URL: https://github.com/apache/spark/pull/56700#discussion_r3464790941


##########
sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/DataflowGraphTransformer.scala:
##########
@@ -134,159 +134,165 @@ class DataflowGraphTransformer(graph: DataflowGraph) 
extends AutoCloseable {
     val failedFlowsQueue = new ConcurrentLinkedQueue[ResolutionFailedFlow]()
     val failedDependentFlows = new ConcurrentHashMap[TableIdentifier, 
Seq[ResolutionFailedFlow]]()
 
-    var futures = ArrayBuffer[Future[Unit]]()
+    val completionService = new ExecutorCompletionService[Unit](executor)
+    var outstanding = 0
     val toBeResolvedFlows = new ConcurrentLinkedDeque[Flow]()
     toBeResolvedFlows.addAll(flows.asJava)
 
-    while (futures.nonEmpty || toBeResolvedFlows.peekFirst() != null) {
-      val (done, notDone) = futures.partition(_.isDone)
-      // Explicitly call future.get() to propagate exceptions one by one if any
+    // Waits on a finished resolution task and propagates its exception, if 
any.
+    def reap(finished: Future[Unit]): Unit = {
       try {
-        done.foreach(_.get())
+        finished.get()
       } catch {
         case exn: ExecutionException =>
           // Computation threw the exception that is the cause of exn
           throw exn.getCause
       }
-      futures = notDone
-      val flowOpt = {
-        // We only schedule [[batchSize]] number of flows in parallel.
-        if (futures.size < batchSize) {
-          Option(toBeResolvedFlows.pollFirst())
-        } else {
-          None
-        }
+      outstanding -= 1
+    }
+
+    while (outstanding > 0 || toBeResolvedFlows.peekFirst() != null) {
+      // Reap every resolution task that has already finished, without 
blocking.
+      var finished = completionService.poll()
+      while (finished != null) {
+        reap(finished)
+        finished = completionService.poll()
       }
-      flowOpt.foreach { flow =>
-        futures.append(
-          executor.submit(
-            () =>
+      // We only schedule [[batchSize]] number of flows in parallel.
+      if (outstanding < batchSize && toBeResolvedFlows.peekFirst() != null) {
+        val flow = toBeResolvedFlows.pollFirst()
+        outstanding += 1
+        completionService.submit(
+          () =>
+            try {
               try {
-                try {
-                  // Note: Flow don't need their inputs passed, so for now we 
send empty Seq.
-                  val result = transformer(flow, Seq.empty)
-                  require(
-                    result.forall(_.isInstanceOf[ResolvedFlow]),
-                    "transformer must return a Seq[Flow]"
-                  )
+                // Note: Flow don't need their inputs passed, so for now we 
send empty Seq.
+                val result = transformer(flow, Seq.empty)
+                require(
+                  result.forall(_.isInstanceOf[ResolvedFlow]),
+                  "transformer must return a Seq[Flow]"
+                )
 
-                  val transformedFlows = 
result.map(_.asInstanceOf[ResolvedFlow])
-                  resolvedFlowsMap.put(flow.identifier, transformedFlows)
-                  resolvedFlows.addAll(transformedFlows.asJava)
-                } catch {
-                  case e: TransformNodeRetryableException =>
-                    val datasetIdentifier = e.datasetIdentifier
-                    failedDependentFlows.compute(
-                      datasetIdentifier,
-                      (_, flows) => {
-                        // Don't add the input flow back but the failed flow 
object
-                        // back which has relevant failure information.
-                        val failedFlow = e.failedNode
-                        if (flows == null) {
-                          Seq(failedFlow)
-                        } else {
-                          flows :+ failedFlow
-                        }
+                val transformedFlows = result.map(_.asInstanceOf[ResolvedFlow])
+                resolvedFlowsMap.put(flow.identifier, transformedFlows)
+                resolvedFlows.addAll(transformedFlows.asJava)
+              } catch {
+                case e: TransformNodeRetryableException =>
+                  val datasetIdentifier = e.datasetIdentifier
+                  failedDependentFlows.compute(
+                    datasetIdentifier,
+                    (_, flows) => {
+                      // Don't add the input flow back but the failed flow 
object
+                      // back which has relevant failure information.
+                      val failedFlow = e.failedNode
+                      if (flows == null) {
+                        Seq(failedFlow)
+                      } else {
+                        flows :+ failedFlow
                       }
-                    )
-                    // Between the time the flow started and finished 
resolving, perhaps the
-                    // dependent dataset was resolved
-                    resolvedFlowDestinationsMap.computeIfPresent(
-                      datasetIdentifier,
-                      (_, resolved) => {
-                        if (resolved) {
-                          // Check if the dataset that the flow is dependent 
on has been resolved
-                          // and if so, remove all dependent flows from the 
failedDependentFlows and
-                          // add them to the toBeResolvedFlows queue for retry.
-                          failedDependentFlows.computeIfPresent(
-                            datasetIdentifier,
-                            (_, toRetryFlows) => {
-                              
toRetryFlows.foreach(toBeResolvedFlows.addFirst(_))
-                              null
-                            }
-                          )
-                        }
-                        resolved
+                    }
+                  )
+                  // Between the time the flow started and finished resolving, 
perhaps the
+                  // dependent dataset was resolved
+                  resolvedFlowDestinationsMap.computeIfPresent(
+                    datasetIdentifier,
+                    (_, resolved) => {
+                      if (resolved) {
+                        // Check if the dataset that the flow is dependent on 
has been resolved
+                        // and if so, remove all dependent flows from the 
failedDependentFlows and
+                        // add them to the toBeResolvedFlows queue for retry.
+                        failedDependentFlows.computeIfPresent(
+                          datasetIdentifier,
+                          (_, toRetryFlows) => {
+                            toRetryFlows.foreach(toBeResolvedFlows.addFirst(_))
+                            null
+                          }
+                        )
                       }
+                      resolved
+                    }
+                  )
+                case other: Throwable => throw other
+              }
+              // If all flows to this particular destination are resolved, 
move to the destination
+              // node transformer
+              if (flowsTo(flow.destinationIdentifier).forall({ 
flowToDestination =>
+                  resolvedFlowsMap.containsKey(flowToDestination.identifier)
+                })) {
+                // If multiple flows completed in parallel, ensure we resolve 
the destination only
+                // once by electing a leader via computeIfAbsent
+                var isCurrentThreadLeader = false
+                
resolvedFlowDestinationsMap.computeIfAbsent(flow.destinationIdentifier, _ => {
+                  isCurrentThreadLeader = true
+                  // Set initial value as false as flow destination is not 
resolved yet.
+                  false
+                })
+                if (isCurrentThreadLeader) {
+                  if (tableMap.contains(flow.destinationIdentifier)) {
+                    val transformed =
+                      transformer(
+                        tableMap(flow.destinationIdentifier),
+                        flowsTo(flow.destinationIdentifier)
+                      )
+                    resolvedTables.addAll(
+                      transformed.collect { case t: Table => t }.asJava
                     )
-                  case other: Throwable => throw other
-                }
-                // If all flows to this particular destination are resolved, 
move to the destination
-                // node transformer
-                if (flowsTo(flow.destinationIdentifier).forall({ 
flowToDestination =>
-                    resolvedFlowsMap.containsKey(flowToDestination.identifier)
-                  })) {
-                  // If multiple flows completed in parallel, ensure we 
resolve the destination only
-                  // once by electing a leader via computeIfAbsent
-                  var isCurrentThreadLeader = false
-                  
resolvedFlowDestinationsMap.computeIfAbsent(flow.destinationIdentifier, _ => {
-                    isCurrentThreadLeader = true
-                    // Set initial value as false as flow destination is not 
resolved yet.
-                    false
-                  })
-                  if (isCurrentThreadLeader) {
-                    if (tableMap.contains(flow.destinationIdentifier)) {
+                    resolvedFlows.addAll(
+                      transformed.collect { case f: ResolvedFlow => f }.asJava
+                    )
+                  } else if (viewMap.contains(flow.destinationIdentifier)) {
+                    resolvedViews.addAll {
                       val transformed =
                         transformer(
-                          tableMap(flow.destinationIdentifier),
+                          viewMap(flow.destinationIdentifier),
                           flowsTo(flow.destinationIdentifier)
                         )
-                      resolvedTables.addAll(
-                        transformed.collect { case t: Table => t }.asJava
-                      )
-                      resolvedFlows.addAll(
-                        transformed.collect { case f: ResolvedFlow => f 
}.asJava
-                      )
-                    } else if (viewMap.contains(flow.destinationIdentifier)) {
-                      resolvedViews.addAll {
-                        val transformed =
-                          transformer(
-                            viewMap(flow.destinationIdentifier),
-                            flowsTo(flow.destinationIdentifier)
-                          )
-                        transformed.map(_.asInstanceOf[View]).asJava
-                      }
-                    } else if (sinkMap.contains(flow.destinationIdentifier)) {
-                      resolvedSinks.addAll {
-                        val transformed =
-                          transformer(
-                            sinkMap(flow.destinationIdentifier), 
flowsTo(flow.destinationIdentifier)
-                          )
-                        require(
-                          transformed.forall(_.isInstanceOf[Sink]),
-                          "transformer must return a Seq[Sink]"
+                      transformed.map(_.asInstanceOf[View]).asJava
+                    }
+                  } else if (sinkMap.contains(flow.destinationIdentifier)) {
+                    resolvedSinks.addAll {
+                      val transformed =
+                        transformer(
+                          sinkMap(flow.destinationIdentifier), 
flowsTo(flow.destinationIdentifier)
                         )
-                        transformed.map(_.asInstanceOf[Sink]).asJava
-                      }
-                    } else {
-                      throw new IllegalArgumentException(
-                        s"Unsupported destination 
${flow.destinationIdentifier.unquotedString}" +
-                        s" in flow: ${flow.displayName} at transformDownNodes"
+                      require(
+                        transformed.forall(_.isInstanceOf[Sink]),
+                        "transformer must return a Seq[Sink]"
                       )
+                      transformed.map(_.asInstanceOf[Sink]).asJava
                     }
-                    // Set flow destination as resolved now.
-                    resolvedFlowDestinationsMap.computeIfPresent(
-                      flow.destinationIdentifier,
-                      (_, _) => {
-                        // If there are any other node failures dependent on 
this destination, retry
-                        // them
-                        failedDependentFlows.computeIfPresent(
-                          flow.destinationIdentifier,
-                          (_, toRetryFlows) => {
-                            toRetryFlows.foreach(toBeResolvedFlows.addFirst(_))
-                            null
-                          }
-                        )
-                        true
-                      }
+                  } else {
+                    throw new IllegalArgumentException(
+                      s"Unsupported destination 
${flow.destinationIdentifier.unquotedString}" +
+                      s" in flow: ${flow.displayName} at transformDownNodes"
                     )
                   }
+                  // Set flow destination as resolved now.
+                  resolvedFlowDestinationsMap.computeIfPresent(
+                    flow.destinationIdentifier,
+                    (_, _) => {
+                      // If there are any other node failures dependent on 
this destination, retry
+                      // them
+                      failedDependentFlows.computeIfPresent(
+                        flow.destinationIdentifier,
+                        (_, toRetryFlows) => {
+                          toRetryFlows.foreach(toBeResolvedFlows.addFirst(_))
+                          null
+                        }
+                      )
+                      true
+                    }
+                  )
                 }
-              } catch {
-                case ex: TransformNodeFailedException => 
failedFlowsQueue.add(ex.failedNode)
               }
-          )
+            } catch {
+              case ex: TransformNodeFailedException => 
failedFlowsQueue.add(ex.failedNode)
+            }
         )
+      } else if (outstanding > 0) {
+        // Nothing finished and nothing could be scheduled, but tasks are 
still running:
+        // block until the next one finishes instead of busy-spinning on 
Future.isDone.
+        reap(completionService.take())

Review Comment:
   Good question, but the guard is needed - with a plain `else`, the loop would 
   hang on its last iteration. The reasoning misses that the `poll()` drain loop 
   runs *after* the loop-condition check and mutates `outstanding` (each `reap` 
   call decrements it). The invariant `outstanding > 0 || queue.nonEmpty` only 
   holds at the top of the loop; by the time we reach this branch, the drain may 
   have reaped the last in-flight tasks and taken `outstanding` to 0 while the 
   queue is already empty. In that case the `if` is false (empty queue) and there 
   is nothing left to wait for, so an unguarded `take()` would block forever. The 
   `outstanding > 0` guard lets the loop fall through so the next condition check 
   (`0 > 0 || empty`) exits cleanly. I added a comment spelling this out.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to