HyukjinKwon commented on code in PR #38759:
URL: https://github.com/apache/spark/pull/38759#discussion_r1030973655


##########
connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala:
##########
@@ -71,20 +73,80 @@ class SparkConnectStreamHandler(responseObserver: 
StreamObserver[ExecutePlanResp
       var numSent = 0
 
       if (numPartitions > 0) {
+        type Batch = (Array[Byte], Long)
+
         val batches = rows.mapPartitionsInternal(
           SparkConnectStreamHandler
             .rowToArrowConverter(schema, maxRecordsPerBatch, maxBatchSize, 
timeZoneId))
 
-        batches.collect().foreach { case (bytes, count) =>
-          val response = 
proto.ExecutePlanResponse.newBuilder().setClientId(clientId)
-          val batch = proto.ExecutePlanResponse.ArrowBatch
-            .newBuilder()
-            .setRowCount(count)
-            .setData(ByteString.copyFrom(bytes))
-            .build()
-          response.setArrowBatch(batch)
-          responseObserver.onNext(response.build())
-          numSent += 1
+        val signal = new Object
+        val partitions = new Array[Array[Batch]](numPartitions)
+        var error: Option[Throwable] = None
+
+        // This callback is executed by the DAGScheduler thread.
+        // After fetching a partition, it inserts the partition into the array,
+        // and then wakes up the main thread.
+        val resultHandler = (partitionId: Int, partition: Array[Batch]) => {
+          signal.synchronized {
+            partitions(partitionId) = partition
+            signal.notify()
+          }
+          ()
+        }
+
+        val future = spark.sparkContext.submitJob(
+          rdd = batches,
+          processPartition = (iter: Iterator[Batch]) => iter.toArray,
+          partitions = Seq.range(0, numPartitions),
+          resultHandler = resultHandler,
+          resultFunc = () => ())
+
+        // Collect errors and propagate them to the main thread.
+        future.onComplete { result =>
+          result.failed.foreach { throwable =>
+            signal.synchronized {
+              error = Some(throwable)
+              signal.notify()
+            }
+          }
+        }(ThreadUtils.sameThread)
+
+        // The main thread will wait until the 0-th partition is available,
+        // then send it to the client and wait for the next partition.
+        // Different from the implementation of [[Dataset#collectAsArrowToPython]],
+        // it sends the arrow batches in the main thread to avoid the DAGScheduler
+        // thread being blocked by tasks not related to scheduling. This is
+        // particularly important if there are multiple users or clients running
+        // code at the same time.
+        var currentPartitionId = 0
+        while (currentPartitionId < numPartitions) {
+          signal.synchronized {
+            while (partitions(currentPartitionId) == null && error.isEmpty) {
+              signal.wait()
+            }
+
+            error.foreach {
+              case NonFatal(e) =>
+                responseObserver.onError(e)
+                logError("Error while processing query.", e)
+                return
+              case other => throw other
+            }
+          }
+
+          partitions(currentPartitionId).foreach { case (bytes, count) =>

Review Comment:
   I actually think this is fine, but let me address this to be doubly sure.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org

Reply via email to