This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 30d0fe072c8b [SPARK-45441][PYTHON] Introduce more util functions for PythonWorkerUtils
30d0fe072c8b is described below

commit 30d0fe072c8bee4b6e98e013360d4c0c948ff003
Author: Takuya UESHIN <[email protected]>
AuthorDate: Sat Oct 7 15:32:32 2023 -0700

    [SPARK-45441][PYTHON] Introduce more util functions for PythonWorkerUtils
    
    ### What changes were proposed in this pull request?
    
    Introduces more util functions for `PythonWorkerUtils`.
    
    The following util functions will be added:
    
    - `writePythonFunction`
    - `readUTF`
    
    ### Why are the changes needed?
    
    There is more common code for communicating with the Python worker that can be shared.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    The existing tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #43251 from ueshin/issues/SPARK-45441/worker_util.
    
    Authored-by: Takuya UESHIN <[email protected]>
    Signed-off-by: Dongjoon Hyun <[email protected]>
---
 .../planner/StreamingForeachBatchHelper.scala      | 10 ++----
 .../planner/StreamingQueryListenerHelper.scala     | 16 ++++------
 .../org/apache/spark/api/python/PythonRunner.scala | 12 ++------
 .../spark/api/python/PythonWorkerUtils.scala       | 24 +++++++++++++++
 .../spark/api/python/StreamingPythonRunner.scala   |  4 +--
 .../execution/python/BatchEvalPythonUDTFExec.scala |  3 +-
 .../sql/execution/python/PythonUDFRunner.scala     |  6 ++--
 .../python/UserDefinedPythonFunction.scala         | 36 ++++++++--------------
 8 files changed, 52 insertions(+), 59 deletions(-)

diff --git 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala
 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala
index a5c0f863a174..b8097b235503 100644
--- 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala
+++ 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala
@@ -17,7 +17,6 @@
 package org.apache.spark.sql.connect.planner
 
 import java.io.EOFException
-import java.nio.charset.StandardCharsets
 import java.util.UUID
 import java.util.concurrent.ConcurrentHashMap
 import java.util.concurrent.ConcurrentMap
@@ -26,7 +25,7 @@ import scala.jdk.CollectionConverters._
 import scala.util.control.NonFatal
 
 import org.apache.spark.SparkException
-import org.apache.spark.api.python.{PythonException, PythonRDD, 
SimplePythonFunction, SpecialLengths, StreamingPythonRunner}
+import org.apache.spark.api.python.{PythonException, PythonWorkerUtils, 
SimplePythonFunction, SpecialLengths, StreamingPythonRunner}
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.connect.service.SessionHolder
@@ -124,7 +123,7 @@ object StreamingForeachBatchHelper extends Logging {
       //     the session alive. The session mapping at Connect server does not 
expire and query
       //     keeps running even if the original client disappears. This keeps 
the query running.
 
-      PythonRDD.writeUTF(args.dfId, dataOut)
+      PythonWorkerUtils.writeUTF(args.dfId, dataOut)
       dataOut.writeLong(args.batchId)
       dataOut.flush()
 
@@ -133,10 +132,7 @@ object StreamingForeachBatchHelper extends Logging {
           case 0 =>
             logInfo(s"Python foreach batch for dfId ${args.dfId} completed 
(ret: 0)")
           case SpecialLengths.PYTHON_EXCEPTION_THROWN =>
-            val exLength = dataIn.readInt()
-            val obj = new Array[Byte](exLength)
-            dataIn.readFully(obj)
-            val msg = new String(obj, StandardCharsets.UTF_8)
+            val msg = PythonWorkerUtils.readUTF(dataIn)
             throw new PythonException(
               s"Found error inside foreachBatch Python process: $msg",
               null)
diff --git 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingQueryListenerHelper.scala
 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingQueryListenerHelper.scala
index 886aeab3befd..685991dbed87 100644
--- 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingQueryListenerHelper.scala
+++ 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingQueryListenerHelper.scala
@@ -18,10 +18,9 @@
 package org.apache.spark.sql.connect.planner
 
 import java.io.EOFException
-import java.nio.charset.StandardCharsets
 
 import org.apache.spark.SparkException
-import org.apache.spark.api.python.{PythonException, PythonRDD, 
SimplePythonFunction, SpecialLengths, StreamingPythonRunner}
+import org.apache.spark.api.python.{PythonException, PythonWorkerUtils, 
SimplePythonFunction, SpecialLengths, StreamingPythonRunner}
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.connect.service.{SessionHolder, 
SparkConnectService}
 import org.apache.spark.sql.streaming.StreamingQueryListener
@@ -47,28 +46,28 @@ class PythonStreamingQueryListener(listener: 
SimplePythonFunction, sessionHolder
   val (dataOut, dataIn) = runner.init()
 
   override def onQueryStarted(event: 
StreamingQueryListener.QueryStartedEvent): Unit = {
-    PythonRDD.writeUTF(event.json, dataOut)
+    PythonWorkerUtils.writeUTF(event.json, dataOut)
     dataOut.writeInt(0)
     dataOut.flush()
     handlePythonWorkerError("onQueryStarted")
   }
 
   override def onQueryProgress(event: 
StreamingQueryListener.QueryProgressEvent): Unit = {
-    PythonRDD.writeUTF(event.json, dataOut)
+    PythonWorkerUtils.writeUTF(event.json, dataOut)
     dataOut.writeInt(1)
     dataOut.flush()
     handlePythonWorkerError("onQueryProgress")
   }
 
   override def onQueryIdle(event: StreamingQueryListener.QueryIdleEvent): Unit 
= {
-    PythonRDD.writeUTF(event.json, dataOut)
+    PythonWorkerUtils.writeUTF(event.json, dataOut)
     dataOut.writeInt(2)
     dataOut.flush()
     handlePythonWorkerError("onQueryIdle")
   }
 
   override def onQueryTerminated(event: 
StreamingQueryListener.QueryTerminatedEvent): Unit = {
-    PythonRDD.writeUTF(event.json, dataOut)
+    PythonWorkerUtils.writeUTF(event.json, dataOut)
     dataOut.writeInt(3)
     dataOut.flush()
     handlePythonWorkerError("onQueryTerminated")
@@ -85,10 +84,7 @@ class PythonStreamingQueryListener(listener: 
SimplePythonFunction, sessionHolder
         case 0 =>
           logInfo(s"Streaming query listener function $functionName completed 
(ret: 0)")
         case SpecialLengths.PYTHON_EXCEPTION_THROWN =>
-          val exLength = dataIn.readInt()
-          val obj = new Array[Byte](exLength)
-          dataIn.readFully(obj)
-          val msg = new String(obj, StandardCharsets.UTF_8)
+          val msg = PythonWorkerUtils.readUTF(dataIn)
           throw new PythonException(
             s"Found error inside Streaming query listener Python " +
               s"process for function $functionName: $msg",
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala 
b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
index b681294735ff..da658227e850 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
@@ -21,7 +21,6 @@ import java.io._
 import java.net._
 import java.nio.ByteBuffer
 import java.nio.channels.SelectionKey
-import java.nio.charset.StandardCharsets
 import java.nio.charset.StandardCharsets.UTF_8
 import java.nio.file.{Files => JavaFiles, Path}
 import java.util.concurrent.ConcurrentHashMap
@@ -516,11 +515,8 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
 
     protected def handlePythonException(): PythonException = {
       // Signals that an exception has been thrown in python
-      val exLength = stream.readInt()
-      val obj = new Array[Byte](exLength)
-      stream.readFully(obj)
-      new PythonException(new String(obj, StandardCharsets.UTF_8),
-        writer.exception.orNull)
+      val msg = PythonWorkerUtils.readUTF(stream)
+      new PythonException(msg, writer.exception.orNull)
     }
 
     protected def handleEndOfDataSection(): Unit = {
@@ -816,9 +812,7 @@ private[spark] class PythonRunner(
     new Writer(env, worker, inputIterator, partitionIndex, context) {
 
       protected override def writeCommand(dataOut: DataOutputStream): Unit = {
-        val command = funcs.head.funcs.head.command
-        dataOut.writeInt(command.length)
-        dataOut.write(command.toArray)
+        PythonWorkerUtils.writePythonFunction(funcs.head.funcs.head, dataOut)
       }
 
       override def writeNextInputToStream(dataOut: DataOutputStream): Boolean 
= {
diff --git 
a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerUtils.scala 
b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerUtils.scala
index 3f7b11a40ada..782099235c43 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerUtils.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerUtils.scala
@@ -130,6 +130,30 @@ private[spark] object PythonWorkerUtils extends Logging {
     dataOut.flush()
   }
 
+  /**
+   * Write PythonFunction to the worker.
+   */
+  def writePythonFunction(func: PythonFunction, dataOut: DataOutputStream): 
Unit = {
+    dataOut.writeInt(func.command.length)
+    dataOut.write(func.command.toArray)
+  }
+
+  /**
+   * Read a string in UTF-8 charset.
+   */
+  def readUTF(dataIn: DataInputStream): String = {
+    readUTF(dataIn.readInt(), dataIn)
+  }
+
+  /**
+   * Read a string in UTF-8 charset with the given byte length.
+   */
+  def readUTF(length: Int, dataIn: DataInputStream): String = {
+    val obj = new Array[Byte](length)
+    dataIn.readFully(obj)
+    new String(obj, StandardCharsets.UTF_8)
+  }
+
   /**
    * Receive accumulator updates from the worker.
    *
diff --git 
a/core/src/main/scala/org/apache/spark/api/python/StreamingPythonRunner.scala 
b/core/src/main/scala/org/apache/spark/api/python/StreamingPythonRunner.scala
index 2fb5d15bcfd4..cf5e912c3da5 100644
--- 
a/core/src/main/scala/org/apache/spark/api/python/StreamingPythonRunner.scala
+++ 
b/core/src/main/scala/org/apache/spark/api/python/StreamingPythonRunner.scala
@@ -90,9 +90,7 @@ private[spark] class StreamingPythonRunner(
     PythonRDD.writeUTF(sessionId, dataOut)
 
     // Send the user function to python process
-    val command = func.command
-    dataOut.writeInt(command.length)
-    dataOut.write(command.toArray)
+    PythonWorkerUtils.writePythonFunction(func, dataOut)
     dataOut.flush()
 
     val dataIn = new DataInputStream(
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonUDTFExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonUDTFExec.scala
index 4ad9bfe3ec58..01fb3bd7ac6a 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonUDTFExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonUDTFExec.scala
@@ -132,8 +132,7 @@ object PythonUDTFRunner {
       case None =>
         dataOut.writeInt(0)
     }
-    dataOut.writeInt(udtf.func.command.length)
-    dataOut.write(udtf.func.command.toArray)
+    PythonWorkerUtils.writePythonFunction(udtf.func, dataOut)
     PythonWorkerUtils.writeUTF(udtf.elementSchema.json, dataOut)
   }
 }
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala
index 12c51506b13c..37e95b608cc7 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala
@@ -152,8 +152,7 @@ object PythonUDFRunner {
       }
       dataOut.writeInt(chained.funcs.length)
       chained.funcs.foreach { f =>
-        dataOut.writeInt(f.command.length)
-        dataOut.write(f.command.toArray)
+        PythonWorkerUtils.writePythonFunction(f, dataOut)
       }
     }
   }
@@ -178,8 +177,7 @@ object PythonUDFRunner {
       }
       dataOut.writeInt(chained.funcs.length)
       chained.funcs.foreach { f =>
-        dataOut.writeInt(f.command.length)
-        dataOut.write(f.command.toArray)
+        PythonWorkerUtils.writePythonFunction(f, dataOut)
       }
     }
   }
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala
index 6400f7978afd..6e053167dc96 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala
@@ -20,7 +20,6 @@ package org.apache.spark.sql.execution.python
 import java.io.{BufferedInputStream, BufferedOutputStream, DataInputStream, 
DataOutputStream, EOFException, InputStream}
 import java.nio.ByteBuffer
 import java.nio.channels.SelectionKey
-import java.nio.charset.StandardCharsets
 import java.util.HashMap
 
 import scala.collection.mutable.ArrayBuffer
@@ -244,8 +243,7 @@ object UserDefinedPythonTableFunction {
       PythonWorkerUtils.writeBroadcasts(broadcastVars, worker, env, dataOut)
 
       // Send Python UDTF
-      dataOut.writeInt(func.command.length)
-      dataOut.write(func.command.toArray)
+      PythonWorkerUtils.writePythonFunction(func, dataOut)
 
       // Send arguments
       dataOut.writeInt(exprs.length)
@@ -276,40 +274,30 @@ object UserDefinedPythonTableFunction {
       val dataIn = new DataInputStream(new BufferedInputStream(
         new WorkerInputStream(worker, bufferStream.toByteBuffer), bufferSize))
 
-      // Receive the schema.
-      val schema = dataIn.readInt() match {
-        case length if length >= 0 =>
-          val obj = new Array[Byte](length)
-          dataIn.readFully(obj)
-          DataType.fromJson(new String(obj, 
StandardCharsets.UTF_8)).asInstanceOf[StructType]
-
-        case SpecialLengths.PYTHON_EXCEPTION_THROWN =>
-          val exLength = dataIn.readInt()
-          val obj = new Array[Byte](exLength)
-          dataIn.readFully(obj)
-          val msg = new String(obj, StandardCharsets.UTF_8)
-          throw 
QueryCompilationErrors.tableValuedFunctionFailedToAnalyseInPythonError(msg)
+      // Receive the schema or an exception raised in Python worker.
+      val length = dataIn.readInt()
+      if (length == SpecialLengths.PYTHON_EXCEPTION_THROWN) {
+        val msg = PythonWorkerUtils.readUTF(dataIn)
+        throw 
QueryCompilationErrors.tableValuedFunctionFailedToAnalyseInPythonError(msg)
       }
+
+      val schema = DataType.fromJson(
+        PythonWorkerUtils.readUTF(length, dataIn)).asInstanceOf[StructType]
+
       // Receive whether the "with single partition" property is requested.
       val withSinglePartition = dataIn.readInt() == 1
       // Receive the list of requested partitioning columns, if any.
       val partitionByColumns = ArrayBuffer.empty[Expression]
       val numPartitionByColumns = dataIn.readInt()
       for (_ <- 0 until numPartitionByColumns) {
-        val length = dataIn.readInt()
-        val obj = new Array[Byte](length)
-        dataIn.readFully(obj)
-        val columnName = new String(obj, StandardCharsets.UTF_8)
+        val columnName = PythonWorkerUtils.readUTF(dataIn)
         partitionByColumns.append(UnresolvedAttribute(columnName))
       }
       // Receive the list of requested ordering columns, if any.
       val orderBy = ArrayBuffer.empty[SortOrder]
       val numOrderByItems = dataIn.readInt()
       for (_ <- 0 until numOrderByItems) {
-        val length = dataIn.readInt()
-        val obj = new Array[Byte](length)
-        dataIn.readFully(obj)
-        val columnName = new String(obj, StandardCharsets.UTF_8)
+        val columnName = PythonWorkerUtils.readUTF(dataIn)
         val direction = if (dataIn.readInt() == 1) Ascending else Descending
         val overrideNullsFirst = dataIn.readInt()
         overrideNullsFirst match {


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to