ueshin commented on code in PR #40806:
URL: https://github.com/apache/spark/pull/40806#discussion_r1169054143
##########
connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala:
##########
@@ -161,100 +161,104 @@ object SparkConnectStreamHandler {
// Conservatively sets it to 70% because the size is not accurate but
estimated.
val maxBatchSize =
(SparkEnv.get.conf.get(CONNECT_GRPC_ARROW_MAX_BATCH_SIZE) * 0.7).toLong
- SQLExecution.withNewExecutionId(dataframe.queryExecution,
Some("collectArrow")) {
- val rows = dataframe.queryExecution.executedPlan.execute()
- val numPartitions = rows.getNumPartitions
- var numSent = 0
-
- if (numPartitions > 0) {
- type Batch = (Array[Byte], Long)
-
- val batches = rows.mapPartitionsInternal(
- SparkConnectStreamHandler
- .rowToArrowConverter(schema, maxRecordsPerBatch, maxBatchSize,
timeZoneId))
-
- val signal = new Object
- val partitions = new Array[Array[Batch]](numPartitions)
- var error: Option[Throwable] = None
-
- // This callback is executed by the DAGScheduler thread.
- // After fetching a partition, it inserts the partition into the Map,
and then
- // wakes up the main thread.
- val resultHandler = (partitionId: Int, partition: Array[Batch]) => {
- signal.synchronized {
- partitions(partitionId) = partition
- signal.notify()
- }
- ()
- }
+ val rowToArrowConverter = SparkConnectStreamHandler
+ .rowToArrowConverter(schema, maxRecordsPerBatch, maxBatchSize,
timeZoneId)
- val future = spark.sparkContext.submitJob(
- rdd = batches,
- processPartition = (iter: Iterator[Batch]) => iter.toArray,
- partitions = Seq.range(0, numPartitions),
- resultHandler = resultHandler,
- resultFunc = () => ())
-
- // Collect errors and propagate them to the main thread.
- future.onComplete { result =>
- result.failed.foreach { throwable =>
- signal.synchronized {
- error = Some(throwable)
- signal.notify()
- }
- }
- }(ThreadUtils.sameThread)
-
- // The main thread will wait until 0-th partition is available,
- // then send it to client and wait for the next partition.
- // Different from the implementation of
[[Dataset#collectAsArrowToPython]], it sends
- // the arrow batches in main thread to avoid DAGScheduler thread been
blocked for
- // tasks not related to scheduling. This is particularly important if
there are
- // multiple users or clients running code at the same time.
- var currentPartitionId = 0
- while (currentPartitionId < numPartitions) {
- val partition = signal.synchronized {
- var part = partitions(currentPartitionId)
- while (part == null && error.isEmpty) {
- signal.wait()
- part = partitions(currentPartitionId)
- }
- partitions(currentPartitionId) = null
+ var numSent = 0
+ def sendBatch(bytes: Array[Byte], count: Long): Unit = {
+ val response =
proto.ExecutePlanResponse.newBuilder().setSessionId(sessionId)
+ val batch = proto.ExecutePlanResponse.ArrowBatch
+ .newBuilder()
+ .setRowCount(count)
+ .setData(ByteString.copyFrom(bytes))
+ .build()
+ response.setArrowBatch(batch)
+ responseObserver.onNext(response.build())
+ numSent += 1
+ }
- error.foreach { case other =>
- throw other
+ dataframe.queryExecution.executedPlan match {
+ case LocalTableScanExec(_, rows) =>
+ rowToArrowConverter(rows.iterator).foreach { case (bytes, count) =>
+ sendBatch(bytes, count)
+ }
+ case _ =>
+ SQLExecution.withNewExecutionId(dataframe.queryExecution,
Some("collectArrow")) {
+ val rows = dataframe.queryExecution.executedPlan.execute()
+ val numPartitions = rows.getNumPartitions
+
+ if (numPartitions > 0) {
+ type Batch = (Array[Byte], Long)
+
+ val batches = rows.mapPartitionsInternal(rowToArrowConverter)
+
+ val signal = new Object
+ val partitions = new Array[Array[Batch]](numPartitions)
+ var error: Option[Throwable] = None
+
+ // This callback is executed by the DAGScheduler thread.
+ // After fetching a partition, it inserts the partition into the
Map, and then
+ // wakes up the main thread.
+ val resultHandler = (partitionId: Int, partition: Array[Batch]) =>
{
+ signal.synchronized {
+ partitions(partitionId) = partition
+ signal.notify()
+ }
+ ()
}
- part
- }
- partition.foreach { case (bytes, count) =>
- val response =
proto.ExecutePlanResponse.newBuilder().setSessionId(sessionId)
- val batch = proto.ExecutePlanResponse.ArrowBatch
- .newBuilder()
- .setRowCount(count)
- .setData(ByteString.copyFrom(bytes))
- .build()
- response.setArrowBatch(batch)
- responseObserver.onNext(response.build())
- numSent += 1
+ val future = spark.sparkContext.submitJob(
+ rdd = batches,
+ processPartition = (iter: Iterator[Batch]) => iter.toArray,
+ partitions = Seq.range(0, numPartitions),
+ resultHandler = resultHandler,
+ resultFunc = () => ())
+
+ // Collect errors and propagate them to the main thread.
+ future.onComplete { result =>
+ result.failed.foreach { throwable =>
+ signal.synchronized {
+ error = Some(throwable)
+ signal.notify()
+ }
+ }
+ }(ThreadUtils.sameThread)
+
+ // The main thread will wait until 0-th partition is available,
+ // then send it to client and wait for the next partition.
+ // Different from the implementation of
[[Dataset#collectAsArrowToPython]], it sends
+ // the arrow batches in main thread to avoid DAGScheduler thread
being blocked for
+ // tasks not related to scheduling. This is particularly important
if there are
+ // multiple users or clients running code at the same time.
+ var currentPartitionId = 0
+ while (currentPartitionId < numPartitions) {
+ val partition = signal.synchronized {
+ var part = partitions(currentPartitionId)
+ while (part == null && error.isEmpty) {
+ signal.wait()
+ part = partitions(currentPartitionId)
+ }
+ partitions(currentPartitionId) = null
+
+ error.foreach { case other =>
Review Comment:
Let me fix it, too, while we're here. Thanks.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]