Github user rxin commented on a diff in the pull request:

    https://github.com/apache/spark/pull/11664#discussion_r55956440
  
    --- Diff: 
sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala ---
    @@ -220,7 +222,61 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] 
with Logging with Serializ
        * Runs this query returning the result as an array.
        */
       def executeCollect(): Array[InternalRow] = {
    -    execute().map(_.copy()).collect()
    +    // Packing the UnsafeRows into byte array for faster serialization.
    +    // The byte arrays are in the following format:
    +    // [size] [bytes of UnsafeRow] [size] [bytes of UnsafeRow] ... [-1]
    +    val byteArrayRdd = execute().mapPartitionsInternal { iter =>
    +      new Iterator[Array[Byte]] {
    +        private var row: UnsafeRow = _
    +        override def hasNext: Boolean = row != null || iter.hasNext
    +        override def next: Array[Byte] = {
    +          var cap = 1 << 20  // 1 MB
    +          if (row != null) {
    +            // the buffered row could be larger than default buffer size
    +            cap = Math.max(cap, 4 + row.getSizeInBytes + 4) // reserve 4 
bytes for ending mark (-1).
    +          }
    +          val buffer = ByteBuffer.allocate(cap)
    +          if (row != null) {
    +            buffer.putInt(row.getSizeInBytes)
    +            row.writeTo(buffer)
    +            row = null
    +          }
    +          while (iter.hasNext) {
    +            row = iter.next().asInstanceOf[UnsafeRow]
    +            // Reserve last 4 bytes for ending mark
    +            if (4 + row.getSizeInBytes + 4 <= buffer.remaining()) {
    +              buffer.putInt(row.getSizeInBytes)
    +              row.writeTo(buffer)
    +              row = null
    +            } else {
    +              buffer.putInt(-1)
    +              return buffer.array()
    +            }
    +          }
    +          buffer.putInt(-1)
    +          // copy the used bytes to make it smaller
    +          val bytes = new Array[Byte](buffer.limit())
    +          System.arraycopy(buffer.array(), 0, bytes, 0, buffer.limit())
    +          bytes
    +        }
    +      }
    +    }
    +    // Collect the byte arrays back to driver, then decode them as 
UnsafeRows.
    +    val nFields = schema.length
    +    byteArrayRdd.collect().flatMap { bytes =>
    --- End diff --
    
    I think this block would be more readable if we just wrote it imperatively,
e.g.
    
    ```scala
    
    // Decode the collected byte arrays back into rows. Each array is laid out as
    // [size] [bytes of UnsafeRow] [size] [bytes of UnsafeRow] ... [-1],
    // so a negative size marks the end of that array.
    val results = ArrayBuffer[InternalRow]()

    byteArrayRdd.collect().foreach { bytes =>
      // Wrap (not copy) the array so sizes can be read with getInt() and the
      // read position is tracked for us.
      val buffer = ByteBuffer.wrap(bytes)
      var sizeOfNextRow = buffer.getInt()
      while (sizeOfNextRow >= 0) {
        val row = new UnsafeRow(nFields)
        // The row's bytes start at the current position, right after its size.
        row.pointTo(
          buffer.array(),
          Platform.BYTE_ARRAY_OFFSET + buffer.position(),
          sizeOfNextRow)
        // Skip over the row's bytes, then read the next size (or the -1 mark);
        // without this the loop would never advance.
        buffer.position(buffer.position() + sizeOfNextRow)
        sizeOfNextRow = buffer.getInt()
        results += row
      }
    }
    results.toArray
    ```


---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and wishes so, or if the feature is enabled but not working, please
contact infrastructure at [email protected] or file a JIRA ticket
with INFRA.
---

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to