Github user sethah commented on a diff in the pull request:
https://github.com/apache/spark/pull/21546#discussion_r199275753
--- Diff: sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala ---
@@ -3236,13 +3237,50 @@ class Dataset[T] private[sql](
}
/**
- * Collect a Dataset as ArrowPayload byte arrays and serve to PySpark.
+ * Collect a Dataset as Arrow batches and serve stream to PySpark.
*/
private[sql] def collectAsArrowToPython(): Array[Any] = {
+ val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone
+
withAction("collectAsArrowToPython", queryExecution) { plan =>
- val iter: Iterator[Array[Byte]] =
- toArrowPayload(plan).collect().iterator.map(_.asPythonSerializable)
- PythonRDD.serveIterator(iter, "serve-Arrow")
+ PythonRDD.serveToStream("serve-Arrow") { outputStream =>
+ val out = new DataOutputStream(outputStream)
+ val batchWriter = new ArrowBatchStreamWriter(schema, out, timeZoneId)
+ val arrowBatchRdd = getArrowBatchRdd(plan)
+ val numPartitions = arrowBatchRdd.partitions.length
+
+ // Batches ordered by index of partition + batch number for that partition
+ val batchOrder = new ArrayBuffer[Int]()
+ var partitionCount = 0
+
+ // Handler to eagerly write batches to Python out of order
+ def handlePartitionBatches(index: Int, arrowBatches: Array[Array[Byte]]): Unit = {
+ if (arrowBatches.nonEmpty) {
+ batchWriter.writeBatches(arrowBatches.iterator)
+ (0 until arrowBatches.length).foreach { i =>
+ batchOrder.append(index + i)
+ }
+ }
+ partitionCount += 1
+
+ // After last batch, end the stream and write batch order
+ if (partitionCount == numPartitions) {
+ batchWriter.end()
+ out.writeInt(batchOrder.length)
+ // Batch order indices are from 0 to N-1 batches, sorted by order they arrived
+ batchOrder.zipWithIndex.sortBy(_._1).foreach { case (_, i) =>
--- End diff --
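Does this logic do what you intend? It interleaves batches from different partitions. It looks like the sort key `index + i` is not unique across partitions (e.g. batch 1 of partition 0 and batch 0 of partition 1 both get key 1), so whenever a partition produces more than one batch the rows can come back out of order. For example: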
```python
df = spark.range(64).toDF("a")
df.rdd.getNumPartitions() # 8
spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", 4)
pdf = df.toPandas()
pdf['a'].values
# array([ 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 16, 17, 18, 19, 12,
# 13, 14, 15, 20, 21, 22, 23, 24, 25, 26, 27, 32, 33, 34, 35, 28, 29,
# 30, 31, 36, 37, 38, 39, 40, 41, 42, 43, 48, 49, 50, 51, 44, 45, 46,
# 47, 56, 57, 58, 59, 52, 53, 54, 55, 60, 61, 62, 63])
```
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]