Github user cloud-fan commented on a diff in the pull request:
https://github.com/apache/spark/pull/19349#discussion_r141104830
--- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowStreamPythonUDFRunner.scala ---
@@ -0,0 +1,197 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.python
+
+import java.io._
+import java.net._
+import java.util.concurrent.atomic.AtomicBoolean
+
+import scala.collection.JavaConverters._
+
+import org.apache.arrow.vector.VectorSchemaRoot
+import org.apache.arrow.vector.stream.{ArrowStreamReader, ArrowStreamWriter}
+
+import org.apache.spark._
+import org.apache.spark.api.python._
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.execution.arrow.{ArrowUtils, ArrowWriter}
+import org.apache.spark.sql.execution.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector}
+import org.apache.spark.sql.types._
+import org.apache.spark.util.Utils
+
+/**
+ * Similar to `PythonUDFRunner`, but exchanges data with the Python worker via an Arrow stream.
+ */
+class ArrowStreamPythonUDFRunner(
+ funcs: Seq[ChainedPythonFunctions],
+ batchSize: Int,
+ bufferSize: Int,
+ reuseWorker: Boolean,
+ evalType: Int,
+ argOffsets: Array[Array[Int]],
+ schema: StructType)
+ extends BasePythonRunner[InternalRow, ColumnarBatch](
+ funcs, bufferSize, reuseWorker, evalType, argOffsets) {
+
+ protected override def newWriterThread(
+ env: SparkEnv,
+ worker: Socket,
+ inputIterator: Iterator[InternalRow],
+ partitionIndex: Int,
+ context: TaskContext): WriterThread = {
+ new WriterThread(env, worker, inputIterator, partitionIndex, context) {
+
+ override def writeCommand(dataOut: DataOutputStream): Unit = {
+ dataOut.writeInt(funcs.length)
+ funcs.zip(argOffsets).foreach { case (chained, offsets) =>
+ dataOut.writeInt(offsets.length)
+ offsets.foreach { offset =>
+ dataOut.writeInt(offset)
+ }
+ dataOut.writeInt(chained.funcs.length)
+ chained.funcs.foreach { f =>
+ dataOut.writeInt(f.command.length)
+ dataOut.write(f.command)
+ }
+ }
+ }
+
+ override def writeIteratorToStream(dataOut: DataOutputStream): Unit = {
+ val arrowSchema = ArrowUtils.toArrowSchema(schema)
+ val allocator = ArrowUtils.rootAllocator.newChildAllocator(
+ s"stdout writer for $pythonExec", 0, Long.MaxValue)
+
+ val root = VectorSchemaRoot.create(arrowSchema, allocator)
+ val arrowWriter = ArrowWriter.create(root)
+
+ var closed = false
+
+ context.addTaskCompletionListener { _ =>
+ if (!closed) {
+ root.close()
+ allocator.close()
+ }
+ }
+
+ val writer = new ArrowStreamWriter(root, null, dataOut)
+ writer.start()
+
+ Utils.tryWithSafeFinally {
+ while (inputIterator.hasNext) {
+ var rowCount = 0
+ while (inputIterator.hasNext && (batchSize <= 0 || rowCount < batchSize)) {
+ val row = inputIterator.next()
+ arrowWriter.write(row)
+ rowCount += 1
+ }
+ arrowWriter.finish()
+ writer.writeBatch()
+ arrowWriter.reset()
+ }
+ } {
+ writer.end()
+ root.close()
+ allocator.close()
+ closed = true
+ }
+ }
+ }
+ }
+
+ protected override def newReaderIterator(
+ stream: DataInputStream,
+ writerThread: WriterThread,
+ startTime: Long,
+ env: SparkEnv,
+ worker: Socket,
+ released: AtomicBoolean,
+ context: TaskContext): Iterator[ColumnarBatch] = {
+ new ReaderIterator(stream, writerThread, startTime, env, worker, released, context) {
+
+ private val allocator = ArrowUtils.rootAllocator.newChildAllocator(
+ s"stdin reader for $pythonExec", 0, Long.MaxValue)
+
+ private var reader: ArrowStreamReader = _
+ private var root: VectorSchemaRoot = _
+ private var schema: StructType = _
+ private var vectors: Array[ColumnVector] = _
+
+ private var closed = false
+
+ context.addTaskCompletionListener { _ =>
+ // todo: we need something like `read.end()`, which releases all the resources, but leaves
--- End diff --
cc @BryanCutler for arrow side issues.
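For context, until something like the `read.end()` API mentioned in the todo exists, the reader side presumably wants the same guard the writer side already has, i.e. a task-completion listener that tears the Arrow resources down once the task finishes. A minimal sketch of that pattern, reusing the `root`/`allocator`/`closed` fields declared in this diff (just the shape being discussed, not the actual fix):

```scala
// Sketch only: mirrors the writer-side completion listener above.
// Closing the reader itself is avoided here because, as far as I can tell,
// that would also close the underlying stream and defeat worker reuse.
context.addTaskCompletionListener { _ =>
  if (!closed) {
    if (root != null) {
      root.close()      // release the Arrow vectors backing the last batch
    }
    allocator.close()   // release the child allocator
    closed = true
  }
}
```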
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]