This is an automated email from the ASF dual-hosted git repository.
dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 30d0fe072c8b [SPARK-45441][PYTHON] Introduce more util functions for
PythonWorkerUtils
30d0fe072c8b is described below
commit 30d0fe072c8bee4b6e98e013360d4c0c948ff003
Author: Takuya UESHIN <[email protected]>
AuthorDate: Sat Oct 7 15:32:32 2023 -0700
[SPARK-45441][PYTHON] Introduce more util functions for PythonWorkerUtils
### What changes were proposed in this pull request?
Introduces more util functions for `PythonWorkerUtils`.
The following util functions will be added:
- `writePythonFunction`
- `readUTF`
### Why are the changes needed?
There is more common code used to communicate with the Python worker.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
The existing tests.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #43251 from ueshin/issues/SPARK-45441/worker_util.
Authored-by: Takuya UESHIN <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
---
.../planner/StreamingForeachBatchHelper.scala | 10 ++----
.../planner/StreamingQueryListenerHelper.scala | 16 ++++------
.../org/apache/spark/api/python/PythonRunner.scala | 12 ++------
.../spark/api/python/PythonWorkerUtils.scala | 24 +++++++++++++++
.../spark/api/python/StreamingPythonRunner.scala | 4 +--
.../execution/python/BatchEvalPythonUDTFExec.scala | 3 +-
.../sql/execution/python/PythonUDFRunner.scala | 6 ++--
.../python/UserDefinedPythonFunction.scala | 36 ++++++++--------------
8 files changed, 52 insertions(+), 59 deletions(-)
diff --git
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala
index a5c0f863a174..b8097b235503 100644
---
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala
+++
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala
@@ -17,7 +17,6 @@
package org.apache.spark.sql.connect.planner
import java.io.EOFException
-import java.nio.charset.StandardCharsets
import java.util.UUID
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.ConcurrentMap
@@ -26,7 +25,7 @@ import scala.jdk.CollectionConverters._
import scala.util.control.NonFatal
import org.apache.spark.SparkException
-import org.apache.spark.api.python.{PythonException, PythonRDD,
SimplePythonFunction, SpecialLengths, StreamingPythonRunner}
+import org.apache.spark.api.python.{PythonException, PythonWorkerUtils,
SimplePythonFunction, SpecialLengths, StreamingPythonRunner}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.connect.service.SessionHolder
@@ -124,7 +123,7 @@ object StreamingForeachBatchHelper extends Logging {
// the session alive. The session mapping at Connect server does not
expire and query
// keeps running even if the original client disappears. This keeps
the query running.
- PythonRDD.writeUTF(args.dfId, dataOut)
+ PythonWorkerUtils.writeUTF(args.dfId, dataOut)
dataOut.writeLong(args.batchId)
dataOut.flush()
@@ -133,10 +132,7 @@ object StreamingForeachBatchHelper extends Logging {
case 0 =>
logInfo(s"Python foreach batch for dfId ${args.dfId} completed
(ret: 0)")
case SpecialLengths.PYTHON_EXCEPTION_THROWN =>
- val exLength = dataIn.readInt()
- val obj = new Array[Byte](exLength)
- dataIn.readFully(obj)
- val msg = new String(obj, StandardCharsets.UTF_8)
+ val msg = PythonWorkerUtils.readUTF(dataIn)
throw new PythonException(
s"Found error inside foreachBatch Python process: $msg",
null)
diff --git
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingQueryListenerHelper.scala
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingQueryListenerHelper.scala
index 886aeab3befd..685991dbed87 100644
---
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingQueryListenerHelper.scala
+++
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingQueryListenerHelper.scala
@@ -18,10 +18,9 @@
package org.apache.spark.sql.connect.planner
import java.io.EOFException
-import java.nio.charset.StandardCharsets
import org.apache.spark.SparkException
-import org.apache.spark.api.python.{PythonException, PythonRDD,
SimplePythonFunction, SpecialLengths, StreamingPythonRunner}
+import org.apache.spark.api.python.{PythonException, PythonWorkerUtils,
SimplePythonFunction, SpecialLengths, StreamingPythonRunner}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.connect.service.{SessionHolder,
SparkConnectService}
import org.apache.spark.sql.streaming.StreamingQueryListener
@@ -47,28 +46,28 @@ class PythonStreamingQueryListener(listener:
SimplePythonFunction, sessionHolder
val (dataOut, dataIn) = runner.init()
override def onQueryStarted(event:
StreamingQueryListener.QueryStartedEvent): Unit = {
- PythonRDD.writeUTF(event.json, dataOut)
+ PythonWorkerUtils.writeUTF(event.json, dataOut)
dataOut.writeInt(0)
dataOut.flush()
handlePythonWorkerError("onQueryStarted")
}
override def onQueryProgress(event:
StreamingQueryListener.QueryProgressEvent): Unit = {
- PythonRDD.writeUTF(event.json, dataOut)
+ PythonWorkerUtils.writeUTF(event.json, dataOut)
dataOut.writeInt(1)
dataOut.flush()
handlePythonWorkerError("onQueryProgress")
}
override def onQueryIdle(event: StreamingQueryListener.QueryIdleEvent): Unit
= {
- PythonRDD.writeUTF(event.json, dataOut)
+ PythonWorkerUtils.writeUTF(event.json, dataOut)
dataOut.writeInt(2)
dataOut.flush()
handlePythonWorkerError("onQueryIdle")
}
override def onQueryTerminated(event:
StreamingQueryListener.QueryTerminatedEvent): Unit = {
- PythonRDD.writeUTF(event.json, dataOut)
+ PythonWorkerUtils.writeUTF(event.json, dataOut)
dataOut.writeInt(3)
dataOut.flush()
handlePythonWorkerError("onQueryTerminated")
@@ -85,10 +84,7 @@ class PythonStreamingQueryListener(listener:
SimplePythonFunction, sessionHolder
case 0 =>
logInfo(s"Streaming query listener function $functionName completed
(ret: 0)")
case SpecialLengths.PYTHON_EXCEPTION_THROWN =>
- val exLength = dataIn.readInt()
- val obj = new Array[Byte](exLength)
- dataIn.readFully(obj)
- val msg = new String(obj, StandardCharsets.UTF_8)
+ val msg = PythonWorkerUtils.readUTF(dataIn)
throw new PythonException(
s"Found error inside Streaming query listener Python " +
s"process for function $functionName: $msg",
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
index b681294735ff..da658227e850 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
@@ -21,7 +21,6 @@ import java.io._
import java.net._
import java.nio.ByteBuffer
import java.nio.channels.SelectionKey
-import java.nio.charset.StandardCharsets
import java.nio.charset.StandardCharsets.UTF_8
import java.nio.file.{Files => JavaFiles, Path}
import java.util.concurrent.ConcurrentHashMap
@@ -516,11 +515,8 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
protected def handlePythonException(): PythonException = {
// Signals that an exception has been thrown in python
- val exLength = stream.readInt()
- val obj = new Array[Byte](exLength)
- stream.readFully(obj)
- new PythonException(new String(obj, StandardCharsets.UTF_8),
- writer.exception.orNull)
+ val msg = PythonWorkerUtils.readUTF(stream)
+ new PythonException(msg, writer.exception.orNull)
}
protected def handleEndOfDataSection(): Unit = {
@@ -816,9 +812,7 @@ private[spark] class PythonRunner(
new Writer(env, worker, inputIterator, partitionIndex, context) {
protected override def writeCommand(dataOut: DataOutputStream): Unit = {
- val command = funcs.head.funcs.head.command
- dataOut.writeInt(command.length)
- dataOut.write(command.toArray)
+ PythonWorkerUtils.writePythonFunction(funcs.head.funcs.head, dataOut)
}
override def writeNextInputToStream(dataOut: DataOutputStream): Boolean
= {
diff --git
a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerUtils.scala
b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerUtils.scala
index 3f7b11a40ada..782099235c43 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerUtils.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerUtils.scala
@@ -130,6 +130,30 @@ private[spark] object PythonWorkerUtils extends Logging {
dataOut.flush()
}
+ /**
+ * Write PythonFunction to the worker.
+ */
+ def writePythonFunction(func: PythonFunction, dataOut: DataOutputStream):
Unit = {
+ dataOut.writeInt(func.command.length)
+ dataOut.write(func.command.toArray)
+ }
+
+ /**
+ * Read a string in UTF-8 charset.
+ */
+ def readUTF(dataIn: DataInputStream): String = {
+ readUTF(dataIn.readInt(), dataIn)
+ }
+
+ /**
+ * Read a string in UTF-8 charset with the given byte length.
+ */
+ def readUTF(length: Int, dataIn: DataInputStream): String = {
+ val obj = new Array[Byte](length)
+ dataIn.readFully(obj)
+ new String(obj, StandardCharsets.UTF_8)
+ }
+
/**
* Receive accumulator updates from the worker.
*
diff --git
a/core/src/main/scala/org/apache/spark/api/python/StreamingPythonRunner.scala
b/core/src/main/scala/org/apache/spark/api/python/StreamingPythonRunner.scala
index 2fb5d15bcfd4..cf5e912c3da5 100644
---
a/core/src/main/scala/org/apache/spark/api/python/StreamingPythonRunner.scala
+++
b/core/src/main/scala/org/apache/spark/api/python/StreamingPythonRunner.scala
@@ -90,9 +90,7 @@ private[spark] class StreamingPythonRunner(
PythonRDD.writeUTF(sessionId, dataOut)
// Send the user function to python process
- val command = func.command
- dataOut.writeInt(command.length)
- dataOut.write(command.toArray)
+ PythonWorkerUtils.writePythonFunction(func, dataOut)
dataOut.flush()
val dataIn = new DataInputStream(
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonUDTFExec.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonUDTFExec.scala
index 4ad9bfe3ec58..01fb3bd7ac6a 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonUDTFExec.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonUDTFExec.scala
@@ -132,8 +132,7 @@ object PythonUDTFRunner {
case None =>
dataOut.writeInt(0)
}
- dataOut.writeInt(udtf.func.command.length)
- dataOut.write(udtf.func.command.toArray)
+ PythonWorkerUtils.writePythonFunction(udtf.func, dataOut)
PythonWorkerUtils.writeUTF(udtf.elementSchema.json, dataOut)
}
}
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala
index 12c51506b13c..37e95b608cc7 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala
@@ -152,8 +152,7 @@ object PythonUDFRunner {
}
dataOut.writeInt(chained.funcs.length)
chained.funcs.foreach { f =>
- dataOut.writeInt(f.command.length)
- dataOut.write(f.command.toArray)
+ PythonWorkerUtils.writePythonFunction(f, dataOut)
}
}
}
@@ -178,8 +177,7 @@ object PythonUDFRunner {
}
dataOut.writeInt(chained.funcs.length)
chained.funcs.foreach { f =>
- dataOut.writeInt(f.command.length)
- dataOut.write(f.command.toArray)
+ PythonWorkerUtils.writePythonFunction(f, dataOut)
}
}
}
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala
index 6400f7978afd..6e053167dc96 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala
@@ -20,7 +20,6 @@ package org.apache.spark.sql.execution.python
import java.io.{BufferedInputStream, BufferedOutputStream, DataInputStream,
DataOutputStream, EOFException, InputStream}
import java.nio.ByteBuffer
import java.nio.channels.SelectionKey
-import java.nio.charset.StandardCharsets
import java.util.HashMap
import scala.collection.mutable.ArrayBuffer
@@ -244,8 +243,7 @@ object UserDefinedPythonTableFunction {
PythonWorkerUtils.writeBroadcasts(broadcastVars, worker, env, dataOut)
// Send Python UDTF
- dataOut.writeInt(func.command.length)
- dataOut.write(func.command.toArray)
+ PythonWorkerUtils.writePythonFunction(func, dataOut)
// Send arguments
dataOut.writeInt(exprs.length)
@@ -276,40 +274,30 @@ object UserDefinedPythonTableFunction {
val dataIn = new DataInputStream(new BufferedInputStream(
new WorkerInputStream(worker, bufferStream.toByteBuffer), bufferSize))
- // Receive the schema.
- val schema = dataIn.readInt() match {
- case length if length >= 0 =>
- val obj = new Array[Byte](length)
- dataIn.readFully(obj)
- DataType.fromJson(new String(obj,
StandardCharsets.UTF_8)).asInstanceOf[StructType]
-
- case SpecialLengths.PYTHON_EXCEPTION_THROWN =>
- val exLength = dataIn.readInt()
- val obj = new Array[Byte](exLength)
- dataIn.readFully(obj)
- val msg = new String(obj, StandardCharsets.UTF_8)
- throw
QueryCompilationErrors.tableValuedFunctionFailedToAnalyseInPythonError(msg)
+ // Receive the schema or an exception raised in Python worker.
+ val length = dataIn.readInt()
+ if (length == SpecialLengths.PYTHON_EXCEPTION_THROWN) {
+ val msg = PythonWorkerUtils.readUTF(dataIn)
+ throw
QueryCompilationErrors.tableValuedFunctionFailedToAnalyseInPythonError(msg)
}
+
+ val schema = DataType.fromJson(
+ PythonWorkerUtils.readUTF(length, dataIn)).asInstanceOf[StructType]
+
// Receive whether the "with single partition" property is requested.
val withSinglePartition = dataIn.readInt() == 1
// Receive the list of requested partitioning columns, if any.
val partitionByColumns = ArrayBuffer.empty[Expression]
val numPartitionByColumns = dataIn.readInt()
for (_ <- 0 until numPartitionByColumns) {
- val length = dataIn.readInt()
- val obj = new Array[Byte](length)
- dataIn.readFully(obj)
- val columnName = new String(obj, StandardCharsets.UTF_8)
+ val columnName = PythonWorkerUtils.readUTF(dataIn)
partitionByColumns.append(UnresolvedAttribute(columnName))
}
// Receive the list of requested ordering columns, if any.
val orderBy = ArrayBuffer.empty[SortOrder]
val numOrderByItems = dataIn.readInt()
for (_ <- 0 until numOrderByItems) {
- val length = dataIn.readInt()
- val obj = new Array[Byte](length)
- dataIn.readFully(obj)
- val columnName = new String(obj, StandardCharsets.UTF_8)
+ val columnName = PythonWorkerUtils.readUTF(dataIn)
val direction = if (dataIn.readInt() == 1) Ascending else Descending
val overrideNullsFirst = dataIn.readInt()
overrideNullsFirst match {
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]