This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new d4de9131ae31 [SPARK-54615][PYTHON] Always pass runner_conf to python worker
d4de9131ae31 is described below
commit d4de9131ae31ba063ff67892b05106f77023e3fc
Author: Tian Gao <[email protected]>
AuthorDate: Mon Dec 8 07:50:17 2025 +0900
[SPARK-54615][PYTHON] Always pass runner_conf to python worker
### What changes were proposed in this pull request?
Always pass runnerConf to the Python worker, even if it's not used.
### Why are the changes needed?
This is part of the effort to consolidate the protocol between the JVM and the
Python worker. We currently have several different ways to pass the runner conf,
and sometimes we don't pass it at all. That makes the worker-side code messy: it
has to decide whether to read the conf based on the eval type. However, reading
an empty conf is very cheap, so we can simply do it unconditionally.
With this infrastructure, vanilla Python UDFs can also pass runner conf in the
future, and it enables further refactoring of the JVM runner code later.
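Below is a minimal, hedged sketch of the framing this consolidates on: the worker always reads a pair count followed by length-prefixed UTF-8 key/value strings, so an empty conf costs a single int read. The `read_int`/`read_utf` helpers here are illustrative stand-ins, not PySpark's actual `RunnerConf` or deserializer implementations.
```python
import io
import struct


def read_int(stream):
    # Big-endian 4-byte int, matching DataOutputStream.writeInt on the JVM side.
    return struct.unpack(">i", stream.read(4))[0]


def read_utf(stream):
    # Length-prefixed UTF-8 string, mirroring how the JVM writes each conf entry.
    return stream.read(read_int(stream)).decode("utf-8")


def read_runner_conf(stream):
    # The JVM now always writes the conf: a pair count, then key/value strings.
    # An empty conf is just a single 0, so reading it unconditionally is cheap.
    conf = {}
    for _ in range(read_int(stream)):
        key = read_utf(stream)
        conf[key] = read_utf(stream)
    return conf


# Example: an empty conf followed by a one-entry conf, as the JVM might emit them.
buf = io.BytesIO(
    struct.pack(">i", 0)
    + struct.pack(">i", 1)
    + struct.pack(">i", 4) + b"key1"
    + struct.pack(">i", 6) + b"value1"
)
print(read_runner_conf(buf))  # {}
print(read_runner_conf(buf))  # {'key1': 'value1'}
```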
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
`pyspark-sql` passed locally. Running the rest on CI.
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #53353 from gaogaotiantian/always-pass-runnerconf.
Authored-by: Tian Gao <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
.../org/apache/spark/api/python/PythonRunner.scala | 14 ++++++++++++++
python/pyspark/worker.py | 21 +++++++++------------
.../sql/execution/python/ArrowPythonRunner.scala | 12 +++++-------
.../execution/python/ArrowPythonUDTFRunner.scala | 2 +-
.../python/CoGroupedArrowPythonRunner.scala | 12 ++----------
.../sql/execution/python/PythonArrowInput.scala | 15 +++------------
.../ApplyInPandasWithStatePythonRunner.scala | 4 ++--
.../TransformWithStateInPySparkPythonRunner.scala | 14 +++++++-------
8 files changed, 43 insertions(+), 51 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
index 3d885ffdb02d..63484c23a920 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
@@ -212,6 +212,8 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
protected val hideTraceback: Boolean = false
protected val simplifiedTraceback: Boolean = false
+ protected val runnerConf: Map[String, String] = Map.empty
+
// All the Python functions should have the same exec, version and envvars.
protected val envVars: java.util.Map[String, String] = funcs.head.funcs.head.envVars
protected val pythonExec: String = funcs.head.funcs.head.pythonExec
@@ -403,6 +405,17 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
*/
protected def writeCommand(dataOut: DataOutputStream): Unit
+ /**
+ * Writes worker configuration to the stream connected to the Python worker.
+ */
+ protected def writeRunnerConf(dataOut: DataOutputStream): Unit = {
+ dataOut.writeInt(runnerConf.size)
+ for ((k, v) <- runnerConf) {
+ PythonWorkerUtils.writeUTF(k, dataOut)
+ PythonWorkerUtils.writeUTF(v, dataOut)
+ }
+ }
+
/**
* Writes input data to the stream connected to the Python worker.
* Returns true if any data was written to the stream, false if the input is exhausted.
@@ -532,6 +545,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
PythonWorkerUtils.writeBroadcasts(broadcastVars, worker, env, dataOut)
dataOut.writeInt(evalType)
+ writeRunnerConf(dataOut)
writeCommand(dataOut)
dataOut.flush()
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 109157e2c339..65dcbbbf23e6 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -1514,10 +1514,8 @@ def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index, profil
# It expects the UDTF to be in a specific format and performs various checks to
# ensure the UDTF is valid. This function also prepares a mapper function for applying
# the UDTF logic to input rows.
-def read_udtf(pickleSer, infile, eval_type):
+def read_udtf(pickleSer, infile, eval_type, runner_conf):
if eval_type == PythonEvalType.SQL_ARROW_TABLE_UDF:
- # Load conf used for arrow evaluation.
- runner_conf = RunnerConf(infile)
input_types = [
field.dataType for field in
_parse_datatype_json_string(utf8_deserializer.loads(infile))
]
@@ -1532,7 +1530,6 @@ def read_udtf(pickleSer, infile, eval_type):
else:
ser = ArrowStreamUDTFSerializer()
elif eval_type == PythonEvalType.SQL_ARROW_UDTF:
- runner_conf = RunnerConf(infile)
# Read the table argument offsets
num_table_arg_offsets = read_int(infile)
table_arg_offsets = [read_int(infile) for _ in range(num_table_arg_offsets)]
@@ -1540,7 +1537,6 @@ def read_udtf(pickleSer, infile, eval_type):
ser = ArrowStreamArrowUDTFSerializer(table_arg_offsets=table_arg_offsets)
else:
# Each row is a group so do not batch but send one by one.
- runner_conf = RunnerConf()
ser = BatchedSerializer(CPickleSerializer(), 1)
# See 'PythonUDTFRunner.PythonUDFWriterThread.writeCommand'
@@ -2688,7 +2684,7 @@ def read_udtf(pickleSer, infile, eval_type):
return mapper, None, ser, ser
-def read_udfs(pickleSer, infile, eval_type):
+def read_udfs(pickleSer, infile, eval_type, runner_conf):
state_server_port = None
key_schema = None
if eval_type in (
@@ -2716,9 +2712,6 @@ def read_udfs(pickleSer, infile, eval_type):
PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_UDF,
PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_INIT_STATE_UDF,
):
- # Load conf used for pandas_udf evaluation
- runner_conf = RunnerConf(infile)
-
state_object_schema = None
if eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE:
state_object_schema = StructType.fromJson(json.loads(utf8_deserializer.loads(infile)))
@@ -2870,7 +2863,6 @@ def read_udfs(pickleSer, infile, eval_type):
int_to_decimal_coercion_enabled=runner_conf.int_to_decimal_coercion_enabled,
)
else:
- runner_conf = RunnerConf()
batch_size = int(os.environ.get("PYTHON_UDF_BATCH_SIZE", "100"))
ser = BatchedSerializer(CPickleSerializer(), batch_size)
@@ -3353,6 +3345,7 @@ def main(infile, outfile):
_accumulatorRegistry.clear()
eval_type = read_int(infile)
+ runner_conf = RunnerConf(infile)
if eval_type == PythonEvalType.NON_UDF:
func, profiler, deserializer, serializer = read_command(pickleSer, infile)
elif eval_type in (
@@ -3360,9 +3353,13 @@ def main(infile, outfile):
PythonEvalType.SQL_ARROW_TABLE_UDF,
PythonEvalType.SQL_ARROW_UDTF,
):
- func, profiler, deserializer, serializer = read_udtf(pickleSer, infile, eval_type)
+ func, profiler, deserializer, serializer = read_udtf(
+ pickleSer, infile, eval_type, runner_conf
+ )
else:
- func, profiler, deserializer, serializer = read_udfs(pickleSer, infile, eval_type)
+ func, profiler, deserializer, serializer = read_udfs(
+ pickleSer, infile, eval_type, runner_conf
+ )
init_time = time.time()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
index f5f968ee9522..499fa99a2444 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
@@ -35,7 +35,6 @@ abstract class BaseArrowPythonRunner[IN, OUT <: AnyRef](
_schema: StructType,
_timeZoneId: String,
protected override val largeVarTypes: Boolean,
- protected override val workerConf: Map[String, String],
override val pythonMetrics: Map[String, SQLMetric],
jobArtifactUUID: Option[String],
sessionUUID: Option[String])
@@ -86,12 +85,11 @@ abstract class RowInputArrowPythonRunner(
_schema: StructType,
_timeZoneId: String,
largeVarTypes: Boolean,
- workerConf: Map[String, String],
pythonMetrics: Map[String, SQLMetric],
jobArtifactUUID: Option[String],
sessionUUID: Option[String])
extends BaseArrowPythonRunner[Iterator[InternalRow], ColumnarBatch](
- funcs, evalType, argOffsets, _schema, _timeZoneId, largeVarTypes, workerConf,
+ funcs, evalType, argOffsets, _schema, _timeZoneId, largeVarTypes,
pythonMetrics, jobArtifactUUID, sessionUUID)
with BasicPythonArrowInput
with BasicPythonArrowOutput
@@ -106,13 +104,13 @@ class ArrowPythonRunner(
_schema: StructType,
_timeZoneId: String,
largeVarTypes: Boolean,
- workerConf: Map[String, String],
+ protected override val runnerConf: Map[String, String],
pythonMetrics: Map[String, SQLMetric],
jobArtifactUUID: Option[String],
sessionUUID: Option[String],
profiler: Option[String])
extends RowInputArrowPythonRunner(
- funcs, evalType, argOffsets, _schema, _timeZoneId, largeVarTypes, workerConf,
+ funcs, evalType, argOffsets, _schema, _timeZoneId, largeVarTypes,
pythonMetrics, jobArtifactUUID, sessionUUID) {
override protected def writeUDF(dataOut: DataOutputStream): Unit =
@@ -130,13 +128,13 @@ class ArrowPythonWithNamedArgumentRunner(
_schema: StructType,
_timeZoneId: String,
largeVarTypes: Boolean,
- workerConf: Map[String, String],
+ protected override val runnerConf: Map[String, String],
pythonMetrics: Map[String, SQLMetric],
jobArtifactUUID: Option[String],
sessionUUID: Option[String],
profiler: Option[String])
extends RowInputArrowPythonRunner(
- funcs, evalType, argMetas.map(_.map(_.offset)), _schema, _timeZoneId, largeVarTypes, workerConf,
+ funcs, evalType, argMetas.map(_.map(_.offset)), _schema, _timeZoneId, largeVarTypes,
pythonMetrics, jobArtifactUUID, sessionUUID) {
override protected def writeUDF(dataOut: DataOutputStream): Unit = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonUDTFRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonUDTFRunner.scala
index 1d5df9bad924..979d91205d5a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonUDTFRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonUDTFRunner.scala
@@ -39,7 +39,7 @@ class ArrowPythonUDTFRunner(
protected override val schema: StructType,
protected override val timeZoneId: String,
protected override val largeVarTypes: Boolean,
- protected override val workerConf: Map[String, String],
+ protected override val runnerConf: Map[String, String],
override val pythonMetrics: Map[String, SQLMetric],
jobArtifactUUID: Option[String],
sessionUUID: Option[String])
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala
index 7f6efbae8881..b5986be9214a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala
@@ -25,7 +25,7 @@ import org.apache.arrow.vector.{VectorSchemaRoot, VectorUnloader}
import org.apache.arrow.vector.compression.{CompressionCodec, NoCompressionCodec}
import org.apache.spark.{SparkEnv, SparkException, TaskContext}
-import org.apache.spark.api.python.{BasePythonRunner, ChainedPythonFunctions, PythonRDD, PythonWorker}
+import org.apache.spark.api.python.{BasePythonRunner, ChainedPythonFunctions, PythonWorker}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.arrow.ArrowWriterWrapper
import org.apache.spark.sql.execution.metric.SQLMetric
@@ -45,7 +45,7 @@ class CoGroupedArrowPythonRunner(
rightSchema: StructType,
timeZoneId: String,
largeVarTypes: Boolean,
- conf: Map[String, String],
+ protected override val runnerConf: Map[String, String],
override val pythonMetrics: Map[String, SQLMetric],
jobArtifactUUID: Option[String],
sessionUUID: Option[String],
@@ -119,14 +119,6 @@ class CoGroupedArrowPythonRunner(
private var rightGroupArrowWriter: ArrowWriterWrapper = null
protected override def writeCommand(dataOut: DataOutputStream): Unit = {
-
- // Write config for the worker as a number of key -> value pairs of strings
- dataOut.writeInt(conf.size)
- for ((k, v) <- conf) {
- PythonRDD.writeUTF(k, dataOut)
- PythonRDD.writeUTF(v, dataOut)
- }
-
PythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets, profiler)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala
index f77b0a9342b0..d2d16b0c9623 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala
@@ -27,7 +27,7 @@ import org.apache.arrow.vector.ipc.WriteChannel
import org.apache.arrow.vector.ipc.message.MessageSerializer
import org.apache.spark.{SparkEnv, SparkException, TaskContext}
-import org.apache.spark.api.python.{BasePythonRunner, PythonRDD, PythonWorker}
+import org.apache.spark.api.python.{BasePythonRunner, PythonWorker}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.arrow
import org.apache.spark.sql.execution.arrow.{ArrowWriter, ArrowWriterWrapper}
@@ -42,8 +42,6 @@ import org.apache.spark.util.Utils
* JVM (an iterator of internal rows + additional data if required) to Python (Arrow).
*/
private[python] trait PythonArrowInput[IN] { self: BasePythonRunner[IN, _] =>
- protected val workerConf: Map[String, String]
-
protected val schema: StructType
protected val timeZoneId: String
@@ -62,14 +60,8 @@ private[python] trait PythonArrowInput[IN] { self: BasePythonRunner[IN, _] =>
protected def writeUDF(dataOut: DataOutputStream): Unit
- protected def handleMetadataBeforeExec(stream: DataOutputStream): Unit = {
- // Write config for the worker as a number of key -> value pairs of strings
- stream.writeInt(workerConf.size)
- for ((k, v) <- workerConf) {
- PythonRDD.writeUTF(k, stream)
- PythonRDD.writeUTF(v, stream)
- }
- }
+ protected def handleMetadataBeforeExec(stream: DataOutputStream): Unit = {}
+
private val arrowSchema = ArrowUtils.toArrowSchema(
schema, timeZoneId, errorOnDuplicatedFieldNames, largeVarTypes)
protected val allocator =
@@ -301,7 +293,6 @@ private[python] trait GroupedPythonArrowInput { self: RowInputArrowPythonRunner
context: TaskContext): Writer = {
new Writer(env, worker, inputIterator, partitionIndex, context) {
protected override def writeCommand(dataOut: DataOutputStream): Unit = {
- handleMetadataBeforeExec(dataOut)
writeUDF(dataOut)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/ApplyInPandasWithStatePythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/ApplyInPandasWithStatePythonRunner.scala
index 14054ba89a94..ae89ff1637ed 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/ApplyInPandasWithStatePythonRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/ApplyInPandasWithStatePythonRunner.scala
@@ -58,7 +58,7 @@ class ApplyInPandasWithStatePythonRunner(
argOffsets: Array[Array[Int]],
inputSchema: StructType,
_timeZoneId: String,
- initialWorkerConf: Map[String, String],
+ initialRunnerConf: Map[String, String],
stateEncoder: ExpressionEncoder[Row],
keySchema: StructType,
outputSchema: StructType,
@@ -113,7 +113,7 @@ class ApplyInPandasWithStatePythonRunner(
// applyInPandasWithState has its own mechanism to construct the Arrow RecordBatch instance.
// Configurations are both applied to executor and Python worker, set them to the worker conf
// to let Python worker read the config properly.
- override protected val workerConf: Map[String, String] = initialWorkerConf +
+ override protected val runnerConf: Map[String, String] = initialRunnerConf +
(SQLConf.ARROW_EXECUTION_MAX_RECORDS_PER_BATCH.key -> arrowMaxRecordsPerBatch.toString) +
(SQLConf.ARROW_EXECUTION_MAX_BYTES_PER_BATCH.key -> arrowMaxBytesPerBatch.toString)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkPythonRunner.scala
index 3eb7c7e64d64..bbf7b9387526 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkPythonRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkPythonRunner.scala
@@ -52,7 +52,7 @@ class TransformWithStateInPySparkPythonRunner(
_schema: StructType,
processorHandle: StatefulProcessorHandleImpl,
_timeZoneId: String,
- initialWorkerConf: Map[String, String],
+ initialRunnerConf: Map[String, String],
override val pythonMetrics: Map[String, SQLMetric],
jobArtifactUUID: Option[String],
groupingKeySchema: StructType,
@@ -60,7 +60,7 @@ class TransformWithStateInPySparkPythonRunner(
eventTimeWatermarkForEviction: Option[Long])
extends TransformWithStateInPySparkPythonBaseRunner[InType](
funcs, evalType, argOffsets, _schema, processorHandle, _timeZoneId,
- initialWorkerConf, pythonMetrics, jobArtifactUUID, groupingKeySchema,
+ initialRunnerConf, pythonMetrics, jobArtifactUUID, groupingKeySchema,
batchTimestampMs, eventTimeWatermarkForEviction)
with PythonArrowInput[InType] {
@@ -126,7 +126,7 @@ class TransformWithStateInPySparkPythonInitialStateRunner(
initStateSchema: StructType,
processorHandle: StatefulProcessorHandleImpl,
_timeZoneId: String,
- initialWorkerConf: Map[String, String],
+ initialRunnerConf: Map[String, String],
override val pythonMetrics: Map[String, SQLMetric],
jobArtifactUUID: Option[String],
groupingKeySchema: StructType,
@@ -134,7 +134,7 @@ class TransformWithStateInPySparkPythonInitialStateRunner(
eventTimeWatermarkForEviction: Option[Long])
extends TransformWithStateInPySparkPythonBaseRunner[GroupedInType](
funcs, evalType, argOffsets, dataSchema, processorHandle, _timeZoneId,
- initialWorkerConf, pythonMetrics, jobArtifactUUID, groupingKeySchema,
+ initialRunnerConf, pythonMetrics, jobArtifactUUID, groupingKeySchema,
batchTimestampMs, eventTimeWatermarkForEviction)
with PythonArrowInput[GroupedInType] {
@@ -221,7 +221,7 @@ abstract class TransformWithStateInPySparkPythonBaseRunner[I](
_schema: StructType,
processorHandle: StatefulProcessorHandleImpl,
_timeZoneId: String,
- initialWorkerConf: Map[String, String],
+ initialRunnerConf: Map[String, String],
override val pythonMetrics: Map[String, SQLMetric],
jobArtifactUUID: Option[String],
groupingKeySchema: StructType,
@@ -238,7 +238,7 @@ abstract class TransformWithStateInPySparkPythonBaseRunner[I](
protected val arrowMaxRecordsPerBatch = sqlConf.arrowMaxRecordsPerBatch
protected val arrowMaxBytesPerBatch = sqlConf.arrowMaxBytesPerBatch
- override protected val workerConf: Map[String, String] = initialWorkerConf +
+ override protected val runnerConf: Map[String, String] = initialRunnerConf +
(SQLConf.ARROW_EXECUTION_MAX_RECORDS_PER_BATCH.key -> arrowMaxRecordsPerBatch.toString) +
(SQLConf.ARROW_EXECUTION_MAX_BYTES_PER_BATCH.key -> arrowMaxBytesPerBatch.toString)
@@ -251,7 +251,7 @@ abstract class TransformWithStateInPySparkPythonBaseRunner[I](
override protected def handleMetadataBeforeExec(stream: DataOutputStream): Unit = {
super.handleMetadataBeforeExec(stream)
- // Also write the port/path number for state server
+ // Write the port/path number for state server
if (isUnixDomainSock) {
stream.writeInt(-1)
PythonWorkerUtils.writeUTF(stateServerSocketPath, stream)
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]