Github user icexelloss commented on a diff in the pull request:
https://github.com/apache/spark/pull/21546#discussion_r197149246
--- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala ---
@@ -183,34 +182,131 @@ private[sql] object ArrowConverters {
   }

   /**
-   * Convert a byte array to an ArrowRecordBatch.
+   * Load a serialized ArrowRecordBatch.
    */
-  private[arrow] def byteArrayToBatch(
+  private[arrow] def loadBatch(
       batchBytes: Array[Byte],
       allocator: BufferAllocator): ArrowRecordBatch = {
-    val in = new ByteArrayReadableSeekableByteChannel(batchBytes)
-    val reader = new ArrowFileReader(in, allocator)
-
-    // Read a batch from a byte stream, ensure the reader is closed
-    Utils.tryWithSafeFinally {
-      val root = reader.getVectorSchemaRoot  // throws IOException
-      val unloader = new VectorUnloader(root)
-      reader.loadNextBatch()  // throws IOException
-      unloader.getRecordBatch
-    } {
-      reader.close()
-    }
+    val in = new ByteArrayInputStream(batchBytes)
+    MessageSerializer.deserializeMessageBatch(new ReadChannel(Channels.newChannel(in)), allocator)
+      .asInstanceOf[ArrowRecordBatch]  // throws IOException
   }

+  /**
+   * Create a DataFrame from a JavaRDD of serialized ArrowRecordBatches.
+   */
   private[sql] def toDataFrame(
-      payloadRDD: JavaRDD[Array[Byte]],
+      arrowBatchRDD: JavaRDD[Array[Byte]],
       schemaString: String,
       sqlContext: SQLContext): DataFrame = {
-    val rdd = payloadRDD.rdd.mapPartitions { iter =>
+    val schema = DataType.fromJson(schemaString).asInstanceOf[StructType]
+    val timeZoneId = sqlContext.sessionState.conf.sessionLocalTimeZone
+    val rdd = arrowBatchRDD.rdd.mapPartitions { iter =>
       val context = TaskContext.get()
-      ArrowConverters.fromPayloadIterator(iter.map(new ArrowPayload(_)), context)
+      ArrowConverters.fromBatchIterator(iter, schema, timeZoneId, context)
     }
-    val schema = DataType.fromJson(schemaString).asInstanceOf[StructType]
     sqlContext.internalCreateDataFrame(rdd, schema)
   }
+
+  /**
+   * Read a file as an Arrow stream and return an RDD of serialized ArrowRecordBatches.
+   */
+  private[sql] def readArrowStreamFromFile(sqlContext: SQLContext, filename: String):
+  JavaRDD[Array[Byte]] = {
+    val fileStream = new FileInputStream(filename)
+    try {
+      // Create array so that we can safely close the file
+      val batches = getBatchesFromStream(fileStream.getChannel).toArray
+      // Parallelize the record batches to create an RDD
+      JavaRDD.fromRDD(sqlContext.sparkContext.parallelize(batches, batches.length))
+    } finally {
+      fileStream.close()
+    }
+  }
+
+  /**
+   * Read an Arrow stream input and return an iterator of serialized ArrowRecordBatches.
+   */
+  private[sql] def getBatchesFromStream(in: SeekableByteChannel): Iterator[Array[Byte]] = {
+
+    // TODO: simplify in super class
+    class RecordBatchMessageReader(inputChannel: SeekableByteChannel) {
--- End diff ---
Btw I don't mind discussing it in this PR too. I'm also curious what other people think.
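
For context, a minimal, untested sketch of how the two driver-side entry points in this diff could be chained. The SparkSession `spark`, the example schema, and the file path are placeholders; also, `ArrowConverters` is `private[sql]`, so real callers live inside that package.

    import org.apache.spark.sql.execution.arrow.ArrowConverters
    import org.apache.spark.sql.types.{DoubleType, LongType, StructField, StructType}

    // Placeholder schema and path; the Arrow stream file would typically be
    // produced by a client (e.g. pyarrow) before this runs on the driver.
    val schema = StructType(Seq(StructField("id", LongType), StructField("value", DoubleType)))
    val path = "/tmp/batches.arrow"

    // RDD of serialized ArrowRecordBatches, parallelized one batch per partition.
    val batchRdd = ArrowConverters.readArrowStreamFromFile(spark.sqlContext, path)

    // Rebuild a DataFrame from the serialized batches.
    val df = ArrowConverters.toDataFrame(batchRdd, schema.json, spark.sqlContext)

The schema is passed in its JSON form because `toDataFrame` re-parses it with `DataType.fromJson` before mapping over the batch RDD.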
---