Github user sethah commented on a diff in the pull request:
https://github.com/apache/spark/pull/21546#discussion_r199496002
--- Diff: sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala ---
@@ -3236,13 +3237,50 @@ class Dataset[T] private[sql](
}
/**
- * Collect a Dataset as ArrowPayload byte arrays and serve to PySpark.
+ * Collect a Dataset as Arrow batches and serve stream to PySpark.
*/
private[sql] def collectAsArrowToPython(): Array[Any] = {
+ val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone
+
withAction("collectAsArrowToPython", queryExecution) { plan =>
- val iter: Iterator[Array[Byte]] =
- toArrowPayload(plan).collect().iterator.map(_.asPythonSerializable)
- PythonRDD.serveIterator(iter, "serve-Arrow")
+ PythonRDD.serveToStream("serve-Arrow") { outputStream =>
+ val out = new DataOutputStream(outputStream)
+ val batchWriter = new ArrowBatchStreamWriter(schema, out, timeZoneId)
+ val arrowBatchRdd = getArrowBatchRdd(plan)
+ val numPartitions = arrowBatchRdd.partitions.length
+
+ // Batches ordered by index of partition + fractional value of batch # in partition
+ val batchOrder = new ArrayBuffer[Float]()
+ var partitionCount = 0
+
+ // Handler to eagerly write batches to Python out of order
+ def handlePartitionBatches(index: Int, arrowBatches: Array[Array[Byte]]): Unit = {
+ if (arrowBatches.nonEmpty) {
+ batchWriter.writeBatches(arrowBatches.iterator)
+ (0 until arrowBatches.length).foreach { i =>
+ batchOrder.append(index + i / arrowBatches.length)
--- End diff --
This code, `(0 until array.length).map(i => i / array.length)`, is
guaranteed to produce only zero values, isn't it? Both operands are `Int`, so
this is integer division and `i / array.length == 0` for every `i` in the
range. The code still works, since `sortBy` is a stable sort and preserves the
relative order of equal elements, but you may as well write
`batchOrder.append(index)`, since it is equivalent.
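
For reference, a minimal standalone sketch of the two behaviors in play
(the names `n` and `batches` are illustrative, not from the PR):

```scala
object BatchOrderSketch extends App {
  val n = 4

  // Int / Int is integer division, so every offset collapses to 0.
  val offsets = (0 until n).map(i => i / n)
  println(offsets)                       // Vector(0, 0, 0, 0)

  // A fractional offset would require promoting one operand to Float.
  val fractional = (0 until n).map(i => i.toFloat / n)
  println(fractional)                    // Vector(0.0, 0.25, 0.5, 0.75)

  // sortBy is stable: elements with equal keys keep their relative order,
  // which is why the all-zero offsets still yield a correct batch order.
  val batches = Seq((1, "p1-b0"), (0, "p0-b0"), (1, "p1-b1"), (0, "p0-b1"))
  println(batches.sortBy(_._1))
  // List((0,p0-b0), (0,p0-b1), (1,p1-b0), (1,p1-b1))
}
```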