Github user sethah commented on a diff in the pull request:

    https://github.com/apache/spark/pull/21546#discussion_r199275753
  
    --- Diff: sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala ---
    @@ -3236,13 +3237,50 @@ class Dataset[T] private[sql](
       }
     
       /**
    -   * Collect a Dataset as ArrowPayload byte arrays and serve to PySpark.
    +   * Collect a Dataset as Arrow batches and serve stream to PySpark.
        */
       private[sql] def collectAsArrowToPython(): Array[Any] = {
    +    val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone
    +
         withAction("collectAsArrowToPython", queryExecution) { plan =>
    -      val iter: Iterator[Array[Byte]] =
    -        toArrowPayload(plan).collect().iterator.map(_.asPythonSerializable)
    -      PythonRDD.serveIterator(iter, "serve-Arrow")
    +      PythonRDD.serveToStream("serve-Arrow") { outputStream =>
    +        val out = new DataOutputStream(outputStream)
    +        val batchWriter = new ArrowBatchStreamWriter(schema, out, timeZoneId)
    +        val arrowBatchRdd = getArrowBatchRdd(plan)
    +        val numPartitions = arrowBatchRdd.partitions.length
    +
    +        // Batches ordered by index of partition + batch number for that partition
    +        val batchOrder = new ArrayBuffer[Int]()
    +        var partitionCount = 0
    +
    +        // Handler to eagerly write batches to Python out of order
    +        def handlePartitionBatches(index: Int, arrowBatches: Array[Array[Byte]]): Unit = {
    +          if (arrowBatches.nonEmpty) {
    +            batchWriter.writeBatches(arrowBatches.iterator)
    +            (0 until arrowBatches.length).foreach { i =>
    +              batchOrder.append(index + i)
    +            }
    +          }
    +          partitionCount += 1
    +
    +          // After last batch, end the stream and write batch order
    +          if (partitionCount == numPartitions) {
    +            batchWriter.end()
    +            out.writeInt(batchOrder.length)
    +            // Batch order indices are from 0 to N-1 batches, sorted by order they arrived
    +            batchOrder.zipWithIndex.sortBy(_._1).foreach { case (_, i) =>
    --- End diff --
    
    Does this logic do what you intend? It interleaves batches from neighboring partitions, so the collected rows come back out of order. For example:
    
    ```python
    df = spark.range(64).toDF("a")
    df.rdd.getNumPartitions()  # 8
    spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", 4)
    pdf = df.toPandas()
    pdf['a'].values
    # array([ 0,  1,  2,  3,  8,  9, 10, 11,  4,  5,  6,  7, 16, 17, 18, 19, 12,
    #       13, 14, 15, 20, 21, 22, 23, 24, 25, 26, 27, 32, 33, 34, 35, 28, 29,
    #       30, 31, 36, 37, 38, 39, 40, 41, 42, 43, 48, 49, 50, 51, 44, 45, 46,
    #       47, 56, 57, 58, 59, 52, 53, 54, 55, 60, 61, 62, 63])
    ```


---

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to