This is an automated email from the ASF dual-hosted git repository.

imbruced pushed a commit to branch arrow-worker
in repository https://gitbox.apache.org/repos/asf/sedona.git
commit 6cbc9680f94c3ea58ebc717807f43500f3062dbe Author: pawelkocinski <[email protected]> AuthorDate: Sun Dec 21 23:29:55 2025 +0100 add code so far --- .../spark/api/python/SedonaPythonRunner.scala | 721 --------------------- .../execution/python/SedonaArrowPythonRunner.scala | 6 +- .../execution/python/SedonaPythonArrowInput.scala | 48 +- .../execution/python/SedonaPythonArrowOutput.scala | 144 ---- .../execution/python/SedonaPythonUDFRunner.scala | 4 +- .../sql/execution/python/SedonaWriterThread.scala | 2 +- 6 files changed, 51 insertions(+), 874 deletions(-) diff --git a/spark/spark-3.5/src/main/scala/org/apache/spark/api/python/SedonaPythonRunner.scala b/spark/spark-3.5/src/main/scala/org/apache/spark/api/python/SedonaPythonRunner.scala deleted file mode 100644 index 38a9c7182b..0000000000 --- a/spark/spark-3.5/src/main/scala/org/apache/spark/api/python/SedonaPythonRunner.scala +++ /dev/null @@ -1,721 +0,0 @@ -package org.apache.spark.api.python - -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -import org.apache.sedona.common.geometrySerde.CoordinateType -import org.apache.spark._ -import org.apache.spark.SedonaSparkEnv -import org.apache.spark.internal.Logging -import org.apache.spark.internal.config.Python._ -import org.apache.spark.internal.config.{BUFFER_SIZE, EXECUTOR_CORES} -import org.apache.spark.resource.ResourceProfile.{EXECUTOR_CORES_LOCAL_PROPERTY, PYSPARK_MEMORY_LOCAL_PROPERTY} -import org.apache.spark.security.SocketAuthHelper -import org.apache.spark.sql.execution.python.{ArrowPythonRunner, BatchIterator} -import org.apache.spark.util._ - -import java.io._ -import java.net._ -import java.nio.charset.StandardCharsets -import java.nio.charset.StandardCharsets.UTF_8 -import java.nio.file.{Path, Files => JavaFiles} -import java.util.concurrent.atomic.AtomicBoolean -import scala.collection.JavaConverters._ -import scala.util.control.NonFatal -import org.apache.spark.sql.catalyst.expressions.GenericInternalRow -import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT -import org.apache.spark.sql.types.StructType - - -private object SedonaBasePythonRunner { - - private lazy val faultHandlerLogDir = Utils.createTempDir(namePrefix = "faulthandler") - - private def faultHandlerLogPath(pid: Int): Path = { - new File(faultHandlerLogDir, pid.toString).toPath - } -} - -/** - * A helper class to run Python mapPartition/UDFs in Spark. - * - * funcs is a list of independent Python functions, each one of them is a list of chained Python - * functions (from bottom to top). 
- */ -private[spark] abstract class SedonaBasePythonRunner[IN, OUT]( - protected val funcs: Seq[ChainedPythonFunctions], - protected val evalType: Int, - protected val argOffsets: Array[Array[Int]], - protected val jobArtifactUUID: Option[String], - schema: StructType) - extends Logging { - - require(funcs.length == argOffsets.length, "argOffsets should have the same length as funcs") - - private val conf = SparkEnv.get.conf - protected val bufferSize: Int = conf.get(BUFFER_SIZE) - protected val authSocketTimeout = conf.get(PYTHON_AUTH_SOCKET_TIMEOUT) - private val reuseWorker = conf.get(PYTHON_WORKER_REUSE) - private val faultHandlerEnabled = conf.get(PYTHON_WORKER_FAULTHANLDER_ENABLED) - protected val simplifiedTraceback: Boolean = false - - // All the Python functions should have the same exec, version and envvars. - protected val envVars: java.util.Map[String, String] = funcs.head.funcs.head.envVars - protected val pythonExec: String = funcs.head.funcs.head.pythonExec - protected val pythonVer: String = funcs.head.funcs.head.pythonVer - - // TODO: support accumulator in multiple UDF - protected val accumulator: PythonAccumulatorV2 = funcs.head.funcs.head.accumulator - - // Python accumulator is always set in production except in tests. See SPARK-27893 - private val maybeAccumulator: Option[PythonAccumulatorV2] = Option(accumulator) - - // Expose a ServerSocket to support method calls via socket from Python side. - private[spark] var serverSocket: Option[ServerSocket] = None - - // Authentication helper used when serving method calls via socket from Python side. - private lazy val authHelper = new SocketAuthHelper(conf) - - // each python worker gets an equal part of the allocation. the worker pool will grow to the - // number of concurrent tasks, which is determined by the number of cores in this executor. - private def getWorkerMemoryMb(mem: Option[Long], cores: Int): Option[Long] = { - mem.map(_ / cores) - } - - def compute( - inputIterator: Iterator[IN], - partitionIndex: Int, - context: TaskContext): Iterator[OUT] = { - val startTime = System.currentTimeMillis - val env = SparkEnv.get - - // Get the executor cores and pyspark memory, they are passed via the local properties when - // the user specified them in a ResourceProfile. - val execCoresProp = Option(context.getLocalProperty(EXECUTOR_CORES_LOCAL_PROPERTY)) - val memoryMb = Option(context.getLocalProperty(PYSPARK_MEMORY_LOCAL_PROPERTY)).map(_.toLong) - val localdir = env.blockManager.diskBlockManager.localDirs.map(f => f.getPath()).mkString(",") - // If OMP_NUM_THREADS is not explicitly set, override it with the number of task cpus. - // See SPARK-42613 for details. - if (conf.getOption("spark.executorEnv.OMP_NUM_THREADS").isEmpty) { - envVars.put("OMP_NUM_THREADS", conf.get("spark.task.cpus", "1")) - } - envVars.put("SPARK_LOCAL_DIRS", localdir) // it's also used in monitor thread - if (reuseWorker) { - envVars.put("SPARK_REUSE_WORKER", "1") - } - if (simplifiedTraceback) { - envVars.put("SPARK_SIMPLIFIED_TRACEBACK", "1") - } - // SPARK-30299 this could be wrong with standalone mode when executor - // cores might not be correct because it defaults to all cores on the box. 
- val execCores = execCoresProp.map(_.toInt).getOrElse(conf.get(EXECUTOR_CORES)) - val workerMemoryMb = getWorkerMemoryMb(memoryMb, execCores) - if (workerMemoryMb.isDefined) { - envVars.put("PYSPARK_EXECUTOR_MEMORY_MB", workerMemoryMb.get.toString) - } - envVars.put("SPARK_AUTH_SOCKET_TIMEOUT", authSocketTimeout.toString) - envVars.put("SPARK_BUFFER_SIZE", bufferSize.toString) - if (faultHandlerEnabled) { - envVars.put("PYTHON_FAULTHANDLER_DIR", SedonaBasePythonRunner.faultHandlerLogDir.toString) - } - - envVars.put("SPARK_JOB_ARTIFACT_UUID", jobArtifactUUID.getOrElse("default")) - - val (worker: Socket, pid: Option[Int]) = env.createPythonWorker( - pythonExec, envVars.asScala.toMap) - // Whether is the worker released into idle pool or closed. When any codes try to release or - // close a worker, they should use `releasedOrClosed.compareAndSet` to flip the state to make - // sure there is only one winner that is going to release or close the worker. - val releasedOrClosed = new AtomicBoolean(false) - - // Start a thread to feed the process input from our parent's iterator - val writerThread = newWriterThread(env, worker, inputIterator, partitionIndex, context) - - context.addTaskCompletionListener[Unit] { _ => - writerThread.shutdownOnTaskCompletion() - if (!reuseWorker || releasedOrClosed.compareAndSet(false, true)) { - try { - worker.close() - } catch { - case e: Exception => - logWarning("Failed to close worker socket", e) - } - } - } - - writerThread.start() - new WriterMonitorThread(SparkEnv.get, worker, writerThread, context).start() - if (reuseWorker) { - val key = (worker, context.taskAttemptId) - // SPARK-35009: avoid creating multiple monitor threads for the same python worker - // and task context - if (PythonRunner.runningMonitorThreads.add(key)) { - new MonitorThread(SparkEnv.get, worker, context).start() - } - } else { - new MonitorThread(SparkEnv.get, worker, context).start() - } - - // Return an iterator that read lines from the process's stdout - val stream = new DataInputStream(new BufferedInputStream(worker.getInputStream, bufferSize)) - - val stdoutIterator = newReaderIterator( - stream, writerThread, startTime, env, worker, pid, releasedOrClosed, context) - new InterruptibleIterator(context, stdoutIterator) - } - - protected def newWriterThread( - env: SparkEnv, - worker: Socket, - inputIterator: Iterator[IN], - partitionIndex: Int, - context: TaskContext): WriterThread - - protected def newReaderIterator( - stream: DataInputStream, - writerThread: WriterThread, - startTime: Long, - env: SparkEnv, - worker: Socket, - pid: Option[Int], - releasedOrClosed: AtomicBoolean, - context: TaskContext): Iterator[OUT] - - /** - * The thread responsible for writing the data from the PythonRDD's parent iterator to the - * Python process. - */ - abstract class WriterThread( - env: SparkEnv, - worker: Socket, - inputIterator: Iterator[IN], - partitionIndex: Int, - context: TaskContext) - extends Thread(s"stdout writer for $pythonExec") { - - @volatile private var _exception: Throwable = null - - private val pythonIncludes = funcs.flatMap(_.funcs.flatMap(_.pythonIncludes.asScala)).toSet - private val broadcastVars = funcs.flatMap(_.funcs.flatMap(_.broadcastVars.asScala)) - - setDaemon(true) - - /** Contains the throwable thrown while writing the parent iterator to the Python process. */ - def exception: Option[Throwable] = Option(_exception) - - /** - * Terminates the writer thread and waits for it to exit, ignoring any exceptions that may occur - * due to cleanup. 
- */ - def shutdownOnTaskCompletion(): Unit = { - assert(context.isCompleted) - this.interrupt() - // Task completion listeners that run after this method returns may invalidate - // `inputIterator`. For example, when `inputIterator` was generated by the off-heap vectorized - // reader, a task completion listener will free the underlying off-heap buffers. If the writer - // thread is still running when `inputIterator` is invalidated, it can cause a use-after-free - // bug that crashes the executor (SPARK-33277). Therefore this method must wait for the writer - // thread to exit before returning. - this.join() - } - - /** - * Writes a command section to the stream connected to the Python worker. - */ - protected def writeCommand(dataOut: DataOutputStream): Unit - - /** - * Writes input data to the stream connected to the Python worker. - */ - protected def writeIteratorToStream(dataOut: DataOutputStream): Unit - - override def run(): Unit = Utils.logUncaughtExceptions { - try { - println("ssss") - val toReadCRS = inputIterator.buffered.headOption.flatMap( - el => el.asInstanceOf[Iterator[IN]].buffered.headOption - ) - - val row = toReadCRS match { - case Some(value) => value match { - case row: GenericInternalRow => - Some(row) - } - case None => None - } - - val geometryFields = schema.zipWithIndex.filter { - case (field, index) => field.dataType == GeometryUDT - }.map { - case (field, index) => - if (row.isEmpty || row.get.values(index) == null) (index, 0) else { - val geom = row.get.get(index, GeometryUDT).asInstanceOf[Array[Byte]] - val preambleByte = geom(0) & 0xFF - val hasSrid = (preambleByte & 0x01) != 0 - - var srid = 0 - if (hasSrid) { - val srid2 = (geom(1) & 0xFF) << 16 - val srid1 = (geom(2) & 0xFF) << 8 - val srid0 = geom(3) & 0xFF - srid = srid2 | srid1 | srid0 - } - (index, srid) - } - } - - TaskContext.setTaskContext(context) - val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize) - val dataOut = new DataOutputStream(stream) - - // Partition index - dataOut.writeInt(partitionIndex) - // Python version of driver - PythonRDD.writeUTF(pythonVer, dataOut) - // Init a ServerSocket to accept method calls from Python side. - val isBarrier = context.isInstanceOf[BarrierTaskContext] - if (isBarrier) { - serverSocket = Some(new ServerSocket(/* port */ 0, - /* backlog */ 1, - InetAddress.getByName("localhost"))) - // A call to accept() for ServerSocket shall block infinitely. - serverSocket.foreach(_.setSoTimeout(0)) - new Thread("accept-connections") { - setDaemon(true) - - override def run(): Unit = { - while (!serverSocket.get.isClosed()) { - var sock: Socket = null - try { - sock = serverSocket.get.accept() - // Wait for function call from python side. - sock.setSoTimeout(10000) - authHelper.authClient(sock) - val input = new DataInputStream(sock.getInputStream()) - val requestMethod = input.readInt() - // The BarrierTaskContext function may wait infinitely, socket shall not timeout - // before the function finishes. 
- sock.setSoTimeout(10000) - requestMethod match { - case BarrierTaskContextMessageProtocol.BARRIER_FUNCTION => - barrierAndServe(requestMethod, sock) - case BarrierTaskContextMessageProtocol.ALL_GATHER_FUNCTION => - val length = input.readInt() - val message = new Array[Byte](length) - input.readFully(message) - barrierAndServe(requestMethod, sock, new String(message, UTF_8)) - case _ => - val out = new DataOutputStream(new BufferedOutputStream( - sock.getOutputStream)) - writeUTF(BarrierTaskContextMessageProtocol.ERROR_UNRECOGNIZED_FUNCTION, out) - } - } catch { - case e: SocketException if e.getMessage.contains("Socket closed") => - // It is possible that the ServerSocket is not closed, but the native socket - // has already been closed, we shall catch and silently ignore this case. - } finally { - if (sock != null) { - sock.close() - } - } - } - } - }.start() - } - val secret = if (isBarrier) { - authHelper.secret - } else { - "" - } - // Close ServerSocket on task completion. - serverSocket.foreach { server => - context.addTaskCompletionListener[Unit](_ => server.close()) - } - val boundPort: Int = serverSocket.map(_.getLocalPort).getOrElse(0) - if (boundPort == -1) { - val message = "ServerSocket failed to bind to Java side." - logError(message) - throw new SparkException(message) - } else if (isBarrier) { - logDebug(s"Started ServerSocket on port $boundPort.") - } - // Write out the TaskContextInfo - dataOut.writeBoolean(isBarrier) - dataOut.writeInt(boundPort) - val secretBytes = secret.getBytes(UTF_8) - dataOut.writeInt(secretBytes.length) - dataOut.write(secretBytes, 0, secretBytes.length) - dataOut.writeInt(context.stageId()) - dataOut.writeInt(context.partitionId()) - dataOut.writeInt(context.attemptNumber()) - dataOut.writeLong(context.taskAttemptId()) - dataOut.writeInt(context.cpus()) - val resources = context.resources() - dataOut.writeInt(resources.size) - resources.foreach { case (k, v) => - PythonRDD.writeUTF(k, dataOut) - PythonRDD.writeUTF(v.name, dataOut) - dataOut.writeInt(v.addresses.size) - v.addresses.foreach { case addr => - PythonRDD.writeUTF(addr, dataOut) - } - } - val localProps = context.getLocalProperties.asScala - dataOut.writeInt(localProps.size) - localProps.foreach { case (k, v) => - PythonRDD.writeUTF(k, dataOut) - PythonRDD.writeUTF(v, dataOut) - } - - // sparkFilesDir - val root = jobArtifactUUID.map { uuid => - new File(SparkFiles.getRootDirectory(), uuid).getAbsolutePath - }.getOrElse(SparkFiles.getRootDirectory()) - PythonRDD.writeUTF(root, dataOut) - // Python includes (*.zip and *.egg files) - dataOut.writeInt(pythonIncludes.size) - for (include <- pythonIncludes) { - PythonRDD.writeUTF(include, dataOut) - } - // Broadcast variables - val oldBids = PythonRDD.getWorkerBroadcasts(worker) - val newBids = broadcastVars.map(_.id).toSet - // number of different broadcasts - val toRemove = oldBids.diff(newBids) - val addedBids = newBids.diff(oldBids) - val cnt = toRemove.size + addedBids.size - val needsDecryptionServer = env.serializerManager.encryptionEnabled && addedBids.nonEmpty - dataOut.writeBoolean(needsDecryptionServer) - dataOut.writeInt(cnt) - - def sendBidsToRemove(): Unit = { - for (bid <- toRemove) { - // remove the broadcast from worker - dataOut.writeLong(-bid - 1) // bid >= 0 - oldBids.remove(bid) - } - } - - if (needsDecryptionServer) { - // if there is encryption, we setup a server which reads the encrypted files, and sends - // the decrypted data to python - val idsAndFiles = broadcastVars.flatMap { broadcast => - if 
(!oldBids.contains(broadcast.id)) { - oldBids.add(broadcast.id) - Some((broadcast.id, broadcast.value.path)) - } else { - None - } - } - val server = new EncryptedPythonBroadcastServer(env, idsAndFiles) - dataOut.writeInt(server.port) - logTrace(s"broadcast decryption server setup on ${server.port}") - PythonRDD.writeUTF(server.secret, dataOut) - sendBidsToRemove() - idsAndFiles.foreach { case (id, _) => - // send new broadcast - dataOut.writeLong(id) - } - dataOut.flush() - logTrace("waiting for python to read decrypted broadcast data from server") - server.waitTillBroadcastDataSent() - logTrace("done sending decrypted data to python") - } else { - sendBidsToRemove() - for (broadcast <- broadcastVars) { - if (!oldBids.contains(broadcast.id)) { - // send new broadcast - dataOut.writeLong(broadcast.id) - PythonRDD.writeUTF(broadcast.value.path, dataOut) - oldBids.add(broadcast.id) - } - } - } - dataOut.flush() - - dataOut.writeInt(evalType) - writeCommand(dataOut) - - // write number of geometry fields - dataOut.writeInt(geometryFields.length) - // write geometry field indices and their SRIDs - geometryFields.foreach { case (index, srid) => - dataOut.writeInt(index) - dataOut.writeInt(srid) - } - - writeIteratorToStream(dataOut) - - dataOut.writeInt(SpecialLengths.END_OF_STREAM) - dataOut.flush() - } catch { - case t: Throwable if (NonFatal(t) || t.isInstanceOf[Exception]) => - if (context.isCompleted || context.isInterrupted) { - logDebug("Exception/NonFatal Error thrown after task completion (likely due to " + - "cleanup)", t) - if (!worker.isClosed) { - Utils.tryLog(worker.shutdownOutput()) - } - } else { - // We must avoid throwing exceptions/NonFatals here, because the thread uncaught - // exception handler will kill the whole executor (see - // org.apache.spark.executor.Executor). - _exception = t - if (!worker.isClosed) { - Utils.tryLog(worker.shutdownOutput()) - } - } - } - } - - /** - * Gateway to call BarrierTaskContext methods. - */ - def barrierAndServe(requestMethod: Int, sock: Socket, message: String = ""): Unit = { - require( - serverSocket.isDefined, - "No available ServerSocket to redirect the BarrierTaskContext method call." 
- ) - val out = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream)) - try { - val messages = requestMethod match { - case BarrierTaskContextMessageProtocol.BARRIER_FUNCTION => - context.asInstanceOf[BarrierTaskContext].barrier() - Array(BarrierTaskContextMessageProtocol.BARRIER_RESULT_SUCCESS) - case BarrierTaskContextMessageProtocol.ALL_GATHER_FUNCTION => - context.asInstanceOf[BarrierTaskContext].allGather(message) - } - out.writeInt(messages.length) - messages.foreach(writeUTF(_, out)) - } catch { - case e: SparkException => - writeUTF(e.getMessage, out) - } finally { - out.close() - } - } - - def writeUTF(str: String, dataOut: DataOutputStream): Unit = { - val bytes = str.getBytes(UTF_8) - dataOut.writeInt(bytes.length) - dataOut.write(bytes) - } - } - - abstract class ReaderIterator( - stream: DataInputStream, - writerThread: WriterThread, - startTime: Long, - env: SparkEnv, - worker: Socket, - pid: Option[Int], - releasedOrClosed: AtomicBoolean, - context: TaskContext) - extends Iterator[OUT] { - - private var nextObj: OUT = _ - private var eos = false - - override def hasNext: Boolean = nextObj != null || { - if (!eos) { - nextObj = read() - hasNext - } else { - false - } - } - - override def next(): OUT = { - if (hasNext) { - val obj = nextObj - nextObj = null.asInstanceOf[OUT] - obj - } else { - Iterator.empty.next() - } - } - - /** - * Reads next object from the stream. - * When the stream reaches end of data, needs to process the following sections, - * and then returns null. - */ - protected def read(): OUT - - protected def handleTimingData(): Unit = { - // Timing data from worker - val bootTime = stream.readLong() - val initTime = stream.readLong() - val finishTime = stream.readLong() - val boot = bootTime - startTime - val init = initTime - bootTime - val finish = finishTime - initTime - val total = finishTime - startTime - logInfo("Times: total = %s, boot = %s, init = %s, finish = %s".format(total, boot, - init, finish)) - val memoryBytesSpilled = stream.readLong() - val diskBytesSpilled = stream.readLong() - context.taskMetrics.incMemoryBytesSpilled(memoryBytesSpilled) - context.taskMetrics.incDiskBytesSpilled(diskBytesSpilled) - } - - protected def handlePythonException(): PythonException = { - // Signals that an exception has been thrown in python - val exLength = stream.readInt() - val obj = new Array[Byte](exLength) - stream.readFully(obj) - new PythonException(new String(obj, StandardCharsets.UTF_8), - writerThread.exception.orNull) - } - - protected def handleEndOfDataSection(): Unit = { - // We've finished the data section of the output, but we can still - // read some accumulator updates: - val numAccumulatorUpdates = stream.readInt() - (1 to numAccumulatorUpdates).foreach { _ => - val updateLen = stream.readInt() - val update = new Array[Byte](updateLen) - stream.readFully(update) - maybeAccumulator.foreach(_.add(update)) - } - // Check whether the worker is ready to be re-used. 
- if (stream.readInt() == SpecialLengths.END_OF_STREAM) { - if (reuseWorker && releasedOrClosed.compareAndSet(false, true)) { - env.releasePythonWorker(pythonExec, envVars.asScala.toMap, worker) - } - } - eos = true - } - - protected val handleException: PartialFunction[Throwable, OUT] = { - case e: Exception if context.isInterrupted => - logDebug("Exception thrown after task interruption", e) - throw new TaskKilledException(context.getKillReason().getOrElse("unknown reason")) - - case e: Exception if writerThread.exception.isDefined => - logError("Python worker exited unexpectedly (crashed)", e) - logError("This may have been caused by a prior exception:", writerThread.exception.get) - throw writerThread.exception.get - - case eof: EOFException if faultHandlerEnabled && pid.isDefined && - JavaFiles.exists(SedonaBasePythonRunner.faultHandlerLogPath(pid.get)) => - val path = SedonaBasePythonRunner.faultHandlerLogPath(pid.get) - val error = String.join("\n", JavaFiles.readAllLines(path)) + "\n" - JavaFiles.deleteIfExists(path) - throw new SparkException(s"Python worker exited unexpectedly (crashed): $error", eof) - - case eof: EOFException => - throw new SparkException("Python worker exited unexpectedly (crashed)", eof) - } - } - - /** - * It is necessary to have a monitor thread for python workers if the user cancels with - * interrupts disabled. In that case we will need to explicitly kill the worker, otherwise the - * threads can block indefinitely. - */ - class MonitorThread(env: SparkEnv, worker: Socket, context: TaskContext) - extends Thread(s"Worker Monitor for $pythonExec") { - - /** How long to wait before killing the python worker if a task cannot be interrupted. */ - private val taskKillTimeout = env.conf.get(PYTHON_TASK_KILL_TIMEOUT) - - setDaemon(true) - - private def monitorWorker(): Unit = { - // Kill the worker if it is interrupted, checking until task completion. - // TODO: This has a race condition if interruption occurs, as completed may still become true. - while (!context.isInterrupted && !context.isCompleted) { - Thread.sleep(2000) - } - if (!context.isCompleted) { - Thread.sleep(taskKillTimeout) - if (!context.isCompleted) { - try { - // Mimic the task name used in `Executor` to help the user find out the task to blame. - val taskName = s"${context.partitionId}.${context.attemptNumber} " + - s"in stage ${context.stageId} (TID ${context.taskAttemptId})" - logWarning(s"Incomplete task $taskName interrupted: Attempting to kill Python Worker") - env.destroyPythonWorker(pythonExec, envVars.asScala.toMap, worker) - } catch { - case e: Exception => - logError("Exception when trying to kill worker", e) - } - } - } - } - - override def run(): Unit = { - try { - monitorWorker() - } finally { - if (reuseWorker) { - val key = (worker, context.taskAttemptId) - PythonRunner.runningMonitorThreads.remove(key) - } - } - } - } - - /** - * This thread monitors the WriterThread and kills it in case of deadlock. - * - * A deadlock can arise if the task completes while the writer thread is sending input to the - * Python process (e.g. due to the use of `take()`), and the Python process is still producing - * output. When the inputs are sufficiently large, this can result in a deadlock due to the use of - * blocking I/O (SPARK-38677). To resolve the deadlock, we need to close the socket. 
- */ - class WriterMonitorThread( - env: SparkEnv, worker: Socket, writerThread: WriterThread, context: TaskContext) - extends Thread(s"Writer Monitor for $pythonExec (writer thread id ${writerThread.getId})") { - - /** - * How long to wait before closing the socket if the writer thread has not exited after the task - * ends. - */ - private val taskKillTimeout = env.conf.get(PYTHON_TASK_KILL_TIMEOUT) - - setDaemon(true) - - override def run(): Unit = { - // Wait until the task is completed (or the writer thread exits, in which case this thread has - // nothing to do). - while (!context.isCompleted && writerThread.isAlive) { - Thread.sleep(2000) - } - if (writerThread.isAlive) { - Thread.sleep(taskKillTimeout) - // If the writer thread continues running, this indicates a deadlock. Kill the worker to - // resolve the deadlock. - if (writerThread.isAlive) { - try { - // Mimic the task name used in `Executor` to help the user find out the task to blame. - val taskName = s"${context.partitionId}.${context.attemptNumber} " + - s"in stage ${context.stageId} (TID ${context.taskAttemptId})" - logWarning( - s"Detected deadlock while completing task $taskName: " + - "Attempting to kill Python Worker") - env.destroyPythonWorker(pythonExec, envVars.asScala.toMap, worker) - } catch { - case e: Exception => - logError("Exception when trying to kill worker", e) - } - } - } - } - } -} diff --git a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaArrowPythonRunner.scala b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaArrowPythonRunner.scala index c3eafc9766..2a28eba6db 100644 --- a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaArrowPythonRunner.scala +++ b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaArrowPythonRunner.scala @@ -37,10 +37,10 @@ class SedonaArrowPythonRunner( protected override val workerConf: Map[String, String], val pythonMetrics: Map[String, SQLMetric], jobArtifactUUID: Option[String]) - extends SedonaBasePythonRunner[Iterator[InternalRow], ColumnarBatch]( - funcs, evalType, argOffsets, jobArtifactUUID, schema) + extends BasePythonRunner[Iterator[InternalRow], ColumnarBatch]( + funcs, evalType, argOffsets, jobArtifactUUID) with SedonaBasicPythonArrowInput - with SedonaBasicPythonArrowOutput { + with BasicPythonArrowOutput { override val pythonExec: String = SQLConf.get.pysparkWorkerPythonExecutable.getOrElse( diff --git a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonArrowInput.scala b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonArrowInput.scala index bf353539bc..7bc0d322c2 100644 --- a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonArrowInput.scala +++ b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonArrowInput.scala @@ -20,10 +20,12 @@ package org.apache.spark.sql.execution.python import org.apache.arrow.vector.VectorSchemaRoot import org.apache.arrow.vector.ipc.ArrowStreamWriter import org.apache.spark.api.python -import org.apache.spark.api.python.{BasePythonRunner, ChainedPythonFunctions, PythonRDD, SedonaBasePythonRunner} +import org.apache.spark.api.python.{BasePythonRunner, ChainedPythonFunctions, PythonRDD} import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow import org.apache.spark.sql.execution.arrow.ArrowWriter import org.apache.spark.sql.execution.metric.SQLMetric 
+import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.ArrowUtils import org.apache.spark.sql.util.ArrowUtils.toArrowSchema @@ -37,7 +39,7 @@ import java.net.Socket * A trait that can be mixed-in with [[python.BasePythonRunner]]. It implements the logic from * JVM (an iterator of internal rows + additional data if required) to Python (Arrow). */ -private[python] trait SedonaPythonArrowInput[IN] { self: SedonaBasePythonRunner[IN, _] => +private[python] trait SedonaPythonArrowInput[IN] { self: BasePythonRunner[IN, _] => protected val workerConf: Map[String, String] protected val schema: StructType @@ -84,6 +86,46 @@ private[python] trait SedonaPythonArrowInput[IN] { self: SedonaBasePythonRunner[ protected override def writeCommand(dataOut: DataOutputStream): Unit = { handleMetadataBeforeExec(dataOut) writeUDF(dataOut, funcs, argOffsets) + + val toReadCRS = inputIterator.buffered.headOption.flatMap( + el => el.asInstanceOf[Iterator[IN]].buffered.headOption + ) + + val row = toReadCRS match { + case Some(value) => value match { + case row: GenericInternalRow => + Some(row) + } + case None => None + } + + val geometryFields = schema.zipWithIndex.filter { + case (field, index) => field.dataType == GeometryUDT + }.map { + case (field, index) => + if (row.isEmpty || row.get.values(index) == null) (index, 0) else { + val geom = row.get.get(index, GeometryUDT).asInstanceOf[Array[Byte]] + val preambleByte = geom(0) & 0xFF + val hasSrid = (preambleByte & 0x01) != 0 + + var srid = 0 + if (hasSrid) { + val srid2 = (geom(1) & 0xFF) << 16 + val srid1 = (geom(2) & 0xFF) << 8 + val srid0 = geom(3) & 0xFF + srid = srid2 | srid1 | srid0 + } + (index, srid) + } + } + + // write number of geometry fields + dataOut.writeInt(geometryFields.length) + // write geometry field indices and their SRIDs + geometryFields.foreach { case (index, srid) => + dataOut.writeInt(index) + dataOut.writeInt(srid) + } } protected override def writeIteratorToStream(dataOut: DataOutputStream): Unit = { @@ -124,7 +166,7 @@ private[python] trait SedonaPythonArrowInput[IN] { self: SedonaBasePythonRunner[ private[python] trait SedonaBasicPythonArrowInput extends SedonaPythonArrowInput[Iterator[InternalRow]] { - self: SedonaBasePythonRunner[Iterator[InternalRow], _] => + self: BasePythonRunner[Iterator[InternalRow], _] => protected def writeIteratorToArrowStream( root: VectorSchemaRoot, diff --git a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonArrowOutput.scala b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonArrowOutput.scala deleted file mode 100644 index 316cb32c3e..0000000000 --- a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonArrowOutput.scala +++ /dev/null @@ -1,144 +0,0 @@ -package org.apache.spark.sql.execution.python - -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -import org.apache.arrow.vector.VectorSchemaRoot -import org.apache.arrow.vector.ipc.ArrowStreamReader -import org.apache.spark.api.python.{BasePythonRunner, SedonaBasePythonRunner, SpecialLengths} -import org.apache.spark.sql.execution.metric.SQLMetric -import org.apache.spark.sql.types.StructType -import org.apache.spark.sql.util.ArrowUtils -import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnVector, ColumnarBatch} -import org.apache.spark.{SparkEnv, TaskContext} - -import java.io.DataInputStream -import java.net.Socket -import java.util.concurrent.atomic.AtomicBoolean -import scala.collection.JavaConverters._ - -/** - * A trait that can be mixed-in with [[BasePythonRunner]]. It implements the logic from - * Python (Arrow) to JVM (output type being deserialized from ColumnarBatch). - */ -private[python] trait SedonaPythonArrowOutput[OUT <: AnyRef] { self: SedonaBasePythonRunner[_, OUT] => - - protected def pythonMetrics: Map[String, SQLMetric] - - protected def handleMetadataAfterExec(stream: DataInputStream): Unit = { } - - protected def deserializeColumnarBatch(batch: ColumnarBatch, schema: StructType): OUT - - protected def newReaderIterator( - stream: DataInputStream, - writerThread: WriterThread, - startTime: Long, - env: SparkEnv, - worker: Socket, - pid: Option[Int], - releasedOrClosed: AtomicBoolean, - context: TaskContext): Iterator[OUT] = { - - new ReaderIterator( - stream, writerThread, startTime, env, worker, pid, releasedOrClosed, context) { - - private val allocator = ArrowUtils.rootAllocator.newChildAllocator( - s"stdin reader for $pythonExec", 0, Long.MaxValue) - - private var reader: ArrowStreamReader = _ - private var root: VectorSchemaRoot = _ - private var schema: StructType = _ - private var vectors: Array[ColumnVector] = _ - - context.addTaskCompletionListener[Unit] { _ => - if (reader != null) { - reader.close(false) - } - allocator.close() - } - - private var batchLoaded = true - - protected override def handleEndOfDataSection(): Unit = { - handleMetadataAfterExec(stream) - super.handleEndOfDataSection() - } - - protected override def read(): OUT = { - if (writerThread.exception.isDefined) { - throw writerThread.exception.get - } - try { - if (reader != null && batchLoaded) { - val bytesReadStart = reader.bytesRead() - batchLoaded = reader.loadNextBatch() - if (batchLoaded) { - val batch = new ColumnarBatch(vectors) - val rowCount = root.getRowCount - batch.setNumRows(root.getRowCount) - val bytesReadEnd = reader.bytesRead() - pythonMetrics("pythonNumRowsReceived") += rowCount - pythonMetrics("pythonDataReceived") += bytesReadEnd - bytesReadStart - val result = deserializeColumnarBatch(batch, schema) - result - } else { - reader.close(false) - allocator.close() - // Reach end of stream. Call `read()` again to read control data. 
- read() - } - } else { - stream.readInt() match { - case SpecialLengths.START_ARROW_STREAM => - reader = new ArrowStreamReader(stream, allocator) - root = reader.getVectorSchemaRoot() - schema = ArrowUtils.fromArrowSchema(root.getSchema()) - vectors = root.getFieldVectors().asScala.map { vector => - new ArrowColumnVector(vector) - }.toArray[ColumnVector] - read() - case SpecialLengths.TIMING_DATA => - handleTimingData() - read() - case SpecialLengths.PYTHON_EXCEPTION_THROWN => - throw handlePythonException() - case SpecialLengths.END_OF_DATA_SECTION => - handleEndOfDataSection() - null.asInstanceOf[OUT] - } - } - } catch { - case e: Exception => - // If an exception happens, make sure to close the reader to release resources. - if (reader != null) { - reader.close(false) - } - allocator.close() - throw e - } - } - } - } -} - -private[python] trait SedonaBasicPythonArrowOutput extends SedonaPythonArrowOutput[ColumnarBatch] { - self: SedonaBasePythonRunner[_, ColumnarBatch] => - - protected def deserializeColumnarBatch( - batch: ColumnarBatch, - schema: StructType): ColumnarBatch = batch -} diff --git a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonUDFRunner.scala b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonUDFRunner.scala index beb49c1dde..dcf93b5213 100644 --- a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonUDFRunner.scala +++ b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonUDFRunner.scala @@ -34,8 +34,8 @@ abstract class SedonaBasePythonUDFRunner( argOffsets: Array[Array[Int]], pythonMetrics: Map[String, SQLMetric], jobArtifactUUID: Option[String]) - extends SedonaBasePythonRunner[Array[Byte], Array[Byte]]( - funcs, evalType, argOffsets, jobArtifactUUID, null) { + extends BasePythonRunner[Array[Byte], Array[Byte]]( + funcs, evalType, argOffsets, jobArtifactUUID) { override val pythonExec: String = SQLConf.get.pysparkWorkerPythonExecutable.getOrElse( diff --git a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaWriterThread.scala b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaWriterThread.scala index b28cf81906..57cf2dc7bb 100644 --- a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaWriterThread.scala +++ b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaWriterThread.scala @@ -4,7 +4,7 @@ package org.apache.spark.sql.execution.python import org.apache.sedona.common.geometrySerde.CoordinateType import org.apache.spark._ import org.apache.spark.SedonaSparkEnv -import org.apache.spark.api.python.{BarrierTaskContextMessageProtocol, BasePythonRunner, ChainedPythonFunctions, EncryptedPythonBroadcastServer, PythonRDD, SedonaBasePythonRunner, SpecialLengths} +import org.apache.spark.api.python.{BarrierTaskContextMessageProtocol, BasePythonRunner, ChainedPythonFunctions, EncryptedPythonBroadcastServer, PythonRDD, SpecialLengths} import org.apache.spark.internal.Logging import org.apache.spark.internal.config.Python._ import org.apache.spark.internal.config.{BUFFER_SIZE, EXECUTOR_CORES}
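
Note on the geometry-SRID handshake in this commit: before streaming Arrow batches, writeCommand in SedonaPythonArrowInput now peeks at the first input row, locates the GeometryUDT columns in the schema, and reads each geometry's SRID out of its serialized preamble, then writes the number of geometry columns followed by (index, srid) pairs to the Python worker. The sketch below is illustrative only and not part of the commit; it isolates the bit arithmetic from the diff, and the preamble layout it assumes (a flag byte whose lowest bit marks an SRID, followed by three big-endian SRID bytes) is inferred from that arithmetic rather than from the Sedona serializer spec.

// Illustrative sketch, not part of the commit: decode the SRID from the
// preamble of a Sedona-serialized geometry, mirroring the bit arithmetic
// that this commit moves into SedonaPythonArrowInput.writeCommand.
object GeometryPreambleSketch {

  /** Returns the SRID encoded in the geometry preamble, or 0 if none is set. */
  def decodeSrid(geom: Array[Byte]): Int = {
    val preambleByte = geom(0) & 0xFF
    val hasSrid = (preambleByte & 0x01) != 0   // lowest flag bit marks an SRID
    if (hasSrid) {
      // three big-endian bytes immediately after the flag byte
      ((geom(1) & 0xFF) << 16) | ((geom(2) & 0xFF) << 8) | (geom(3) & 0xFF)
    } else {
      0
    }
  }

  def main(args: Array[String]): Unit = {
    // Hypothetical 4-byte preamble: flag byte with the SRID bit set, then SRID 4326 (0x0010E6).
    val preamble = Array(0x01.toByte, 0x00.toByte, 0x10.toByte, 0xE6.toByte)
    println(decodeSrid(preamble)) // prints 4326
  }
}

The surrounding protocol, as shown in the diff, is simply dataOut.writeInt(geometryFields.length) followed by one writeInt(index) and writeInt(srid) per geometry column, which the Python worker is expected to read back in the same order.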
