HyukjinKwon commented on code in PR #38759: URL: https://github.com/apache/spark/pull/38759#discussion_r1030973655
########## connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala: ########## @@ -71,20 +73,80 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[ExecutePlanResp var numSent = 0 if (numPartitions > 0) { + type Batch = (Array[Byte], Long) + val batches = rows.mapPartitionsInternal( SparkConnectStreamHandler .rowToArrowConverter(schema, maxRecordsPerBatch, maxBatchSize, timeZoneId)) - batches.collect().foreach { case (bytes, count) => - val response = proto.ExecutePlanResponse.newBuilder().setClientId(clientId) - val batch = proto.ExecutePlanResponse.ArrowBatch - .newBuilder() - .setRowCount(count) - .setData(ByteString.copyFrom(bytes)) - .build() - response.setArrowBatch(batch) - responseObserver.onNext(response.build()) - numSent += 1 + val signal = new Object + val partitions = new Array[Array[Batch]](numPartitions) + var error: Option[Throwable] = None + + // This callback is executed by the DAGScheduler thread. + // After fetching a partition, it inserts the partition into the array, and then + // wakes up the main thread. + val resultHandler = (partitionId: Int, partition: Array[Batch]) => { + signal.synchronized { + partitions(partitionId) = partition + signal.notify() + } + () + } + + val future = spark.sparkContext.submitJob( + rdd = batches, + processPartition = (iter: Iterator[Batch]) => iter.toArray, + partitions = Seq.range(0, numPartitions), + resultHandler = resultHandler, + resultFunc = () => ()) + + // Collect errors and propagate them to the main thread. + future.onComplete { result => + result.failed.foreach { throwable => + signal.synchronized { + error = Some(throwable) + signal.notify() + } + } + }(ThreadUtils.sameThread) + + // The main thread will wait until the 0-th partition is available, + // then send it to the client and wait for the next partition. 
+ // Different from the implementation of [[Dataset#collectAsArrowToPython]], it sends + // the arrow batches in the main thread to avoid the DAGScheduler thread being blocked for + // tasks not related to scheduling. This is particularly important if there are + // multiple users or clients running code at the same time. + var currentPartitionId = 0 + while (currentPartitionId < numPartitions) { + signal.synchronized { + while (partitions(currentPartitionId) == null && error.isEmpty) { + signal.wait() + } + + error.foreach { + case NonFatal(e) => + responseObserver.onError(e) + logError("Error while processing query.", e) + return + case other => throw other + } + } + + partitions(currentPartitionId).foreach { case (bytes, count) => Review Comment: I actually think this is fine but let me address this to be doubly sure. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org For queries about this service, please contact Infrastructure at: users@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org For additional commands, e-mail: reviews-help@spark.apache.org