Github user icexelloss commented on a diff in the pull request:
https://github.com/apache/spark/pull/21546#discussion_r211964996
--- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala ---
@@ -183,34 +178,106 @@ private[sql] object ArrowConverters {
}
/**
- * Convert a byte array to an ArrowRecordBatch.
+ * Load a serialized ArrowRecordBatch.
*/
- private[arrow] def byteArrayToBatch(
+ private[arrow] def loadBatch(
batchBytes: Array[Byte],
allocator: BufferAllocator): ArrowRecordBatch = {
- val in = new ByteArrayReadableSeekableByteChannel(batchBytes)
- val reader = new ArrowFileReader(in, allocator)
-
- // Read a batch from a byte stream, ensure the reader is closed
- Utils.tryWithSafeFinally {
- val root = reader.getVectorSchemaRoot // throws IOException
- val unloader = new VectorUnloader(root)
- reader.loadNextBatch() // throws IOException
- unloader.getRecordBatch
- } {
- reader.close()
- }
+ val in = new ByteArrayInputStream(batchBytes)
+ MessageSerializer.deserializeRecordBatch(
+ new ReadChannel(Channels.newChannel(in)), allocator) // throws IOException
}
+ /**
+ * Create a DataFrame from a JavaRDD of serialized ArrowRecordBatches.
+ */
private[sql] def toDataFrame(
- payloadRDD: JavaRDD[Array[Byte]],
+ arrowBatchRDD: JavaRDD[Array[Byte]],
schemaString: String,
sqlContext: SQLContext): DataFrame = {
- val rdd = payloadRDD.rdd.mapPartitions { iter =>
+ val schema = DataType.fromJson(schemaString).asInstanceOf[StructType]
+ val timeZoneId = sqlContext.sessionState.conf.sessionLocalTimeZone
+ val rdd = arrowBatchRDD.rdd.mapPartitions { iter =>
val context = TaskContext.get()
- ArrowConverters.fromPayloadIterator(iter.map(new ArrowPayload(_)), context)
+ ArrowConverters.fromBatchIterator(iter, schema, timeZoneId, context)
}
- val schema = DataType.fromJson(schemaString).asInstanceOf[StructType]
sqlContext.internalCreateDataFrame(rdd, schema)
}
+
+ /**
+ * Read a file as an Arrow stream and parallelize as an RDD of serialized ArrowRecordBatches.
+ */
+ private[sql] def readArrowStreamFromFile(
+ sqlContext: SQLContext,
+ filename: String): JavaRDD[Array[Byte]] = {
+ val fileStream = new FileInputStream(filename)
+ try {
+ // Create array so that we can safely close the file
+ val batches = getBatchesFromStream(fileStream.getChannel).toArray
+ // Parallelize the record batches to create an RDD
+ JavaRDD.fromRDD(sqlContext.sparkContext.parallelize(batches, batches.length))
+ } finally {
+ fileStream.close()
+ }
+ }
+
+ /**
+ * Read an Arrow stream input and return an iterator of serialized ArrowRecordBatches.
+ */
+ private[sql] def getBatchesFromStream(in: SeekableByteChannel): Iterator[Array[Byte]] = {
+
+ // Create an iterator to get each serialized ArrowRecordBatch from a stream
+ new Iterator[Array[Byte]] {
+ var batch: Array[Byte] = readNextBatch()
+
+ override def hasNext: Boolean = batch != null
+
+ override def next(): Array[Byte] = {
+ val prevBatch = batch
+ batch = readNextBatch()
+ prevBatch
+ }
+
+ def readNextBatch(): Array[Byte] = {
+ val msgMetadata = MessageSerializer.readMessage(new ReadChannel(in))
+ if (msgMetadata == null) {
+ return null
+ }
+
+ // Get the length of the body, which has not been read at this point
+ val bodyLength = msgMetadata.getMessageBodyLength.toInt
+
+ // Only care about RecordBatch data, skip Schema and unsupported Dictionary messages
+ if (msgMetadata.getMessage.headerType() == MessageHeader.RecordBatch) {
+
+ // Create output backed by buffer to hold msg length (int32), msg metadata, msg body
+ val bbout = new ByteBufferOutputStream(4 + msgMetadata.getMessageLength + bodyLength)
--- End diff ---
Add a comment that this is the deserialized form of an Arrow Record Batch?
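
Something like the sketch below might work, just to make it explicit what layout the buffer holds (exact wording up to you; the names are copied from this diff):

```scala
// The buffer holds one ArrowRecordBatch in its serialized form, i.e. a single
// encapsulated Arrow message: msg length (int32), msg metadata, msg body.
// This is the same layout that loadBatch() above reads back with
// MessageSerializer.deserializeRecordBatch.
val bbout = new ByteBufferOutputStream(4 + msgMetadata.getMessageLength + bodyLength)
```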
---
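
For context, a rough sketch of how the two new helpers would be wired together on the driver side (the call site and `tempFilePath` are assumed for illustration; they are not part of this diff):

```scala
// Assumed usage, based only on the signatures in this diff: a file containing
// an Arrow stream is parallelized into an RDD of serialized record batches,
// which is then turned into a DataFrame using the JSON form of the schema.
val batchRdd: JavaRDD[Array[Byte]] =
  ArrowConverters.readArrowStreamFromFile(sqlContext, tempFilePath) // hypothetical file path
val df: DataFrame =
  ArrowConverters.toDataFrame(batchRdd, schema.json, sqlContext) // schema is a StructType
```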