Github user viirya commented on a diff in the pull request:
https://github.com/apache/spark/pull/21546#discussion_r194963031
--- Diff:
sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
---
@@ -183,34 +182,131 @@ private[sql] object ArrowConverters {
}
/**
- * Convert a byte array to an ArrowRecordBatch.
+ * Load a serialized ArrowRecordBatch.
*/
- private[arrow] def byteArrayToBatch(
+ private[arrow] def loadBatch(
batchBytes: Array[Byte],
allocator: BufferAllocator): ArrowRecordBatch = {
- val in = new ByteArrayReadableSeekableByteChannel(batchBytes)
- val reader = new ArrowFileReader(in, allocator)
-
- // Read a batch from a byte stream, ensure the reader is closed
- Utils.tryWithSafeFinally {
- val root = reader.getVectorSchemaRoot // throws IOException
- val unloader = new VectorUnloader(root)
- reader.loadNextBatch() // throws IOException
- unloader.getRecordBatch
- } {
- reader.close()
- }
+ val in = new ByteArrayInputStream(batchBytes)
+ MessageSerializer.deserializeMessageBatch(new ReadChannel(Channels.newChannel(in)), allocator)
+ .asInstanceOf[ArrowRecordBatch] // throws IOException
}
+ /**
+ * Create a DataFrame from a JavaRDD of serialized ArrowRecordBatches.
+ */
private[sql] def toDataFrame(
- payloadRDD: JavaRDD[Array[Byte]],
+ arrowBatchRDD: JavaRDD[Array[Byte]],
schemaString: String,
sqlContext: SQLContext): DataFrame = {
- val rdd = payloadRDD.rdd.mapPartitions { iter =>
+ val schema = DataType.fromJson(schemaString).asInstanceOf[StructType]
+ val timeZoneId = sqlContext.sessionState.conf.sessionLocalTimeZone
+ val rdd = arrowBatchRDD.rdd.mapPartitions { iter =>
val context = TaskContext.get()
- ArrowConverters.fromPayloadIterator(iter.map(new ArrowPayload(_)), context)
+ ArrowConverters.fromBatchIterator(iter, schema, timeZoneId, context)
}
- val schema = DataType.fromJson(schemaString).asInstanceOf[StructType]
sqlContext.internalCreateDataFrame(rdd, schema)
}
+
+ /**
+ * Read a file as an Arrow stream and return an RDD of serialized ArrowRecordBatches.
+ */
+ private[sql] def readArrowStreamFromFile(sqlContext: SQLContext, filename: String):
+ JavaRDD[Array[Byte]] = {
--- End diff --
Nit: indentation — the wrapped signature of `readArrowStreamFromFile` above is not indented correctly.
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]