Github user viirya commented on a diff in the pull request: https://github.com/apache/spark/pull/21546#discussion_r194963031 --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala --- @@ -183,34 +182,131 @@ private[sql] object ArrowConverters { } /** - * Convert a byte array to an ArrowRecordBatch. + * Load a serialized ArrowRecordBatch. */ - private[arrow] def byteArrayToBatch( + private[arrow] def loadBatch( batchBytes: Array[Byte], allocator: BufferAllocator): ArrowRecordBatch = { - val in = new ByteArrayReadableSeekableByteChannel(batchBytes) - val reader = new ArrowFileReader(in, allocator) - - // Read a batch from a byte stream, ensure the reader is closed - Utils.tryWithSafeFinally { - val root = reader.getVectorSchemaRoot // throws IOException - val unloader = new VectorUnloader(root) - reader.loadNextBatch() // throws IOException - unloader.getRecordBatch - } { - reader.close() - } + val in = new ByteArrayInputStream(batchBytes) + MessageSerializer.deserializeMessageBatch(new ReadChannel(Channels.newChannel(in)), allocator) + .asInstanceOf[ArrowRecordBatch] // throws IOException } + /** + * Create a DataFrame from a JavaRDD of serialized ArrowRecordBatches. 
+ */ private[sql] def toDataFrame( - payloadRDD: JavaRDD[Array[Byte]], + arrowBatchRDD: JavaRDD[Array[Byte]], schemaString: String, sqlContext: SQLContext): DataFrame = { - val rdd = payloadRDD.rdd.mapPartitions { iter => + val schema = DataType.fromJson(schemaString).asInstanceOf[StructType] + val timeZoneId = sqlContext.sessionState.conf.sessionLocalTimeZone + val rdd = arrowBatchRDD.rdd.mapPartitions { iter => val context = TaskContext.get() - ArrowConverters.fromPayloadIterator(iter.map(new ArrowPayload(_)), context) + ArrowConverters.fromBatchIterator(iter, schema, timeZoneId, context) } - val schema = DataType.fromJson(schemaString).asInstanceOf[StructType] sqlContext.internalCreateDataFrame(rdd, schema) } + + /** + * Read a file as an Arrow stream and return an RDD of serialized ArrowRecordBatches. + */ + private[sql] def readArrowStreamFromFile(sqlContext: SQLContext, filename: String): + JavaRDD[Array[Byte]] = { --- End diff -- indentation
--- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org For additional commands, e-mail: reviews-help@spark.apache.org