Github user ueshin commented on a diff in the pull request:
https://github.com/apache/spark/pull/19147#discussion_r137507456
--- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/python/VectorizedPythonRunner.scala ---
@@ -0,0 +1,329 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.python
+
+import java.io.{BufferedInputStream, BufferedOutputStream, DataInputStream, DataOutputStream}
+import java.net.Socket
+import java.nio.charset.StandardCharsets
+
+import scala.collection.JavaConverters._
+
+import org.apache.arrow.vector.VectorSchemaRoot
+import org.apache.arrow.vector.stream.{ArrowStreamReader, ArrowStreamWriter}
+
+import org.apache.spark.{SparkEnv, SparkFiles, TaskContext}
+import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType, PythonException, PythonRDD, SpecialLengths}
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.execution.arrow.{ArrowUtils, ArrowWriter}
+import org.apache.spark.sql.execution.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector}
+import org.apache.spark.sql.types._
+import org.apache.spark.util.Utils
+
+/**
+ * Similar to `PythonRunner`, but exchanges data with the Python worker via a columnar format.
+ */
+class VectorizedPythonRunner(
+    funcs: Seq[ChainedPythonFunctions],
+    batchSize: Int,
+    bufferSize: Int,
+    reuse_worker: Boolean,
+    argOffsets: Array[Array[Int]]) extends Logging {
+
+  require(funcs.length == argOffsets.length, "argOffsets should have the same length as funcs")
+
+  // All the Python functions should have the same exec, version and envvars.
+  private val envVars = funcs.head.funcs.head.envVars
+  private val pythonExec = funcs.head.funcs.head.pythonExec
+  private val pythonVer = funcs.head.funcs.head.pythonVer
+
+  // TODO: support accumulator in multiple UDF
+  private val accumulator = funcs.head.funcs.head.accumulator
+
+  // todo: return column batch?
+  def compute(
+      inputRows: Iterator[InternalRow],
+      schema: StructType,
+      partitionIndex: Int,
+      context: TaskContext): Iterator[InternalRow] = {
+    val startTime = System.currentTimeMillis
+    val env = SparkEnv.get
+    val localdir = env.blockManager.diskBlockManager.localDirs.map(f => f.getPath()).mkString(",")
+    envVars.put("SPARK_LOCAL_DIRS", localdir) // it's also used in monitor thread
+    if (reuse_worker) {
+      envVars.put("SPARK_REUSE_WORKER", "1")
+    }
+    val worker: Socket = env.createPythonWorker(pythonExec, envVars.asScala.toMap)
+    // Whether the worker is released into the idle pool
+    @volatile var released = false
+
+    // Start a thread to feed the process input from our parent's iterator
+    val writerThread = new WriterThread(
+      env, worker, inputRows, schema, partitionIndex, context)
+
+    context.addTaskCompletionListener { context =>
+      writerThread.shutdownOnTaskCompletion()
+      if (!reuse_worker || !released) {
+        try {
+          worker.close()
+        } catch {
+          case e: Exception =>
+            logWarning("Failed to close worker socket", e)
+        }
+      }
+    }
+
+    writerThread.start()
+
+    val stream = new DataInputStream(new BufferedInputStream(worker.getInputStream, bufferSize))
+
+    val allocator = ArrowUtils.rootAllocator.newChildAllocator(
+      s"stdin reader for $pythonExec", 0, Long.MaxValue)
+    val reader = new ArrowStreamReader(stream, allocator)
+
+    new Iterator[InternalRow] {
+      private val root = reader.getVectorSchemaRoot
+      private val vectors = root.getFieldVectors.asScala.map { vector =>
+        new ArrowColumnVector(vector)
+      }.toArray[ColumnVector]
+
+      var closed = false
+
+      context.addTaskCompletionListener { _ =>
+        // todo: we need something like `reader.end()`, which releases all the resources but
+        // leaves the input stream open. `reader.close` will close the socket and then we can't
+        // reuse the worker. So here we simply don't close the reader, which is problematic.
+        if (!closed) {
+          root.close()
+          allocator.close()
+        }
+      }
+
+      private[this] var batchLoaded = true
+      private[this] var currentIter: Iterator[InternalRow] = Iterator.empty
+
+      override def hasNext: Boolean = batchLoaded && (currentIter.hasNext || loadNextBatch()) || {
+        root.close()
+        allocator.close()
+        closed = true
+        false
+      }
+
+      private def loadNextBatch(): Boolean = {
+        batchLoaded = reader.loadNextBatch()
+        if (batchLoaded) {
+          val batch = new ColumnarBatch(schema, vectors, root.getRowCount)
--- End diff ---
That's right, it doesn't return UnsafeRow.
But I believe it's performant enough, because it can return row values directly from the column vectors without copying them into an UnsafeRow.
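To illustrate the idea (this is only a toy sketch with made-up names, not the PR's actual `ColumnarBatch`/`ArrowColumnVector` classes): each call to `next()` can hand back a reusable row view whose getters read straight from the backing column vectors, so no per-row copy into an UnsafeRow-like buffer is needed.

    // Hypothetical toy types (ToyColumnVector, ToyRowView, ToyColumnarBatch) for this sketch only.
    class ToyColumnVector(data: Array[Int]) {
      def getInt(rowId: Int): Int = data(rowId)
    }

    // A "row" that is just a cursor into the columns; reading a field reads
    // directly from the backing column vector, so nothing is copied per row.
    class ToyRowView(columns: Array[ToyColumnVector]) {
      var rowId: Int = -1
      def getInt(ordinal: Int): Int = columns(ordinal).getInt(rowId)
    }

    class ToyColumnarBatch(columns: Array[ToyColumnVector], numRows: Int) {
      def rowIterator: Iterator[ToyRowView] = new Iterator[ToyRowView] {
        private val row = new ToyRowView(columns) // single reused view over the columns
        private var i = 0
        override def hasNext: Boolean = i < numRows
        override def next(): ToyRowView = { row.rowId = i; i += 1; row }
      }
    }

    object ToyExample extends App {
      val batch = new ToyColumnarBatch(
        Array(new ToyColumnVector(Array(1, 2, 3)), new ToyColumnVector(Array(10, 20, 30))), 3)
      batch.rowIterator.foreach(r => println(s"${r.getInt(0)}, ${r.getInt(1)}"))
    }

Whether this is fast enough in practice is something a benchmark would have to confirm; the point is only that row values can be served out of the column vectors directly.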
---