GitHub user viirya commented on a diff in the pull request:

    https://github.com/apache/spark/pull/19349#discussion_r141243843
  
    --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala ---
    @@ -0,0 +1,197 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.sql.execution.python
    +
    +import java.io._
    +import java.net._
    +import java.util.concurrent.atomic.AtomicBoolean
    +
    +import scala.collection.JavaConverters._
    +
    +import org.apache.arrow.vector.VectorSchemaRoot
    +import org.apache.arrow.vector.stream.{ArrowStreamReader, ArrowStreamWriter}
    +
    +import org.apache.spark._
    +import org.apache.spark.api.python._
    +import org.apache.spark.sql.catalyst.InternalRow
    +import org.apache.spark.sql.execution.arrow.{ArrowUtils, ArrowWriter}
    +import org.apache.spark.sql.execution.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector}
    +import org.apache.spark.sql.types._
    +import org.apache.spark.util.Utils
    +
    +/**
    + * Similar to `PythonUDFRunner`, but exchanges data with the Python worker via Arrow stream.
    + */
    +class ArrowPythonRunner(
    +    funcs: Seq[ChainedPythonFunctions],
    +    batchSize: Int,
    +    bufferSize: Int,
    +    reuseWorker: Boolean,
    +    evalType: Int,
    +    argOffsets: Array[Array[Int]],
    +    schema: StructType)
    +  extends BasePythonRunner[InternalRow, ColumnarBatch](
    +    funcs, bufferSize, reuseWorker, evalType, argOffsets) {
    +
    +  protected override def newWriterThread(
    +      env: SparkEnv,
    +      worker: Socket,
    +      inputIterator: Iterator[InternalRow],
    +      partitionIndex: Int,
    +      context: TaskContext): WriterThread = {
    +    new WriterThread(env, worker, inputIterator, partitionIndex, context) {
    +
    +      override def writeCommand(dataOut: DataOutputStream): Unit = {
    +        dataOut.writeInt(funcs.length)
    +        funcs.zip(argOffsets).foreach { case (chained, offsets) =>
    +          dataOut.writeInt(offsets.length)
    +          offsets.foreach { offset =>
    +            dataOut.writeInt(offset)
    +          }
    +          dataOut.writeInt(chained.funcs.length)
    +          chained.funcs.foreach { f =>
    +            dataOut.writeInt(f.command.length)
    +            dataOut.write(f.command)
    +          }
    +        }
    +      }
    +
    +      override def writeIteratorToStream(dataOut: DataOutputStream): Unit = {
    +        val arrowSchema = ArrowUtils.toArrowSchema(schema)
    +        val allocator = ArrowUtils.rootAllocator.newChildAllocator(
    +          s"stdout writer for $pythonExec", 0, Long.MaxValue)
    +
    +        val root = VectorSchemaRoot.create(arrowSchema, allocator)
    +        val arrowWriter = ArrowWriter.create(root)
    +
    +        var closed = false
    +
    +        context.addTaskCompletionListener { _ =>
    +          if (!closed) {
    +            root.close()
    +            allocator.close()
    +          }
    +        }
    +
    +        val writer = new ArrowStreamWriter(root, null, dataOut)
    +        writer.start()
    +
    +        Utils.tryWithSafeFinally {
    +          while (inputIterator.hasNext) {
    +            var rowCount = 0
    +            while (inputIterator.hasNext && (batchSize <= 0 || rowCount < batchSize)) {
    +              val row = inputIterator.next()
    +              arrowWriter.write(row)
    +              rowCount += 1
    +            }
    +            arrowWriter.finish()
    +            writer.writeBatch()
    +            arrowWriter.reset()
    +          }
    +        } {
    +          writer.end()
    +          root.close()
    +          allocator.close()
    +          closed = true
    +        }
    --- End diff --
    
    nvm. `ArrowStreamPandasSerializer` is not a `FramedSerializer`.
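    
    For context: `FramedSerializer` subclasses in `pyspark/serializers.py`
    length-prefix every object they write, whereas an Arrow-stream serializer
    writes a single Arrow IPC stream with no per-object framing. A minimal
    sketch of the difference (the function names and bodies here are
    illustrative, not the actual PySpark implementations):
    
        import struct
    
        import pyarrow as pa
    
        def framed_dump(objs, stream, dumps):
            # FramedSerializer-style framing: a 4-byte big-endian length,
            # then the serialized payload, repeated for each object.
            for obj in objs:
                payload = dumps(obj)
                stream.write(struct.pack("!i", len(payload)))
                stream.write(payload)
    
        def arrow_stream_dump(batches, stream):
            # Arrow-stream style: one IPC stream (a schema header followed
            # by record batches), with no per-object length prefix to strip.
            writer = None
            for batch in batches:
                if writer is None:
                    writer = pa.RecordBatchStreamWriter(stream, batch.schema)
                writer.write_batch(batch)
            if writer is not None:
                writer.close()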


---
