This is an automated email from the ASF dual-hosted git repository.

imbruced pushed a commit to branch arrow-worker
in repository https://gitbox.apache.org/repos/asf/sedona.git

commit 6cbc9680f94c3ea58ebc717807f43500f3062dbe
Author: pawelkocinski <[email protected]>
AuthorDate: Sun Dec 21 23:29:55 2025 +0100

    add code so far
---
 .../spark/api/python/SedonaPythonRunner.scala      | 721 ---------------------
 .../execution/python/SedonaArrowPythonRunner.scala |   6 +-
 .../execution/python/SedonaPythonArrowInput.scala  |  48 +-
 .../execution/python/SedonaPythonArrowOutput.scala | 144 ----
 .../execution/python/SedonaPythonUDFRunner.scala   |   4 +-
 .../sql/execution/python/SedonaWriterThread.scala  |   2 +-
 6 files changed, 51 insertions(+), 874 deletions(-)
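
For reference, the Sedona-specific logic this change keeps is the SRID sniff performed on the
serialized geometry preamble before the Arrow stream is written. The sketch below restates that
decoding as a standalone helper, assuming the byte layout used by the code in the diff (bit 0x01
of byte 0 flags a SRID, bytes 1-3 carry it big-endian); the object and method names are
illustrative only and are not part of the patch.

    object GeometryPreamble {
      // Decode the SRID from the first bytes of a Sedona-serialized geometry value.
      // Returns 0 when the preamble flags no SRID.
      def sridOf(geom: Array[Byte]): Int = {
        val preambleByte = geom(0) & 0xFF
        val hasSrid = (preambleByte & 0x01) != 0
        if (hasSrid) {
          ((geom(1) & 0xFF) << 16) | ((geom(2) & 0xFF) << 8) | (geom(3) & 0xFF)
        } else {
          0
        }
      }
    }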

diff --git a/spark/spark-3.5/src/main/scala/org/apache/spark/api/python/SedonaPythonRunner.scala b/spark/spark-3.5/src/main/scala/org/apache/spark/api/python/SedonaPythonRunner.scala
deleted file mode 100644
index 38a9c7182b..0000000000
--- a/spark/spark-3.5/src/main/scala/org/apache/spark/api/python/SedonaPythonRunner.scala
+++ /dev/null
@@ -1,721 +0,0 @@
-package org.apache.spark.api.python
-
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-import org.apache.sedona.common.geometrySerde.CoordinateType
-import org.apache.spark._
-import org.apache.spark.SedonaSparkEnv
-import org.apache.spark.internal.Logging
-import org.apache.spark.internal.config.Python._
-import org.apache.spark.internal.config.{BUFFER_SIZE, EXECUTOR_CORES}
-import org.apache.spark.resource.ResourceProfile.{EXECUTOR_CORES_LOCAL_PROPERTY, PYSPARK_MEMORY_LOCAL_PROPERTY}
-import org.apache.spark.security.SocketAuthHelper
-import org.apache.spark.sql.execution.python.{ArrowPythonRunner, BatchIterator}
-import org.apache.spark.util._
-
-import java.io._
-import java.net._
-import java.nio.charset.StandardCharsets
-import java.nio.charset.StandardCharsets.UTF_8
-import java.nio.file.{Path, Files => JavaFiles}
-import java.util.concurrent.atomic.AtomicBoolean
-import scala.collection.JavaConverters._
-import scala.util.control.NonFatal
-import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
-import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT
-import org.apache.spark.sql.types.StructType
-
-
-private object SedonaBasePythonRunner {
-
-  private lazy val faultHandlerLogDir = Utils.createTempDir(namePrefix = "faulthandler")
-
-  private def faultHandlerLogPath(pid: Int): Path = {
-    new File(faultHandlerLogDir, pid.toString).toPath
-  }
-}
-
-/**
- * A helper class to run Python mapPartition/UDFs in Spark.
- *
- * funcs is a list of independent Python functions, each one of them is a list of chained Python
- * functions (from bottom to top).
- */
-private[spark] abstract class SedonaBasePythonRunner[IN, OUT](
-                                                               protected val funcs: Seq[ChainedPythonFunctions],
-                                                               protected val evalType: Int,
-                                                               protected val argOffsets: Array[Array[Int]],
-                                                               protected val jobArtifactUUID: Option[String],
-                                                               schema: StructType)
-  extends Logging {
-
-  require(funcs.length == argOffsets.length, "argOffsets should have the same length as funcs")
-
-  private val conf = SparkEnv.get.conf
-  protected val bufferSize: Int = conf.get(BUFFER_SIZE)
-  protected val authSocketTimeout = conf.get(PYTHON_AUTH_SOCKET_TIMEOUT)
-  private val reuseWorker = conf.get(PYTHON_WORKER_REUSE)
-  private val faultHandlerEnabled = conf.get(PYTHON_WORKER_FAULTHANLDER_ENABLED)
-  protected val simplifiedTraceback: Boolean = false
-
-  // All the Python functions should have the same exec, version and envvars.
-  protected val envVars: java.util.Map[String, String] = funcs.head.funcs.head.envVars
-  protected val pythonExec: String = funcs.head.funcs.head.pythonExec
-  protected val pythonVer: String = funcs.head.funcs.head.pythonVer
-
-  // TODO: support accumulator in multiple UDF
-  protected val accumulator: PythonAccumulatorV2 = funcs.head.funcs.head.accumulator
-
-  // Python accumulator is always set in production except in tests. See SPARK-27893
-  private val maybeAccumulator: Option[PythonAccumulatorV2] = Option(accumulator)
-
-  // Expose a ServerSocket to support method calls via socket from Python side.
-  private[spark] var serverSocket: Option[ServerSocket] = None
-
-  // Authentication helper used when serving method calls via socket from Python side.
-  private lazy val authHelper = new SocketAuthHelper(conf)
-
-  // each python worker gets an equal part of the allocation. the worker pool will grow to the
-  // number of concurrent tasks, which is determined by the number of cores in this executor.
-  private def getWorkerMemoryMb(mem: Option[Long], cores: Int): Option[Long] = {
-    mem.map(_ / cores)
-  }
-
-  def compute(
-               inputIterator: Iterator[IN],
-               partitionIndex: Int,
-               context: TaskContext): Iterator[OUT] = {
-    val startTime = System.currentTimeMillis
-    val env = SparkEnv.get
-
-    // Get the executor cores and pyspark memory, they are passed via the local properties when
-    // the user specified them in a ResourceProfile.
-    val execCoresProp = Option(context.getLocalProperty(EXECUTOR_CORES_LOCAL_PROPERTY))
-    val memoryMb = Option(context.getLocalProperty(PYSPARK_MEMORY_LOCAL_PROPERTY)).map(_.toLong)
-    val localdir = env.blockManager.diskBlockManager.localDirs.map(f => f.getPath()).mkString(",")
-    // If OMP_NUM_THREADS is not explicitly set, override it with the number of task cpus.
-    // See SPARK-42613 for details.
-    if (conf.getOption("spark.executorEnv.OMP_NUM_THREADS").isEmpty) {
-      envVars.put("OMP_NUM_THREADS", conf.get("spark.task.cpus", "1"))
-    }
-    envVars.put("SPARK_LOCAL_DIRS", localdir) // it's also used in monitor thread
-    if (reuseWorker) {
-      envVars.put("SPARK_REUSE_WORKER", "1")
-    }
-    if (simplifiedTraceback) {
-      envVars.put("SPARK_SIMPLIFIED_TRACEBACK", "1")
-    }
-    // SPARK-30299 this could be wrong with standalone mode when executor
-    // cores might not be correct because it defaults to all cores on the box.
-    val execCores = execCoresProp.map(_.toInt).getOrElse(conf.get(EXECUTOR_CORES))
-    val workerMemoryMb = getWorkerMemoryMb(memoryMb, execCores)
-    if (workerMemoryMb.isDefined) {
-      envVars.put("PYSPARK_EXECUTOR_MEMORY_MB", workerMemoryMb.get.toString)
-    }
-    envVars.put("SPARK_AUTH_SOCKET_TIMEOUT", authSocketTimeout.toString)
-    envVars.put("SPARK_BUFFER_SIZE", bufferSize.toString)
-    if (faultHandlerEnabled) {
-      envVars.put("PYTHON_FAULTHANDLER_DIR", SedonaBasePythonRunner.faultHandlerLogDir.toString)
-    }
-
-    envVars.put("SPARK_JOB_ARTIFACT_UUID", jobArtifactUUID.getOrElse("default"))
-
-    val (worker: Socket, pid: Option[Int]) = env.createPythonWorker(
-      pythonExec, envVars.asScala.toMap)
-    // Whether is the worker released into idle pool or closed. When any codes try to release or
-    // close a worker, they should use `releasedOrClosed.compareAndSet` to flip the state to make
-    // sure there is only one winner that is going to release or close the worker.
-    val releasedOrClosed = new AtomicBoolean(false)
-
-    // Start a thread to feed the process input from our parent's iterator
-    val writerThread = newWriterThread(env, worker, inputIterator, partitionIndex, context)
-
-    context.addTaskCompletionListener[Unit] { _ =>
-      writerThread.shutdownOnTaskCompletion()
-      if (!reuseWorker || releasedOrClosed.compareAndSet(false, true)) {
-        try {
-          worker.close()
-        } catch {
-          case e: Exception =>
-            logWarning("Failed to close worker socket", e)
-        }
-      }
-    }
-
-    writerThread.start()
-    new WriterMonitorThread(SparkEnv.get, worker, writerThread, context).start()
-    if (reuseWorker) {
-      val key = (worker, context.taskAttemptId)
-      // SPARK-35009: avoid creating multiple monitor threads for the same python worker
-      // and task context
-      if (PythonRunner.runningMonitorThreads.add(key)) {
-        new MonitorThread(SparkEnv.get, worker, context).start()
-      }
-    } else {
-      new MonitorThread(SparkEnv.get, worker, context).start()
-    }
-
-    // Return an iterator that read lines from the process's stdout
-    val stream = new DataInputStream(new BufferedInputStream(worker.getInputStream, bufferSize))
-
-    val stdoutIterator = newReaderIterator(
-      stream, writerThread, startTime, env, worker, pid, releasedOrClosed, context)
-    new InterruptibleIterator(context, stdoutIterator)
-  }
-
-  protected def newWriterThread(
-                                 env: SparkEnv,
-                                 worker: Socket,
-                                 inputIterator: Iterator[IN],
-                                 partitionIndex: Int,
-                                 context: TaskContext): WriterThread
-
-  protected def newReaderIterator(
-                                   stream: DataInputStream,
-                                   writerThread: WriterThread,
-                                   startTime: Long,
-                                   env: SparkEnv,
-                                   worker: Socket,
-                                   pid: Option[Int],
-                                   releasedOrClosed: AtomicBoolean,
-                                   context: TaskContext): Iterator[OUT]
-
-  /**
-   * The thread responsible for writing the data from the PythonRDD's parent iterator to the
-   * Python process.
-   */
-  abstract class WriterThread(
-                               env: SparkEnv,
-                               worker: Socket,
-                               inputIterator: Iterator[IN],
-                               partitionIndex: Int,
-                               context: TaskContext)
-    extends Thread(s"stdout writer for $pythonExec") {
-
-    @volatile private var _exception: Throwable = null
-
-    private val pythonIncludes = funcs.flatMap(_.funcs.flatMap(_.pythonIncludes.asScala)).toSet
-    private val broadcastVars = funcs.flatMap(_.funcs.flatMap(_.broadcastVars.asScala))
-
-    setDaemon(true)
-
-    /** Contains the throwable thrown while writing the parent iterator to the Python process. */
-    def exception: Option[Throwable] = Option(_exception)
-
-    /**
-     * Terminates the writer thread and waits for it to exit, ignoring any exceptions that may occur
-     * due to cleanup.
-     */
-    def shutdownOnTaskCompletion(): Unit = {
-      assert(context.isCompleted)
-      this.interrupt()
-      // Task completion listeners that run after this method returns may invalidate
-      // `inputIterator`. For example, when `inputIterator` was generated by the off-heap vectorized
-      // reader, a task completion listener will free the underlying off-heap buffers. If the writer
-      // thread is still running when `inputIterator` is invalidated, it can cause a use-after-free
-      // bug that crashes the executor (SPARK-33277). Therefore this method must wait for the writer
-      // thread to exit before returning.
-      this.join()
-    }
-
-    /**
-     * Writes a command section to the stream connected to the Python worker.
-     */
-    protected def writeCommand(dataOut: DataOutputStream): Unit
-
-    /**
-     * Writes input data to the stream connected to the Python worker.
-     */
-    protected def writeIteratorToStream(dataOut: DataOutputStream): Unit
-
-    override def run(): Unit = Utils.logUncaughtExceptions {
-      try {
-        println("ssss")
-        val toReadCRS = inputIterator.buffered.headOption.flatMap(
-          el => el.asInstanceOf[Iterator[IN]].buffered.headOption
-        )
-
-        val row = toReadCRS match {
-          case Some(value) => value match {
-            case row: GenericInternalRow =>
-              Some(row)
-          }
-          case None => None
-        }
-
-        val geometryFields = schema.zipWithIndex.filter {
-          case (field, index) => field.dataType == GeometryUDT
-        }.map {
-          case (field, index) =>
-            if (row.isEmpty || row.get.values(index) == null) (index, 0) else {
-              val geom = row.get.get(index, GeometryUDT).asInstanceOf[Array[Byte]]
-              val preambleByte = geom(0) & 0xFF
-              val hasSrid = (preambleByte & 0x01) != 0
-
-              var srid = 0
-              if (hasSrid) {
-                val srid2 = (geom(1) & 0xFF) << 16
-                val srid1 = (geom(2) & 0xFF) << 8
-                val srid0 = geom(3) & 0xFF
-                srid = srid2 | srid1 | srid0
-              }
-              (index, srid)
-            }
-        }
-
-        TaskContext.setTaskContext(context)
-        val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize)
-        val dataOut = new DataOutputStream(stream)
-
-        // Partition index
-        dataOut.writeInt(partitionIndex)
-        // Python version of driver
-        PythonRDD.writeUTF(pythonVer, dataOut)
-        // Init a ServerSocket to accept method calls from Python side.
-        val isBarrier = context.isInstanceOf[BarrierTaskContext]
-        if (isBarrier) {
-          serverSocket = Some(new ServerSocket(/* port */ 0,
-            /* backlog */ 1,
-            InetAddress.getByName("localhost")))
-          // A call to accept() for ServerSocket shall block infinitely.
-          serverSocket.foreach(_.setSoTimeout(0))
-          new Thread("accept-connections") {
-            setDaemon(true)
-
-            override def run(): Unit = {
-              while (!serverSocket.get.isClosed()) {
-                var sock: Socket = null
-                try {
-                  sock = serverSocket.get.accept()
-                  // Wait for function call from python side.
-                  sock.setSoTimeout(10000)
-                  authHelper.authClient(sock)
-                  val input = new DataInputStream(sock.getInputStream())
-                  val requestMethod = input.readInt()
-                  // The BarrierTaskContext function may wait infinitely, socket shall not timeout
-                  // before the function finishes.
-                  sock.setSoTimeout(10000)
-                  requestMethod match {
-                    case BarrierTaskContextMessageProtocol.BARRIER_FUNCTION =>
-                      barrierAndServe(requestMethod, sock)
-                    case BarrierTaskContextMessageProtocol.ALL_GATHER_FUNCTION =>
-                      val length = input.readInt()
-                      val message = new Array[Byte](length)
-                      input.readFully(message)
-                      barrierAndServe(requestMethod, sock, new String(message, UTF_8))
-                    case _ =>
-                      val out = new DataOutputStream(new BufferedOutputStream(
-                        sock.getOutputStream))
-                      writeUTF(BarrierTaskContextMessageProtocol.ERROR_UNRECOGNIZED_FUNCTION, out)
-                  }
-                } catch {
-                  case e: SocketException if e.getMessage.contains("Socket closed") =>
-                  // It is possible that the ServerSocket is not closed, but the native socket
-                  // has already been closed, we shall catch and silently ignore this case.
-                } finally {
-                  if (sock != null) {
-                    sock.close()
-                  }
-                }
-              }
-            }
-          }.start()
-        }
-        val secret = if (isBarrier) {
-          authHelper.secret
-        } else {
-          ""
-        }
-        // Close ServerSocket on task completion.
-        serverSocket.foreach { server =>
-          context.addTaskCompletionListener[Unit](_ => server.close())
-        }
-        val boundPort: Int = serverSocket.map(_.getLocalPort).getOrElse(0)
-        if (boundPort == -1) {
-          val message = "ServerSocket failed to bind to Java side."
-          logError(message)
-          throw new SparkException(message)
-        } else if (isBarrier) {
-          logDebug(s"Started ServerSocket on port $boundPort.")
-        }
-        // Write out the TaskContextInfo
-        dataOut.writeBoolean(isBarrier)
-        dataOut.writeInt(boundPort)
-        val secretBytes = secret.getBytes(UTF_8)
-        dataOut.writeInt(secretBytes.length)
-        dataOut.write(secretBytes, 0, secretBytes.length)
-        dataOut.writeInt(context.stageId())
-        dataOut.writeInt(context.partitionId())
-        dataOut.writeInt(context.attemptNumber())
-        dataOut.writeLong(context.taskAttemptId())
-        dataOut.writeInt(context.cpus())
-        val resources = context.resources()
-        dataOut.writeInt(resources.size)
-        resources.foreach { case (k, v) =>
-          PythonRDD.writeUTF(k, dataOut)
-          PythonRDD.writeUTF(v.name, dataOut)
-          dataOut.writeInt(v.addresses.size)
-          v.addresses.foreach { case addr =>
-            PythonRDD.writeUTF(addr, dataOut)
-          }
-        }
-        val localProps = context.getLocalProperties.asScala
-        dataOut.writeInt(localProps.size)
-        localProps.foreach { case (k, v) =>
-          PythonRDD.writeUTF(k, dataOut)
-          PythonRDD.writeUTF(v, dataOut)
-        }
-
-        // sparkFilesDir
-        val root = jobArtifactUUID.map { uuid =>
-          new File(SparkFiles.getRootDirectory(), uuid).getAbsolutePath
-        }.getOrElse(SparkFiles.getRootDirectory())
-        PythonRDD.writeUTF(root, dataOut)
-        // Python includes (*.zip and *.egg files)
-        dataOut.writeInt(pythonIncludes.size)
-        for (include <- pythonIncludes) {
-          PythonRDD.writeUTF(include, dataOut)
-        }
-        // Broadcast variables
-        val oldBids = PythonRDD.getWorkerBroadcasts(worker)
-        val newBids = broadcastVars.map(_.id).toSet
-        // number of different broadcasts
-        val toRemove = oldBids.diff(newBids)
-        val addedBids = newBids.diff(oldBids)
-        val cnt = toRemove.size + addedBids.size
-        val needsDecryptionServer = env.serializerManager.encryptionEnabled && addedBids.nonEmpty
-        dataOut.writeBoolean(needsDecryptionServer)
-        dataOut.writeInt(cnt)
-
-        def sendBidsToRemove(): Unit = {
-          for (bid <- toRemove) {
-            // remove the broadcast from worker
-            dataOut.writeLong(-bid - 1) // bid >= 0
-            oldBids.remove(bid)
-          }
-        }
-
-        if (needsDecryptionServer) {
-          // if there is encryption, we setup a server which reads the encrypted files, and sends
-          // the decrypted data to python
-          val idsAndFiles = broadcastVars.flatMap { broadcast =>
-            if (!oldBids.contains(broadcast.id)) {
-              oldBids.add(broadcast.id)
-              Some((broadcast.id, broadcast.value.path))
-            } else {
-              None
-            }
-          }
-          val server = new EncryptedPythonBroadcastServer(env, idsAndFiles)
-          dataOut.writeInt(server.port)
-          logTrace(s"broadcast decryption server setup on ${server.port}")
-          PythonRDD.writeUTF(server.secret, dataOut)
-          sendBidsToRemove()
-          idsAndFiles.foreach { case (id, _) =>
-            // send new broadcast
-            dataOut.writeLong(id)
-          }
-          dataOut.flush()
-          logTrace("waiting for python to read decrypted broadcast data from server")
-          server.waitTillBroadcastDataSent()
-          logTrace("done sending decrypted data to python")
-        } else {
-          sendBidsToRemove()
-          for (broadcast <- broadcastVars) {
-            if (!oldBids.contains(broadcast.id)) {
-              // send new broadcast
-              dataOut.writeLong(broadcast.id)
-              PythonRDD.writeUTF(broadcast.value.path, dataOut)
-              oldBids.add(broadcast.id)
-            }
-          }
-        }
-        dataOut.flush()
-
-        dataOut.writeInt(evalType)
-        writeCommand(dataOut)
-
-        // write number of geometry fields
-        dataOut.writeInt(geometryFields.length)
-        // write geometry field indices and their SRIDs
-        geometryFields.foreach { case (index, srid) =>
-          dataOut.writeInt(index)
-          dataOut.writeInt(srid)
-        }
-
-        writeIteratorToStream(dataOut)
-
-        dataOut.writeInt(SpecialLengths.END_OF_STREAM)
-        dataOut.flush()
-      } catch {
-        case t: Throwable if (NonFatal(t) || t.isInstanceOf[Exception]) =>
-          if (context.isCompleted || context.isInterrupted) {
-            logDebug("Exception/NonFatal Error thrown after task completion (likely due to " +
-              "cleanup)", t)
-            if (!worker.isClosed) {
-              Utils.tryLog(worker.shutdownOutput())
-            }
-          } else {
-            // We must avoid throwing exceptions/NonFatals here, because the thread uncaught
-            // exception handler will kill the whole executor (see
-            // org.apache.spark.executor.Executor).
-            _exception = t
-            if (!worker.isClosed) {
-              Utils.tryLog(worker.shutdownOutput())
-            }
-          }
-      }
-    }
-
-    /**
-     * Gateway to call BarrierTaskContext methods.
-     */
-    def barrierAndServe(requestMethod: Int, sock: Socket, message: String = ""): Unit = {
-      require(
-        serverSocket.isDefined,
-        "No available ServerSocket to redirect the BarrierTaskContext method call."
-      )
-      val out = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream))
-      try {
-        val messages = requestMethod match {
-          case BarrierTaskContextMessageProtocol.BARRIER_FUNCTION =>
-            context.asInstanceOf[BarrierTaskContext].barrier()
-            Array(BarrierTaskContextMessageProtocol.BARRIER_RESULT_SUCCESS)
-          case BarrierTaskContextMessageProtocol.ALL_GATHER_FUNCTION =>
-            context.asInstanceOf[BarrierTaskContext].allGather(message)
-        }
-        out.writeInt(messages.length)
-        messages.foreach(writeUTF(_, out))
-      } catch {
-        case e: SparkException =>
-          writeUTF(e.getMessage, out)
-      } finally {
-        out.close()
-      }
-    }
-
-    def writeUTF(str: String, dataOut: DataOutputStream): Unit = {
-      val bytes = str.getBytes(UTF_8)
-      dataOut.writeInt(bytes.length)
-      dataOut.write(bytes)
-    }
-  }
-
-  abstract class ReaderIterator(
-                                 stream: DataInputStream,
-                                 writerThread: WriterThread,
-                                 startTime: Long,
-                                 env: SparkEnv,
-                                 worker: Socket,
-                                 pid: Option[Int],
-                                 releasedOrClosed: AtomicBoolean,
-                                 context: TaskContext)
-    extends Iterator[OUT] {
-
-    private var nextObj: OUT = _
-    private var eos = false
-
-    override def hasNext: Boolean = nextObj != null || {
-      if (!eos) {
-        nextObj = read()
-        hasNext
-      } else {
-        false
-      }
-    }
-
-    override def next(): OUT = {
-      if (hasNext) {
-        val obj = nextObj
-        nextObj = null.asInstanceOf[OUT]
-        obj
-      } else {
-        Iterator.empty.next()
-      }
-    }
-
-    /**
-     * Reads next object from the stream.
-     * When the stream reaches end of data, needs to process the following sections,
-     * and then returns null.
-     */
-    protected def read(): OUT
-
-    protected def handleTimingData(): Unit = {
-      // Timing data from worker
-      val bootTime = stream.readLong()
-      val initTime = stream.readLong()
-      val finishTime = stream.readLong()
-      val boot = bootTime - startTime
-      val init = initTime - bootTime
-      val finish = finishTime - initTime
-      val total = finishTime - startTime
-      logInfo("Times: total = %s, boot = %s, init = %s, finish = %s".format(total, boot,
-        init, finish))
-      val memoryBytesSpilled = stream.readLong()
-      val diskBytesSpilled = stream.readLong()
-      context.taskMetrics.incMemoryBytesSpilled(memoryBytesSpilled)
-      context.taskMetrics.incDiskBytesSpilled(diskBytesSpilled)
-    }
-
-    protected def handlePythonException(): PythonException = {
-      // Signals that an exception has been thrown in python
-      val exLength = stream.readInt()
-      val obj = new Array[Byte](exLength)
-      stream.readFully(obj)
-      new PythonException(new String(obj, StandardCharsets.UTF_8),
-        writerThread.exception.orNull)
-    }
-
-    protected def handleEndOfDataSection(): Unit = {
-      // We've finished the data section of the output, but we can still
-      // read some accumulator updates:
-      val numAccumulatorUpdates = stream.readInt()
-      (1 to numAccumulatorUpdates).foreach { _ =>
-        val updateLen = stream.readInt()
-        val update = new Array[Byte](updateLen)
-        stream.readFully(update)
-        maybeAccumulator.foreach(_.add(update))
-      }
-      // Check whether the worker is ready to be re-used.
-      if (stream.readInt() == SpecialLengths.END_OF_STREAM) {
-        if (reuseWorker && releasedOrClosed.compareAndSet(false, true)) {
-          env.releasePythonWorker(pythonExec, envVars.asScala.toMap, worker)
-        }
-      }
-      eos = true
-    }
-
-    protected val handleException: PartialFunction[Throwable, OUT] = {
-      case e: Exception if context.isInterrupted =>
-        logDebug("Exception thrown after task interruption", e)
-        throw new TaskKilledException(context.getKillReason().getOrElse("unknown reason"))
-
-      case e: Exception if writerThread.exception.isDefined =>
-        logError("Python worker exited unexpectedly (crashed)", e)
-        logError("This may have been caused by a prior exception:", writerThread.exception.get)
-        throw writerThread.exception.get
-
-      case eof: EOFException if faultHandlerEnabled && pid.isDefined &&
-        JavaFiles.exists(SedonaBasePythonRunner.faultHandlerLogPath(pid.get)) =>
-        val path = SedonaBasePythonRunner.faultHandlerLogPath(pid.get)
-        val error = String.join("\n", JavaFiles.readAllLines(path)) + "\n"
-        JavaFiles.deleteIfExists(path)
-        throw new SparkException(s"Python worker exited unexpectedly (crashed): $error", eof)
-
-      case eof: EOFException =>
-        throw new SparkException("Python worker exited unexpectedly (crashed)", eof)
-    }
-  }
-
-  /**
-   * It is necessary to have a monitor thread for python workers if the user cancels with
-   * interrupts disabled. In that case we will need to explicitly kill the worker, otherwise the
-   * threads can block indefinitely.
-   */
-  class MonitorThread(env: SparkEnv, worker: Socket, context: TaskContext)
-    extends Thread(s"Worker Monitor for $pythonExec") {
-
-    /** How long to wait before killing the python worker if a task cannot be interrupted. */
-    private val taskKillTimeout = env.conf.get(PYTHON_TASK_KILL_TIMEOUT)
-
-    setDaemon(true)
-
-    private def monitorWorker(): Unit = {
-      // Kill the worker if it is interrupted, checking until task completion.
-      // TODO: This has a race condition if interruption occurs, as completed may still become true.
-      while (!context.isInterrupted && !context.isCompleted) {
-        Thread.sleep(2000)
-      }
-      if (!context.isCompleted) {
-        Thread.sleep(taskKillTimeout)
-        if (!context.isCompleted) {
-          try {
-            // Mimic the task name used in `Executor` to help the user find out the task to blame.
-            val taskName = s"${context.partitionId}.${context.attemptNumber} " +
-              s"in stage ${context.stageId} (TID ${context.taskAttemptId})"
-            logWarning(s"Incomplete task $taskName interrupted: Attempting to kill Python Worker")
-            env.destroyPythonWorker(pythonExec, envVars.asScala.toMap, worker)
-          } catch {
-            case e: Exception =>
-              logError("Exception when trying to kill worker", e)
-          }
-        }
-      }
-    }
-
-    override def run(): Unit = {
-      try {
-        monitorWorker()
-      } finally {
-        if (reuseWorker) {
-          val key = (worker, context.taskAttemptId)
-          PythonRunner.runningMonitorThreads.remove(key)
-        }
-      }
-    }
-  }
-
-  /**
-   * This thread monitors the WriterThread and kills it in case of deadlock.
-   *
-   * A deadlock can arise if the task completes while the writer thread is sending input to the
-   * Python process (e.g. due to the use of `take()`), and the Python process is still producing
-   * output. When the inputs are sufficiently large, this can result in a deadlock due to the use of
-   * blocking I/O (SPARK-38677). To resolve the deadlock, we need to close the socket.
-   */
-  class WriterMonitorThread(
-                             env: SparkEnv, worker: Socket, writerThread: WriterThread, context: TaskContext)
-    extends Thread(s"Writer Monitor for $pythonExec (writer thread id ${writerThread.getId})") {
-
-    /**
-     * How long to wait before closing the socket if the writer thread has not exited after the task
-     * ends.
-     */
-    private val taskKillTimeout = env.conf.get(PYTHON_TASK_KILL_TIMEOUT)
-
-    setDaemon(true)
-
-    override def run(): Unit = {
-      // Wait until the task is completed (or the writer thread exits, in which case this thread has
-      // nothing to do).
-      while (!context.isCompleted && writerThread.isAlive) {
-        Thread.sleep(2000)
-      }
-      if (writerThread.isAlive) {
-        Thread.sleep(taskKillTimeout)
-        // If the writer thread continues running, this indicates a deadlock. Kill the worker to
-        // resolve the deadlock.
-        if (writerThread.isAlive) {
-          try {
-            // Mimic the task name used in `Executor` to help the user find out the task to blame.
-            val taskName = s"${context.partitionId}.${context.attemptNumber} " +
-              s"in stage ${context.stageId} (TID ${context.taskAttemptId})"
-            logWarning(
-              s"Detected deadlock while completing task $taskName: " +
-                "Attempting to kill Python Worker")
-            env.destroyPythonWorker(pythonExec, envVars.asScala.toMap, worker)
-          } catch {
-            case e: Exception =>
-              logError("Exception when trying to kill worker", e)
-          }
-        }
-      }
-    }
-  }
-}
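
The file removed above is, apart from the geometry metadata writes, a copy of Spark's
BasePythonRunner: its writer appended a small geometry section to the command stream after
writeCommand. A minimal sketch of that framing, with hypothetical names standing in for the values
used in the real writer:

    import java.io.DataOutputStream

    // Illustrative only: the extra section the Sedona writer appends after the UDF command.
    def writeGeometryMetadata(dataOut: DataOutputStream, geometryFields: Seq[(Int, Int)]): Unit = {
      dataOut.writeInt(geometryFields.length)   // number of geometry columns
      geometryFields.foreach { case (index, srid) =>
        dataOut.writeInt(index)                 // ordinal of the geometry column in the schema
        dataOut.writeInt(srid)                  // SRID sniffed from the first row, 0 if unknown
      }
    }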
diff --git a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaArrowPythonRunner.scala b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaArrowPythonRunner.scala
index c3eafc9766..2a28eba6db 100644
--- a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaArrowPythonRunner.scala
+++ b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaArrowPythonRunner.scala
@@ -37,10 +37,10 @@ class SedonaArrowPythonRunner(
                          protected override val workerConf: Map[String, String],
                          val pythonMetrics: Map[String, SQLMetric],
                          jobArtifactUUID: Option[String])
-  extends SedonaBasePythonRunner[Iterator[InternalRow], ColumnarBatch](
-    funcs, evalType, argOffsets, jobArtifactUUID, schema)
+  extends BasePythonRunner[Iterator[InternalRow], ColumnarBatch](
+    funcs, evalType, argOffsets, jobArtifactUUID)
     with SedonaBasicPythonArrowInput
-    with SedonaBasicPythonArrowOutput {
+    with BasicPythonArrowOutput {
 
   override val pythonExec: String =
     SQLConf.get.pysparkWorkerPythonExecutable.getOrElse(
diff --git a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonArrowInput.scala b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonArrowInput.scala
index bf353539bc..7bc0d322c2 100644
--- a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonArrowInput.scala
+++ b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonArrowInput.scala
@@ -20,10 +20,12 @@ package org.apache.spark.sql.execution.python
 import org.apache.arrow.vector.VectorSchemaRoot
 import org.apache.arrow.vector.ipc.ArrowStreamWriter
 import org.apache.spark.api.python
-import org.apache.spark.api.python.{BasePythonRunner, ChainedPythonFunctions, PythonRDD, SedonaBasePythonRunner}
+import org.apache.spark.api.python.{BasePythonRunner, ChainedPythonFunctions, PythonRDD}
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
 import org.apache.spark.sql.execution.arrow.ArrowWriter
 import org.apache.spark.sql.execution.metric.SQLMetric
+import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.sql.util.ArrowUtils
 import org.apache.spark.sql.util.ArrowUtils.toArrowSchema
@@ -37,7 +39,7 @@ import java.net.Socket
  * A trait that can be mixed-in with [[python.BasePythonRunner]]. It implements the logic from
  * JVM (an iterator of internal rows + additional data if required) to Python (Arrow).
  */
-private[python] trait SedonaPythonArrowInput[IN] { self: SedonaBasePythonRunner[IN, _] =>
+private[python] trait SedonaPythonArrowInput[IN] { self: BasePythonRunner[IN, _] =>
   protected val workerConf: Map[String, String]
 
   protected val schema: StructType
@@ -84,6 +86,46 @@ private[python] trait SedonaPythonArrowInput[IN] { self: SedonaBasePythonRunner[
       protected override def writeCommand(dataOut: DataOutputStream): Unit = {
         handleMetadataBeforeExec(dataOut)
         writeUDF(dataOut, funcs, argOffsets)
+
+        val toReadCRS = inputIterator.buffered.headOption.flatMap(
+          el => el.asInstanceOf[Iterator[IN]].buffered.headOption
+        )
+
+        val row = toReadCRS match {
+          case Some(value) => value match {
+            case row: GenericInternalRow =>
+              Some(row)
+          }
+          case None => None
+        }
+
+        val geometryFields = schema.zipWithIndex.filter {
+          case (field, index) => field.dataType == GeometryUDT
+        }.map {
+          case (field, index) =>
+            if (row.isEmpty || row.get.values(index) == null) (index, 0) else {
+              val geom = row.get.get(index, GeometryUDT).asInstanceOf[Array[Byte]]
+              val preambleByte = geom(0) & 0xFF
+              val hasSrid = (preambleByte & 0x01) != 0
+
+              var srid = 0
+              if (hasSrid) {
+                val srid2 = (geom(1) & 0xFF) << 16
+                val srid1 = (geom(2) & 0xFF) << 8
+                val srid0 = geom(3) & 0xFF
+                srid = srid2 | srid1 | srid0
+              }
+              (index, srid)
+            }
+        }
+
+        // write number of geometry fields
+        dataOut.writeInt(geometryFields.length)
+        // write geometry field indices and their SRIDs
+        geometryFields.foreach { case (index, srid) =>
+          dataOut.writeInt(index)
+          dataOut.writeInt(srid)
+        }
       }
 
      protected override def writeIteratorToStream(dataOut: DataOutputStream): Unit = {
@@ -124,7 +166,7 @@ private[python] trait SedonaPythonArrowInput[IN] { self: SedonaBasePythonRunner[
 
 
 private[python] trait SedonaBasicPythonArrowInput extends SedonaPythonArrowInput[Iterator[InternalRow]] {
-  self: SedonaBasePythonRunner[Iterator[InternalRow], _] =>
+  self: BasePythonRunner[Iterator[InternalRow], _] =>
 
   protected def writeIteratorToArrowStream(
                                             root: VectorSchemaRoot,
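
The counterpart of the block added to writeCommand above is a reader that consumes the count and
the (index, SRID) pairs before the Arrow batches. The actual consumer is the Python worker, so the
Scala sketch below is only a format illustration with hypothetical names:

    import java.io.DataInputStream

    // Illustrative only: how the metadata written by writeCommand would be read back.
    def readGeometryMetadata(dataIn: DataInputStream): Seq[(Int, Int)] = {
      val numGeometryFields = dataIn.readInt()
      (0 until numGeometryFields).map { _ =>
        val index = dataIn.readInt()
        val srid = dataIn.readInt()
        (index, srid)
      }
    }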
diff --git a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonArrowOutput.scala b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonArrowOutput.scala
deleted file mode 100644
index 316cb32c3e..0000000000
--- a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonArrowOutput.scala
+++ /dev/null
@@ -1,144 +0,0 @@
-package org.apache.spark.sql.execution.python
-
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-import org.apache.arrow.vector.VectorSchemaRoot
-import org.apache.arrow.vector.ipc.ArrowStreamReader
-import org.apache.spark.api.python.{BasePythonRunner, SedonaBasePythonRunner, SpecialLengths}
-import org.apache.spark.sql.execution.metric.SQLMetric
-import org.apache.spark.sql.types.StructType
-import org.apache.spark.sql.util.ArrowUtils
-import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnVector, ColumnarBatch}
-import org.apache.spark.{SparkEnv, TaskContext}
-
-import java.io.DataInputStream
-import java.net.Socket
-import java.util.concurrent.atomic.AtomicBoolean
-import scala.collection.JavaConverters._
-
-/**
- * A trait that can be mixed-in with [[BasePythonRunner]]. It implements the logic from
- * Python (Arrow) to JVM (output type being deserialized from ColumnarBatch).
- */
-private[python] trait SedonaPythonArrowOutput[OUT <: AnyRef] { self: SedonaBasePythonRunner[_, OUT] =>
-
-  protected def pythonMetrics: Map[String, SQLMetric]
-
-  protected def handleMetadataAfterExec(stream: DataInputStream): Unit = { }
-
-  protected def deserializeColumnarBatch(batch: ColumnarBatch, schema: StructType): OUT
-
-  protected def newReaderIterator(
-                                   stream: DataInputStream,
-                                   writerThread: WriterThread,
-                                   startTime: Long,
-                                   env: SparkEnv,
-                                   worker: Socket,
-                                   pid: Option[Int],
-                                   releasedOrClosed: AtomicBoolean,
-                                   context: TaskContext): Iterator[OUT] = {
-
-    new ReaderIterator(
-      stream, writerThread, startTime, env, worker, pid, releasedOrClosed, context) {
-
-      private val allocator = ArrowUtils.rootAllocator.newChildAllocator(
-        s"stdin reader for $pythonExec", 0, Long.MaxValue)
-
-      private var reader: ArrowStreamReader = _
-      private var root: VectorSchemaRoot = _
-      private var schema: StructType = _
-      private var vectors: Array[ColumnVector] = _
-
-      context.addTaskCompletionListener[Unit] { _ =>
-        if (reader != null) {
-          reader.close(false)
-        }
-        allocator.close()
-      }
-
-      private var batchLoaded = true
-
-      protected override def handleEndOfDataSection(): Unit = {
-        handleMetadataAfterExec(stream)
-        super.handleEndOfDataSection()
-      }
-
-      protected override def read(): OUT = {
-        if (writerThread.exception.isDefined) {
-          throw writerThread.exception.get
-        }
-        try {
-          if (reader != null && batchLoaded) {
-            val bytesReadStart = reader.bytesRead()
-            batchLoaded = reader.loadNextBatch()
-            if (batchLoaded) {
-              val batch = new ColumnarBatch(vectors)
-              val rowCount = root.getRowCount
-              batch.setNumRows(root.getRowCount)
-              val bytesReadEnd = reader.bytesRead()
-              pythonMetrics("pythonNumRowsReceived") += rowCount
-              pythonMetrics("pythonDataReceived") += bytesReadEnd - bytesReadStart
-              val result = deserializeColumnarBatch(batch, schema)
-              result
-            } else {
-              reader.close(false)
-              allocator.close()
-              // Reach end of stream. Call `read()` again to read control data.
-              read()
-            }
-          } else {
-            stream.readInt() match {
-              case SpecialLengths.START_ARROW_STREAM =>
-                reader = new ArrowStreamReader(stream, allocator)
-                root = reader.getVectorSchemaRoot()
-                schema = ArrowUtils.fromArrowSchema(root.getSchema())
-                vectors = root.getFieldVectors().asScala.map { vector =>
-                  new ArrowColumnVector(vector)
-                }.toArray[ColumnVector]
-                read()
-              case SpecialLengths.TIMING_DATA =>
-                handleTimingData()
-                read()
-              case SpecialLengths.PYTHON_EXCEPTION_THROWN =>
-                throw handlePythonException()
-              case SpecialLengths.END_OF_DATA_SECTION =>
-                handleEndOfDataSection()
-                null.asInstanceOf[OUT]
-            }
-          }
-        } catch {
-          case e: Exception =>
-            // If an exception happens, make sure to close the reader to release resources.
-            if (reader != null) {
-              reader.close(false)
-            }
-            allocator.close()
-            throw e
-        }
-      }
-    }
-  }
-}
-
-private[python] trait SedonaBasicPythonArrowOutput extends SedonaPythonArrowOutput[ColumnarBatch] {
-  self: SedonaBasePythonRunner[_, ColumnarBatch] =>
-
-  protected def deserializeColumnarBatch(
-                                          batch: ColumnarBatch,
-                                          schema: StructType): ColumnarBatch = batch
-}
diff --git a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonUDFRunner.scala b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonUDFRunner.scala
index beb49c1dde..dcf93b5213 100644
--- a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonUDFRunner.scala
+++ b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonUDFRunner.scala
@@ -34,8 +34,8 @@ abstract class SedonaBasePythonUDFRunner(
                                     argOffsets: Array[Array[Int]],
                                     pythonMetrics: Map[String, SQLMetric],
                                     jobArtifactUUID: Option[String])
-  extends SedonaBasePythonRunner[Array[Byte], Array[Byte]](
-    funcs, evalType, argOffsets, jobArtifactUUID, null) {
+  extends BasePythonRunner[Array[Byte], Array[Byte]](
+    funcs, evalType, argOffsets, jobArtifactUUID) {
 
   override val pythonExec: String =
     SQLConf.get.pysparkWorkerPythonExecutable.getOrElse(
diff --git a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaWriterThread.scala b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaWriterThread.scala
index b28cf81906..57cf2dc7bb 100644
--- a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaWriterThread.scala
+++ b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaWriterThread.scala
@@ -4,7 +4,7 @@ package org.apache.spark.sql.execution.python
 import org.apache.sedona.common.geometrySerde.CoordinateType
 import org.apache.spark._
 import org.apache.spark.SedonaSparkEnv
-import org.apache.spark.api.python.{BarrierTaskContextMessageProtocol, BasePythonRunner, ChainedPythonFunctions, EncryptedPythonBroadcastServer, PythonRDD, SedonaBasePythonRunner, SpecialLengths}
+import org.apache.spark.api.python.{BarrierTaskContextMessageProtocol, BasePythonRunner, ChainedPythonFunctions, EncryptedPythonBroadcastServer, PythonRDD, SpecialLengths}
 import org.apache.spark.internal.Logging
 import org.apache.spark.internal.config.Python._
 import org.apache.spark.internal.config.{BUFFER_SIZE, EXECUTOR_CORES}

