This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 4235048b2c5f [SPARK-55323][PYTHON] Move UDF metadata to EvalConf to simplify worker protocol
4235048b2c5f is described below

commit 4235048b2c5ff270579961c577ec41f06aa48a78
Author: Tian Gao <[email protected]>
AuthorDate: Tue Feb 3 13:43:44 2026 +0800

    [SPARK-55323][PYTHON] Move UDF metadata to EvalConf to simplify worker protocol
    
    ### What changes were proposed in this pull request?
    
    A new EvalConf is created to hold all configurations specific to the UDF being
    evaluated (as opposed to RunnerConf, which carries Spark configs for the runner).
    The `handleMetadataBeforeExec` logic is eliminated in favor of this new conf.
    
    ### Why are the changes needed?
    
    We should minimize the protocol logic that is special to individual eval types.
    Passing per-UDF metadata through one unified key-value mapping achieves that, and
    it also makes it easier to add extra confs in the future.
    
    This PR intentionally does not remove all such logic on the Scala side; it is
    safe to do this bit by bit. After this PR is merged and has run properly for a
    while, I'll fix the rest.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    The local `test_udf` suite passes; CI should cover the rest.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #54103 from gaogaotiantian/consolidate-eval-conf.
    
    Authored-by: Tian Gao <[email protected]>
    Signed-off-by: Ruifeng Zheng <[email protected]>
---
 .../org/apache/spark/api/python/PythonRunner.scala |  2 +
 python/pyspark/worker.py                           | 63 ++++++++++++++--------
 .../sql/execution/python/PythonArrowInput.scala    |  3 --
 .../ApplyInPandasWithStatePythonRunner.scala       | 15 ++----
 .../TransformWithStateInPySparkPythonRunner.scala  | 21 +++-----
 5 files changed, 55 insertions(+), 49 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala 
b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
index f830dd2d8b6e..c3ee3853ce0f 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
@@ -215,6 +215,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
   protected val simplifiedTraceback: Boolean = false
 
   protected def runnerConf: Map[String, String] = Map.empty
+  protected def evalConf: Map[String, String] = Map.empty // UDF settings, sent after runnerConf
 
   // All the Python functions should have the same exec, version and envvars.
   protected val envVars: java.util.Map[String, String] = 
funcs.head.funcs.head.envVars
@@ -516,6 +517,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
 
         dataOut.writeInt(evalType)
         PythonWorkerUtils.writeConf(runnerConf, dataOut)
+        PythonWorkerUtils.writeConf(evalConf, dataOut)
         writeCommand(dataOut)
 
         dataOut.flush()
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 7606ece86ee7..7299c6211cf1 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -180,6 +180,47 @@ class RunnerConf(Conf):
         return self.get("spark.sql.pyspark.dataSource.profiler", None)


+class EvalConf(Conf):
+    """
+    Settings specific to the UDF being evaluated, as opposed to RunnerConf,
+    which carries Spark configs for the runner. ``main`` reads it right after
+    RunnerConf. The JVM sends no keys by default, so every property here
+    returns None when its key is absent.
+    """
+
+    def _get_schema(self, key: str) -> Optional[StructType]:
+        # Schemas arrive as the JSON string produced by the JVM's StructType.json.
+        schema = self.get(key, None)
+        if schema is None:
+            return None
+        return StructType.fromJson(json.loads(schema))
+
+    @property
+    def state_value_schema(self) -> Optional[StructType]:
+        # Set by ApplyInPandasWithStatePythonRunner.
+        return self._get_schema("state_value_schema")
+
+    @property
+    def grouping_key_schema(self) -> Optional[StructType]:
+        # Set by the TransformWithStateInPySpark runners.
+        return self._get_schema("grouping_key_schema")
+
+    @property
+    def state_server_socket_port(self) -> Optional[int | str]:
+        # The JVM sends the TCP port as a decimal string, or the socket file
+        # path when Unix domain sockets are enabled; return an int for the
+        # former and the raw string for the latter.
+        port = self.get("state_server_socket_port", None)
+        if port is None:
+            # int(None) raises TypeError, which the ValueError handler below
+            # would not catch; honour the Optional contract instead.
+            return None
+        try:
+            return int(port)
+        except ValueError:
+            return port
+
+
 def report_times(outfile, boot, init, finish, processing_time_ms):
     write_int(SpecialLengths.TIMING_DATA, outfile)
     write_long(int(1000 * boot), outfile)
@@ -2620,9 +2644,7 @@ def read_udtf(pickleSer, infile, eval_type, runner_conf):
         return mapper, None, ser, ser
 
 
-def read_udfs(pickleSer, infile, eval_type, runner_conf):
-    state_server_port = None
-    key_schema = None
+def read_udfs(pickleSer, infile, eval_type, runner_conf, eval_conf):
     if eval_type in (
         PythonEvalType.SQL_ARROW_BATCHED_UDF,
         PythonEvalType.SQL_SCALAR_PANDAS_UDF,
@@ -2649,20 +2671,6 @@ def read_udfs(pickleSer, infile, eval_type, runner_conf):
         PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_UDF,
         PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_INIT_STATE_UDF,
     ):
-        state_object_schema = None
-        if eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE:
-            state_object_schema = 
StructType.fromJson(json.loads(utf8_deserializer.loads(infile)))
-        elif (
-            eval_type == PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_UDF
-            or eval_type == 
PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_INIT_STATE_UDF
-            or eval_type == 
PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_UDF
-            or eval_type == 
PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_INIT_STATE_UDF
-        ):
-            state_server_port = read_int(infile)
-            if state_server_port == -1:
-                state_server_port = utf8_deserializer.loads(infile)
-            key_schema = 
StructType.fromJson(json.loads(utf8_deserializer.loads(infile)))
-
         # NOTE: if timezone is set here, that implies respectSessionTimeZone 
is True
         if (
             eval_type == PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF
@@ -2711,7 +2719,7 @@ def read_udfs(pickleSer, infile, eval_type, runner_conf):
                 runner_conf.timezone,
                 runner_conf.safecheck,
                 runner_conf.assign_cols_by_name,
-                state_object_schema,
+                eval_conf.state_value_schema,
                 runner_conf.arrow_max_records_per_batch,
                 runner_conf.use_large_var_types,
                 
int_to_decimal_coercion_enabled=runner_conf.int_to_decimal_coercion_enabled,
@@ -2952,7 +2960,9 @@ def read_udfs(pickleSer, infile, eval_type, runner_conf):
         arg_offsets, f = udfs[0]
         parsed_offsets = extract_key_value_indexes(arg_offsets)
         ser.key_offsets = parsed_offsets[0][0]
-        stateful_processor_api_client = 
StatefulProcessorApiClient(state_server_port, key_schema)
+        stateful_processor_api_client = StatefulProcessorApiClient(
+            eval_conf.state_server_socket_port, eval_conf.grouping_key_schema
+        )
 
         def mapper(a):
             mode = a[0]
@@ -2987,7 +2997,9 @@ def read_udfs(pickleSer, infile, eval_type, runner_conf):
         parsed_offsets = extract_key_value_indexes(arg_offsets)
         ser.key_offsets = parsed_offsets[0][0]
         ser.init_key_offsets = parsed_offsets[1][0]
-        stateful_processor_api_client = 
StatefulProcessorApiClient(state_server_port, key_schema)
+        stateful_processor_api_client = StatefulProcessorApiClient(
+            eval_conf.state_server_socket_port, eval_conf.grouping_key_schema
+        )
 
         def mapper(a):
             mode = a[0]
@@ -3017,7 +3029,9 @@ def read_udfs(pickleSer, infile, eval_type, runner_conf):
         arg_offsets, f = udfs[0]
         parsed_offsets = extract_key_value_indexes(arg_offsets)
         ser.key_offsets = parsed_offsets[0][0]
-        stateful_processor_api_client = 
StatefulProcessorApiClient(state_server_port, key_schema)
+        stateful_processor_api_client = StatefulProcessorApiClient(
+            eval_conf.state_server_socket_port, eval_conf.grouping_key_schema
+        )
 
         def mapper(a):
             mode = a[0]
@@ -3048,7 +3062,9 @@ def read_udfs(pickleSer, infile, eval_type, runner_conf):
         parsed_offsets = extract_key_value_indexes(arg_offsets)
         ser.key_offsets = parsed_offsets[0][0]
         ser.init_key_offsets = parsed_offsets[1][0]
-        stateful_processor_api_client = 
StatefulProcessorApiClient(state_server_port, key_schema)
+        stateful_processor_api_client = StatefulProcessorApiClient(
+            eval_conf.state_server_socket_port, eval_conf.grouping_key_schema
+        )
 
         def mapper(a):
             mode = a[0]
@@ -3340,6 +3356,7 @@ def main(infile, outfile):
         _accumulatorRegistry.clear()
         eval_type = read_int(infile)
         runner_conf = RunnerConf(infile)
+        eval_conf = EvalConf(infile)
         if eval_type == PythonEvalType.NON_UDF:
             func, profiler, deserializer, serializer = read_command(pickleSer, 
infile)
         elif eval_type in (
@@ -3352,7 +3369,7 @@ def main(infile, outfile):
             )
         else:
             func, profiler, deserializer, serializer = read_udfs(
-                pickleSer, infile, eval_type, runner_conf
+                pickleSer, infile, eval_type, runner_conf, eval_conf
             )
 
         init_time = time.time()
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala
index 45011fa3cebc..a659cb599b2a 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala
@@ -61,8 +61,6 @@ private[python] trait PythonArrowInput[IN] { self: 
BasePythonRunner[IN, _] =>
 
   protected def writeUDF(dataOut: DataOutputStream): Unit
 
-  protected def handleMetadataBeforeExec(stream: DataOutputStream): Unit = {}
-
   protected lazy val allocator: BufferAllocator =
     ArrowUtils.rootAllocator.newChildAllocator(s"stdout writer for 
$pythonExec", 0, Long.MaxValue)
 
@@ -112,7 +110,6 @@ private[python] trait PythonArrowInput[IN] { self: 
BasePythonRunner[IN, _] =>
     new Writer(env, worker, inputIterator, partitionIndex, context) {
 
       protected override def writeCommand(dataOut: DataOutputStream): Unit = {
-        handleMetadataBeforeExec(dataOut)
         writeUDF(dataOut)
       }
 
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/ApplyInPandasWithStatePythonRunner.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/ApplyInPandasWithStatePythonRunner.scala
index 165b90046630..89d8e425fd2b 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/ApplyInPandasWithStatePythonRunner.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/ApplyInPandasWithStatePythonRunner.scala
@@ -116,22 +116,17 @@ class ApplyInPandasWithStatePythonRunner(
       SQLConf.ARROW_EXECUTION_MAX_BYTES_PER_BATCH.key -> 
arrowMaxBytesPerBatch.toString
     )
 
+  // Sends the state value schema to the Python worker as JSON; the worker reads it back
+  // through its EvalConf under the same "state_value_schema" key.
+  override protected def evalConf: Map[String, String] =
+    super.evalConf + ("state_value_schema" -> stateValueSchema.json)
+
   private val stateRowDeserializer = stateEncoder.createDeserializer()
 
   override protected def writeUDF(dataOut: DataOutputStream): Unit = {
     PythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets)
   }
 
-  /**
-   * This method sends out the additional metadata before sending out actual 
data.
-   *
-   * Specifically, this class overrides this method to also write the schema 
for state value.
-   */
-  override protected def handleMetadataBeforeExec(stream: DataOutputStream): 
Unit = {
-    super.handleMetadataBeforeExec(stream)
-    // Also write the schema for state value
-    PythonRDD.writeUTF(stateValueSchema.json, stream)
-  }
   private var pandasWriter: ApplyInPandasWithStateWriter = _
   /**
    * Read the (key, state, values) from input iterator and construct Arrow 
RecordBatches, and
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkPythonRunner.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkPythonRunner.scala
index 4f3b6c0b951d..05771d38cd84 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkPythonRunner.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkPythonRunner.scala
@@ -28,7 +28,7 @@ import org.apache.arrow.vector.VectorSchemaRoot
 import org.apache.arrow.vector.ipc.ArrowStreamWriter
 
 import org.apache.spark.{SparkEnv, SparkException, TaskContext}
-import org.apache.spark.api.python.{BasePythonRunner, ChainedPythonFunctions, 
PythonFunction, PythonRDD, PythonWorkerUtils, StreamingPythonRunner}
+import org.apache.spark.api.python.{BasePythonRunner, ChainedPythonFunctions, 
PythonFunction, PythonWorkerUtils, StreamingPythonRunner}
 import org.apache.spark.internal.Logging
 import org.apache.spark.internal.config.Python.{PYTHON_UNIX_DOMAIN_SOCKET_DIR, 
PYTHON_UNIX_DOMAIN_SOCKET_ENABLED}
 import org.apache.spark.sql.catalyst.InternalRow
@@ -244,21 +244,16 @@ abstract class 
TransformWithStateInPySparkPythonBaseRunner[I](
       SQLConf.ARROW_EXECUTION_MAX_BYTES_PER_BATCH.key -> 
arrowMaxBytesPerBatch.toString
     )
 
+  // Grouping key schema (as JSON) and state server endpoint for the Python worker's EvalConf.
+  // The endpoint is the socket path when Unix domain sockets are on, else the TCP port string.
+  override protected def evalConf: Map[String, String] = super.evalConf ++ Seq(
+    "grouping_key_schema" -> groupingKeySchema.json,
+    "state_server_socket_port" ->
+      (if (isUnixDomainSock) stateServerSocketPath else stateServerSocketPort.toString))
+
   override protected val errorOnDuplicatedFieldNames: Boolean = true
   override protected val largeVarTypes: Boolean = sqlConf.arrowUseLargeVarTypes
 
-  override protected def handleMetadataBeforeExec(stream: DataOutputStream): 
Unit = {
-    super.handleMetadataBeforeExec(stream)
-    // Write the port/path number for state server
-    if (isUnixDomainSock) {
-      stream.writeInt(-1)
-      PythonWorkerUtils.writeUTF(stateServerSocketPath, stream)
-    } else {
-      stream.writeInt(stateServerSocketPort)
-    }
-    PythonRDD.writeUTF(groupingKeySchema.json, stream)
-  }
-
   override def compute(
       inputIterator: Iterator[I],
       partitionIndex: Int,


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to