Github user BryanCutler commented on a diff in the pull request:
https://github.com/apache/spark/pull/21546#discussion_r197932884
--- Diff: sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala ---
@@ -3236,13 +3236,49 @@ class Dataset[T] private[sql](
}
/**
- * Collect a Dataset as ArrowPayload byte arrays and serve to PySpark.
+ * Collect a Dataset as Arrow batches and serve stream to PySpark.
*/
private[sql] def collectAsArrowToPython(): Array[Any] = {
+ val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone
+
withAction("collectAsArrowToPython", queryExecution) { plan =>
- val iter: Iterator[Array[Byte]] =
- toArrowPayload(plan).collect().iterator.map(_.asPythonSerializable)
- PythonRDD.serveIterator(iter, "serve-Arrow")
+ PythonRDD.serveToStream("serve-Arrow") { out =>
+ val batchWriter = new ArrowBatchStreamWriter(schema, out, timeZoneId)
+ val arrowBatchRdd = getArrowBatchRdd(plan)
+ val numPartitions = arrowBatchRdd.partitions.length
+
+ // Store collection results for worst case of 1 to N-1 partitions
+ val results = new Array[Array[Array[Byte]]](numPartitions - 1)
+ var lastIndex = -1 // index of last partition written
+
+ // Handler to eagerly write partitions to Python in order
+ def handlePartitionBatches(index: Int, arrowBatches: Array[Array[Byte]]): Unit = {
+ // If result is from next partition in order
+ if (index - 1 == lastIndex) {
+ batchWriter.writeBatches(arrowBatches.iterator)
+ lastIndex += 1
+ // Write stored partitions that come next in order
+ while (lastIndex < results.length && results(lastIndex) != null) {
+ batchWriter.writeBatches(results(lastIndex).iterator)
+ results(lastIndex) = null
+ lastIndex += 1
+ }
+ // After last batch, end the stream
+ if (lastIndex == results.length) {
+ batchWriter.end()
+ }
+ } else {
+ // Store partitions received out of order
+ results(index - 1) = arrowBatches
+ }
+ }
+
+ sparkSession.sparkContext.runJob(
+ arrowBatchRdd,
+ (ctx: TaskContext, it: Iterator[Array[Byte]]) => it.toArray,
--- End diff --
Looking at this again, `it.toArray` is run on the executor, which ends up
doing the same thing as `collect()`, and then `handlePartitionBatches` is run
on the results of that in the driver. The task results need to be serialized,
so I'm not sure we can avoid `it.toArray` here — any thoughts, @ueshin?
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]