[GitHub] spark pull request #21546: [WIP][SPARK-23030][SQL][PYTHON] Use Arrow stream ...
Github user BryanCutler commented on a diff in the pull request: https://github.com/apache/spark/pull/21546#discussion_r202825767 --- Diff: python/pyspark/serializers.py --- @@ -184,27 +184,59 @@ def loads(self, obj): raise NotImplementedError -class ArrowSerializer(FramedSerializer): +class BatchOrderSerializer(Serializer): """ -Serializes bytes as Arrow data with the Arrow file format. +Deserialize a stream of batches followed by batch order information. """ -def dumps(self, batch): +def __init__(self, serializer): +self.serializer = serializer +self.batch_order = [] + +def dump_stream(self, iterator, stream): +return self.serializer.dump_stream(iterator, stream) + +def load_stream(self, stream): +for batch in self.serializer.load_stream(stream): +yield batch +num = read_int(stream) +for i in xrange(num): +index = read_int(stream) +self.batch_order.append(index) +raise StopIteration() + +def get_batch_order(self): --- End diff -- I added an assert and reset `self.batch_order` after calling to ensure `load_stream` is called first and the serializer instance could be used again without retaining state. --- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
[GitHub] spark pull request #21546: [WIP][SPARK-23030][SQL][PYTHON] Use Arrow stream ...
Github user BryanCutler commented on a diff in the pull request: https://github.com/apache/spark/pull/21546#discussion_r199623502 --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala --- @@ -183,34 +182,111 @@ private[sql] object ArrowConverters { } /** - * Convert a byte array to an ArrowRecordBatch. + * Load a serialized ArrowRecordBatch. */ - private[arrow] def byteArrayToBatch( + private[arrow] def loadBatch( batchBytes: Array[Byte], allocator: BufferAllocator): ArrowRecordBatch = { -val in = new ByteArrayReadableSeekableByteChannel(batchBytes) -val reader = new ArrowFileReader(in, allocator) - -// Read a batch from a byte stream, ensure the reader is closed -Utils.tryWithSafeFinally { - val root = reader.getVectorSchemaRoot // throws IOException - val unloader = new VectorUnloader(root) - reader.loadNextBatch() // throws IOException - unloader.getRecordBatch -} { - reader.close() -} +val in = new ByteArrayInputStream(batchBytes) +MessageSerializer.deserializeMessageBatch(new ReadChannel(Channels.newChannel(in)), allocator) + .asInstanceOf[ArrowRecordBatch] // throws IOException } + /** + * Create a DataFrame from a JavaRDD of serialized ArrowRecordBatches. 
+ */ private[sql] def toDataFrame( - payloadRDD: JavaRDD[Array[Byte]], + arrowBatchRDD: JavaRDD[Array[Byte]], schemaString: String, sqlContext: SQLContext): DataFrame = { -val rdd = payloadRDD.rdd.mapPartitions { iter => +val schema = DataType.fromJson(schemaString).asInstanceOf[StructType] +val timeZoneId = sqlContext.sessionState.conf.sessionLocalTimeZone +val rdd = arrowBatchRDD.rdd.mapPartitions { iter => val context = TaskContext.get() - ArrowConverters.fromPayloadIterator(iter.map(new ArrowPayload(_)), context) + ArrowConverters.fromBatchIterator(iter, schema, timeZoneId, context) } -val schema = DataType.fromJson(schemaString).asInstanceOf[StructType] sqlContext.internalCreateDataFrame(rdd, schema) } + + /** + * Read a file as an Arrow stream and parallelize as an RDD of serialized ArrowRecordBatches. + */ + private[sql] def readArrowStreamFromFile( + sqlContext: SQLContext, + filename: String): JavaRDD[Array[Byte]] = { +val fileStream = new FileInputStream(filename) +try { + // Create array so that we can safely close the file + val batches = getBatchesFromStream(fileStream.getChannel).toArray + // Parallelize the record batches to create an RDD + JavaRDD.fromRDD(sqlContext.sparkContext.parallelize(batches, batches.length)) +} finally { + fileStream.close() +} + } + + /** + * Read an Arrow stream input and return an iterator of serialized ArrowRecordBatches. 
+ */ + private[sql] def getBatchesFromStream(in: SeekableByteChannel): Iterator[Array[Byte]] = { + +// TODO: this could be moved to Arrow +def readMessageLength(in: ReadChannel): Int = { + val buffer = ByteBuffer.allocate(4) + if (in.readFully(buffer) != 4) { +return 0 + } + MessageSerializer.bytesToInt(buffer.array()) +} + +// TODO: this could be moved to Arrow +def loadMessage(in: ReadChannel, messageLength: Int, buffer: ByteBuffer): Message = { + if (in.readFully(buffer) != messageLength) { +throw new java.io.IOException( + "Unexpected end of stream trying to read message.") + } + buffer.rewind() + Message.getRootAsMessage(buffer) +} + + +// Create an iterator to get each serialized ArrowRecordBatch from a stream +new Iterator[Array[Byte]] { + val inputChannel = new ReadChannel(in) + var batch: Array[Byte] = readNextBatch() + + override def hasNext: Boolean = batch != null + + override def next(): Array[Byte] = { +val prevBatch = batch +batch = readNextBatch() +prevBatch + } + + def readNextBatch(): Array[Byte] = { --- End diff -- I don't think there is any reason since this is just internal and it's an anonymous class so that method can't be called directly anyway. --- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
[GitHub] spark pull request #21546: [WIP][SPARK-23030][SQL][PYTHON] Use Arrow stream ...
Github user BryanCutler commented on a diff in the pull request: https://github.com/apache/spark/pull/21546#discussion_r199622313 --- Diff: python/pyspark/serializers.py --- @@ -184,27 +184,59 @@ def loads(self, obj): raise NotImplementedError -class ArrowSerializer(FramedSerializer): +class BatchOrderSerializer(Serializer): """ -Serializes bytes as Arrow data with the Arrow file format. +Deserialize a stream of batches followed by batch order information. """ -def dumps(self, batch): +def __init__(self, serializer): +self.serializer = serializer +self.batch_order = [] + +def dump_stream(self, iterator, stream): +return self.serializer.dump_stream(iterator, stream) + +def load_stream(self, stream): +for batch in self.serializer.load_stream(stream): +yield batch +num = read_int(stream) +for i in xrange(num): +index = read_int(stream) +self.batch_order.append(index) +raise StopIteration() + +def get_batch_order(self): --- End diff -- Is this to protect against it being called before `load_stream`? I guess there is also the issue that if the same serializer instance is used twice, `batch_order` will never get cleared. --- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
[GitHub] spark pull request #21546: [WIP][SPARK-23030][SQL][PYTHON] Use Arrow stream ...
Github user BryanCutler commented on a diff in the pull request: https://github.com/apache/spark/pull/21546#discussion_r199618628 --- Diff: core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala --- @@ -398,6 +398,25 @@ private[spark] object PythonRDD extends Logging { * data collected from this job, and the secret for authentication. */ def serveIterator(items: Iterator[_], threadName: String): Array[Any] = { +serveToStream(threadName) { out => + writeIteratorToStream(items, new DataOutputStream(out)) +} + } + + /** + * Create a socket server and background thread to execute the block of code + * for the given DataOutputStream. + * + * The socket server can only accept one connection, or close if no connection + * in 15 seconds. + * + * Once a connection comes in, it will execute the block of code and pass in + * the socket output stream. + * + * The thread will terminate after the block of code is executed or any + * exceptions happen. + */ + private[spark] def serveToStream(threadName: String)(block: OutputStream => Unit): Array[Any] = { --- End diff -- Yeah, I think I started off with `writeFunc`. I agree it sounds a bit better. --- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
[GitHub] spark pull request #21546: [WIP][SPARK-23030][SQL][PYTHON] Use Arrow stream ...
Github user BryanCutler commented on a diff in the pull request: https://github.com/apache/spark/pull/21546#discussion_r199615244 --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala --- @@ -38,70 +39,75 @@ import org.apache.spark.util.Utils /** - * Store Arrow data in a form that can be serialized by Spark and served to a Python process. + * Writes serialized ArrowRecordBatches to a DataOutputStream in the Arrow stream format. */ -private[sql] class ArrowPayload private[sql] (payload: Array[Byte]) extends Serializable { +private[sql] class ArrowBatchStreamWriter( +schema: StructType, +out: OutputStream, +timeZoneId: String) { - /** - * Convert the ArrowPayload to an ArrowRecordBatch. - */ - def loadBatch(allocator: BufferAllocator): ArrowRecordBatch = { -ArrowConverters.byteArrayToBatch(payload, allocator) - } + val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId) + val writeChannel = new WriteChannel(Channels.newChannel(out)) + + // Write the Arrow schema first, before batches + MessageSerializer.serialize(writeChannel, arrowSchema) /** - * Get the ArrowPayload as a type that can be served to Python. + * Consume iterator to write each serialized ArrowRecordBatch to the stream. */ - def asPythonSerializable: Array[Byte] = payload -} - -/** - * Iterator interface to iterate over Arrow record batches and return rows - */ -private[sql] trait ArrowRowIterator extends Iterator[InternalRow] { + def writeBatches(arrowBatchIter: Iterator[Array[Byte]]): Unit = { +arrowBatchIter.foreach { batchBytes => + writeChannel.write(batchBytes) +} + } /** - * Return the schema loaded from the Arrow record batch being iterated over + * End the Arrow stream, does not close output stream. */ - def schema: StructType + def end(): Unit = { +// Write End of Stream --- End diff -- Since right now it's just writing a `0`, I think it's useful to comment that this means the EOS code. 
I have a TODO here to wrap this call in an Arrow function that will be more clear, then we wouldn't need a comment. --- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
[GitHub] spark pull request #21546: [WIP][SPARK-23030][SQL][PYTHON] Use Arrow stream ...
Github user BryanCutler commented on a diff in the pull request: https://github.com/apache/spark/pull/21546#discussion_r199614435 --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala --- @@ -183,34 +182,111 @@ private[sql] object ArrowConverters { } /** - * Convert a byte array to an ArrowRecordBatch. + * Load a serialized ArrowRecordBatch. */ - private[arrow] def byteArrayToBatch( + private[arrow] def loadBatch( batchBytes: Array[Byte], allocator: BufferAllocator): ArrowRecordBatch = { -val in = new ByteArrayReadableSeekableByteChannel(batchBytes) -val reader = new ArrowFileReader(in, allocator) - -// Read a batch from a byte stream, ensure the reader is closed -Utils.tryWithSafeFinally { - val root = reader.getVectorSchemaRoot // throws IOException - val unloader = new VectorUnloader(root) - reader.loadNextBatch() // throws IOException - unloader.getRecordBatch -} { - reader.close() -} +val in = new ByteArrayInputStream(batchBytes) +MessageSerializer.deserializeMessageBatch(new ReadChannel(Channels.newChannel(in)), allocator) + .asInstanceOf[ArrowRecordBatch] // throws IOException } + /** + * Create a DataFrame from a JavaRDD of serialized ArrowRecordBatches. 
+ */ private[sql] def toDataFrame( - payloadRDD: JavaRDD[Array[Byte]], + arrowBatchRDD: JavaRDD[Array[Byte]], schemaString: String, sqlContext: SQLContext): DataFrame = { -val rdd = payloadRDD.rdd.mapPartitions { iter => +val schema = DataType.fromJson(schemaString).asInstanceOf[StructType] +val timeZoneId = sqlContext.sessionState.conf.sessionLocalTimeZone +val rdd = arrowBatchRDD.rdd.mapPartitions { iter => val context = TaskContext.get() - ArrowConverters.fromPayloadIterator(iter.map(new ArrowPayload(_)), context) + ArrowConverters.fromBatchIterator(iter, schema, timeZoneId, context) } -val schema = DataType.fromJson(schemaString).asInstanceOf[StructType] sqlContext.internalCreateDataFrame(rdd, schema) } + + /** + * Read a file as an Arrow stream and parallelize as an RDD of serialized ArrowRecordBatches. + */ + private[sql] def readArrowStreamFromFile( + sqlContext: SQLContext, + filename: String): JavaRDD[Array[Byte]] = { +val fileStream = new FileInputStream(filename) +try { + // Create array so that we can safely close the file + val batches = getBatchesFromStream(fileStream.getChannel).toArray + // Parallelize the record batches to create an RDD + JavaRDD.fromRDD(sqlContext.sparkContext.parallelize(batches, batches.length)) +} finally { + fileStream.close() +} + } + + /** + * Read an Arrow stream input and return an iterator of serialized ArrowRecordBatches. 
+ */ + private[sql] def getBatchesFromStream(in: SeekableByteChannel): Iterator[Array[Byte]] = { + +// TODO: this could be moved to Arrow +def readMessageLength(in: ReadChannel): Int = { + val buffer = ByteBuffer.allocate(4) + if (in.readFully(buffer) != 4) { +return 0 + } + MessageSerializer.bytesToInt(buffer.array()) +} + +// TODO: this could be moved to Arrow +def loadMessage(in: ReadChannel, messageLength: Int, buffer: ByteBuffer): Message = { + if (in.readFully(buffer) != messageLength) { +throw new java.io.IOException( + "Unexpected end of stream trying to read message.") + } + buffer.rewind() + Message.getRootAsMessage(buffer) +} + + +// Create an iterator to get each serialized ArrowRecordBatch from a stream +new Iterator[Array[Byte]] { + val inputChannel = new ReadChannel(in) + var batch: Array[Byte] = readNextBatch() + + override def hasNext: Boolean = batch != null + + override def next(): Array[Byte] = { +val prevBatch = batch +batch = readNextBatch() +prevBatch + } + + def readNextBatch(): Array[Byte] = { +val messageLength = readMessageLength(inputChannel) +if (messageLength == 0) { + return null +} + +val buffer = ByteBuffer.allocate(messageLength) +val msg = loadMessage(inputChannel, messageLength, buffer) +val bodyLength = msg.bodyLength().asInstanceOf[Int] + +if (msg.headerType() == MessageHeader.RecordBatch) { + val allbuf = ByteBuffer.allocate(4 + messageLength + bodyLength) + al
[GitHub] spark pull request #21546: [WIP][SPARK-23030][SQL][PYTHON] Use Arrow stream ...
Github user BryanCutler commented on a diff in the pull request: https://github.com/apache/spark/pull/21546#discussion_r199613741 --- Diff: sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala --- @@ -3236,13 +3237,50 @@ class Dataset[T] private[sql]( } /** - * Collect a Dataset as ArrowPayload byte arrays and serve to PySpark. + * Collect a Dataset as Arrow batches and serve stream to PySpark. */ private[sql] def collectAsArrowToPython(): Array[Any] = { +val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone + withAction("collectAsArrowToPython", queryExecution) { plan => - val iter: Iterator[Array[Byte]] = -toArrowPayload(plan).collect().iterator.map(_.asPythonSerializable) - PythonRDD.serveIterator(iter, "serve-Arrow") + PythonRDD.serveToStream("serve-Arrow") { outputStream => +val out = new DataOutputStream(outputStream) +val batchWriter = new ArrowBatchStreamWriter(schema, out, timeZoneId) +val arrowBatchRdd = getArrowBatchRdd(plan) +val numPartitions = arrowBatchRdd.partitions.length + +// Batches ordered by index of partition + fractional value of batch # in partition +val batchOrder = new ArrayBuffer[Float]() +var partitionCount = 0 + +// Handler to eagerly write batches to Python out of order +def handlePartitionBatches(index: Int, arrowBatches: Array[Array[Byte]]): Unit = { + if (arrowBatches.nonEmpty) { +batchWriter.writeBatches(arrowBatches.iterator) +(0 until arrowBatches.length).foreach { i => --- End diff -- Great, Intellij is trying to steal my job! :fearful: --- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
[GitHub] spark pull request #21546: [WIP][SPARK-23030][SQL][PYTHON] Use Arrow stream ...
Github user BryanCutler commented on a diff in the pull request: https://github.com/apache/spark/pull/21546#discussion_r199613477 --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala --- @@ -183,34 +182,111 @@ private[sql] object ArrowConverters { } /** - * Convert a byte array to an ArrowRecordBatch. + * Load a serialized ArrowRecordBatch. */ - private[arrow] def byteArrayToBatch( + private[arrow] def loadBatch( batchBytes: Array[Byte], allocator: BufferAllocator): ArrowRecordBatch = { -val in = new ByteArrayReadableSeekableByteChannel(batchBytes) -val reader = new ArrowFileReader(in, allocator) - -// Read a batch from a byte stream, ensure the reader is closed -Utils.tryWithSafeFinally { - val root = reader.getVectorSchemaRoot // throws IOException - val unloader = new VectorUnloader(root) - reader.loadNextBatch() // throws IOException - unloader.getRecordBatch -} { - reader.close() -} +val in = new ByteArrayInputStream(batchBytes) +MessageSerializer.deserializeMessageBatch(new ReadChannel(Channels.newChannel(in)), allocator) + .asInstanceOf[ArrowRecordBatch] // throws IOException } + /** + * Create a DataFrame from a JavaRDD of serialized ArrowRecordBatches. 
+ */ private[sql] def toDataFrame( - payloadRDD: JavaRDD[Array[Byte]], + arrowBatchRDD: JavaRDD[Array[Byte]], schemaString: String, sqlContext: SQLContext): DataFrame = { -val rdd = payloadRDD.rdd.mapPartitions { iter => +val schema = DataType.fromJson(schemaString).asInstanceOf[StructType] +val timeZoneId = sqlContext.sessionState.conf.sessionLocalTimeZone +val rdd = arrowBatchRDD.rdd.mapPartitions { iter => val context = TaskContext.get() - ArrowConverters.fromPayloadIterator(iter.map(new ArrowPayload(_)), context) + ArrowConverters.fromBatchIterator(iter, schema, timeZoneId, context) } -val schema = DataType.fromJson(schemaString).asInstanceOf[StructType] sqlContext.internalCreateDataFrame(rdd, schema) } + + /** + * Read a file as an Arrow stream and parallelize as an RDD of serialized ArrowRecordBatches. + */ + private[sql] def readArrowStreamFromFile( + sqlContext: SQLContext, + filename: String): JavaRDD[Array[Byte]] = { +val fileStream = new FileInputStream(filename) +try { + // Create array so that we can safely close the file + val batches = getBatchesFromStream(fileStream.getChannel).toArray + // Parallelize the record batches to create an RDD + JavaRDD.fromRDD(sqlContext.sparkContext.parallelize(batches, batches.length)) +} finally { + fileStream.close() +} + } + + /** + * Read an Arrow stream input and return an iterator of serialized ArrowRecordBatches. 
+ */ + private[sql] def getBatchesFromStream(in: SeekableByteChannel): Iterator[Array[Byte]] = { + +// TODO: this could be moved to Arrow +def readMessageLength(in: ReadChannel): Int = { + val buffer = ByteBuffer.allocate(4) + if (in.readFully(buffer) != 4) { +return 0 + } + MessageSerializer.bytesToInt(buffer.array()) +} + +// TODO: this could be moved to Arrow +def loadMessage(in: ReadChannel, messageLength: Int, buffer: ByteBuffer): Message = { + if (in.readFully(buffer) != messageLength) { +throw new java.io.IOException( + "Unexpected end of stream trying to read message.") + } + buffer.rewind() + Message.getRootAsMessage(buffer) +} + + +// Create an iterator to get each serialized ArrowRecordBatch from a stream +new Iterator[Array[Byte]] { + val inputChannel = new ReadChannel(in) --- End diff -- It is being closed, just by `readArrowStreamFromFile` which owns the stream and consumes the iterator. --- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
[GitHub] spark pull request #21546: [WIP][SPARK-23030][SQL][PYTHON] Use Arrow stream ...
Github user BryanCutler commented on a diff in the pull request: https://github.com/apache/spark/pull/21546#discussion_r199612847 --- Diff: sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala --- @@ -34,17 +33,19 @@ private[sql] object PythonSQLUtils { } /** - * Python Callable function to convert ArrowPayloads into a [[DataFrame]]. + * Python callable function to read a file in Arrow stream format and create a [[DataFrame]] + * using each serialized ArrowRecordBatch as a partition. * - * @param payloadRDD A JavaRDD of ArrowPayloads. - * @param schemaString JSON Formatted Schema for ArrowPayloads. * @param sqlContext The active [[SQLContext]]. - * @return The converted [[DataFrame]]. + * @param filename File to read the Arrow stream from. + * @param schemaString JSON Formatted Spark schema for Arrow batches. + * @return A new [[DataFrame]]. */ - def arrowPayloadToDataFrame( - payloadRDD: JavaRDD[Array[Byte]], - schemaString: String, - sqlContext: SQLContext): DataFrame = { -ArrowConverters.toDataFrame(payloadRDD, schemaString, sqlContext) + def arrowReadStreamFromFile( --- End diff -- `arrowStreamFromFile` is important to get in the name since it is a stream format being read from a file, but how about `arrowStreamFromFileToDataFrame`? It's a bit long but it would be good to indicate that it produces a `DataFrame` for the call from Python. --- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
[GitHub] spark pull request #21546: [WIP][SPARK-23030][SQL][PYTHON] Use Arrow stream ...
Github user BryanCutler commented on a diff in the pull request: https://github.com/apache/spark/pull/21546#discussion_r199561983 --- Diff: sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala --- @@ -3236,13 +3237,50 @@ class Dataset[T] private[sql]( } /** - * Collect a Dataset as ArrowPayload byte arrays and serve to PySpark. + * Collect a Dataset as Arrow batches and serve stream to PySpark. */ private[sql] def collectAsArrowToPython(): Array[Any] = { +val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone + withAction("collectAsArrowToPython", queryExecution) { plan => - val iter: Iterator[Array[Byte]] = -toArrowPayload(plan).collect().iterator.map(_.asPythonSerializable) - PythonRDD.serveIterator(iter, "serve-Arrow") + PythonRDD.serveToStream("serve-Arrow") { outputStream => +val out = new DataOutputStream(outputStream) +val batchWriter = new ArrowBatchStreamWriter(schema, out, timeZoneId) +val arrowBatchRdd = getArrowBatchRdd(plan) +val numPartitions = arrowBatchRdd.partitions.length + +// Batches ordered by index of partition + fractional value of batch # in partition +val batchOrder = new ArrayBuffer[Float]() +var partitionCount = 0 + +// Handler to eagerly write batches to Python out of order +def handlePartitionBatches(index: Int, arrowBatches: Array[Array[Byte]]): Unit = { + if (arrowBatches.nonEmpty) { +batchWriter.writeBatches(arrowBatches.iterator) +(0 until arrowBatches.length).foreach { i => + batchOrder.append(index + i / arrowBatches.length) --- End diff -- I thought I would need to specify a comparison function, but it looks like Scala can sort a `(Int, Int)` tuple correctly. I agree this would be best, so I'll change it - thanks! --- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
[GitHub] spark pull request #21546: [WIP][SPARK-23030][SQL][PYTHON] Use Arrow stream ...
Github user sethah commented on a diff in the pull request: https://github.com/apache/spark/pull/21546#discussion_r199478745 --- Diff: sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala --- @@ -34,17 +33,19 @@ private[sql] object PythonSQLUtils { } /** - * Python Callable function to convert ArrowPayloads into a [[DataFrame]]. + * Python callable function to read a file in Arrow stream format and create a [[DataFrame]] + * using each serialized ArrowRecordBatch as a partition. * - * @param payloadRDD A JavaRDD of ArrowPayloads. - * @param schemaString JSON Formatted Schema for ArrowPayloads. * @param sqlContext The active [[SQLContext]]. - * @return The converted [[DataFrame]]. + * @param filename File to read the Arrow stream from. + * @param schemaString JSON Formatted Spark schema for Arrow batches. + * @return A new [[DataFrame]]. */ - def arrowPayloadToDataFrame( - payloadRDD: JavaRDD[Array[Byte]], - schemaString: String, - sqlContext: SQLContext): DataFrame = { -ArrowConverters.toDataFrame(payloadRDD, schemaString, sqlContext) + def arrowReadStreamFromFile( --- End diff -- Can we call it `arrowFileToDataFrame` or something... `arrowReadStreamFromFile` and `readArrowStreamFromFile` are just too similar... --- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
[GitHub] spark pull request #21546: [WIP][SPARK-23030][SQL][PYTHON] Use Arrow stream ...
Github user sethah commented on a diff in the pull request: https://github.com/apache/spark/pull/21546#discussion_r199498622 --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala --- @@ -38,70 +39,75 @@ import org.apache.spark.util.Utils /** - * Store Arrow data in a form that can be serialized by Spark and served to a Python process. + * Writes serialized ArrowRecordBatches to a DataOutputStream in the Arrow stream format. */ -private[sql] class ArrowPayload private[sql] (payload: Array[Byte]) extends Serializable { +private[sql] class ArrowBatchStreamWriter( +schema: StructType, +out: OutputStream, +timeZoneId: String) { - /** - * Convert the ArrowPayload to an ArrowRecordBatch. - */ - def loadBatch(allocator: BufferAllocator): ArrowRecordBatch = { -ArrowConverters.byteArrayToBatch(payload, allocator) - } + val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId) + val writeChannel = new WriteChannel(Channels.newChannel(out)) + + // Write the Arrow schema first, before batches + MessageSerializer.serialize(writeChannel, arrowSchema) /** - * Get the ArrowPayload as a type that can be served to Python. + * Consume iterator to write each serialized ArrowRecordBatch to the stream. */ - def asPythonSerializable: Array[Byte] = payload -} - -/** - * Iterator interface to iterate over Arrow record batches and return rows - */ -private[sql] trait ArrowRowIterator extends Iterator[InternalRow] { + def writeBatches(arrowBatchIter: Iterator[Array[Byte]]): Unit = { +arrowBatchIter.foreach { batchBytes => + writeChannel.write(batchBytes) +} + } /** - * Return the schema loaded from the Arrow record batch being iterated over + * End the Arrow stream, does not close output stream. */ - def schema: StructType + def end(): Unit = { +// Write End of Stream --- End diff -- this comment can be removed I think --- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
[GitHub] spark pull request #21546: [WIP][SPARK-23030][SQL][PYTHON] Use Arrow stream ...
Github user sethah commented on a diff in the pull request: https://github.com/apache/spark/pull/21546#discussion_r199502733 --- Diff: core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala --- @@ -398,6 +398,25 @@ private[spark] object PythonRDD extends Logging { * data collected from this job, and the secret for authentication. */ def serveIterator(items: Iterator[_], threadName: String): Array[Any] = { +serveToStream(threadName) { out => + writeIteratorToStream(items, new DataOutputStream(out)) +} + } + + /** + * Create a socket server and background thread to execute the block of code + * for the given DataOutputStream. + * + * The socket server can only accept one connection, or close if no connection + * in 15 seconds. + * + * Once a connection comes in, it will execute the block of code and pass in + * the socket output stream. + * + * The thread will terminate after the block of code is executed or any + * exceptions happen. + */ + private[spark] def serveToStream(threadName: String)(block: OutputStream => Unit): Array[Any] = { --- End diff -- can you change `block` to `writeFunc` or something? `block` makes me think of thread blocking --- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
[GitHub] spark pull request #21546: [WIP][SPARK-23030][SQL][PYTHON] Use Arrow stream ...
Github user sethah commented on a diff in the pull request: https://github.com/apache/spark/pull/21546#discussion_r199482134 --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala --- @@ -183,34 +182,111 @@ private[sql] object ArrowConverters { } /** - * Convert a byte array to an ArrowRecordBatch. + * Load a serialized ArrowRecordBatch. */ - private[arrow] def byteArrayToBatch( + private[arrow] def loadBatch( batchBytes: Array[Byte], allocator: BufferAllocator): ArrowRecordBatch = { -val in = new ByteArrayReadableSeekableByteChannel(batchBytes) -val reader = new ArrowFileReader(in, allocator) - -// Read a batch from a byte stream, ensure the reader is closed -Utils.tryWithSafeFinally { - val root = reader.getVectorSchemaRoot // throws IOException - val unloader = new VectorUnloader(root) - reader.loadNextBatch() // throws IOException - unloader.getRecordBatch -} { - reader.close() -} +val in = new ByteArrayInputStream(batchBytes) +MessageSerializer.deserializeMessageBatch(new ReadChannel(Channels.newChannel(in)), allocator) + .asInstanceOf[ArrowRecordBatch] // throws IOException } + /** + * Create a DataFrame from a JavaRDD of serialized ArrowRecordBatches. 
+ */ private[sql] def toDataFrame( - payloadRDD: JavaRDD[Array[Byte]], + arrowBatchRDD: JavaRDD[Array[Byte]], schemaString: String, sqlContext: SQLContext): DataFrame = { -val rdd = payloadRDD.rdd.mapPartitions { iter => +val schema = DataType.fromJson(schemaString).asInstanceOf[StructType] +val timeZoneId = sqlContext.sessionState.conf.sessionLocalTimeZone +val rdd = arrowBatchRDD.rdd.mapPartitions { iter => val context = TaskContext.get() - ArrowConverters.fromPayloadIterator(iter.map(new ArrowPayload(_)), context) + ArrowConverters.fromBatchIterator(iter, schema, timeZoneId, context) } -val schema = DataType.fromJson(schemaString).asInstanceOf[StructType] sqlContext.internalCreateDataFrame(rdd, schema) } + + /** + * Read a file as an Arrow stream and parallelize as an RDD of serialized ArrowRecordBatches. + */ + private[sql] def readArrowStreamFromFile( + sqlContext: SQLContext, + filename: String): JavaRDD[Array[Byte]] = { +val fileStream = new FileInputStream(filename) +try { + // Create array so that we can safely close the file + val batches = getBatchesFromStream(fileStream.getChannel).toArray + // Parallelize the record batches to create an RDD + JavaRDD.fromRDD(sqlContext.sparkContext.parallelize(batches, batches.length)) +} finally { + fileStream.close() +} + } + + /** + * Read an Arrow stream input and return an iterator of serialized ArrowRecordBatches. 
+ */ + private[sql] def getBatchesFromStream(in: SeekableByteChannel): Iterator[Array[Byte]] = { + +// TODO: this could be moved to Arrow +def readMessageLength(in: ReadChannel): Int = { + val buffer = ByteBuffer.allocate(4) + if (in.readFully(buffer) != 4) { +return 0 + } + MessageSerializer.bytesToInt(buffer.array()) +} + +// TODO: this could be moved to Arrow +def loadMessage(in: ReadChannel, messageLength: Int, buffer: ByteBuffer): Message = { + if (in.readFully(buffer) != messageLength) { +throw new java.io.IOException( + "Unexpected end of stream trying to read message.") + } + buffer.rewind() + Message.getRootAsMessage(buffer) +} + + +// Create an iterator to get each serialized ArrowRecordBatch from a stream +new Iterator[Array[Byte]] { + val inputChannel = new ReadChannel(in) --- End diff -- do we not need to close this when the iterator has been consumed? --- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
[GitHub] spark pull request #21546: [WIP][SPARK-23030][SQL][PYTHON] Use Arrow stream ...
Github user sethah commented on a diff in the pull request: https://github.com/apache/spark/pull/21546#discussion_r199476976 --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala --- @@ -183,34 +182,111 @@ private[sql] object ArrowConverters { } /** - * Convert a byte array to an ArrowRecordBatch. + * Load a serialized ArrowRecordBatch. */ - private[arrow] def byteArrayToBatch( + private[arrow] def loadBatch( batchBytes: Array[Byte], allocator: BufferAllocator): ArrowRecordBatch = { -val in = new ByteArrayReadableSeekableByteChannel(batchBytes) -val reader = new ArrowFileReader(in, allocator) - -// Read a batch from a byte stream, ensure the reader is closed -Utils.tryWithSafeFinally { - val root = reader.getVectorSchemaRoot // throws IOException - val unloader = new VectorUnloader(root) - reader.loadNextBatch() // throws IOException - unloader.getRecordBatch -} { - reader.close() -} +val in = new ByteArrayInputStream(batchBytes) +MessageSerializer.deserializeMessageBatch(new ReadChannel(Channels.newChannel(in)), allocator) + .asInstanceOf[ArrowRecordBatch] // throws IOException } + /** + * Create a DataFrame from a JavaRDD of serialized ArrowRecordBatches. 
+ */ private[sql] def toDataFrame( - payloadRDD: JavaRDD[Array[Byte]], + arrowBatchRDD: JavaRDD[Array[Byte]], schemaString: String, sqlContext: SQLContext): DataFrame = { -val rdd = payloadRDD.rdd.mapPartitions { iter => +val schema = DataType.fromJson(schemaString).asInstanceOf[StructType] +val timeZoneId = sqlContext.sessionState.conf.sessionLocalTimeZone +val rdd = arrowBatchRDD.rdd.mapPartitions { iter => val context = TaskContext.get() - ArrowConverters.fromPayloadIterator(iter.map(new ArrowPayload(_)), context) + ArrowConverters.fromBatchIterator(iter, schema, timeZoneId, context) } -val schema = DataType.fromJson(schemaString).asInstanceOf[StructType] sqlContext.internalCreateDataFrame(rdd, schema) } + + /** + * Read a file as an Arrow stream and parallelize as an RDD of serialized ArrowRecordBatches. + */ + private[sql] def readArrowStreamFromFile( + sqlContext: SQLContext, + filename: String): JavaRDD[Array[Byte]] = { +val fileStream = new FileInputStream(filename) +try { + // Create array so that we can safely close the file + val batches = getBatchesFromStream(fileStream.getChannel).toArray + // Parallelize the record batches to create an RDD + JavaRDD.fromRDD(sqlContext.sparkContext.parallelize(batches, batches.length)) +} finally { + fileStream.close() +} + } + + /** + * Read an Arrow stream input and return an iterator of serialized ArrowRecordBatches. 
+ */ + private[sql] def getBatchesFromStream(in: SeekableByteChannel): Iterator[Array[Byte]] = { + +// TODO: this could be moved to Arrow +def readMessageLength(in: ReadChannel): Int = { + val buffer = ByteBuffer.allocate(4) + if (in.readFully(buffer) != 4) { +return 0 + } + MessageSerializer.bytesToInt(buffer.array()) +} + +// TODO: this could be moved to Arrow +def loadMessage(in: ReadChannel, messageLength: Int, buffer: ByteBuffer): Message = { + if (in.readFully(buffer) != messageLength) { +throw new java.io.IOException( + "Unexpected end of stream trying to read message.") + } + buffer.rewind() + Message.getRootAsMessage(buffer) +} + + +// Create an iterator to get each serialized ArrowRecordBatch from a stream +new Iterator[Array[Byte]] { + val inputChannel = new ReadChannel(in) + var batch: Array[Byte] = readNextBatch() + + override def hasNext: Boolean = batch != null + + override def next(): Array[Byte] = { +val prevBatch = batch +batch = readNextBatch() +prevBatch + } + + def readNextBatch(): Array[Byte] = { +val messageLength = readMessageLength(inputChannel) +if (messageLength == 0) { + return null +} + +val buffer = ByteBuffer.allocate(messageLength) +val msg = loadMessage(inputChannel, messageLength, buffer) +val bodyLength = msg.bodyLength().asInstanceOf[Int] --- End diff -- why not `toInt`? --- - To unsubscribe, e-mail: reviews-unsubscr...@spark.a
[GitHub] spark pull request #21546: [WIP][SPARK-23030][SQL][PYTHON] Use Arrow stream ...
Github user sethah commented on a diff in the pull request: https://github.com/apache/spark/pull/21546#discussion_r199496002 --- Diff: sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala --- @@ -3236,13 +3237,50 @@ class Dataset[T] private[sql]( } /** - * Collect a Dataset as ArrowPayload byte arrays and serve to PySpark. + * Collect a Dataset as Arrow batches and serve stream to PySpark. */ private[sql] def collectAsArrowToPython(): Array[Any] = { +val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone + withAction("collectAsArrowToPython", queryExecution) { plan => - val iter: Iterator[Array[Byte]] = -toArrowPayload(plan).collect().iterator.map(_.asPythonSerializable) - PythonRDD.serveIterator(iter, "serve-Arrow") + PythonRDD.serveToStream("serve-Arrow") { outputStream => +val out = new DataOutputStream(outputStream) +val batchWriter = new ArrowBatchStreamWriter(schema, out, timeZoneId) +val arrowBatchRdd = getArrowBatchRdd(plan) +val numPartitions = arrowBatchRdd.partitions.length + +// Batches ordered by index of partition + fractional value of batch # in partition +val batchOrder = new ArrayBuffer[Float]() +var partitionCount = 0 + +// Handler to eagerly write batches to Python out of order +def handlePartitionBatches(index: Int, arrowBatches: Array[Array[Byte]]): Unit = { + if (arrowBatches.nonEmpty) { +batchWriter.writeBatches(arrowBatches.iterator) +(0 until arrowBatches.length).foreach { i => + batchOrder.append(index + i / arrowBatches.length) --- End diff -- This code: `(0 until array.length).map(i => i / array.length)` is guaranteed to produce only zero values, isn't it? Both operands are `Int`, so `i / array.length` is integer division and truncates to 0 for every `i < array.length`. The code works, since `sortBy` evidently preserves the ordering of equal elements, but you may as well do `batchOrder.append(index)` since it's the same. --- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
[GitHub] spark pull request #21546: [WIP][SPARK-23030][SQL][PYTHON] Use Arrow stream ...
Github user sethah commented on a diff in the pull request: https://github.com/apache/spark/pull/21546#discussion_r199484323 --- Diff: sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala --- @@ -3236,13 +3237,50 @@ class Dataset[T] private[sql]( } /** - * Collect a Dataset as ArrowPayload byte arrays and serve to PySpark. + * Collect a Dataset as Arrow batches and serve stream to PySpark. */ private[sql] def collectAsArrowToPython(): Array[Any] = { +val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone + withAction("collectAsArrowToPython", queryExecution) { plan => - val iter: Iterator[Array[Byte]] = -toArrowPayload(plan).collect().iterator.map(_.asPythonSerializable) - PythonRDD.serveIterator(iter, "serve-Arrow") + PythonRDD.serveToStream("serve-Arrow") { outputStream => +val out = new DataOutputStream(outputStream) +val batchWriter = new ArrowBatchStreamWriter(schema, out, timeZoneId) +val arrowBatchRdd = getArrowBatchRdd(plan) +val numPartitions = arrowBatchRdd.partitions.length + +// Batches ordered by index of partition + fractional value of batch # in partition +val batchOrder = new ArrayBuffer[Float]() +var partitionCount = 0 + +// Handler to eagerly write batches to Python out of order +def handlePartitionBatches(index: Int, arrowBatches: Array[Array[Byte]]): Unit = { + if (arrowBatches.nonEmpty) { +batchWriter.writeBatches(arrowBatches.iterator) +(0 until arrowBatches.length).foreach { i => --- End diff -- intellij would like you to know about `arrowBatches.indices` :grin: --- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
[GitHub] spark pull request #21546: [WIP][SPARK-23030][SQL][PYTHON] Use Arrow stream ...
Github user sethah commented on a diff in the pull request: https://github.com/apache/spark/pull/21546#discussion_r199508609 --- Diff: python/pyspark/serializers.py --- @@ -184,27 +184,59 @@ def loads(self, obj): raise NotImplementedError -class ArrowSerializer(FramedSerializer): +class BatchOrderSerializer(Serializer): """ -Serializes bytes as Arrow data with the Arrow file format. +Deserialize a stream of batches followed by batch order information. """ -def dumps(self, batch): +def __init__(self, serializer): +self.serializer = serializer +self.batch_order = [] + +def dump_stream(self, iterator, stream): +return self.serializer.dump_stream(iterator, stream) + +def load_stream(self, stream): +for batch in self.serializer.load_stream(stream): +yield batch +num = read_int(stream) +for i in xrange(num): +index = read_int(stream) +self.batch_order.append(index) +raise StopIteration() + +def get_batch_order(self): --- End diff -- maybe we should initialize `self.batch_order = None`, and add `assert self.batch_order is not None` here. --- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
[GitHub] spark pull request #21546: [WIP][SPARK-23030][SQL][PYTHON] Use Arrow stream ...
Github user sethah commented on a diff in the pull request: https://github.com/apache/spark/pull/21546#discussion_r199482021 --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala --- @@ -183,34 +182,111 @@ private[sql] object ArrowConverters { } /** - * Convert a byte array to an ArrowRecordBatch. + * Load a serialized ArrowRecordBatch. */ - private[arrow] def byteArrayToBatch( + private[arrow] def loadBatch( batchBytes: Array[Byte], allocator: BufferAllocator): ArrowRecordBatch = { -val in = new ByteArrayReadableSeekableByteChannel(batchBytes) -val reader = new ArrowFileReader(in, allocator) - -// Read a batch from a byte stream, ensure the reader is closed -Utils.tryWithSafeFinally { - val root = reader.getVectorSchemaRoot // throws IOException - val unloader = new VectorUnloader(root) - reader.loadNextBatch() // throws IOException - unloader.getRecordBatch -} { - reader.close() -} +val in = new ByteArrayInputStream(batchBytes) +MessageSerializer.deserializeMessageBatch(new ReadChannel(Channels.newChannel(in)), allocator) + .asInstanceOf[ArrowRecordBatch] // throws IOException } + /** + * Create a DataFrame from a JavaRDD of serialized ArrowRecordBatches. 
+ */ private[sql] def toDataFrame( - payloadRDD: JavaRDD[Array[Byte]], + arrowBatchRDD: JavaRDD[Array[Byte]], schemaString: String, sqlContext: SQLContext): DataFrame = { -val rdd = payloadRDD.rdd.mapPartitions { iter => +val schema = DataType.fromJson(schemaString).asInstanceOf[StructType] +val timeZoneId = sqlContext.sessionState.conf.sessionLocalTimeZone +val rdd = arrowBatchRDD.rdd.mapPartitions { iter => val context = TaskContext.get() - ArrowConverters.fromPayloadIterator(iter.map(new ArrowPayload(_)), context) + ArrowConverters.fromBatchIterator(iter, schema, timeZoneId, context) } -val schema = DataType.fromJson(schemaString).asInstanceOf[StructType] sqlContext.internalCreateDataFrame(rdd, schema) } + + /** + * Read a file as an Arrow stream and parallelize as an RDD of serialized ArrowRecordBatches. + */ + private[sql] def readArrowStreamFromFile( + sqlContext: SQLContext, + filename: String): JavaRDD[Array[Byte]] = { +val fileStream = new FileInputStream(filename) +try { + // Create array so that we can safely close the file + val batches = getBatchesFromStream(fileStream.getChannel).toArray + // Parallelize the record batches to create an RDD + JavaRDD.fromRDD(sqlContext.sparkContext.parallelize(batches, batches.length)) +} finally { + fileStream.close() +} + } + + /** + * Read an Arrow stream input and return an iterator of serialized ArrowRecordBatches. 
+ */ + private[sql] def getBatchesFromStream(in: SeekableByteChannel): Iterator[Array[Byte]] = { + +// TODO: this could be moved to Arrow +def readMessageLength(in: ReadChannel): Int = { + val buffer = ByteBuffer.allocate(4) + if (in.readFully(buffer) != 4) { +return 0 + } + MessageSerializer.bytesToInt(buffer.array()) +} + +// TODO: this could be moved to Arrow +def loadMessage(in: ReadChannel, messageLength: Int, buffer: ByteBuffer): Message = { + if (in.readFully(buffer) != messageLength) { +throw new java.io.IOException( + "Unexpected end of stream trying to read message.") + } + buffer.rewind() + Message.getRootAsMessage(buffer) +} + + --- End diff -- delete extra line --- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
[GitHub] spark pull request #21546: [WIP][SPARK-23030][SQL][PYTHON] Use Arrow stream ...
Github user sethah commented on a diff in the pull request: https://github.com/apache/spark/pull/21546#discussion_r199497456 --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala --- @@ -183,34 +182,111 @@ private[sql] object ArrowConverters { } /** - * Convert a byte array to an ArrowRecordBatch. + * Load a serialized ArrowRecordBatch. */ - private[arrow] def byteArrayToBatch( + private[arrow] def loadBatch( batchBytes: Array[Byte], allocator: BufferAllocator): ArrowRecordBatch = { -val in = new ByteArrayReadableSeekableByteChannel(batchBytes) -val reader = new ArrowFileReader(in, allocator) - -// Read a batch from a byte stream, ensure the reader is closed -Utils.tryWithSafeFinally { - val root = reader.getVectorSchemaRoot // throws IOException - val unloader = new VectorUnloader(root) - reader.loadNextBatch() // throws IOException - unloader.getRecordBatch -} { - reader.close() -} +val in = new ByteArrayInputStream(batchBytes) +MessageSerializer.deserializeMessageBatch(new ReadChannel(Channels.newChannel(in)), allocator) + .asInstanceOf[ArrowRecordBatch] // throws IOException } + /** + * Create a DataFrame from a JavaRDD of serialized ArrowRecordBatches. 
+ */ private[sql] def toDataFrame( - payloadRDD: JavaRDD[Array[Byte]], + arrowBatchRDD: JavaRDD[Array[Byte]], schemaString: String, sqlContext: SQLContext): DataFrame = { -val rdd = payloadRDD.rdd.mapPartitions { iter => +val schema = DataType.fromJson(schemaString).asInstanceOf[StructType] +val timeZoneId = sqlContext.sessionState.conf.sessionLocalTimeZone +val rdd = arrowBatchRDD.rdd.mapPartitions { iter => val context = TaskContext.get() - ArrowConverters.fromPayloadIterator(iter.map(new ArrowPayload(_)), context) + ArrowConverters.fromBatchIterator(iter, schema, timeZoneId, context) } -val schema = DataType.fromJson(schemaString).asInstanceOf[StructType] sqlContext.internalCreateDataFrame(rdd, schema) } + + /** + * Read a file as an Arrow stream and parallelize as an RDD of serialized ArrowRecordBatches. + */ + private[sql] def readArrowStreamFromFile( + sqlContext: SQLContext, + filename: String): JavaRDD[Array[Byte]] = { +val fileStream = new FileInputStream(filename) +try { + // Create array so that we can safely close the file + val batches = getBatchesFromStream(fileStream.getChannel).toArray + // Parallelize the record batches to create an RDD + JavaRDD.fromRDD(sqlContext.sparkContext.parallelize(batches, batches.length)) +} finally { + fileStream.close() +} + } + + /** + * Read an Arrow stream input and return an iterator of serialized ArrowRecordBatches. 
+ */ + private[sql] def getBatchesFromStream(in: SeekableByteChannel): Iterator[Array[Byte]] = { + +// TODO: this could be moved to Arrow +def readMessageLength(in: ReadChannel): Int = { + val buffer = ByteBuffer.allocate(4) + if (in.readFully(buffer) != 4) { +return 0 + } + MessageSerializer.bytesToInt(buffer.array()) +} + +// TODO: this could be moved to Arrow +def loadMessage(in: ReadChannel, messageLength: Int, buffer: ByteBuffer): Message = { + if (in.readFully(buffer) != messageLength) { +throw new java.io.IOException( + "Unexpected end of stream trying to read message.") + } + buffer.rewind() + Message.getRootAsMessage(buffer) +} + + +// Create an iterator to get each serialized ArrowRecordBatch from a stream +new Iterator[Array[Byte]] { + val inputChannel = new ReadChannel(in) + var batch: Array[Byte] = readNextBatch() + + override def hasNext: Boolean = batch != null + + override def next(): Array[Byte] = { +val prevBatch = batch +batch = readNextBatch() +prevBatch + } + + def readNextBatch(): Array[Byte] = { +val messageLength = readMessageLength(inputChannel) +if (messageLength == 0) { + return null +} + +val buffer = ByteBuffer.allocate(messageLength) +val msg = loadMessage(inputChannel, messageLength, buffer) +val bodyLength = msg.bodyLength().asInstanceOf[Int] + +if (msg.headerType() == MessageHeader.RecordBatch) { + val allbuf = ByteBuffer.allocate(4 + messageLength + bodyLength) + allbuf.
[GitHub] spark pull request #21546: [WIP][SPARK-23030][SQL][PYTHON] Use Arrow stream ...
Github user sethah commented on a diff in the pull request: https://github.com/apache/spark/pull/21546#discussion_r199499070 --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala --- @@ -38,70 +39,75 @@ import org.apache.spark.util.Utils /** - * Store Arrow data in a form that can be serialized by Spark and served to a Python process. + * Writes serialized ArrowRecordBatches to a DataOutputStream in the Arrow stream format. */ -private[sql] class ArrowPayload private[sql] (payload: Array[Byte]) extends Serializable { +private[sql] class ArrowBatchStreamWriter( +schema: StructType, +out: OutputStream, +timeZoneId: String) { - /** - * Convert the ArrowPayload to an ArrowRecordBatch. - */ - def loadBatch(allocator: BufferAllocator): ArrowRecordBatch = { -ArrowConverters.byteArrayToBatch(payload, allocator) - } + val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId) + val writeChannel = new WriteChannel(Channels.newChannel(out)) + + // Write the Arrow schema first, before batches + MessageSerializer.serialize(writeChannel, arrowSchema) /** - * Get the ArrowPayload as a type that can be served to Python. + * Consume iterator to write each serialized ArrowRecordBatch to the stream. */ - def asPythonSerializable: Array[Byte] = payload -} - -/** - * Iterator interface to iterate over Arrow record batches and return rows - */ -private[sql] trait ArrowRowIterator extends Iterator[InternalRow] { + def writeBatches(arrowBatchIter: Iterator[Array[Byte]]): Unit = { +arrowBatchIter.foreach { batchBytes => --- End diff -- nit: `arrowBatchIter.foreach(writeChannel.write)` --- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
[GitHub] spark pull request #21546: [WIP][SPARK-23030][SQL][PYTHON] Use Arrow stream ...
Github user sethah commented on a diff in the pull request: https://github.com/apache/spark/pull/21546#discussion_r199371158 --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala --- @@ -183,34 +182,111 @@ private[sql] object ArrowConverters { } /** - * Convert a byte array to an ArrowRecordBatch. + * Load a serialized ArrowRecordBatch. */ - private[arrow] def byteArrayToBatch( + private[arrow] def loadBatch( batchBytes: Array[Byte], allocator: BufferAllocator): ArrowRecordBatch = { -val in = new ByteArrayReadableSeekableByteChannel(batchBytes) -val reader = new ArrowFileReader(in, allocator) - -// Read a batch from a byte stream, ensure the reader is closed -Utils.tryWithSafeFinally { - val root = reader.getVectorSchemaRoot // throws IOException - val unloader = new VectorUnloader(root) - reader.loadNextBatch() // throws IOException - unloader.getRecordBatch -} { - reader.close() -} +val in = new ByteArrayInputStream(batchBytes) +MessageSerializer.deserializeMessageBatch(new ReadChannel(Channels.newChannel(in)), allocator) + .asInstanceOf[ArrowRecordBatch] // throws IOException } + /** + * Create a DataFrame from a JavaRDD of serialized ArrowRecordBatches. 
+ */ private[sql] def toDataFrame( - payloadRDD: JavaRDD[Array[Byte]], + arrowBatchRDD: JavaRDD[Array[Byte]], schemaString: String, sqlContext: SQLContext): DataFrame = { -val rdd = payloadRDD.rdd.mapPartitions { iter => +val schema = DataType.fromJson(schemaString).asInstanceOf[StructType] +val timeZoneId = sqlContext.sessionState.conf.sessionLocalTimeZone +val rdd = arrowBatchRDD.rdd.mapPartitions { iter => val context = TaskContext.get() - ArrowConverters.fromPayloadIterator(iter.map(new ArrowPayload(_)), context) + ArrowConverters.fromBatchIterator(iter, schema, timeZoneId, context) } -val schema = DataType.fromJson(schemaString).asInstanceOf[StructType] sqlContext.internalCreateDataFrame(rdd, schema) } + + /** + * Read a file as an Arrow stream and parallelize as an RDD of serialized ArrowRecordBatches. + */ + private[sql] def readArrowStreamFromFile( + sqlContext: SQLContext, + filename: String): JavaRDD[Array[Byte]] = { +val fileStream = new FileInputStream(filename) +try { + // Create array so that we can safely close the file + val batches = getBatchesFromStream(fileStream.getChannel).toArray + // Parallelize the record batches to create an RDD + JavaRDD.fromRDD(sqlContext.sparkContext.parallelize(batches, batches.length)) +} finally { + fileStream.close() +} + } + + /** + * Read an Arrow stream input and return an iterator of serialized ArrowRecordBatches. 
+ */ + private[sql] def getBatchesFromStream(in: SeekableByteChannel): Iterator[Array[Byte]] = { + +// TODO: this could be moved to Arrow +def readMessageLength(in: ReadChannel): Int = { + val buffer = ByteBuffer.allocate(4) + if (in.readFully(buffer) != 4) { +return 0 + } + MessageSerializer.bytesToInt(buffer.array()) +} + +// TODO: this could be moved to Arrow +def loadMessage(in: ReadChannel, messageLength: Int, buffer: ByteBuffer): Message = { + if (in.readFully(buffer) != messageLength) { +throw new java.io.IOException( + "Unexpected end of stream trying to read message.") + } + buffer.rewind() + Message.getRootAsMessage(buffer) +} + + +// Create an iterator to get each serialized ArrowRecordBatch from a stream +new Iterator[Array[Byte]] { + val inputChannel = new ReadChannel(in) + var batch: Array[Byte] = readNextBatch() + + override def hasNext: Boolean = batch != null + + override def next(): Array[Byte] = { +val prevBatch = batch +batch = readNextBatch() +prevBatch + } + + def readNextBatch(): Array[Byte] = { --- End diff -- Mostly I'm just curious, is there any point in making this a private method? --- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
[GitHub] spark pull request #21546: [WIP][SPARK-23030][SQL][PYTHON] Use Arrow stream ...
Github user ueshin commented on a diff in the pull request: https://github.com/apache/spark/pull/21546#discussion_r199384074 --- Diff: sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala --- @@ -3236,13 +3237,50 @@ class Dataset[T] private[sql]( } /** - * Collect a Dataset as ArrowPayload byte arrays and serve to PySpark. + * Collect a Dataset as Arrow batches and serve stream to PySpark. */ private[sql] def collectAsArrowToPython(): Array[Any] = { +val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone + withAction("collectAsArrowToPython", queryExecution) { plan => - val iter: Iterator[Array[Byte]] = -toArrowPayload(plan).collect().iterator.map(_.asPythonSerializable) - PythonRDD.serveIterator(iter, "serve-Arrow") + PythonRDD.serveToStream("serve-Arrow") { outputStream => +val out = new DataOutputStream(outputStream) +val batchWriter = new ArrowBatchStreamWriter(schema, out, timeZoneId) +val arrowBatchRdd = getArrowBatchRdd(plan) +val numPartitions = arrowBatchRdd.partitions.length + +// Batches ordered by index of partition + fractional value of batch # in partition +val batchOrder = new ArrayBuffer[Float]() +var partitionCount = 0 + +// Handler to eagerly write batches to Python out of order +def handlePartitionBatches(index: Int, arrowBatches: Array[Array[Byte]]): Unit = { + if (arrowBatches.nonEmpty) { +batchWriter.writeBatches(arrowBatches.iterator) +(0 until arrowBatches.length).foreach { i => + batchOrder.append(index + i / arrowBatches.length) --- End diff -- Maybe we should cast to Float when calculating `i / arrowBatches.length`, otherwise the same values are appended? --- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
[GitHub] spark pull request #21546: [WIP][SPARK-23030][SQL][PYTHON] Use Arrow stream ...
Github user ueshin commented on a diff in the pull request: https://github.com/apache/spark/pull/21546#discussion_r199384249 --- Diff: sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala --- @@ -3236,13 +3237,50 @@ class Dataset[T] private[sql]( } /** - * Collect a Dataset as ArrowPayload byte arrays and serve to PySpark. + * Collect a Dataset as Arrow batches and serve stream to PySpark. */ private[sql] def collectAsArrowToPython(): Array[Any] = { +val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone + withAction("collectAsArrowToPython", queryExecution) { plan => - val iter: Iterator[Array[Byte]] = -toArrowPayload(plan).collect().iterator.map(_.asPythonSerializable) - PythonRDD.serveIterator(iter, "serve-Arrow") + PythonRDD.serveToStream("serve-Arrow") { outputStream => +val out = new DataOutputStream(outputStream) +val batchWriter = new ArrowBatchStreamWriter(schema, out, timeZoneId) +val arrowBatchRdd = getArrowBatchRdd(plan) +val numPartitions = arrowBatchRdd.partitions.length + +// Batches ordered by index of partition + fractional value of batch # in partition +val batchOrder = new ArrayBuffer[Float]() +var partitionCount = 0 + +// Handler to eagerly write batches to Python out of order +def handlePartitionBatches(index: Int, arrowBatches: Array[Array[Byte]]): Unit = { + if (arrowBatches.nonEmpty) { +batchWriter.writeBatches(arrowBatches.iterator) +(0 until arrowBatches.length).foreach { i => + batchOrder.append(index + i / arrowBatches.length) --- End diff -- Btw, how about using tuple `(Int, Int)` instead of `Float`, and `batchOrder.append((index, i))` ? --- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
[GitHub] spark pull request #21546: [WIP][SPARK-23030][SQL][PYTHON] Use Arrow stream ...
Github user BryanCutler commented on a diff in the pull request: https://github.com/apache/spark/pull/21546#discussion_r199287248 --- Diff: sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala --- @@ -3236,13 +3237,50 @@ class Dataset[T] private[sql]( } /** - * Collect a Dataset as ArrowPayload byte arrays and serve to PySpark. + * Collect a Dataset as Arrow batches and serve stream to PySpark. */ private[sql] def collectAsArrowToPython(): Array[Any] = { +val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone + withAction("collectAsArrowToPython", queryExecution) { plan => - val iter: Iterator[Array[Byte]] = -toArrowPayload(plan).collect().iterator.map(_.asPythonSerializable) - PythonRDD.serveIterator(iter, "serve-Arrow") + PythonRDD.serveToStream("serve-Arrow") { outputStream => +val out = new DataOutputStream(outputStream) +val batchWriter = new ArrowBatchStreamWriter(schema, out, timeZoneId) +val arrowBatchRdd = getArrowBatchRdd(plan) +val numPartitions = arrowBatchRdd.partitions.length + +// Batches ordered by index of partition + batch number for that partition +val batchOrder = new ArrayBuffer[Int]() +var partitionCount = 0 + +// Handler to eagerly write batches to Python out of order +def handlePartitionBatches(index: Int, arrowBatches: Array[Array[Byte]]): Unit = { + if (arrowBatches.nonEmpty) { +batchWriter.writeBatches(arrowBatches.iterator) +(0 until arrowBatches.length).foreach { i => + batchOrder.append(index + i) +} + } + partitionCount += 1 + + // After last batch, end the stream and write batch order + if (partitionCount == numPartitions) { +batchWriter.end() +out.writeInt(batchOrder.length) +// Batch order indices are from 0 to N-1 batches, sorted by order they arrived +batchOrder.zipWithIndex.sortBy(_._1).foreach { case (_, i) => --- End diff -- Yeah, looks like something wasn't quite right with the batch indexing... I fixed it and added your test. Thanks @sethah ! 
--- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
[GitHub] spark pull request #21546: [WIP][SPARK-23030][SQL][PYTHON] Use Arrow stream ...
Github user sethah commented on a diff in the pull request: https://github.com/apache/spark/pull/21546#discussion_r199275753 --- Diff: sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala --- @@ -3236,13 +3237,50 @@ class Dataset[T] private[sql]( } /** - * Collect a Dataset as ArrowPayload byte arrays and serve to PySpark. + * Collect a Dataset as Arrow batches and serve stream to PySpark. */ private[sql] def collectAsArrowToPython(): Array[Any] = { +val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone + withAction("collectAsArrowToPython", queryExecution) { plan => - val iter: Iterator[Array[Byte]] = -toArrowPayload(plan).collect().iterator.map(_.asPythonSerializable) - PythonRDD.serveIterator(iter, "serve-Arrow") + PythonRDD.serveToStream("serve-Arrow") { outputStream => +val out = new DataOutputStream(outputStream) +val batchWriter = new ArrowBatchStreamWriter(schema, out, timeZoneId) +val arrowBatchRdd = getArrowBatchRdd(plan) +val numPartitions = arrowBatchRdd.partitions.length + +// Batches ordered by index of partition + batch number for that partition +val batchOrder = new ArrayBuffer[Int]() +var partitionCount = 0 + +// Handler to eagerly write batches to Python out of order +def handlePartitionBatches(index: Int, arrowBatches: Array[Array[Byte]]): Unit = { + if (arrowBatches.nonEmpty) { +batchWriter.writeBatches(arrowBatches.iterator) +(0 until arrowBatches.length).foreach { i => + batchOrder.append(index + i) +} + } + partitionCount += 1 + + // After last batch, end the stream and write batch order + if (partitionCount == numPartitions) { +batchWriter.end() +out.writeInt(batchOrder.length) +// Batch order indices are from 0 to N-1 batches, sorted by order they arrived +batchOrder.zipWithIndex.sortBy(_._1).foreach { case (_, i) => --- End diff -- Does this logic do what you intend? It interleaves batches. 
```python
df = spark.range(64).toDF("a")
df.rdd.getNumPartitions() # 8
spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", 4)
pdf = df.toPandas()
pdf['a'].values
# array([ 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 16, 17, 18, 19, 12,
# 13, 14, 15, 20, 21, 22, 23, 24, 25, 26, 27, 32, 33, 34, 35, 28, 29,
# 30, 31, 36, 37, 38, 39, 40, 41, 42, 43, 48, 49, 50, 51, 44, 45, 46,
# 47, 56, 57, 58, 59, 52, 53, 54, 55, 60, 61, 62, 63])
```
--- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
[GitHub] spark pull request #21546: [WIP][SPARK-23030][SQL][PYTHON] Use Arrow stream ...
Github user BryanCutler commented on a diff in the pull request: https://github.com/apache/spark/pull/21546#discussion_r199241520 --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala --- @@ -183,34 +182,111 @@ private[sql] object ArrowConverters { } /** - * Convert a byte array to an ArrowRecordBatch. + * Load a serialized ArrowRecordBatch. */ - private[arrow] def byteArrayToBatch( + private[arrow] def loadBatch( batchBytes: Array[Byte], allocator: BufferAllocator): ArrowRecordBatch = { -val in = new ByteArrayReadableSeekableByteChannel(batchBytes) -val reader = new ArrowFileReader(in, allocator) - -// Read a batch from a byte stream, ensure the reader is closed -Utils.tryWithSafeFinally { - val root = reader.getVectorSchemaRoot // throws IOException - val unloader = new VectorUnloader(root) - reader.loadNextBatch() // throws IOException - unloader.getRecordBatch -} { - reader.close() -} +val in = new ByteArrayInputStream(batchBytes) +MessageSerializer.deserializeMessageBatch(new ReadChannel(Channels.newChannel(in)), allocator) + .asInstanceOf[ArrowRecordBatch] // throws IOException } + /** + * Create a DataFrame from a JavaRDD of serialized ArrowRecordBatches. 
+ */ private[sql] def toDataFrame( - payloadRDD: JavaRDD[Array[Byte]], + arrowBatchRDD: JavaRDD[Array[Byte]], schemaString: String, sqlContext: SQLContext): DataFrame = { -val rdd = payloadRDD.rdd.mapPartitions { iter => +val schema = DataType.fromJson(schemaString).asInstanceOf[StructType] +val timeZoneId = sqlContext.sessionState.conf.sessionLocalTimeZone +val rdd = arrowBatchRDD.rdd.mapPartitions { iter => val context = TaskContext.get() - ArrowConverters.fromPayloadIterator(iter.map(new ArrowPayload(_)), context) + ArrowConverters.fromBatchIterator(iter, schema, timeZoneId, context) } -val schema = DataType.fromJson(schemaString).asInstanceOf[StructType] sqlContext.internalCreateDataFrame(rdd, schema) } + + /** + * Read a file as an Arrow stream and parallelize as an RDD of serialized ArrowRecordBatches. + */ + private[sql] def readArrowStreamFromFile( + sqlContext: SQLContext, + filename: String): JavaRDD[Array[Byte]] = { +val fileStream = new FileInputStream(filename) +try { + // Create array so that we can safely close the file + val batches = getBatchesFromStream(fileStream.getChannel).toArray + // Parallelize the record batches to create an RDD + JavaRDD.fromRDD(sqlContext.sparkContext.parallelize(batches, batches.length)) +} finally { + fileStream.close() +} + } + + /** + * Read an Arrow stream input and return an iterator of serialized ArrowRecordBatches. 
+ */ + private[sql] def getBatchesFromStream(in: SeekableByteChannel): Iterator[Array[Byte]] = { + +// TODO: this could be moved to Arrow +def readMessageLength(in: ReadChannel): Int = { + val buffer = ByteBuffer.allocate(4) + if (in.readFully(buffer) != 4) { +return 0 + } + MessageSerializer.bytesToInt(buffer.array()) +} + +// TODO: this could be moved to Arrow +def loadMessage(in: ReadChannel, messageLength: Int, buffer: ByteBuffer): Message = { + if (in.readFully(buffer) != messageLength) { +throw new java.io.IOException( + "Unexpected end of stream trying to read message.") + } + buffer.rewind() + Message.getRootAsMessage(buffer) +} + + +// Create an iterator to get each serialized ArrowRecordBatch from a stream +new Iterator[Array[Byte]] { + val inputChannel = new ReadChannel(in) + var batch: Array[Byte] = readNextBatch() + + override def hasNext: Boolean = batch != null + + override def next(): Array[Byte] = { +val prevBatch = batch +batch = readNextBatch() +prevBatch + } + + def readNextBatch(): Array[Byte] = { +val messageLength = readMessageLength(inputChannel) +if (messageLength == 0) { + return null +} + +val buffer = ByteBuffer.allocate(messageLength) +val msg = loadMessage(inputChannel, messageLength, buffer) --- End diff -- I'll propose something different under https://github.com/apache/arrow/pull/2139, we can continue to discuss there --- - To unsubscri
[GitHub] spark pull request #21546: [WIP][SPARK-23030][SQL][PYTHON] Use Arrow stream ...
Github user icexelloss commented on a diff in the pull request: https://github.com/apache/spark/pull/21546#discussion_r199193918 --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala --- @@ -183,34 +182,111 @@ private[sql] object ArrowConverters { } /** - * Convert a byte array to an ArrowRecordBatch. + * Load a serialized ArrowRecordBatch. */ - private[arrow] def byteArrayToBatch( + private[arrow] def loadBatch( batchBytes: Array[Byte], allocator: BufferAllocator): ArrowRecordBatch = { -val in = new ByteArrayReadableSeekableByteChannel(batchBytes) -val reader = new ArrowFileReader(in, allocator) - -// Read a batch from a byte stream, ensure the reader is closed -Utils.tryWithSafeFinally { - val root = reader.getVectorSchemaRoot // throws IOException - val unloader = new VectorUnloader(root) - reader.loadNextBatch() // throws IOException - unloader.getRecordBatch -} { - reader.close() -} +val in = new ByteArrayInputStream(batchBytes) +MessageSerializer.deserializeMessageBatch(new ReadChannel(Channels.newChannel(in)), allocator) + .asInstanceOf[ArrowRecordBatch] // throws IOException } + /** + * Create a DataFrame from a JavaRDD of serialized ArrowRecordBatches. 
+ */ private[sql] def toDataFrame( - payloadRDD: JavaRDD[Array[Byte]], + arrowBatchRDD: JavaRDD[Array[Byte]], schemaString: String, sqlContext: SQLContext): DataFrame = { -val rdd = payloadRDD.rdd.mapPartitions { iter => +val schema = DataType.fromJson(schemaString).asInstanceOf[StructType] +val timeZoneId = sqlContext.sessionState.conf.sessionLocalTimeZone +val rdd = arrowBatchRDD.rdd.mapPartitions { iter => val context = TaskContext.get() - ArrowConverters.fromPayloadIterator(iter.map(new ArrowPayload(_)), context) + ArrowConverters.fromBatchIterator(iter, schema, timeZoneId, context) } -val schema = DataType.fromJson(schemaString).asInstanceOf[StructType] sqlContext.internalCreateDataFrame(rdd, schema) } + + /** + * Read a file as an Arrow stream and parallelize as an RDD of serialized ArrowRecordBatches. + */ + private[sql] def readArrowStreamFromFile( + sqlContext: SQLContext, + filename: String): JavaRDD[Array[Byte]] = { +val fileStream = new FileInputStream(filename) +try { + // Create array so that we can safely close the file + val batches = getBatchesFromStream(fileStream.getChannel).toArray + // Parallelize the record batches to create an RDD + JavaRDD.fromRDD(sqlContext.sparkContext.parallelize(batches, batches.length)) +} finally { + fileStream.close() +} + } + + /** + * Read an Arrow stream input and return an iterator of serialized ArrowRecordBatches. 
+ */ + private[sql] def getBatchesFromStream(in: SeekableByteChannel): Iterator[Array[Byte]] = { + +// TODO: this could be moved to Arrow +def readMessageLength(in: ReadChannel): Int = { + val buffer = ByteBuffer.allocate(4) + if (in.readFully(buffer) != 4) { +return 0 + } + MessageSerializer.bytesToInt(buffer.array()) +} + +// TODO: this could be moved to Arrow +def loadMessage(in: ReadChannel, messageLength: Int, buffer: ByteBuffer): Message = { + if (in.readFully(buffer) != messageLength) { +throw new java.io.IOException( + "Unexpected end of stream trying to read message.") + } + buffer.rewind() + Message.getRootAsMessage(buffer) +} + + +// Create an iterator to get each serialized ArrowRecordBatch from a stream +new Iterator[Array[Byte]] { + val inputChannel = new ReadChannel(in) + var batch: Array[Byte] = readNextBatch() + + override def hasNext: Boolean = batch != null + + override def next(): Array[Byte] = { +val prevBatch = batch +batch = readNextBatch() +prevBatch + } + + def readNextBatch(): Array[Byte] = { +val messageLength = readMessageLength(inputChannel) +if (messageLength == 0) { + return null +} + +val buffer = ByteBuffer.allocate(messageLength) +val msg = loadMessage(inputChannel, messageLength, buffer) --- End diff -- I see. IMHO the class still touches too many Arrow stream low-level details, but maybe necessary because of the performance improvement. @BryanCutler does all the speed up of `createDataF
[GitHub] spark pull request #21546: [WIP][SPARK-23030][SQL][PYTHON] Use Arrow stream ...
Github user BryanCutler commented on a diff in the pull request: https://github.com/apache/spark/pull/21546#discussion_r199001584 --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala --- @@ -183,34 +182,111 @@ private[sql] object ArrowConverters { } /** - * Convert a byte array to an ArrowRecordBatch. + * Load a serialized ArrowRecordBatch. */ - private[arrow] def byteArrayToBatch( + private[arrow] def loadBatch( batchBytes: Array[Byte], allocator: BufferAllocator): ArrowRecordBatch = { -val in = new ByteArrayReadableSeekableByteChannel(batchBytes) -val reader = new ArrowFileReader(in, allocator) - -// Read a batch from a byte stream, ensure the reader is closed -Utils.tryWithSafeFinally { - val root = reader.getVectorSchemaRoot // throws IOException - val unloader = new VectorUnloader(root) - reader.loadNextBatch() // throws IOException - unloader.getRecordBatch -} { - reader.close() -} +val in = new ByteArrayInputStream(batchBytes) +MessageSerializer.deserializeMessageBatch(new ReadChannel(Channels.newChannel(in)), allocator) + .asInstanceOf[ArrowRecordBatch] // throws IOException } + /** + * Create a DataFrame from a JavaRDD of serialized ArrowRecordBatches. 
+ */ private[sql] def toDataFrame( - payloadRDD: JavaRDD[Array[Byte]], + arrowBatchRDD: JavaRDD[Array[Byte]], schemaString: String, sqlContext: SQLContext): DataFrame = { -val rdd = payloadRDD.rdd.mapPartitions { iter => +val schema = DataType.fromJson(schemaString).asInstanceOf[StructType] +val timeZoneId = sqlContext.sessionState.conf.sessionLocalTimeZone +val rdd = arrowBatchRDD.rdd.mapPartitions { iter => val context = TaskContext.get() - ArrowConverters.fromPayloadIterator(iter.map(new ArrowPayload(_)), context) + ArrowConverters.fromBatchIterator(iter, schema, timeZoneId, context) } -val schema = DataType.fromJson(schemaString).asInstanceOf[StructType] sqlContext.internalCreateDataFrame(rdd, schema) } + + /** + * Read a file as an Arrow stream and parallelize as an RDD of serialized ArrowRecordBatches. + */ + private[sql] def readArrowStreamFromFile( + sqlContext: SQLContext, + filename: String): JavaRDD[Array[Byte]] = { +val fileStream = new FileInputStream(filename) +try { + // Create array so that we can safely close the file + val batches = getBatchesFromStream(fileStream.getChannel).toArray + // Parallelize the record batches to create an RDD + JavaRDD.fromRDD(sqlContext.sparkContext.parallelize(batches, batches.length)) +} finally { + fileStream.close() +} + } + + /** + * Read an Arrow stream input and return an iterator of serialized ArrowRecordBatches. 
+ */ + private[sql] def getBatchesFromStream(in: SeekableByteChannel): Iterator[Array[Byte]] = { + +// TODO: this could be moved to Arrow +def readMessageLength(in: ReadChannel): Int = { + val buffer = ByteBuffer.allocate(4) + if (in.readFully(buffer) != 4) { +return 0 + } + MessageSerializer.bytesToInt(buffer.array()) +} + +// TODO: this could be moved to Arrow +def loadMessage(in: ReadChannel, messageLength: Int, buffer: ByteBuffer): Message = { + if (in.readFully(buffer) != messageLength) { +throw new java.io.IOException( + "Unexpected end of stream trying to read message.") + } + buffer.rewind() + Message.getRootAsMessage(buffer) +} + + +// Create an iterator to get each serialized ArrowRecordBatch from a stream +new Iterator[Array[Byte]] { + val inputChannel = new ReadChannel(in) + var batch: Array[Byte] = readNextBatch() + + override def hasNext: Boolean = batch != null + + override def next(): Array[Byte] = { +val prevBatch = batch +batch = readNextBatch() +prevBatch + } + + def readNextBatch(): Array[Byte] = { +val messageLength = readMessageLength(inputChannel) +if (messageLength == 0) { + return null +} + +val buffer = ByteBuffer.allocate(messageLength) +val msg = loadMessage(inputChannel, messageLength, buffer) --- End diff -- We can't really use it as is for a couple of reasons. It just returns a flatbuffer `Message`, which means that to get it back into another buffer, it requires doing a lot of work with a `FlatBufferBuild
[GitHub] spark pull request #21546: [WIP][SPARK-23030][SQL][PYTHON] Use Arrow stream ...
Github user icexelloss commented on a diff in the pull request: https://github.com/apache/spark/pull/21546#discussion_r198994936 --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala --- @@ -183,34 +182,111 @@ private[sql] object ArrowConverters { } /** - * Convert a byte array to an ArrowRecordBatch. + * Load a serialized ArrowRecordBatch. */ - private[arrow] def byteArrayToBatch( + private[arrow] def loadBatch( batchBytes: Array[Byte], allocator: BufferAllocator): ArrowRecordBatch = { -val in = new ByteArrayReadableSeekableByteChannel(batchBytes) -val reader = new ArrowFileReader(in, allocator) - -// Read a batch from a byte stream, ensure the reader is closed -Utils.tryWithSafeFinally { - val root = reader.getVectorSchemaRoot // throws IOException - val unloader = new VectorUnloader(root) - reader.loadNextBatch() // throws IOException - unloader.getRecordBatch -} { - reader.close() -} +val in = new ByteArrayInputStream(batchBytes) +MessageSerializer.deserializeMessageBatch(new ReadChannel(Channels.newChannel(in)), allocator) + .asInstanceOf[ArrowRecordBatch] // throws IOException } + /** + * Create a DataFrame from a JavaRDD of serialized ArrowRecordBatches. 
+ */ private[sql] def toDataFrame( - payloadRDD: JavaRDD[Array[Byte]], + arrowBatchRDD: JavaRDD[Array[Byte]], schemaString: String, sqlContext: SQLContext): DataFrame = { -val rdd = payloadRDD.rdd.mapPartitions { iter => +val schema = DataType.fromJson(schemaString).asInstanceOf[StructType] +val timeZoneId = sqlContext.sessionState.conf.sessionLocalTimeZone +val rdd = arrowBatchRDD.rdd.mapPartitions { iter => val context = TaskContext.get() - ArrowConverters.fromPayloadIterator(iter.map(new ArrowPayload(_)), context) + ArrowConverters.fromBatchIterator(iter, schema, timeZoneId, context) } -val schema = DataType.fromJson(schemaString).asInstanceOf[StructType] sqlContext.internalCreateDataFrame(rdd, schema) } + + /** + * Read a file as an Arrow stream and parallelize as an RDD of serialized ArrowRecordBatches. + */ + private[sql] def readArrowStreamFromFile( + sqlContext: SQLContext, + filename: String): JavaRDD[Array[Byte]] = { +val fileStream = new FileInputStream(filename) +try { + // Create array so that we can safely close the file + val batches = getBatchesFromStream(fileStream.getChannel).toArray + // Parallelize the record batches to create an RDD + JavaRDD.fromRDD(sqlContext.sparkContext.parallelize(batches, batches.length)) +} finally { + fileStream.close() +} + } + + /** + * Read an Arrow stream input and return an iterator of serialized ArrowRecordBatches. 
+ */ + private[sql] def getBatchesFromStream(in: SeekableByteChannel): Iterator[Array[Byte]] = { + +// TODO: this could be moved to Arrow +def readMessageLength(in: ReadChannel): Int = { + val buffer = ByteBuffer.allocate(4) + if (in.readFully(buffer) != 4) { +return 0 + } + MessageSerializer.bytesToInt(buffer.array()) +} + +// TODO: this could be moved to Arrow +def loadMessage(in: ReadChannel, messageLength: Int, buffer: ByteBuffer): Message = { + if (in.readFully(buffer) != messageLength) { +throw new java.io.IOException( + "Unexpected end of stream trying to read message.") + } + buffer.rewind() + Message.getRootAsMessage(buffer) +} + + +// Create an iterator to get each serialized ArrowRecordBatch from a stream +new Iterator[Array[Byte]] { + val inputChannel = new ReadChannel(in) + var batch: Array[Byte] = readNextBatch() + + override def hasNext: Boolean = batch != null + + override def next(): Array[Byte] = { +val prevBatch = batch +batch = readNextBatch() +prevBatch + } + + def readNextBatch(): Array[Byte] = { +val messageLength = readMessageLength(inputChannel) +if (messageLength == 0) { + return null +} + +val buffer = ByteBuffer.allocate(messageLength) +val msg = loadMessage(inputChannel, messageLength, buffer) --- End diff -- Can we use `MessageChannelReader.readNextMessage()`? --- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional
[GitHub] spark pull request #21546: [WIP][SPARK-23030][SQL][PYTHON] Use Arrow stream ...
Github user BryanCutler commented on a diff in the pull request: https://github.com/apache/spark/pull/21546#discussion_r198925610 --- Diff: dev/make-distribution.sh --- @@ -168,10 +168,10 @@ export MAVEN_OPTS="${MAVEN_OPTS:--Xmx2g -XX:ReservedCodeCacheSize=512m}" BUILD_COMMAND=("$MVN" -T 1C clean package -DskipTests $@) # Actually build the jar -echo -e "\nBuilding with..." -echo -e "\$ ${BUILD_COMMAND[@]}\n" +#echo -e "\nBuilding with..." +#echo -e "\$ ${BUILD_COMMAND[@]}\n" -"${BUILD_COMMAND[@]}" +#"${BUILD_COMMAND[@]}" --- End diff -- Oops, yeah thanks for pointing it out. --- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
[GitHub] spark pull request #21546: [WIP][SPARK-23030][SQL][PYTHON] Use Arrow stream ...
Github user felixcheung commented on a diff in the pull request: https://github.com/apache/spark/pull/21546#discussion_r198723785 --- Diff: dev/make-distribution.sh --- @@ -168,10 +168,10 @@ export MAVEN_OPTS="${MAVEN_OPTS:--Xmx2g -XX:ReservedCodeCacheSize=512m}" BUILD_COMMAND=("$MVN" -T 1C clean package -DskipTests $@) # Actually build the jar -echo -e "\nBuilding with..." -echo -e "\$ ${BUILD_COMMAND[@]}\n" +#echo -e "\nBuilding with..." +#echo -e "\$ ${BUILD_COMMAND[@]}\n" -"${BUILD_COMMAND[@]}" +#"${BUILD_COMMAND[@]}" --- End diff -- did you comment this out for local test? --- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
[GitHub] spark pull request #21546: [WIP][SPARK-23030][SQL][PYTHON] Use Arrow stream ...
Github user BryanCutler commented on a diff in the pull request: https://github.com/apache/spark/pull/21546#discussion_r198660818 --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala --- @@ -183,34 +182,111 @@ private[sql] object ArrowConverters { } /** - * Convert a byte array to an ArrowRecordBatch. + * Load a serialized ArrowRecordBatch. */ - private[arrow] def byteArrayToBatch( + private[arrow] def loadBatch( batchBytes: Array[Byte], allocator: BufferAllocator): ArrowRecordBatch = { -val in = new ByteArrayReadableSeekableByteChannel(batchBytes) -val reader = new ArrowFileReader(in, allocator) - -// Read a batch from a byte stream, ensure the reader is closed -Utils.tryWithSafeFinally { - val root = reader.getVectorSchemaRoot // throws IOException - val unloader = new VectorUnloader(root) - reader.loadNextBatch() // throws IOException - unloader.getRecordBatch -} { - reader.close() -} +val in = new ByteArrayInputStream(batchBytes) +MessageSerializer.deserializeMessageBatch(new ReadChannel(Channels.newChannel(in)), allocator) + .asInstanceOf[ArrowRecordBatch] // throws IOException } + /** + * Create a DataFrame from a JavaRDD of serialized ArrowRecordBatches. 
+ */ private[sql] def toDataFrame( - payloadRDD: JavaRDD[Array[Byte]], + arrowBatchRDD: JavaRDD[Array[Byte]], schemaString: String, sqlContext: SQLContext): DataFrame = { -val rdd = payloadRDD.rdd.mapPartitions { iter => +val schema = DataType.fromJson(schemaString).asInstanceOf[StructType] +val timeZoneId = sqlContext.sessionState.conf.sessionLocalTimeZone +val rdd = arrowBatchRDD.rdd.mapPartitions { iter => val context = TaskContext.get() - ArrowConverters.fromPayloadIterator(iter.map(new ArrowPayload(_)), context) + ArrowConverters.fromBatchIterator(iter, schema, timeZoneId, context) } -val schema = DataType.fromJson(schemaString).asInstanceOf[StructType] sqlContext.internalCreateDataFrame(rdd, schema) } + + /** + * Read a file as an Arrow stream and parallelize as an RDD of serialized ArrowRecordBatches. + */ + private[sql] def readArrowStreamFromFile( + sqlContext: SQLContext, + filename: String): JavaRDD[Array[Byte]] = { +val fileStream = new FileInputStream(filename) +try { + // Create array so that we can safely close the file + val batches = getBatchesFromStream(fileStream.getChannel).toArray + // Parallelize the record batches to create an RDD + JavaRDD.fromRDD(sqlContext.sparkContext.parallelize(batches, batches.length)) +} finally { + fileStream.close() +} + } + + /** + * Read an Arrow stream input and return an iterator of serialized ArrowRecordBatches. 
+ */ + private[sql] def getBatchesFromStream(in: SeekableByteChannel): Iterator[Array[Byte]] = { + +// TODO: this could be moved to Arrow +def readMessageLength(in: ReadChannel): Int = { + val buffer = ByteBuffer.allocate(4) + if (in.readFully(buffer) != 4) { +return 0 + } + MessageSerializer.bytesToInt(buffer.array()) +} + +// TODO: this could be moved to Arrow +def loadMessage(in: ReadChannel, messageLength: Int, buffer: ByteBuffer): Message = { + if (in.readFully(buffer) != messageLength) { +throw new java.io.IOException( + "Unexpected end of stream trying to read message.") + } + buffer.rewind() + Message.getRootAsMessage(buffer) +} + + +// Create an iterator to get each serialized ArrowRecordBatch from a stream +new Iterator[Array[Byte]] { + val inputChannel = new ReadChannel(in) + var batch: Array[Byte] = readNextBatch() + + override def hasNext: Boolean = batch != null + + override def next(): Array[Byte] = { +val prevBatch = batch +batch = readNextBatch() +prevBatch + } + + def readNextBatch(): Array[Byte] = { +val messageLength = readMessageLength(inputChannel) +if (messageLength == 0) { + return null +} + +val buffer = ByteBuffer.allocate(messageLength) +val msg = loadMessage(inputChannel, messageLength, buffer) --- End diff -- The only issue with this is that it is aware that a message is preceded by the message length and that a length of zero indicates no more messages. Ideally, this logic would be abstracted to Arrow...
[GitHub] spark pull request #21546: [WIP][SPARK-23030][SQL][PYTHON] Use Arrow stream ...
Github user ueshin commented on a diff in the pull request: https://github.com/apache/spark/pull/21546#discussion_r198000610 --- Diff: sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala --- @@ -3236,13 +3236,49 @@ class Dataset[T] private[sql]( } /** - * Collect a Dataset as ArrowPayload byte arrays and serve to PySpark. + * Collect a Dataset as Arrow batches and serve stream to PySpark. */ private[sql] def collectAsArrowToPython(): Array[Any] = { +val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone + withAction("collectAsArrowToPython", queryExecution) { plan => - val iter: Iterator[Array[Byte]] = -toArrowPayload(plan).collect().iterator.map(_.asPythonSerializable) - PythonRDD.serveIterator(iter, "serve-Arrow") + PythonRDD.serveToStream("serve-Arrow") { out => +val batchWriter = new ArrowBatchStreamWriter(schema, out, timeZoneId) +val arrowBatchRdd = getArrowBatchRdd(plan) +val numPartitions = arrowBatchRdd.partitions.length + +// Store collection results for worst case of 1 to N-1 partitions +val results = new Array[Array[Array[Byte]]](numPartitions - 1) +var lastIndex = -1 // index of last partition written + +// Handler to eagerly write partitions to Python in order +def handlePartitionBatches(index: Int, arrowBatches: Array[Array[Byte]]): Unit = { + // If result is from next partition in order + if (index - 1 == lastIndex) { +batchWriter.writeBatches(arrowBatches.iterator) +lastIndex += 1 +// Write stored partitions that come next in order +while (lastIndex < results.length && results(lastIndex) != null) { + batchWriter.writeBatches(results(lastIndex).iterator) + results(lastIndex) = null + lastIndex += 1 +} +// After last batch, end the stream +if (lastIndex == results.length) { + batchWriter.end() +} + } else { +// Store partitions received out of order +results(index - 1) = arrowBatches + } +} + +sparkSession.sparkContext.runJob( + arrowBatchRdd, + (ctx: TaskContext, it: Iterator[Array[Byte]]) => it.toArray, --- End diff -- Oh, I 
see. In that case, we need to do `it.toArray`. Thanks. --- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
[GitHub] spark pull request #21546: [WIP][SPARK-23030][SQL][PYTHON] Use Arrow stream ...
Github user BryanCutler commented on a diff in the pull request: https://github.com/apache/spark/pull/21546#discussion_r197932884 --- Diff: sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala --- @@ -3236,13 +3236,49 @@ class Dataset[T] private[sql]( } /** - * Collect a Dataset as ArrowPayload byte arrays and serve to PySpark. + * Collect a Dataset as Arrow batches and serve stream to PySpark. */ private[sql] def collectAsArrowToPython(): Array[Any] = { +val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone + withAction("collectAsArrowToPython", queryExecution) { plan => - val iter: Iterator[Array[Byte]] = -toArrowPayload(plan).collect().iterator.map(_.asPythonSerializable) - PythonRDD.serveIterator(iter, "serve-Arrow") + PythonRDD.serveToStream("serve-Arrow") { out => +val batchWriter = new ArrowBatchStreamWriter(schema, out, timeZoneId) +val arrowBatchRdd = getArrowBatchRdd(plan) +val numPartitions = arrowBatchRdd.partitions.length + +// Store collection results for worst case of 1 to N-1 partitions +val results = new Array[Array[Array[Byte]]](numPartitions - 1) +var lastIndex = -1 // index of last partition written + +// Handler to eagerly write partitions to Python in order +def handlePartitionBatches(index: Int, arrowBatches: Array[Array[Byte]]): Unit = { + // If result is from next partition in order + if (index - 1 == lastIndex) { +batchWriter.writeBatches(arrowBatches.iterator) +lastIndex += 1 +// Write stored partitions that come next in order +while (lastIndex < results.length && results(lastIndex) != null) { + batchWriter.writeBatches(results(lastIndex).iterator) + results(lastIndex) = null + lastIndex += 1 +} +// After last batch, end the stream +if (lastIndex == results.length) { + batchWriter.end() +} + } else { +// Store partitions received out of order +results(index - 1) = arrowBatches + } +} + +sparkSession.sparkContext.runJob( + arrowBatchRdd, + (ctx: TaskContext, it: Iterator[Array[Byte]]) => it.toArray, --- End diff -- 
Looking at this again, `it.toArray` is run on the executor, which ends up doing the same thing as `collect()` and then `handlePartitionBatches` is run on the results of that in the driver. The task results need to be serialized, so I'm not sure if we can avoid `it.toArray` here. Any thoughts, @ueshin? --- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
[GitHub] spark pull request #21546: [WIP][SPARK-23030][SQL][PYTHON] Use Arrow stream ...
Github user BryanCutler commented on a diff in the pull request: https://github.com/apache/spark/pull/21546#discussion_r197575182 --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala --- @@ -183,34 +182,131 @@ private[sql] object ArrowConverters { } /** - * Convert a byte array to an ArrowRecordBatch. + * Load a serialized ArrowRecordBatch. */ - private[arrow] def byteArrayToBatch( + private[arrow] def loadBatch( batchBytes: Array[Byte], allocator: BufferAllocator): ArrowRecordBatch = { -val in = new ByteArrayReadableSeekableByteChannel(batchBytes) -val reader = new ArrowFileReader(in, allocator) - -// Read a batch from a byte stream, ensure the reader is closed -Utils.tryWithSafeFinally { - val root = reader.getVectorSchemaRoot // throws IOException - val unloader = new VectorUnloader(root) - reader.loadNextBatch() // throws IOException - unloader.getRecordBatch -} { - reader.close() -} +val in = new ByteArrayInputStream(batchBytes) +MessageSerializer.deserializeMessageBatch(new ReadChannel(Channels.newChannel(in)), allocator) + .asInstanceOf[ArrowRecordBatch] // throws IOException } + /** + * Create a DataFrame from a JavaRDD of serialized ArrowRecordBatches. 
+ */ private[sql] def toDataFrame( - payloadRDD: JavaRDD[Array[Byte]], + arrowBatchRDD: JavaRDD[Array[Byte]], schemaString: String, sqlContext: SQLContext): DataFrame = { -val rdd = payloadRDD.rdd.mapPartitions { iter => +val schema = DataType.fromJson(schemaString).asInstanceOf[StructType] +val timeZoneId = sqlContext.sessionState.conf.sessionLocalTimeZone +val rdd = arrowBatchRDD.rdd.mapPartitions { iter => val context = TaskContext.get() - ArrowConverters.fromPayloadIterator(iter.map(new ArrowPayload(_)), context) + ArrowConverters.fromBatchIterator(iter, schema, timeZoneId, context) } -val schema = DataType.fromJson(schemaString).asInstanceOf[StructType] sqlContext.internalCreateDataFrame(rdd, schema) } + + /** + * Read a file as an Arrow stream and return an RDD of serialized ArrowRecordBatches. + */ + private[sql] def readArrowStreamFromFile(sqlContext: SQLContext, filename: String): + JavaRDD[Array[Byte]] = { +val fileStream = new FileInputStream(filename) +try { + // Create array so that we can safely close the file + val batches = getBatchesFromStream(fileStream.getChannel).toArray + // Parallelize the record batches to create an RDD + JavaRDD.fromRDD(sqlContext.sparkContext.parallelize(batches, batches.length)) +} finally { + fileStream.close() +} + } + + /** + * Read an Arrow stream input and return an iterator of serialized ArrowRecordBatches. + */ + private[sql] def getBatchesFromStream(in: SeekableByteChannel): Iterator[Array[Byte]] = { + +// TODO: simplify in super class +class RecordBatchMessageReader(inputChannel: SeekableByteChannel) { --- End diff -- I was just referring to the two static functions in the previous post. These will contain most of the low level operations inside for stability. I'm not sure we need a new interface to handle this case, it's probably not a common use case. I'll just implement what I thought and maybe it will be more clear. 
--- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
[GitHub] spark pull request #21546: [WIP][SPARK-23030][SQL][PYTHON] Use Arrow stream ...
Github user icexelloss commented on a diff in the pull request: https://github.com/apache/spark/pull/21546#discussion_r197552655 --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala --- @@ -183,34 +182,131 @@ private[sql] object ArrowConverters { } /** - * Convert a byte array to an ArrowRecordBatch. + * Load a serialized ArrowRecordBatch. */ - private[arrow] def byteArrayToBatch( + private[arrow] def loadBatch( batchBytes: Array[Byte], allocator: BufferAllocator): ArrowRecordBatch = { -val in = new ByteArrayReadableSeekableByteChannel(batchBytes) -val reader = new ArrowFileReader(in, allocator) - -// Read a batch from a byte stream, ensure the reader is closed -Utils.tryWithSafeFinally { - val root = reader.getVectorSchemaRoot // throws IOException - val unloader = new VectorUnloader(root) - reader.loadNextBatch() // throws IOException - unloader.getRecordBatch -} { - reader.close() -} +val in = new ByteArrayInputStream(batchBytes) +MessageSerializer.deserializeMessageBatch(new ReadChannel(Channels.newChannel(in)), allocator) + .asInstanceOf[ArrowRecordBatch] // throws IOException } + /** + * Create a DataFrame from a JavaRDD of serialized ArrowRecordBatches. 
+ */ private[sql] def toDataFrame( - payloadRDD: JavaRDD[Array[Byte]], + arrowBatchRDD: JavaRDD[Array[Byte]], schemaString: String, sqlContext: SQLContext): DataFrame = { -val rdd = payloadRDD.rdd.mapPartitions { iter => +val schema = DataType.fromJson(schemaString).asInstanceOf[StructType] +val timeZoneId = sqlContext.sessionState.conf.sessionLocalTimeZone +val rdd = arrowBatchRDD.rdd.mapPartitions { iter => val context = TaskContext.get() - ArrowConverters.fromPayloadIterator(iter.map(new ArrowPayload(_)), context) + ArrowConverters.fromBatchIterator(iter, schema, timeZoneId, context) } -val schema = DataType.fromJson(schemaString).asInstanceOf[StructType] sqlContext.internalCreateDataFrame(rdd, schema) } + + /** + * Read a file as an Arrow stream and return an RDD of serialized ArrowRecordBatches. + */ + private[sql] def readArrowStreamFromFile(sqlContext: SQLContext, filename: String): + JavaRDD[Array[Byte]] = { +val fileStream = new FileInputStream(filename) +try { + // Create array so that we can safely close the file + val batches = getBatchesFromStream(fileStream.getChannel).toArray + // Parallelize the record batches to create an RDD + JavaRDD.fromRDD(sqlContext.sparkContext.parallelize(batches, batches.length)) +} finally { + fileStream.close() +} + } + + /** + * Read an Arrow stream input and return an iterator of serialized ArrowRecordBatches. + */ + private[sql] def getBatchesFromStream(in: SeekableByteChannel): Iterator[Array[Byte]] = { + +// TODO: simplify in super class +class RecordBatchMessageReader(inputChannel: SeekableByteChannel) { --- End diff -- Yeah.. one way to do it is to write a new MessageReader interface to read Arrow message from a Channel: ``` public class OnHeapMessageChannelReader { /** * Read the next message in the sequence. 
* * @return The read message or null if reached the end of the message sequence * @throws IOException */ Message readNextMessage() throws IOException; /** * When a message is followed by a body of data, read that data into an ArrowBuf. This should * only be called when a Message has a body length > 0. * * @param message Read message that is followed by a body of data * @param allocator BufferAllocator to allocate memory for body data * @return An ArrowBuf containing the body of the message that was read * @throws IOException */ ByteBuffer readMessageBody(Message message) throws IOException; ... } ``` We might need to duplicate some logic in https://github.com/apache/arrow/blob/master/java/vector/src/main/java/org/apache/arrow/vector/ipc/message/MessageChannelReader.java#L33 For record batches, it's not too bad because the logic is pretty simple, but the downside is that we will be using low-level APIs of Arrow, which might not be guaranteed to be stable. @BryanCutler what kind of static function do you think we need to add on the Arrow side? --- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-
[GitHub] spark pull request #21546: [WIP][SPARK-23030][SQL][PYTHON] Use Arrow stream ...
Github user BryanCutler commented on a diff in the pull request: https://github.com/apache/spark/pull/21546#discussion_r197538429 --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala --- @@ -183,34 +182,131 @@ private[sql] object ArrowConverters { } /** - * Convert a byte array to an ArrowRecordBatch. + * Load a serialized ArrowRecordBatch. */ - private[arrow] def byteArrayToBatch( + private[arrow] def loadBatch( batchBytes: Array[Byte], allocator: BufferAllocator): ArrowRecordBatch = { -val in = new ByteArrayReadableSeekableByteChannel(batchBytes) -val reader = new ArrowFileReader(in, allocator) - -// Read a batch from a byte stream, ensure the reader is closed -Utils.tryWithSafeFinally { - val root = reader.getVectorSchemaRoot // throws IOException - val unloader = new VectorUnloader(root) - reader.loadNextBatch() // throws IOException - unloader.getRecordBatch -} { - reader.close() -} +val in = new ByteArrayInputStream(batchBytes) +MessageSerializer.deserializeMessageBatch(new ReadChannel(Channels.newChannel(in)), allocator) + .asInstanceOf[ArrowRecordBatch] // throws IOException } + /** + * Create a DataFrame from a JavaRDD of serialized ArrowRecordBatches. 
+ */ private[sql] def toDataFrame( - payloadRDD: JavaRDD[Array[Byte]], + arrowBatchRDD: JavaRDD[Array[Byte]], schemaString: String, sqlContext: SQLContext): DataFrame = { -val rdd = payloadRDD.rdd.mapPartitions { iter => +val schema = DataType.fromJson(schemaString).asInstanceOf[StructType] +val timeZoneId = sqlContext.sessionState.conf.sessionLocalTimeZone +val rdd = arrowBatchRDD.rdd.mapPartitions { iter => val context = TaskContext.get() - ArrowConverters.fromPayloadIterator(iter.map(new ArrowPayload(_)), context) + ArrowConverters.fromBatchIterator(iter, schema, timeZoneId, context) } -val schema = DataType.fromJson(schemaString).asInstanceOf[StructType] sqlContext.internalCreateDataFrame(rdd, schema) } + + /** + * Read a file as an Arrow stream and return an RDD of serialized ArrowRecordBatches. + */ + private[sql] def readArrowStreamFromFile(sqlContext: SQLContext, filename: String): + JavaRDD[Array[Byte]] = { +val fileStream = new FileInputStream(filename) +try { + // Create array so that we can safely close the file + val batches = getBatchesFromStream(fileStream.getChannel).toArray + // Parallelize the record batches to create an RDD + JavaRDD.fromRDD(sqlContext.sparkContext.parallelize(batches, batches.length)) +} finally { + fileStream.close() +} + } + + /** + * Read an Arrow stream input and return an iterator of serialized ArrowRecordBatches. + */ + private[sql] def getBatchesFromStream(in: SeekableByteChannel): Iterator[Array[Byte]] = { + +// TODO: simplify in super class +class RecordBatchMessageReader(inputChannel: SeekableByteChannel) { --- End diff -- Ok... I don't really need to subclass `MessageChannelReader` here. What if instead we just make a couple of static functions on the Arrow side to help with the details of processing messages, like: ```Java public class MessageChannelReader { ... 
public static Integer readMessageLength(ReadChannel in) {..} public static Message loadMessage(ReadChannel in, int messageLength, ByteBuffer buffer) {..} } ``` Is that better? --- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
[GitHub] spark pull request #21546: [WIP][SPARK-23030][SQL][PYTHON] Use Arrow stream ...
Github user icexelloss commented on a diff in the pull request: https://github.com/apache/spark/pull/21546#discussion_r197149246 --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala --- @@ -183,34 +182,131 @@ private[sql] object ArrowConverters { } /** - * Convert a byte array to an ArrowRecordBatch. + * Load a serialized ArrowRecordBatch. */ - private[arrow] def byteArrayToBatch( + private[arrow] def loadBatch( batchBytes: Array[Byte], allocator: BufferAllocator): ArrowRecordBatch = { -val in = new ByteArrayReadableSeekableByteChannel(batchBytes) -val reader = new ArrowFileReader(in, allocator) - -// Read a batch from a byte stream, ensure the reader is closed -Utils.tryWithSafeFinally { - val root = reader.getVectorSchemaRoot // throws IOException - val unloader = new VectorUnloader(root) - reader.loadNextBatch() // throws IOException - unloader.getRecordBatch -} { - reader.close() -} +val in = new ByteArrayInputStream(batchBytes) +MessageSerializer.deserializeMessageBatch(new ReadChannel(Channels.newChannel(in)), allocator) + .asInstanceOf[ArrowRecordBatch] // throws IOException } + /** + * Create a DataFrame from a JavaRDD of serialized ArrowRecordBatches. 
+ */ private[sql] def toDataFrame( - payloadRDD: JavaRDD[Array[Byte]], + arrowBatchRDD: JavaRDD[Array[Byte]], schemaString: String, sqlContext: SQLContext): DataFrame = { -val rdd = payloadRDD.rdd.mapPartitions { iter => +val schema = DataType.fromJson(schemaString).asInstanceOf[StructType] +val timeZoneId = sqlContext.sessionState.conf.sessionLocalTimeZone +val rdd = arrowBatchRDD.rdd.mapPartitions { iter => val context = TaskContext.get() - ArrowConverters.fromPayloadIterator(iter.map(new ArrowPayload(_)), context) + ArrowConverters.fromBatchIterator(iter, schema, timeZoneId, context) } -val schema = DataType.fromJson(schemaString).asInstanceOf[StructType] sqlContext.internalCreateDataFrame(rdd, schema) } + + /** + * Read a file as an Arrow stream and return an RDD of serialized ArrowRecordBatches. + */ + private[sql] def readArrowStreamFromFile(sqlContext: SQLContext, filename: String): + JavaRDD[Array[Byte]] = { +val fileStream = new FileInputStream(filename) +try { + // Create array so that we can safely close the file + val batches = getBatchesFromStream(fileStream.getChannel).toArray + // Parallelize the record batches to create an RDD + JavaRDD.fromRDD(sqlContext.sparkContext.parallelize(batches, batches.length)) +} finally { + fileStream.close() +} + } + + /** + * Read an Arrow stream input and return an iterator of serialized ArrowRecordBatches. + */ + private[sql] def getBatchesFromStream(in: SeekableByteChannel): Iterator[Array[Byte]] = { + +// TODO: simplify in super class +class RecordBatchMessageReader(inputChannel: SeekableByteChannel) { --- End diff -- Btw, I don't mind discussing it in this PR too. I'm also curious what other people think. --- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
[GitHub] spark pull request #21546: [WIP][SPARK-23030][SQL][PYTHON] Use Arrow stream ...
Github user icexelloss commented on a diff in the pull request: https://github.com/apache/spark/pull/21546#discussion_r196277485 --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala --- @@ -183,34 +182,131 @@ private[sql] object ArrowConverters { } /** - * Convert a byte array to an ArrowRecordBatch. + * Load a serialized ArrowRecordBatch. */ - private[arrow] def byteArrayToBatch( + private[arrow] def loadBatch( batchBytes: Array[Byte], allocator: BufferAllocator): ArrowRecordBatch = { -val in = new ByteArrayReadableSeekableByteChannel(batchBytes) -val reader = new ArrowFileReader(in, allocator) - -// Read a batch from a byte stream, ensure the reader is closed -Utils.tryWithSafeFinally { - val root = reader.getVectorSchemaRoot // throws IOException - val unloader = new VectorUnloader(root) - reader.loadNextBatch() // throws IOException - unloader.getRecordBatch -} { - reader.close() -} +val in = new ByteArrayInputStream(batchBytes) +MessageSerializer.deserializeMessageBatch(new ReadChannel(Channels.newChannel(in)), allocator) + .asInstanceOf[ArrowRecordBatch] // throws IOException } + /** + * Create a DataFrame from a JavaRDD of serialized ArrowRecordBatches. 
+ */ private[sql] def toDataFrame( - payloadRDD: JavaRDD[Array[Byte]], + arrowBatchRDD: JavaRDD[Array[Byte]], schemaString: String, sqlContext: SQLContext): DataFrame = { -val rdd = payloadRDD.rdd.mapPartitions { iter => +val schema = DataType.fromJson(schemaString).asInstanceOf[StructType] +val timeZoneId = sqlContext.sessionState.conf.sessionLocalTimeZone +val rdd = arrowBatchRDD.rdd.mapPartitions { iter => val context = TaskContext.get() - ArrowConverters.fromPayloadIterator(iter.map(new ArrowPayload(_)), context) + ArrowConverters.fromBatchIterator(iter, schema, timeZoneId, context) } -val schema = DataType.fromJson(schemaString).asInstanceOf[StructType] sqlContext.internalCreateDataFrame(rdd, schema) } + + /** + * Read a file as an Arrow stream and return an RDD of serialized ArrowRecordBatches. + */ + private[sql] def readArrowStreamFromFile(sqlContext: SQLContext, filename: String): + JavaRDD[Array[Byte]] = { +val fileStream = new FileInputStream(filename) +try { + // Create array so that we can safely close the file + val batches = getBatchesFromStream(fileStream.getChannel).toArray + // Parallelize the record batches to create an RDD + JavaRDD.fromRDD(sqlContext.sparkContext.parallelize(batches, batches.length)) +} finally { + fileStream.close() +} + } + + /** + * Read an Arrow stream input and return an iterator of serialized ArrowRecordBatches. + */ + private[sql] def getBatchesFromStream(in: SeekableByteChannel): Iterator[Array[Byte]] = { + +// TODO: simplify in super class +class RecordBatchMessageReader(inputChannel: SeekableByteChannel) { --- End diff -- Sure. The main reason I am unsure about this code is that it breaks encapsulation. If I understand correctly, the Arrow reader only supports reading record batches from a Channel into Arrow memory; in order to read record batches from a Channel directly into on-heap memory, we need to subclass `MessageChannelReader` and override `readNextMessage` to load both the metadata and the body of a record batch.
Now, the main reason I am not comfortable with this approach: the subclass changes the behavior of `readNextMessage` to load both the metadata and the body of a record batch, whereas in the parent class it only loads the metadata of a record batch. I think that is also the contract of the interface, so this feels a bit hacky. I am not saying I am totally against this for performance reasons, but considering the code path already involves writing data to disk (so avoiding one memory copy won't necessarily gain us much) and is one of the less frequent operations (pandas DataFrame -> Spark DataFrame), I am not sure it's worth it. That's why I suggest resolving this separately so as not to block this PR. --- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
[GitHub] spark pull request #21546: [WIP][SPARK-23030][SQL][PYTHON] Use Arrow stream ...
Github user BryanCutler commented on a diff in the pull request: https://github.com/apache/spark/pull/21546#discussion_r196247663 --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala --- @@ -183,34 +182,131 @@ private[sql] object ArrowConverters { } /** - * Convert a byte array to an ArrowRecordBatch. + * Load a serialized ArrowRecordBatch. */ - private[arrow] def byteArrayToBatch( + private[arrow] def loadBatch( batchBytes: Array[Byte], allocator: BufferAllocator): ArrowRecordBatch = { -val in = new ByteArrayReadableSeekableByteChannel(batchBytes) -val reader = new ArrowFileReader(in, allocator) - -// Read a batch from a byte stream, ensure the reader is closed -Utils.tryWithSafeFinally { - val root = reader.getVectorSchemaRoot // throws IOException - val unloader = new VectorUnloader(root) - reader.loadNextBatch() // throws IOException - unloader.getRecordBatch -} { - reader.close() -} +val in = new ByteArrayInputStream(batchBytes) +MessageSerializer.deserializeMessageBatch(new ReadChannel(Channels.newChannel(in)), allocator) + .asInstanceOf[ArrowRecordBatch] // throws IOException } + /** + * Create a DataFrame from a JavaRDD of serialized ArrowRecordBatches. 
+ */ private[sql] def toDataFrame( - payloadRDD: JavaRDD[Array[Byte]], + arrowBatchRDD: JavaRDD[Array[Byte]], schemaString: String, sqlContext: SQLContext): DataFrame = { -val rdd = payloadRDD.rdd.mapPartitions { iter => +val schema = DataType.fromJson(schemaString).asInstanceOf[StructType] +val timeZoneId = sqlContext.sessionState.conf.sessionLocalTimeZone +val rdd = arrowBatchRDD.rdd.mapPartitions { iter => val context = TaskContext.get() - ArrowConverters.fromPayloadIterator(iter.map(new ArrowPayload(_)), context) + ArrowConverters.fromBatchIterator(iter, schema, timeZoneId, context) } -val schema = DataType.fromJson(schemaString).asInstanceOf[StructType] sqlContext.internalCreateDataFrame(rdd, schema) } + + /** + * Read a file as an Arrow stream and return an RDD of serialized ArrowRecordBatches. + */ + private[sql] def readArrowStreamFromFile(sqlContext: SQLContext, filename: String): + JavaRDD[Array[Byte]] = { +val fileStream = new FileInputStream(filename) +try { + // Create array so that we can safely close the file + val batches = getBatchesFromStream(fileStream.getChannel).toArray + // Parallelize the record batches to create an RDD + JavaRDD.fromRDD(sqlContext.sparkContext.parallelize(batches, batches.length)) +} finally { + fileStream.close() +} + } + + /** + * Read an Arrow stream input and return an iterator of serialized ArrowRecordBatches. + */ + private[sql] def getBatchesFromStream(in: SeekableByteChannel): Iterator[Array[Byte]] = { + +// TODO: simplify in super class +class RecordBatchMessageReader(inputChannel: SeekableByteChannel) { --- End diff -- I think we should definitely avoid extra copies of the data in the JVM if possible, since we are trying to be efficient here. This process doesn't really seem complex to me, the specialized code here is about 10 lines and just reads bytes from an input stream to a byte buffer. Can you clarify a bit more on why you think this should be in a separate PR? 
--- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
[GitHub] spark pull request #21546: [WIP][SPARK-23030][SQL][PYTHON] Use Arrow stream ...
Github user BryanCutler commented on a diff in the pull request: https://github.com/apache/spark/pull/21546#discussion_r196172708 --- Diff: sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala --- @@ -3236,13 +3236,49 @@ class Dataset[T] private[sql]( } /** - * Collect a Dataset as ArrowPayload byte arrays and serve to PySpark. + * Collect a Dataset as Arrow batches and serve stream to PySpark. */ private[sql] def collectAsArrowToPython(): Array[Any] = { +val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone + withAction("collectAsArrowToPython", queryExecution) { plan => - val iter: Iterator[Array[Byte]] = -toArrowPayload(plan).collect().iterator.map(_.asPythonSerializable) - PythonRDD.serveIterator(iter, "serve-Arrow") + PythonRDD.serveToStream("serve-Arrow") { out => +val batchWriter = new ArrowBatchStreamWriter(schema, out, timeZoneId) +val arrowBatchRdd = getArrowBatchRdd(plan) +val numPartitions = arrowBatchRdd.partitions.length + +// Store collection results for worst case of 1 to N-1 partitions +val results = new Array[Array[Array[Byte]]](numPartitions - 1) +var lastIndex = -1 // index of last partition written + +// Handler to eagerly write partitions to Python in order +def handlePartitionBatches(index: Int, arrowBatches: Array[Array[Byte]]): Unit = { + // If result is from next partition in order + if (index - 1 == lastIndex) { +batchWriter.writeBatches(arrowBatches.iterator) +lastIndex += 1 +// Write stored partitions that come next in order +while (lastIndex < results.length && results(lastIndex) != null) { + batchWriter.writeBatches(results(lastIndex).iterator) + results(lastIndex) = null + lastIndex += 1 +} +// After last batch, end the stream +if (lastIndex == results.length) { + batchWriter.end() +} + } else { +// Store partitions received out of order +results(index - 1) = arrowBatches + } +} + +sparkSession.sparkContext.runJob( + arrowBatchRdd, + (ctx: TaskContext, it: Iterator[Array[Byte]]) => it.toArray, --- End diff -- I 
tried playing around with that a while ago and can't remember if there was some problem, but I'll give it another shot. --- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
[GitHub] spark pull request #21546: [WIP][SPARK-23030][SQL][PYTHON] Use Arrow stream ...
Github user BryanCutler commented on a diff in the pull request: https://github.com/apache/spark/pull/21546#discussion_r196172265 --- Diff: python/pyspark/serializers.py --- @@ -184,27 +184,31 @@ def loads(self, obj): raise NotImplementedError -class ArrowSerializer(FramedSerializer): +class ArrowStreamSerializer(Serializer): --- End diff -- That was my thought too. It's pretty close, although we do some different handling in `ArrowStreamPandasSerializer` that needs to fit in somewhere. Maybe we can look into this as a follow-up? --- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
[GitHub] spark pull request #21546: [WIP][SPARK-23030][SQL][PYTHON] Use Arrow stream ...
Github user icexelloss commented on a diff in the pull request: https://github.com/apache/spark/pull/21546#discussion_r196171319 --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala --- @@ -183,34 +182,131 @@ private[sql] object ArrowConverters { } /** - * Convert a byte array to an ArrowRecordBatch. + * Load a serialized ArrowRecordBatch. */ - private[arrow] def byteArrayToBatch( + private[arrow] def loadBatch( batchBytes: Array[Byte], allocator: BufferAllocator): ArrowRecordBatch = { -val in = new ByteArrayReadableSeekableByteChannel(batchBytes) -val reader = new ArrowFileReader(in, allocator) - -// Read a batch from a byte stream, ensure the reader is closed -Utils.tryWithSafeFinally { - val root = reader.getVectorSchemaRoot // throws IOException - val unloader = new VectorUnloader(root) - reader.loadNextBatch() // throws IOException - unloader.getRecordBatch -} { - reader.close() -} +val in = new ByteArrayInputStream(batchBytes) +MessageSerializer.deserializeMessageBatch(new ReadChannel(Channels.newChannel(in)), allocator) + .asInstanceOf[ArrowRecordBatch] // throws IOException } + /** + * Create a DataFrame from a JavaRDD of serialized ArrowRecordBatches. 
+ */ private[sql] def toDataFrame( - payloadRDD: JavaRDD[Array[Byte]], + arrowBatchRDD: JavaRDD[Array[Byte]], schemaString: String, sqlContext: SQLContext): DataFrame = { -val rdd = payloadRDD.rdd.mapPartitions { iter => +val schema = DataType.fromJson(schemaString).asInstanceOf[StructType] +val timeZoneId = sqlContext.sessionState.conf.sessionLocalTimeZone +val rdd = arrowBatchRDD.rdd.mapPartitions { iter => val context = TaskContext.get() - ArrowConverters.fromPayloadIterator(iter.map(new ArrowPayload(_)), context) + ArrowConverters.fromBatchIterator(iter, schema, timeZoneId, context) } -val schema = DataType.fromJson(schemaString).asInstanceOf[StructType] sqlContext.internalCreateDataFrame(rdd, schema) } + + /** + * Read a file as an Arrow stream and return an RDD of serialized ArrowRecordBatches. + */ + private[sql] def readArrowStreamFromFile(sqlContext: SQLContext, filename: String): + JavaRDD[Array[Byte]] = { +val fileStream = new FileInputStream(filename) +try { + // Create array so that we can safely close the file + val batches = getBatchesFromStream(fileStream.getChannel).toArray + // Parallelize the record batches to create an RDD + JavaRDD.fromRDD(sqlContext.sparkContext.parallelize(batches, batches.length)) +} finally { + fileStream.close() +} + } + + /** + * Read an Arrow stream input and return an iterator of serialized ArrowRecordBatches. + */ + private[sql] def getBatchesFromStream(in: SeekableByteChannel): Iterator[Array[Byte]] = { + +// TODO: simplify in super class +class RecordBatchMessageReader(inputChannel: SeekableByteChannel) { --- End diff -- I see. Since this code path already involves multiple copies of the data (Python -> disk -> JVM), I am not sure eliminating one memory copy is worth the complexity. I feel we should at least make this change in a separate PR, maybe? What do you think? --- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
[GitHub] spark pull request #21546: [WIP][SPARK-23030][SQL][PYTHON] Use Arrow stream ...
Github user BryanCutler commented on a diff in the pull request: https://github.com/apache/spark/pull/21546#discussion_r196170750 --- Diff: python/pyspark/sql/dataframe.py --- @@ -2153,7 +2153,7 @@ def _collectAsArrow(self): """ with SCCallSiteSync(self._sc) as css: sock_info = self._jdf.collectAsArrowToPython() -return list(_load_from_socket(sock_info, ArrowSerializer())) +return list(_load_from_socket(sock_info, ArrowStreamSerializer())) --- End diff -- Oh yeah, thanks! --- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
[GitHub] spark pull request #21546: [WIP][SPARK-23030][SQL][PYTHON] Use Arrow stream ...
Github user BryanCutler commented on a diff in the pull request: https://github.com/apache/spark/pull/21546#discussion_r196169993 --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala --- @@ -183,34 +182,131 @@ private[sql] object ArrowConverters { } /** - * Convert a byte array to an ArrowRecordBatch. + * Load a serialized ArrowRecordBatch. */ - private[arrow] def byteArrayToBatch( + private[arrow] def loadBatch( batchBytes: Array[Byte], allocator: BufferAllocator): ArrowRecordBatch = { -val in = new ByteArrayReadableSeekableByteChannel(batchBytes) -val reader = new ArrowFileReader(in, allocator) - -// Read a batch from a byte stream, ensure the reader is closed -Utils.tryWithSafeFinally { - val root = reader.getVectorSchemaRoot // throws IOException - val unloader = new VectorUnloader(root) - reader.loadNextBatch() // throws IOException - unloader.getRecordBatch -} { - reader.close() -} +val in = new ByteArrayInputStream(batchBytes) +MessageSerializer.deserializeMessageBatch(new ReadChannel(Channels.newChannel(in)), allocator) + .asInstanceOf[ArrowRecordBatch] // throws IOException } + /** + * Create a DataFrame from a JavaRDD of serialized ArrowRecordBatches. 
+ */ private[sql] def toDataFrame( - payloadRDD: JavaRDD[Array[Byte]], + arrowBatchRDD: JavaRDD[Array[Byte]], schemaString: String, sqlContext: SQLContext): DataFrame = { -val rdd = payloadRDD.rdd.mapPartitions { iter => +val schema = DataType.fromJson(schemaString).asInstanceOf[StructType] +val timeZoneId = sqlContext.sessionState.conf.sessionLocalTimeZone +val rdd = arrowBatchRDD.rdd.mapPartitions { iter => val context = TaskContext.get() - ArrowConverters.fromPayloadIterator(iter.map(new ArrowPayload(_)), context) + ArrowConverters.fromBatchIterator(iter, schema, timeZoneId, context) } -val schema = DataType.fromJson(schemaString).asInstanceOf[StructType] sqlContext.internalCreateDataFrame(rdd, schema) } + + /** + * Read a file as an Arrow stream and return an RDD of serialized ArrowRecordBatches. + */ + private[sql] def readArrowStreamFromFile(sqlContext: SQLContext, filename: String): + JavaRDD[Array[Byte]] = { +val fileStream = new FileInputStream(filename) +try { + // Create array so that we can safely close the file + val batches = getBatchesFromStream(fileStream.getChannel).toArray + // Parallelize the record batches to create an RDD + JavaRDD.fromRDD(sqlContext.sparkContext.parallelize(batches, batches.length)) +} finally { + fileStream.close() +} + } + + /** + * Read an Arrow stream input and return an iterator of serialized ArrowRecordBatches. + */ + private[sql] def getBatchesFromStream(in: SeekableByteChannel): Iterator[Array[Byte]] = { + +// TODO: simplify in super class +class RecordBatchMessageReader(inputChannel: SeekableByteChannel) { --- End diff -- >The whole class here is trying to read Arrow record batches from an stream into Java on-heap memory without going through Arrow off-heap memory, is that correct? Yes, that's correct. This is done to parallelize the Arrow record batches. > Also, this function is only used for pandas DataFrame -> Spark DataFrame? Yes, `RecordBatchMessageReader` is a specialized class and only meant for this purpose. 
--- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
[GitHub] spark pull request #21546: [WIP][SPARK-23030][SQL][PYTHON] Use Arrow stream ...
Github user icexelloss commented on a diff in the pull request: https://github.com/apache/spark/pull/21546#discussion_r196108332 --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala --- @@ -183,34 +182,131 @@ private[sql] object ArrowConverters { } /** - * Convert a byte array to an ArrowRecordBatch. + * Load a serialized ArrowRecordBatch. */ - private[arrow] def byteArrayToBatch( + private[arrow] def loadBatch( batchBytes: Array[Byte], allocator: BufferAllocator): ArrowRecordBatch = { -val in = new ByteArrayReadableSeekableByteChannel(batchBytes) -val reader = new ArrowFileReader(in, allocator) - -// Read a batch from a byte stream, ensure the reader is closed -Utils.tryWithSafeFinally { - val root = reader.getVectorSchemaRoot // throws IOException - val unloader = new VectorUnloader(root) - reader.loadNextBatch() // throws IOException - unloader.getRecordBatch -} { - reader.close() -} +val in = new ByteArrayInputStream(batchBytes) +MessageSerializer.deserializeMessageBatch(new ReadChannel(Channels.newChannel(in)), allocator) + .asInstanceOf[ArrowRecordBatch] // throws IOException } + /** + * Create a DataFrame from a JavaRDD of serialized ArrowRecordBatches. 
+ */ private[sql] def toDataFrame( - payloadRDD: JavaRDD[Array[Byte]], + arrowBatchRDD: JavaRDD[Array[Byte]], schemaString: String, sqlContext: SQLContext): DataFrame = { -val rdd = payloadRDD.rdd.mapPartitions { iter => +val schema = DataType.fromJson(schemaString).asInstanceOf[StructType] +val timeZoneId = sqlContext.sessionState.conf.sessionLocalTimeZone +val rdd = arrowBatchRDD.rdd.mapPartitions { iter => val context = TaskContext.get() - ArrowConverters.fromPayloadIterator(iter.map(new ArrowPayload(_)), context) + ArrowConverters.fromBatchIterator(iter, schema, timeZoneId, context) } -val schema = DataType.fromJson(schemaString).asInstanceOf[StructType] sqlContext.internalCreateDataFrame(rdd, schema) } + + /** + * Read a file as an Arrow stream and return an RDD of serialized ArrowRecordBatches. + */ + private[sql] def readArrowStreamFromFile(sqlContext: SQLContext, filename: String): + JavaRDD[Array[Byte]] = { +val fileStream = new FileInputStream(filename) +try { + // Create array so that we can safely close the file + val batches = getBatchesFromStream(fileStream.getChannel).toArray + // Parallelize the record batches to create an RDD + JavaRDD.fromRDD(sqlContext.sparkContext.parallelize(batches, batches.length)) +} finally { + fileStream.close() +} + } + + /** + * Read an Arrow stream input and return an iterator of serialized ArrowRecordBatches. + */ + private[sql] def getBatchesFromStream(in: SeekableByteChannel): Iterator[Array[Byte]] = { + +// TODO: simplify in super class +class RecordBatchMessageReader(inputChannel: SeekableByteChannel) { --- End diff -- @BryanCutler Now that I have looked at this more, I think I understand what you are trying to do: the whole class here is trying to read Arrow record batches from a stream into Java on-heap memory without going through Arrow off-heap memory, is that correct? Also, is this function only used for pandas DataFrame -> Spark DataFrame?
--- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
[GitHub] spark pull request #21546: [WIP][SPARK-23030][SQL][PYTHON] Use Arrow stream ...
Github user ueshin commented on a diff in the pull request: https://github.com/apache/spark/pull/21546#discussion_r195852498 --- Diff: sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala --- @@ -3236,13 +3236,49 @@ class Dataset[T] private[sql]( } /** - * Collect a Dataset as ArrowPayload byte arrays and serve to PySpark. + * Collect a Dataset as Arrow batches and serve stream to PySpark. */ private[sql] def collectAsArrowToPython(): Array[Any] = { +val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone + withAction("collectAsArrowToPython", queryExecution) { plan => - val iter: Iterator[Array[Byte]] = -toArrowPayload(plan).collect().iterator.map(_.asPythonSerializable) - PythonRDD.serveIterator(iter, "serve-Arrow") + PythonRDD.serveToStream("serve-Arrow") { out => +val batchWriter = new ArrowBatchStreamWriter(schema, out, timeZoneId) +val arrowBatchRdd = getArrowBatchRdd(plan) +val numPartitions = arrowBatchRdd.partitions.length + +// Store collection results for worst case of 1 to N-1 partitions +val results = new Array[Array[Array[Byte]]](numPartitions - 1) +var lastIndex = -1 // index of last partition written + +// Handler to eagerly write partitions to Python in order +def handlePartitionBatches(index: Int, arrowBatches: Array[Array[Byte]]): Unit = { + // If result is from next partition in order + if (index - 1 == lastIndex) { +batchWriter.writeBatches(arrowBatches.iterator) +lastIndex += 1 +// Write stored partitions that come next in order +while (lastIndex < results.length && results(lastIndex) != null) { + batchWriter.writeBatches(results(lastIndex).iterator) + results(lastIndex) = null + lastIndex += 1 +} +// After last batch, end the stream +if (lastIndex == results.length) { + batchWriter.end() +} + } else { +// Store partitions received out of order +results(index - 1) = arrowBatches + } +} + +sparkSession.sparkContext.runJob( + arrowBatchRdd, + (ctx: TaskContext, it: Iterator[Array[Byte]]) => it.toArray, --- End diff -- Can we 
call `handlePartitionBatches` here before `it.toArray`? I'd do `it.toArray` as lazily as possible. --- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
[GitHub] spark pull request #21546: [WIP][SPARK-23030][SQL][PYTHON] Use Arrow stream ...
Github user ueshin commented on a diff in the pull request: https://github.com/apache/spark/pull/21546#discussion_r195822151 --- Diff: python/pyspark/serializers.py --- @@ -184,27 +184,31 @@ def loads(self, obj): raise NotImplementedError -class ArrowSerializer(FramedSerializer): +class ArrowStreamSerializer(Serializer): --- End diff -- I'm wondering if we can reuse this for `ArrowStreamPandasSerializer`? --- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
[GitHub] spark pull request #21546: [WIP][SPARK-23030][SQL][PYTHON] Use Arrow stream ...
Github user ueshin commented on a diff in the pull request: https://github.com/apache/spark/pull/21546#discussion_r195824174 --- Diff: python/pyspark/sql/dataframe.py --- @@ -2153,7 +2153,7 @@ def _collectAsArrow(self): """ with SCCallSiteSync(self._sc) as css: sock_info = self._jdf.collectAsArrowToPython() -return list(_load_from_socket(sock_info, ArrowSerializer())) +return list(_load_from_socket(sock_info, ArrowStreamSerializer())) --- End diff -- We also need to update the description of `_collectAsArrow()`? --- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
[GitHub] spark pull request #21546: [WIP][SPARK-23030][SQL][PYTHON] Use Arrow stream ...
Github user icexelloss commented on a diff in the pull request: https://github.com/apache/spark/pull/21546#discussion_r195562446 --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala --- @@ -183,34 +182,131 @@ private[sql] object ArrowConverters { } /** - * Convert a byte array to an ArrowRecordBatch. + * Load a serialized ArrowRecordBatch. */ - private[arrow] def byteArrayToBatch( + private[arrow] def loadBatch( batchBytes: Array[Byte], allocator: BufferAllocator): ArrowRecordBatch = { -val in = new ByteArrayReadableSeekableByteChannel(batchBytes) -val reader = new ArrowFileReader(in, allocator) - -// Read a batch from a byte stream, ensure the reader is closed -Utils.tryWithSafeFinally { - val root = reader.getVectorSchemaRoot // throws IOException - val unloader = new VectorUnloader(root) - reader.loadNextBatch() // throws IOException - unloader.getRecordBatch -} { - reader.close() -} +val in = new ByteArrayInputStream(batchBytes) +MessageSerializer.deserializeMessageBatch(new ReadChannel(Channels.newChannel(in)), allocator) + .asInstanceOf[ArrowRecordBatch] // throws IOException } + /** + * Create a DataFrame from a JavaRDD of serialized ArrowRecordBatches. 
+ */ private[sql] def toDataFrame( - payloadRDD: JavaRDD[Array[Byte]], + arrowBatchRDD: JavaRDD[Array[Byte]], schemaString: String, sqlContext: SQLContext): DataFrame = { -val rdd = payloadRDD.rdd.mapPartitions { iter => +val schema = DataType.fromJson(schemaString).asInstanceOf[StructType] +val timeZoneId = sqlContext.sessionState.conf.sessionLocalTimeZone +val rdd = arrowBatchRDD.rdd.mapPartitions { iter => val context = TaskContext.get() - ArrowConverters.fromPayloadIterator(iter.map(new ArrowPayload(_)), context) + ArrowConverters.fromBatchIterator(iter, schema, timeZoneId, context) } -val schema = DataType.fromJson(schemaString).asInstanceOf[StructType] sqlContext.internalCreateDataFrame(rdd, schema) } + + /** + * Read a file as an Arrow stream and return an RDD of serialized ArrowRecordBatches. + */ + private[sql] def readArrowStreamFromFile(sqlContext: SQLContext, filename: String): + JavaRDD[Array[Byte]] = { +val fileStream = new FileInputStream(filename) +try { + // Create array so that we can safely close the file + val batches = getBatchesFromStream(fileStream.getChannel).toArray + // Parallelize the record batches to create an RDD + JavaRDD.fromRDD(sqlContext.sparkContext.parallelize(batches, batches.length)) +} finally { + fileStream.close() +} + } + + /** + * Read an Arrow stream input and return an iterator of serialized ArrowRecordBatches. + */ + private[sql] def getBatchesFromStream(in: SeekableByteChannel): Iterator[Array[Byte]] = { + +// TODO: simplify in super class +class RecordBatchMessageReader(inputChannel: SeekableByteChannel) { --- End diff -- I wonder if this works? 
In a for-loop: (1) Read the next batch into a VectorSchemaRoot (copy into Arrow memory); (2) Use VectorUnloader to unload the VectorSchemaRoot to an ArrowRecordBatch (no copy); (3) Use MessageSerializer.serialize to write the ArrowRecordBatch to a ByteChannel (copy from Arrow memory to Java memory). It seems that we cannot read directly from the socket into Java memory anyway (we have to go through the Arrow memory allocator). --- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
[GitHub] spark pull request #21546: [WIP][SPARK-23030][SQL][PYTHON] Use Arrow stream ...
Github user icexelloss commented on a diff in the pull request: https://github.com/apache/spark/pull/21546#discussion_r195559505 --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala --- @@ -183,34 +182,131 @@ private[sql] object ArrowConverters { } /** - * Convert a byte array to an ArrowRecordBatch. + * Load a serialized ArrowRecordBatch. */ - private[arrow] def byteArrayToBatch( + private[arrow] def loadBatch( batchBytes: Array[Byte], allocator: BufferAllocator): ArrowRecordBatch = { -val in = new ByteArrayReadableSeekableByteChannel(batchBytes) -val reader = new ArrowFileReader(in, allocator) - -// Read a batch from a byte stream, ensure the reader is closed -Utils.tryWithSafeFinally { - val root = reader.getVectorSchemaRoot // throws IOException - val unloader = new VectorUnloader(root) - reader.loadNextBatch() // throws IOException - unloader.getRecordBatch -} { - reader.close() -} +val in = new ByteArrayInputStream(batchBytes) +MessageSerializer.deserializeMessageBatch(new ReadChannel(Channels.newChannel(in)), allocator) + .asInstanceOf[ArrowRecordBatch] // throws IOException } + /** + * Create a DataFrame from a JavaRDD of serialized ArrowRecordBatches. 
+ */ private[sql] def toDataFrame( - payloadRDD: JavaRDD[Array[Byte]], + arrowBatchRDD: JavaRDD[Array[Byte]], schemaString: String, sqlContext: SQLContext): DataFrame = { -val rdd = payloadRDD.rdd.mapPartitions { iter => +val schema = DataType.fromJson(schemaString).asInstanceOf[StructType] +val timeZoneId = sqlContext.sessionState.conf.sessionLocalTimeZone +val rdd = arrowBatchRDD.rdd.mapPartitions { iter => val context = TaskContext.get() - ArrowConverters.fromPayloadIterator(iter.map(new ArrowPayload(_)), context) + ArrowConverters.fromBatchIterator(iter, schema, timeZoneId, context) } -val schema = DataType.fromJson(schemaString).asInstanceOf[StructType] sqlContext.internalCreateDataFrame(rdd, schema) } + + /** + * Read a file as an Arrow stream and return an RDD of serialized ArrowRecordBatches. + */ + private[sql] def readArrowStreamFromFile(sqlContext: SQLContext, filename: String): + JavaRDD[Array[Byte]] = { +val fileStream = new FileInputStream(filename) +try { + // Create array so that we can safely close the file + val batches = getBatchesFromStream(fileStream.getChannel).toArray + // Parallelize the record batches to create an RDD + JavaRDD.fromRDD(sqlContext.sparkContext.parallelize(batches, batches.length)) +} finally { + fileStream.close() +} + } + + /** + * Read an Arrow stream input and return an iterator of serialized ArrowRecordBatches. + */ + private[sql] def getBatchesFromStream(in: SeekableByteChannel): Iterator[Array[Byte]] = { + +// TODO: simplify in super class +class RecordBatchMessageReader(inputChannel: SeekableByteChannel) { --- End diff -- Yeah this does seem pretty complicated. I suppose you didn't use ``` public static ArrowBlock serialize(WriteChannel out, ArrowRecordBatch batch) ``` in message serialize to avoid double copy? --- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
[GitHub] spark pull request #21546: [WIP][SPARK-23030][SQL][PYTHON] Use Arrow stream ...
Github user icexelloss commented on a diff in the pull request: https://github.com/apache/spark/pull/21546#discussion_r195554625 --- Diff: sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala --- @@ -3236,13 +3236,49 @@ class Dataset[T] private[sql]( } /** - * Collect a Dataset as ArrowPayload byte arrays and serve to PySpark. + * Collect a Dataset as Arrow batches and serve stream to PySpark. */ private[sql] def collectAsArrowToPython(): Array[Any] = { +val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone + withAction("collectAsArrowToPython", queryExecution) { plan => - val iter: Iterator[Array[Byte]] = -toArrowPayload(plan).collect().iterator.map(_.asPythonSerializable) - PythonRDD.serveIterator(iter, "serve-Arrow") + PythonRDD.serveToStream("serve-Arrow") { out => +val batchWriter = new ArrowBatchStreamWriter(schema, out, timeZoneId) +val arrowBatchRdd = getArrowBatchRdd(plan) +val numPartitions = arrowBatchRdd.partitions.length + +// Store collection results for worst case of 1 to N-1 partitions +val results = new Array[Array[Array[Byte]]](numPartitions - 1) +var lastIndex = -1 // index of last partition written + +// Handler to eagerly write partitions to Python in order +def handlePartitionBatches(index: Int, arrowBatches: Array[Array[Byte]]): Unit = { + // If result is from next partition in order + if (index - 1 == lastIndex) { +batchWriter.writeBatches(arrowBatches.iterator) +lastIndex += 1 +// Write stored partitions that come next in order +while (lastIndex < results.length && results(lastIndex) != null) { + batchWriter.writeBatches(results(lastIndex).iterator) + results(lastIndex) = null + lastIndex += 1 +} +// After last batch, end the stream +if (lastIndex == results.length) { + batchWriter.end() +} + } else { +// Store partitions received out of order +results(index - 1) = arrowBatches + } +} + +sparkSession.sparkContext.runJob( + arrowBatchRdd, + (ctx: TaskContext, it: Iterator[Array[Byte]]) => it.toArray, + 0 until 
numPartitions, + handlePartitionBatches) --- End diff -- I guess in the worst-case scenario, the driver still needs to hold all batches in memory. For example, all the batches arrive at the same time. I wonder if there is a way to: (1) Compute all tasks in parallel; once tasks are done, store the results in the Block manager on executors. (2) Return all block ids to the driver. (3) Driver fetches each block and streams it individually. This way at least the computation is done in parallel; fetching the results sequentially is a trade-off of speed vs. memory, something we or the user can choose, but I imagine fetching some 10G - 20G of data from executors sequentially shouldn't be too bad. --- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
[GitHub] spark pull request #21546: [WIP][SPARK-23030][SQL][PYTHON] Use Arrow stream ...
Github user icexelloss commented on a diff in the pull request: https://github.com/apache/spark/pull/21546#discussion_r195552695 --- Diff: sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala --- @@ -3236,13 +3236,49 @@ class Dataset[T] private[sql]( } /** - * Collect a Dataset as ArrowPayload byte arrays and serve to PySpark. + * Collect a Dataset as Arrow batches and serve stream to PySpark. */ private[sql] def collectAsArrowToPython(): Array[Any] = { +val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone + withAction("collectAsArrowToPython", queryExecution) { plan => - val iter: Iterator[Array[Byte]] = -toArrowPayload(plan).collect().iterator.map(_.asPythonSerializable) - PythonRDD.serveIterator(iter, "serve-Arrow") + PythonRDD.serveToStream("serve-Arrow") { out => +val batchWriter = new ArrowBatchStreamWriter(schema, out, timeZoneId) +val arrowBatchRdd = getArrowBatchRdd(plan) +val numPartitions = arrowBatchRdd.partitions.length + +// Store collection results for worst case of 1 to N-1 partitions +val results = new Array[Array[Array[Byte]]](numPartitions - 1) +var lastIndex = -1 // index of last partition written + +// Handler to eagerly write partitions to Python in order +def handlePartitionBatches(index: Int, arrowBatches: Array[Array[Byte]]): Unit = { + // If result is from next partition in order + if (index - 1 == lastIndex) { +batchWriter.writeBatches(arrowBatches.iterator) +lastIndex += 1 +// Write stored partitions that come next in order +while (lastIndex < results.length && results(lastIndex) != null) { + batchWriter.writeBatches(results(lastIndex).iterator) + results(lastIndex) = null + lastIndex += 1 +} +// After last batch, end the stream +if (lastIndex == results.length) { + batchWriter.end() +} + } else { +// Store partitions received out of order +results(index - 1) = arrowBatches + } +} + +sparkSession.sparkContext.runJob( + arrowBatchRdd, + (ctx: TaskContext, it: Iterator[Array[Byte]]) => it.toArray, + 0 until 
numPartitions, + handlePartitionBatches) --- End diff -- > I did have another idea though, we could stream all partitions to Python out of order, then follow with another small batch of data that contains maps of partitionIndex to orderReceived. Then the partitions could be put into order on the Python side before making the Pandas DataFrame. This sounds good! --- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
[GitHub] spark pull request #21546: [WIP][SPARK-23030][SQL][PYTHON] Use Arrow stream ...
Github user BryanCutler commented on a diff in the pull request: https://github.com/apache/spark/pull/21546#discussion_r195512218 --- Diff: sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala --- @@ -1318,18 +1318,52 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { } } - test("roundtrip payloads") { + test("roundtrip arrow batches") { val inputRows = (0 until 9).map { i => InternalRow(i) } :+ InternalRow(null) val schema = StructType(Seq(StructField("int", IntegerType, nullable = true))) val ctx = TaskContext.empty() -val payloadIter = ArrowConverters.toPayloadIterator(inputRows.toIterator, schema, 0, null, ctx) -val outputRowIter = ArrowConverters.fromPayloadIterator(payloadIter, ctx) +val batchIter = ArrowConverters.toBatchIterator(inputRows.toIterator, schema, 5, null, ctx) +val outputRowIter = ArrowConverters.fromBatchIterator(batchIter, schema, null, ctx) -assert(schema == outputRowIter.schema) +var count = 0 +outputRowIter.zipWithIndex.foreach { case (row, i) => + if (i != 9) { +assert(row.getInt(0) == i) + } else { +assert(row.isNullAt(0)) + } + count += 1 +} + +assert(count == inputRows.length) + } + + test("ArrowBatchStreamWriter roundtrip") { +val inputRows = (0 until 9).map { i => + InternalRow(i) +} :+ InternalRow(null) + +val schema = StructType(Seq(StructField("int", IntegerType, nullable = true))) + +val ctx = TaskContext.empty() +val batchIter = ArrowConverters.toBatchIterator(inputRows.toIterator, schema, 5, null, ctx) + +// Write batches to Arrow stream format as a byte array +val out = new ByteArrayOutputStream() --- End diff -- done --- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
[GitHub] spark pull request #21546: [WIP][SPARK-23030][SQL][PYTHON] Use Arrow stream ...
Github user BryanCutler commented on a diff in the pull request: https://github.com/apache/spark/pull/21546#discussion_r195504451 --- Diff: sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala --- @@ -34,17 +34,36 @@ private[sql] object PythonSQLUtils { } /** - * Python Callable function to convert ArrowPayloads into a [[DataFrame]]. + * Python callable function to convert an RDD of serialized ArrowRecordBatches into + * a [[DataFrame]]. * - * @param payloadRDD A JavaRDD of ArrowPayloads. - * @param schemaString JSON Formatted Schema for ArrowPayloads. + * @param arrowBatchRDD A JavaRDD of serialized ArrowRecordBatches. + * @param schemaString JSON Formatted Spark schema for Arrow batches. * @param sqlContext The active [[SQLContext]]. * @return The converted [[DataFrame]]. */ - def arrowPayloadToDataFrame( - payloadRDD: JavaRDD[Array[Byte]], + def arrowStreamToDataFrame( --- End diff -- oh right, this is only called by the function below so I suppose we don't even need it.. --- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
[GitHub] spark pull request #21546: [WIP][SPARK-23030][SQL][PYTHON] Use Arrow stream ...
Github user BryanCutler commented on a diff in the pull request: https://github.com/apache/spark/pull/21546#discussion_r195502588 --- Diff: sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala --- @@ -34,17 +34,36 @@ private[sql] object PythonSQLUtils { } /** - * Python Callable function to convert ArrowPayloads into a [[DataFrame]]. + * Python callable function to convert an RDD of serialized ArrowRecordBatches into + * a [[DataFrame]]. * - * @param payloadRDD A JavaRDD of ArrowPayloads. - * @param schemaString JSON Formatted Schema for ArrowPayloads. + * @param arrowBatchRDD A JavaRDD of serialized ArrowRecordBatches. + * @param schemaString JSON Formatted Spark schema for Arrow batches. * @param sqlContext The active [[SQLContext]]. * @return The converted [[DataFrame]]. */ - def arrowPayloadToDataFrame( - payloadRDD: JavaRDD[Array[Byte]], + def arrowStreamToDataFrame( + arrowBatchRDD: JavaRDD[Array[Byte]], schemaString: String, sqlContext: SQLContext): DataFrame = { -ArrowConverters.toDataFrame(payloadRDD, schemaString, sqlContext) +ArrowConverters.toDataFrame(arrowBatchRDD, schemaString, sqlContext) + } + + /** + * Python callable function to read a file in Arrow stream format and create a [[DataFrame]] + * using each serialized ArrowRecordBatch as a partition. + * + * @param sqlContext The active [[SQLContext]]. + * @param filename File to read the Arrow stream from. + * @param schemaString JSON Formatted Spark schema for Arrow batches. + * @return A new [[DataFrame]]. + */ + def arrowReadStreamFromFile( + sqlContext: SQLContext, + filename: String, + schemaString: String): DataFrame = { +JavaSparkContext.fromSparkContext(sqlContext.sparkContext) --- End diff -- oops, nothing! I must have forgot to delete, thanks! --- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
[GitHub] spark pull request #21546: [WIP][SPARK-23030][SQL][PYTHON] Use Arrow stream ...
Github user BryanCutler commented on a diff in the pull request: https://github.com/apache/spark/pull/21546#discussion_r195501939 --- Diff: python/pyspark/serializers.py --- @@ -184,24 +184,28 @@ def loads(self, obj): raise NotImplementedError -class ArrowSerializer(FramedSerializer): +class ArrowSerializer(Serializer): --- End diff -- Maybe `ArrowStreamSerializer`? --- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
[GitHub] spark pull request #21546: [WIP][SPARK-23030][SQL][PYTHON] Use Arrow stream ...
Github user BryanCutler commented on a diff in the pull request: https://github.com/apache/spark/pull/21546#discussion_r195501843 --- Diff: sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala --- @@ -1318,18 +1318,52 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { } } - test("roundtrip payloads") { + test("roundtrip arrow batches") { val inputRows = (0 until 9).map { i => InternalRow(i) } :+ InternalRow(null) val schema = StructType(Seq(StructField("int", IntegerType, nullable = true))) val ctx = TaskContext.empty() -val payloadIter = ArrowConverters.toPayloadIterator(inputRows.toIterator, schema, 0, null, ctx) -val outputRowIter = ArrowConverters.fromPayloadIterator(payloadIter, ctx) +val batchIter = ArrowConverters.toBatchIterator(inputRows.toIterator, schema, 5, null, ctx) +val outputRowIter = ArrowConverters.fromBatchIterator(batchIter, schema, null, ctx) -assert(schema == outputRowIter.schema) +var count = 0 +outputRowIter.zipWithIndex.foreach { case (row, i) => + if (i != 9) { +assert(row.getInt(0) == i) + } else { +assert(row.isNullAt(0)) + } + count += 1 +} + +assert(count == inputRows.length) + } + + test("ArrowBatchStreamWriter roundtrip") { +val inputRows = (0 until 9).map { i => + InternalRow(i) +} :+ InternalRow(null) + +val schema = StructType(Seq(StructField("int", IntegerType, nullable = true))) + +val ctx = TaskContext.empty() +val batchIter = ArrowConverters.toBatchIterator(inputRows.toIterator, schema, 5, null, ctx) + +// Write batches to Arrow stream format as a byte array +val out = new ByteArrayOutputStream() --- End diff -- This doesn't actually need to be closed, but I should be closing the DataOutputStream, so I'll put that in tryWithResource --- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
[GitHub] spark pull request #21546: [WIP][SPARK-23030][SQL][PYTHON] Use Arrow stream ...
Github user BryanCutler commented on a diff in the pull request: https://github.com/apache/spark/pull/21546#discussion_r195499089 --- Diff: sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala --- @@ -34,17 +34,36 @@ private[sql] object PythonSQLUtils { } /** - * Python Callable function to convert ArrowPayloads into a [[DataFrame]]. + * Python callable function to convert an RDD of serialized ArrowRecordBatches into + * a [[DataFrame]]. * - * @param payloadRDD A JavaRDD of ArrowPayloads. - * @param schemaString JSON Formatted Schema for ArrowPayloads. + * @param arrowBatchRDD A JavaRDD of serialized ArrowRecordBatches. + * @param schemaString JSON Formatted Spark schema for Arrow batches. * @param sqlContext The active [[SQLContext]]. * @return The converted [[DataFrame]]. */ - def arrowPayloadToDataFrame( - payloadRDD: JavaRDD[Array[Byte]], + def arrowStreamToDataFrame( --- End diff -- it's public so it can be called in Python with Py4j --- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
[GitHub] spark pull request #21546: [WIP][SPARK-23030][SQL][PYTHON] Use Arrow stream ...
Github user BryanCutler commented on a diff in the pull request: https://github.com/apache/spark/pull/21546#discussion_r195498764 --- Diff: sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala --- @@ -3236,13 +3236,49 @@ class Dataset[T] private[sql]( } /** - * Collect a Dataset as ArrowPayload byte arrays and serve to PySpark. + * Collect a Dataset as Arrow batches and serve stream to PySpark. */ private[sql] def collectAsArrowToPython(): Array[Any] = { +val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone + withAction("collectAsArrowToPython", queryExecution) { plan => - val iter: Iterator[Array[Byte]] = -toArrowPayload(plan).collect().iterator.map(_.asPythonSerializable) - PythonRDD.serveIterator(iter, "serve-Arrow") + PythonRDD.serveToStream("serve-Arrow") { out => +val batchWriter = new ArrowBatchStreamWriter(schema, out, timeZoneId) +val arrowBatchRdd = getArrowBatchRdd(plan) +val numPartitions = arrowBatchRdd.partitions.length + +// Store collection results for worst case of 1 to N-1 partitions +val results = new Array[Array[Array[Byte]]](numPartitions - 1) +var lastIndex = -1 // index of last partition written + +// Handler to eagerly write partitions to Python in order +def handlePartitionBatches(index: Int, arrowBatches: Array[Array[Byte]]): Unit = { + // If result is from next partition in order + if (index - 1 == lastIndex) { +batchWriter.writeBatches(arrowBatches.iterator) +lastIndex += 1 +// Write stored partitions that come next in order +while (lastIndex < results.length && results(lastIndex) != null) { + batchWriter.writeBatches(results(lastIndex).iterator) + results(lastIndex) = null + lastIndex += 1 +} +// After last batch, end the stream +if (lastIndex == results.length) { + batchWriter.end() +} + } else { +// Store partitions received out of order +results(index - 1) = arrowBatches + } +} + +sparkSession.sparkContext.runJob( + arrowBatchRdd, + (ctx: TaskContext, it: Iterator[Array[Byte]]) => it.toArray, + 0 until 
numPartitions, + handlePartitionBatches) --- End diff -- > +1 chunking if we could. I recall Bryan said for grouped UDF we need the entire set. This still keeps Arrow record batches chunked within each partition, which can help the executor memory, but doesn't do anything for the driver side because we still need to collect the entire partition in the driver JVM. > Also not sure if python side we have any assumption on how much of the partition is in each chunk (there shouldn't be?) No, Python doesn't care how many chunks the data is in, it's handled by pyarrow --- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
[GitHub] spark pull request #21546: [WIP][SPARK-23030][SQL][PYTHON] Use Arrow stream ...
Github user BryanCutler commented on a diff in the pull request: https://github.com/apache/spark/pull/21546#discussion_r195497043 --- Diff: sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala --- @@ -3236,13 +3236,49 @@ class Dataset[T] private[sql]( } /** - * Collect a Dataset as ArrowPayload byte arrays and serve to PySpark. + * Collect a Dataset as Arrow batches and serve stream to PySpark. */ private[sql] def collectAsArrowToPython(): Array[Any] = { +val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone + withAction("collectAsArrowToPython", queryExecution) { plan => - val iter: Iterator[Array[Byte]] = -toArrowPayload(plan).collect().iterator.map(_.asPythonSerializable) - PythonRDD.serveIterator(iter, "serve-Arrow") + PythonRDD.serveToStream("serve-Arrow") { out => +val batchWriter = new ArrowBatchStreamWriter(schema, out, timeZoneId) +val arrowBatchRdd = getArrowBatchRdd(plan) +val numPartitions = arrowBatchRdd.partitions.length + +// Store collection results for worst case of 1 to N-1 partitions +val results = new Array[Array[Array[Byte]]](numPartitions - 1) +var lastIndex = -1 // index of last partition written + +// Handler to eagerly write partitions to Python in order +def handlePartitionBatches(index: Int, arrowBatches: Array[Array[Byte]]): Unit = { + // If result is from next partition in order + if (index - 1 == lastIndex) { +batchWriter.writeBatches(arrowBatches.iterator) +lastIndex += 1 +// Write stored partitions that come next in order +while (lastIndex < results.length && results(lastIndex) != null) { + batchWriter.writeBatches(results(lastIndex).iterator) + results(lastIndex) = null + lastIndex += 1 +} +// After last batch, end the stream +if (lastIndex == results.length) { + batchWriter.end() +} + } else { +// Store partitions received out of order +results(index - 1) = arrowBatches + } +} + +sparkSession.sparkContext.runJob( + arrowBatchRdd, + (ctx: TaskContext, it: Iterator[Array[Byte]]) => it.toArray, + 0 until 
numPartitions, + handlePartitionBatches) --- End diff -- > is it better to incrementally run job on partitions in order I believe this is how `toLocalIterator` works, right? I tried using that because it does only keep 1 partition in memory at a time, but the performance took quite a hit from the multiple jobs. I think we should still prioritize performance over memory for `toPandas()` since it's assumed the data to be collected should be relatively small. I did have another idea though, we could stream all partitions to Python out of order, then follow with another small batch of data that contains maps of partitionIndex to orderReceived. Then the partitions could be put into order on the Python side before making the Pandas DataFrame. --- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
[GitHub] spark pull request #21546: [WIP][SPARK-23030][SQL][PYTHON] Use Arrow stream ...
Github user HyukjinKwon commented on a diff in the pull request: https://github.com/apache/spark/pull/21546#discussion_r194966161 --- Diff: python/pyspark/serializers.py --- @@ -184,24 +184,28 @@ def loads(self, obj): raise NotImplementedError -class ArrowSerializer(FramedSerializer): +class ArrowSerializer(Serializer): --- End diff -- Should we rename this? --- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
[GitHub] spark pull request #21546: [WIP][SPARK-23030][SQL][PYTHON] Use Arrow stream ...
Github user HyukjinKwon commented on a diff in the pull request: https://github.com/apache/spark/pull/21546#discussion_r194965715 --- Diff: sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala --- @@ -1318,18 +1318,52 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { } } - test("roundtrip payloads") { + test("roundtrip arrow batches") { val inputRows = (0 until 9).map { i => InternalRow(i) } :+ InternalRow(null) val schema = StructType(Seq(StructField("int", IntegerType, nullable = true))) val ctx = TaskContext.empty() -val payloadIter = ArrowConverters.toPayloadIterator(inputRows.toIterator, schema, 0, null, ctx) -val outputRowIter = ArrowConverters.fromPayloadIterator(payloadIter, ctx) +val batchIter = ArrowConverters.toBatchIterator(inputRows.toIterator, schema, 5, null, ctx) +val outputRowIter = ArrowConverters.fromBatchIterator(batchIter, schema, null, ctx) -assert(schema == outputRowIter.schema) +var count = 0 +outputRowIter.zipWithIndex.foreach { case (row, i) => + if (i != 9) { +assert(row.getInt(0) == i) + } else { +assert(row.isNullAt(0)) + } + count += 1 +} + +assert(count == inputRows.length) + } + + test("ArrowBatchStreamWriter roundtrip") { +val inputRows = (0 until 9).map { i => + InternalRow(i) +} :+ InternalRow(null) + +val schema = StructType(Seq(StructField("int", IntegerType, nullable = true))) + +val ctx = TaskContext.empty() +val batchIter = ArrowConverters.toBatchIterator(inputRows.toIterator, schema, 5, null, ctx) + +// Write batches to Arrow stream format as a byte array +val out = new ByteArrayOutputStream() --- End diff -- Can we use `Utils.tryWithResource { ... }`? --- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
[GitHub] spark pull request #21546: [WIP][SPARK-23030][SQL][PYTHON] Use Arrow stream ...
Github user viirya commented on a diff in the pull request: https://github.com/apache/spark/pull/21546#discussion_r194963031 --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala --- @@ -183,34 +182,131 @@ private[sql] object ArrowConverters { } /** - * Convert a byte array to an ArrowRecordBatch. + * Load a serialized ArrowRecordBatch. */ - private[arrow] def byteArrayToBatch( + private[arrow] def loadBatch( batchBytes: Array[Byte], allocator: BufferAllocator): ArrowRecordBatch = { -val in = new ByteArrayReadableSeekableByteChannel(batchBytes) -val reader = new ArrowFileReader(in, allocator) - -// Read a batch from a byte stream, ensure the reader is closed -Utils.tryWithSafeFinally { - val root = reader.getVectorSchemaRoot // throws IOException - val unloader = new VectorUnloader(root) - reader.loadNextBatch() // throws IOException - unloader.getRecordBatch -} { - reader.close() -} +val in = new ByteArrayInputStream(batchBytes) +MessageSerializer.deserializeMessageBatch(new ReadChannel(Channels.newChannel(in)), allocator) + .asInstanceOf[ArrowRecordBatch] // throws IOException } + /** + * Create a DataFrame from a JavaRDD of serialized ArrowRecordBatches. 
+ */ private[sql] def toDataFrame( - payloadRDD: JavaRDD[Array[Byte]], + arrowBatchRDD: JavaRDD[Array[Byte]], schemaString: String, sqlContext: SQLContext): DataFrame = { -val rdd = payloadRDD.rdd.mapPartitions { iter => +val schema = DataType.fromJson(schemaString).asInstanceOf[StructType] +val timeZoneId = sqlContext.sessionState.conf.sessionLocalTimeZone +val rdd = arrowBatchRDD.rdd.mapPartitions { iter => val context = TaskContext.get() - ArrowConverters.fromPayloadIterator(iter.map(new ArrowPayload(_)), context) + ArrowConverters.fromBatchIterator(iter, schema, timeZoneId, context) } -val schema = DataType.fromJson(schemaString).asInstanceOf[StructType] sqlContext.internalCreateDataFrame(rdd, schema) } + + /** + * Read a file as an Arrow stream and return an RDD of serialized ArrowRecordBatches. + */ + private[sql] def readArrowStreamFromFile(sqlContext: SQLContext, filename: String): + JavaRDD[Array[Byte]] = { --- End diff -- indentation --- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
[GitHub] spark pull request #21546: [WIP][SPARK-23030][SQL][PYTHON] Use Arrow stream ...
Github user viirya commented on a diff in the pull request: https://github.com/apache/spark/pull/21546#discussion_r194962360 --- Diff: sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala --- @@ -34,17 +34,36 @@ private[sql] object PythonSQLUtils { } /** - * Python Callable function to convert ArrowPayloads into a [[DataFrame]]. + * Python callable function to convert an RDD of serialized ArrowRecordBatches into + * a [[DataFrame]]. * - * @param payloadRDD A JavaRDD of ArrowPayloads. - * @param schemaString JSON Formatted Schema for ArrowPayloads. + * @param arrowBatchRDD A JavaRDD of serialized ArrowRecordBatches. + * @param schemaString JSON Formatted Spark schema for Arrow batches. * @param sqlContext The active [[SQLContext]]. * @return The converted [[DataFrame]]. */ - def arrowPayloadToDataFrame( - payloadRDD: JavaRDD[Array[Byte]], + def arrowStreamToDataFrame( --- End diff -- This seems to be a private method now? --- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
[GitHub] spark pull request #21546: [WIP][SPARK-23030][SQL][PYTHON] Use Arrow stream ...
Github user felixcheung commented on a diff in the pull request: https://github.com/apache/spark/pull/21546#discussion_r194954051 --- Diff: sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala --- @@ -3236,13 +3236,49 @@ class Dataset[T] private[sql]( } /** - * Collect a Dataset as ArrowPayload byte arrays and serve to PySpark. + * Collect a Dataset as Arrow batches and serve stream to PySpark. */ private[sql] def collectAsArrowToPython(): Array[Any] = { +val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone + withAction("collectAsArrowToPython", queryExecution) { plan => - val iter: Iterator[Array[Byte]] = -toArrowPayload(plan).collect().iterator.map(_.asPythonSerializable) - PythonRDD.serveIterator(iter, "serve-Arrow") + PythonRDD.serveToStream("serve-Arrow") { out => +val batchWriter = new ArrowBatchStreamWriter(schema, out, timeZoneId) +val arrowBatchRdd = getArrowBatchRdd(plan) +val numPartitions = arrowBatchRdd.partitions.length + +// Store collection results for worst case of 1 to N-1 partitions +val results = new Array[Array[Array[Byte]]](numPartitions - 1) +var lastIndex = -1 // index of last partition written + +// Handler to eagerly write partitions to Python in order +def handlePartitionBatches(index: Int, arrowBatches: Array[Array[Byte]]): Unit = { + // If result is from next partition in order + if (index - 1 == lastIndex) { +batchWriter.writeBatches(arrowBatches.iterator) +lastIndex += 1 +// Write stored partitions that come next in order +while (lastIndex < results.length && results(lastIndex) != null) { + batchWriter.writeBatches(results(lastIndex).iterator) + results(lastIndex) = null + lastIndex += 1 +} +// After last batch, end the stream +if (lastIndex == results.length) { + batchWriter.end() +} + } else { +// Store partitions received out of order +results(index - 1) = arrowBatches + } +} + +sparkSession.sparkContext.runJob( + arrowBatchRdd, + (ctx: TaskContext, it: Iterator[Array[Byte]]) => it.toArray, + 0 until 
numPartitions, + handlePartitionBatches) --- End diff -- +1 to chunking if we can. I recall Bryan said that for grouped UDFs we need the entire set. Also, I'm not sure whether the Python side makes any assumption about how much of the partition is in each chunk (there shouldn't be?) --- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
[GitHub] spark pull request #21546: [WIP][SPARK-23030][SQL][PYTHON] Use Arrow stream ...
Github user viirya commented on a diff in the pull request: https://github.com/apache/spark/pull/21546#discussion_r194948874 --- Diff: sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala --- @@ -3236,13 +3236,49 @@ class Dataset[T] private[sql]( } /** - * Collect a Dataset as ArrowPayload byte arrays and serve to PySpark. + * Collect a Dataset as Arrow batches and serve stream to PySpark. */ private[sql] def collectAsArrowToPython(): Array[Any] = { +val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone + withAction("collectAsArrowToPython", queryExecution) { plan => - val iter: Iterator[Array[Byte]] = -toArrowPayload(plan).collect().iterator.map(_.asPythonSerializable) - PythonRDD.serveIterator(iter, "serve-Arrow") + PythonRDD.serveToStream("serve-Arrow") { out => +val batchWriter = new ArrowBatchStreamWriter(schema, out, timeZoneId) +val arrowBatchRdd = getArrowBatchRdd(plan) +val numPartitions = arrowBatchRdd.partitions.length + +// Store collection results for worst case of 1 to N-1 partitions +val results = new Array[Array[Array[Byte]]](numPartitions - 1) +var lastIndex = -1 // index of last partition written + +// Handler to eagerly write partitions to Python in order +def handlePartitionBatches(index: Int, arrowBatches: Array[Array[Byte]]): Unit = { + // If result is from next partition in order + if (index - 1 == lastIndex) { +batchWriter.writeBatches(arrowBatches.iterator) +lastIndex += 1 +// Write stored partitions that come next in order +while (lastIndex < results.length && results(lastIndex) != null) { + batchWriter.writeBatches(results(lastIndex).iterator) + results(lastIndex) = null + lastIndex += 1 +} +// After last batch, end the stream +if (lastIndex == results.length) { + batchWriter.end() +} + } else { +// Store partitions received out of order +results(index - 1) = arrowBatches + } +} + +sparkSession.sparkContext.runJob( + arrowBatchRdd, + (ctx: TaskContext, it: Iterator[Array[Byte]]) => it.toArray, + 0 until numPartitions, 
+ handlePartitionBatches) --- End diff -- Instead of collecting all partitions back at once and holding out-of-order partitions in the driver while waiting for the earlier ones, would it be better to run the job incrementally on partitions in order and send the streams to the Python side? That way we don't need to hold out-of-order partitions in the driver. --- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
[GitHub] spark pull request #21546: [WIP][SPARK-23030][SQL][PYTHON] Use Arrow stream ...
Github user viirya commented on a diff in the pull request: https://github.com/apache/spark/pull/21546#discussion_r194949076 --- Diff: sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala --- @@ -34,17 +34,36 @@ private[sql] object PythonSQLUtils { } /** - * Python Callable function to convert ArrowPayloads into a [[DataFrame]]. + * Python callable function to convert an RDD of serialized ArrowRecordBatches into + * a [[DataFrame]]. * - * @param payloadRDD A JavaRDD of ArrowPayloads. - * @param schemaString JSON Formatted Schema for ArrowPayloads. + * @param arrowBatchRDD A JavaRDD of serialized ArrowRecordBatches. + * @param schemaString JSON Formatted Spark schema for Arrow batches. * @param sqlContext The active [[SQLContext]]. * @return The converted [[DataFrame]]. */ - def arrowPayloadToDataFrame( - payloadRDD: JavaRDD[Array[Byte]], + def arrowStreamToDataFrame( + arrowBatchRDD: JavaRDD[Array[Byte]], schemaString: String, sqlContext: SQLContext): DataFrame = { -ArrowConverters.toDataFrame(payloadRDD, schemaString, sqlContext) +ArrowConverters.toDataFrame(arrowBatchRDD, schemaString, sqlContext) + } + + /** + * Python callable function to read a file in Arrow stream format and create a [[DataFrame]] + * using each serialized ArrowRecordBatch as a partition. + * + * @param sqlContext The active [[SQLContext]]. + * @param filename File to read the Arrow stream from. + * @param schemaString JSON Formatted Spark schema for Arrow batches. + * @return A new [[DataFrame]]. + */ + def arrowReadStreamFromFile( + sqlContext: SQLContext, + filename: String, + schemaString: String): DataFrame = { +JavaSparkContext.fromSparkContext(sqlContext.sparkContext) --- End diff -- What is this line for? --- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
[GitHub] spark pull request #21546: [WIP][SPARK-23030][SQL][PYTHON] Use Arrow stream ...
Github user BryanCutler commented on a diff in the pull request: https://github.com/apache/spark/pull/21546#discussion_r194904013 --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala --- @@ -183,34 +182,131 @@ private[sql] object ArrowConverters { } /** - * Convert a byte array to an ArrowRecordBatch. + * Load a serialized ArrowRecordBatch. */ - private[arrow] def byteArrayToBatch( + private[arrow] def loadBatch( batchBytes: Array[Byte], allocator: BufferAllocator): ArrowRecordBatch = { -val in = new ByteArrayReadableSeekableByteChannel(batchBytes) -val reader = new ArrowFileReader(in, allocator) - -// Read a batch from a byte stream, ensure the reader is closed -Utils.tryWithSafeFinally { - val root = reader.getVectorSchemaRoot // throws IOException - val unloader = new VectorUnloader(root) - reader.loadNextBatch() // throws IOException - unloader.getRecordBatch -} { - reader.close() -} +val in = new ByteArrayInputStream(batchBytes) +MessageSerializer.deserializeMessageBatch(new ReadChannel(Channels.newChannel(in)), allocator) + .asInstanceOf[ArrowRecordBatch] // throws IOException } + /** + * Create a DataFrame from a JavaRDD of serialized ArrowRecordBatches. 
+ */ private[sql] def toDataFrame( - payloadRDD: JavaRDD[Array[Byte]], + arrowBatchRDD: JavaRDD[Array[Byte]], schemaString: String, sqlContext: SQLContext): DataFrame = { -val rdd = payloadRDD.rdd.mapPartitions { iter => +val schema = DataType.fromJson(schemaString).asInstanceOf[StructType] +val timeZoneId = sqlContext.sessionState.conf.sessionLocalTimeZone +val rdd = arrowBatchRDD.rdd.mapPartitions { iter => val context = TaskContext.get() - ArrowConverters.fromPayloadIterator(iter.map(new ArrowPayload(_)), context) + ArrowConverters.fromBatchIterator(iter, schema, timeZoneId, context) } -val schema = DataType.fromJson(schemaString).asInstanceOf[StructType] sqlContext.internalCreateDataFrame(rdd, schema) } + + /** + * Read a file as an Arrow stream and return an RDD of serialized ArrowRecordBatches. + */ + private[sql] def readArrowStreamFromFile(sqlContext: SQLContext, filename: String): + JavaRDD[Array[Byte]] = { +val fileStream = new FileInputStream(filename) +try { + // Create array so that we can safely close the file + val batches = getBatchesFromStream(fileStream.getChannel).toArray + // Parallelize the record batches to create an RDD + JavaRDD.fromRDD(sqlContext.sparkContext.parallelize(batches, batches.length)) +} finally { + fileStream.close() +} + } + + /** + * Read an Arrow stream input and return an iterator of serialized ArrowRecordBatches. + */ + private[sql] def getBatchesFromStream(in: SeekableByteChannel): Iterator[Array[Byte]] = { + +// TODO: simplify in super class +class RecordBatchMessageReader(inputChannel: SeekableByteChannel) { --- End diff -- made https://issues.apache.org/jira/browse/ARROW-2704 to track --- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
[GitHub] spark pull request #21546: [WIP][SPARK-23030][SQL][PYTHON] Use Arrow stream ...
Github user BryanCutler commented on a diff in the pull request: https://github.com/apache/spark/pull/21546#discussion_r194898976 --- Diff: sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala --- @@ -51,11 +51,11 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { test("collect to arrow record batch") { val indexData = (1 to 6).toDF("i") -val arrowPayloads = indexData.toArrowPayload.collect() -assert(arrowPayloads.nonEmpty) -assert(arrowPayloads.length == indexData.rdd.getNumPartitions) +val arrowBatches = indexData.getArrowBatchRdd.collect() +assert(arrowBatches.nonEmpty) +assert(arrowBatches.length == indexData.rdd.getNumPartitions) --- End diff -- Most of these changes are just renames to be consistent --- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
[GitHub] spark pull request #21546: [WIP][SPARK-23030][SQL][PYTHON] Use Arrow stream ...
Github user BryanCutler commented on a diff in the pull request: https://github.com/apache/spark/pull/21546#discussion_r194898793 --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala --- @@ -183,34 +182,131 @@ private[sql] object ArrowConverters { } /** - * Convert a byte array to an ArrowRecordBatch. + * Load a serialized ArrowRecordBatch. */ - private[arrow] def byteArrayToBatch( + private[arrow] def loadBatch( batchBytes: Array[Byte], allocator: BufferAllocator): ArrowRecordBatch = { -val in = new ByteArrayReadableSeekableByteChannel(batchBytes) -val reader = new ArrowFileReader(in, allocator) - -// Read a batch from a byte stream, ensure the reader is closed -Utils.tryWithSafeFinally { - val root = reader.getVectorSchemaRoot // throws IOException - val unloader = new VectorUnloader(root) - reader.loadNextBatch() // throws IOException - unloader.getRecordBatch -} { - reader.close() -} +val in = new ByteArrayInputStream(batchBytes) +MessageSerializer.deserializeMessageBatch(new ReadChannel(Channels.newChannel(in)), allocator) + .asInstanceOf[ArrowRecordBatch] // throws IOException } + /** + * Create a DataFrame from a JavaRDD of serialized ArrowRecordBatches. 
+ */ private[sql] def toDataFrame( - payloadRDD: JavaRDD[Array[Byte]], + arrowBatchRDD: JavaRDD[Array[Byte]], schemaString: String, sqlContext: SQLContext): DataFrame = { -val rdd = payloadRDD.rdd.mapPartitions { iter => +val schema = DataType.fromJson(schemaString).asInstanceOf[StructType] +val timeZoneId = sqlContext.sessionState.conf.sessionLocalTimeZone +val rdd = arrowBatchRDD.rdd.mapPartitions { iter => val context = TaskContext.get() - ArrowConverters.fromPayloadIterator(iter.map(new ArrowPayload(_)), context) + ArrowConverters.fromBatchIterator(iter, schema, timeZoneId, context) } -val schema = DataType.fromJson(schemaString).asInstanceOf[StructType] sqlContext.internalCreateDataFrame(rdd, schema) } + + /** + * Read a file as an Arrow stream and return an RDD of serialized ArrowRecordBatches. + */ + private[sql] def readArrowStreamFromFile(sqlContext: SQLContext, filename: String): + JavaRDD[Array[Byte]] = { +val fileStream = new FileInputStream(filename) +try { + // Create array so that we can safely close the file + val batches = getBatchesFromStream(fileStream.getChannel).toArray + // Parallelize the record batches to create an RDD + JavaRDD.fromRDD(sqlContext.sparkContext.parallelize(batches, batches.length)) +} finally { + fileStream.close() +} + } + + /** + * Read an Arrow stream input and return an iterator of serialized ArrowRecordBatches. + */ + private[sql] def getBatchesFromStream(in: SeekableByteChannel): Iterator[Array[Byte]] = { + +// TODO: simplify in super class +class RecordBatchMessageReader(inputChannel: SeekableByteChannel) { --- End diff -- I had to modify the existing Arrow code to allow for this, but I will work on getting these changes into Arrow for 0.10.0 and then this class can be simplified a lot. --- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
[GitHub] spark pull request #21546: [WIP][SPARK-23030][SQL][PYTHON] Use Arrow stream ...
GitHub user BryanCutler opened a pull request: https://github.com/apache/spark/pull/21546 [WIP][SPARK-23030][SQL][PYTHON] Use Arrow stream format for creating from and collecting Pandas DataFrames ## What changes were proposed in this pull request? This changes the calls of `toPandas()` and `createDataFrame()` to use the Arrow stream format, when Arrow is enabled. Previously, Arrow data was written to byte arrays where each chunk is an output of the Arrow file format. This was mainly due to constraints at the time, and caused some overhead by writing the schema/footer on each chunk of data and then having to read multiple Arrow file inputs and concat them together. Using the Arrow stream format improves on this by increasing performance, lowering memory overhead for the average case, and simplifying the code. Here are the details of this change: **toPandas()** _Before:_ Spark internal rows are converted to Arrow file format; each group of records is a complete Arrow file which contains the schema and other metadata. Next a collect is done and an Array of Arrow files is the result. After that each Arrow file is sent to the Python driver which then loads each file and concats them to a single Arrow DataFrame. _After:_ Spark internal rows are converted to ArrowRecordBatches directly, which is the simplest Arrow component for IPC data transfers. The driver JVM then immediately starts serving data to Python as an Arrow stream, sending the schema first. It then starts Spark jobs with a custom handler such that when a partition is received (and in the correct order) the ArrowRecordBatches can be sent to Python as soon as possible. This improves performance, simplifies memory usage on executors, and improves the average memory usage on the JVM driver. Since the order of partitions must be preserved, the worst case is that the first partition will be the last to arrive and all data must be kept in memory until then. This case is no worse than before when doing a full collect. 
**createDataFrame()** _Before:_ A Pandas DataFrame is split into parts and each part is made into an Arrow file. Then each file is prefixed by the buffer size and written to a temp file. The temp file is read and each Arrow file is parallelized as a byte array. _After:_ A Pandas DataFrame is split into parts, then an Arrow stream is written to a temp file where each part is an ArrowRecordBatch. The temp file is read as a stream and the Arrow messages are examined. If the message is an ArrowRecordBatch, the data is saved as a byte array. After reading the file, each ArrowRecordBatch is parallelized as a byte array. This has slightly more processing than before because we must look at each Arrow message to extract the record batches, but performance remains the same. It is cleaner in the sense that IPC from Python to JVM is done over a single Arrow stream. ## How was this patch tested? Added new unit tests for the additions to ArrowConverters in Scala, existing tests for Python. You can merge this pull request into a Git repository by running: $ git pull https://github.com/BryanCutler/spark arrow-toPandas-stream-SPARK-23030 Alternatively you can review and apply these changes as the patch at: https://github.com/apache/spark/pull/21546.patch To close this pull request, make a commit to your master/trunk branch with (at least) the following in the commit message: This closes #21546 commit 9af482170ee95831bbda139e6e931ba2631df386 Author: Bryan Cutler Date: 2018-01-10T22:02:15Z change ArrowConverters to stream format commit d617f0da8eff1509da465bb707340e391314bec4 Author: Bryan Cutler Date: 2018-01-10T22:14:07Z Change ArrowSerializer to use stream format commit f10d5d9cd3cece7f56749e1de7fe01699e4759a0 Author: Bryan Cutler Date: 2018-01-12T00:40:36Z toPandas is working with RecordBatch payloads, using custom handler to stream ordered partitions commit 03653c687473b82bbfb6653504479498a2a3c63b Author: Bryan Cutler Date: 2018-02-10T00:23:17Z cleanup and removed ArrowPayload, 
createDataFrame now working commit 1b932463bca0815e79f3a8d61d1c816e62949698 Author: Bryan Cutler Date: 2018-03-09T00:14:06Z toPandas and createDataFrame working but tests fail with date cols commit ce22d8ad18e052d150528752b727c6cfe11485f7 Author: Bryan Cutler Date: 2018-03-27T00:32:03Z removed usage of seekableByteChannel commit dede0bd96921c439747a9176f24c9ecbb9c8ce0a Author: Bryan Cutler Date: 2018-03-28T00:28:54Z for toPandas, set old collection result to null and add comments commit 9e29b092cb7d45fa486db0215c3bd4a99c5f8d98 Author: Bryan Cutler Date: 2018-03-28T18:28:18Z cleanup, not yet passing ArrowConvertersSuite commit ceb8d38a6c83c3b6dae040c9e8d860811ecad0cc Author: Bryan Cutler Date: 2018-03-29T21:14:03Z fix