Github user BryanCutler commented on a diff in the pull request:
https://github.com/apache/spark/pull/21546#discussion_r199241520
--- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala ---
@@ -183,34 +182,111 @@ private[sql] object ArrowConverters {
}
/**
- * Convert a byte array to an ArrowRecordBatch.
+ * Load a serialized ArrowRecordBatch.
*/
- private[arrow] def byteArrayToBatch(
+ private[arrow] def loadBatch(
batchBytes: Array[Byte],
allocator: BufferAllocator): ArrowRecordBatch = {
- val in = new ByteArrayReadableSeekableByteChannel(batchBytes)
- val reader = new ArrowFileReader(in, allocator)
-
- // Read a batch from a byte stream, ensure the reader is closed
- Utils.tryWithSafeFinally {
- val root = reader.getVectorSchemaRoot // throws IOException
- val unloader = new VectorUnloader(root)
- reader.loadNextBatch() // throws IOException
- unloader.getRecordBatch
- } {
- reader.close()
- }
+ val in = new ByteArrayInputStream(batchBytes)
+ MessageSerializer.deserializeMessageBatch(new ReadChannel(Channels.newChannel(in)), allocator)
+ .asInstanceOf[ArrowRecordBatch] // throws IOException
}
+ /**
+ * Create a DataFrame from a JavaRDD of serialized ArrowRecordBatches.
+ */
private[sql] def toDataFrame(
- payloadRDD: JavaRDD[Array[Byte]],
+ arrowBatchRDD: JavaRDD[Array[Byte]],
schemaString: String,
sqlContext: SQLContext): DataFrame = {
- val rdd = payloadRDD.rdd.mapPartitions { iter =>
+ val schema = DataType.fromJson(schemaString).asInstanceOf[StructType]
+ val timeZoneId = sqlContext.sessionState.conf.sessionLocalTimeZone
+ val rdd = arrowBatchRDD.rdd.mapPartitions { iter =>
val context = TaskContext.get()
- ArrowConverters.fromPayloadIterator(iter.map(new ArrowPayload(_)), context)
+ ArrowConverters.fromBatchIterator(iter, schema, timeZoneId, context)
}
- val schema = DataType.fromJson(schemaString).asInstanceOf[StructType]
sqlContext.internalCreateDataFrame(rdd, schema)
}
+
+ /**
+ * Read a file as an Arrow stream and parallelize as an RDD of serialized ArrowRecordBatches.
+ */
+ private[sql] def readArrowStreamFromFile(
+ sqlContext: SQLContext,
+ filename: String): JavaRDD[Array[Byte]] = {
+ val fileStream = new FileInputStream(filename)
+ try {
+ // Create array so that we can safely close the file
+ val batches = getBatchesFromStream(fileStream.getChannel).toArray
+ // Parallelize the record batches to create an RDD
+ JavaRDD.fromRDD(sqlContext.sparkContext.parallelize(batches, batches.length))
+ } finally {
+ fileStream.close()
+ }
+ }
+
+ /**
+ * Read an Arrow stream input and return an iterator of serialized ArrowRecordBatches.
+ */
+ private[sql] def getBatchesFromStream(in: SeekableByteChannel): Iterator[Array[Byte]] = {
+
+ // TODO: this could be moved to Arrow
+ def readMessageLength(in: ReadChannel): Int = {
+ val buffer = ByteBuffer.allocate(4)
+ if (in.readFully(buffer) != 4) {
+ return 0
+ }
+ MessageSerializer.bytesToInt(buffer.array())
+ }
+
+ // TODO: this could be moved to Arrow
+ def loadMessage(in: ReadChannel, messageLength: Int, buffer: ByteBuffer): Message = {
+ if (in.readFully(buffer) != messageLength) {
+ throw new java.io.IOException(
+ "Unexpected end of stream trying to read message.")
+ }
+ buffer.rewind()
+ Message.getRootAsMessage(buffer)
+ }
+
+
+ // Create an iterator to get each serialized ArrowRecordBatch from a stream
+ new Iterator[Array[Byte]] {
+ val inputChannel = new ReadChannel(in)
+ var batch: Array[Byte] = readNextBatch()
+
+ override def hasNext: Boolean = batch != null
+
+ override def next(): Array[Byte] = {
+ val prevBatch = batch
+ batch = readNextBatch()
+ prevBatch
+ }
+
+ def readNextBatch(): Array[Byte] = {
+ val messageLength = readMessageLength(inputChannel)
+ if (messageLength == 0) {
+ return null
+ }
+
+ val buffer = ByteBuffer.allocate(messageLength)
+ val msg = loadMessage(inputChannel, messageLength, buffer)
--- End diff ---
I'll propose something different under https://github.com/apache/arrow/pull/2139, and we can continue the discussion there.
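
For context (not part of the diff): a rough sketch of how the two new entry points above might be wired together on the JVM side, using only the signatures shown in the hunk. It assumes the caller sits under org.apache.spark.sql so the private[sql] methods are visible; the file path and schema JSON are placeholders.

    // Hypothetical sketch only, based on the signatures in the diff above.
    import org.apache.spark.api.java.JavaRDD
    import org.apache.spark.sql.{DataFrame, SQLContext}
    import org.apache.spark.sql.execution.arrow.ArrowConverters

    def createFromArrowFile(
        sqlContext: SQLContext,
        arrowFile: String,    // placeholder: path to a file containing an Arrow stream
        schemaJson: String    // placeholder: JSON-serialized StructType for that stream
      ): DataFrame = {
      // Read the file as an Arrow stream and parallelize the serialized record batches
      val batchRdd: JavaRDD[Array[Byte]] = ArrowConverters.readArrowStreamFromFile(sqlContext, arrowFile)
      // Deserialize the batches on the executors and assemble the DataFrame
      ArrowConverters.toDataFrame(batchRdd, schemaJson, sqlContext)
    }

Presumably this is the path the Python side would use after writing its Arrow batches to a temporary file, instead of shipping one large payload through Py4J.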
---