This is an automated email from the ASF dual-hosted git repository.
asf-gitbox-commits pushed a commit to branch branch-4.x
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-4.x by this push:
new 7e101f443458 [SPARK-56661] Introducing logical and physical planning
nodes for language-agnostic Spark UDFs
7e101f443458 is described below
commit 7e101f443458087d85954f170247b1421432328e
Author: Sven Weber <[email protected]>
AuthorDate: Fri May 22 09:42:47 2026 -0400
[SPARK-56661] Introducing logical and physical planning nodes for
language-agnostic Spark UDFs
### What changes were proposed in this pull request?
This PR introduces new logical and physical Catalyst nodes for
language-agnostic User Defined Functions (UDF) as part of [SPIP
SPARK-55278](https://issues.apache.org/jira/browse/SPARK-55278), which proposes
language-agnostic UDFs.
As a first step towards the goal of language-agnostic UDFs, we want to
target mapPartitions-style UDFs like `pyspark.sql.DataFrame.mapInArrow` or
`pyspark.RDD.mapPartitions`. The overarching goal is to deprecate the
current, language-specific Catalyst nodes (like `MapInArrow`). However, for
now, the new nodes will exist in addition to the old ones until the new
framework has reached maturity.
In summary, this PR introduces:
- A new Catalyst expression, `ExternalUserDefinedFunction`, which captures
language-agnostic UDF properties (payload, name, etc.)
- A new Catalyst logical node, `ExternalUDF`, which serves as a base trait
for all language-agnostic UDF nodes
- A new Catalyst logical node, `MapPartitionsExternalUDF`, which is the new,
language-agnostic map partition node
- Catalyst physical nodes for both logical nodes
- `UDFDispatcherManager` - A manager class which manages UDF dispatchers
based on the target `UDFWorkerSpecification`
None of the changes introduced above are currently consumed in Spark.
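The idea of a manager that hands out one dispatcher per worker specification
can be illustrated with a minimal, self-contained sketch. This is not the
code added by this commit: `WorkerSpec`, `Dispatcher`, `DispatcherFactory`,
and `DispatcherManager` are hypothetical stand-ins for the real types, and
the choice of a `ConcurrentHashMap` keyed on case-class equality is an
assumption made only for illustration.

```scala
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.atomic.AtomicInteger
import java.util.function.{Function => JFunction}

// Hypothetical stand-in for a worker specification (not the real proto type).
final case class WorkerSpec(language: String, version: String)

// Hypothetical stand-in for a dispatcher bound to one specification.
trait Dispatcher extends AutoCloseable {
  def spec: WorkerSpec
}

// Hypothetical factory that knows how to build a dispatcher for a spec.
trait DispatcherFactory {
  def createDispatcher(spec: WorkerSpec): Dispatcher
}

// Caches at most one dispatcher per specification and closes them together.
final class DispatcherManager(factory: DispatcherFactory) extends AutoCloseable {
  private val dispatchers = new ConcurrentHashMap[WorkerSpec, Dispatcher]()

  // computeIfAbsent runs the factory at most once per distinct key,
  // even when several threads request the same spec concurrently.
  def getDispatcher(spec: WorkerSpec): Dispatcher =
    dispatchers.computeIfAbsent(spec, new JFunction[WorkerSpec, Dispatcher] {
      override def apply(s: WorkerSpec): Dispatcher = factory.createDispatcher(s)
    })

  override def close(): Unit = {
    val it = dispatchers.values().iterator()
    while (it.hasNext) it.next().close()
    dispatchers.clear()
  }
}

// Small demo factory that counts how many dispatchers were created and closed.
object DispatcherManagerDemo {
  val created = new AtomicInteger(0)
  val closed = new AtomicInteger(0)

  val factory: DispatcherFactory = new DispatcherFactory {
    override def createDispatcher(s: WorkerSpec): Dispatcher = {
      created.incrementAndGet()
      new Dispatcher {
        override val spec: WorkerSpec = s
        override def close(): Unit = { closed.incrementAndGet(); () }
      }
    }
  }
}
```

In this sketch, requesting the same specification twice yields the same
dispatcher instance, and closing the manager closes every cached dispatcher
once.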
### Why are the changes needed?
This is the first step toward language-agnostic UDF execution for Spark.
Existing physical and logical planning nodes need to be replaced eventually to
achieve this goal as they make language-specific assumptions.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
New unit tests were added.
### Was this patch authored or co-authored using generative AI tooling?
Partially. However, the code was manually reviewed and adjusted.
Closes #55768 from
sven-weber-db/sven-weber_data/spark-56661-catalyst-and-udf.
Authored-by: Sven Weber <[email protected]>
Signed-off-by: Herman van Hövell <[email protected]>
(cherry picked from commit c2057a302ad719519f4493dcb2a60041e6b8b6ba)
Signed-off-by: Herman van Hövell <[email protected]>
---
core/pom.xml | 10 +
.../src/main/scala/org/apache/spark/SparkEnv.scala | 46 +++++
.../apache/spark/util/SparkUDFWorkerLogger.scala | 41 +++++
dev/deps/spark-deps-hadoop-3-hive-2.3 | 7 +
sql/catalyst/pom.xml | 10 +
.../expressions/ExternalUserDefinedFunction.scala | 76 ++++++++
.../logical/logicalExternalUDFOperators.scala | 67 +++++++
.../spark/sql/catalyst/trees/TreePatterns.scala | 1 +
.../org/apache/spark/sql/internal/SQLConf.scala | 11 ++
sql/core/pom.xml | 17 ++
.../org/apache/spark/sql/classic/Dataset.scala | 19 +-
.../spark/sql/execution/SparkStrategies.scala | 5 +
.../execution/externalUDF/ExternalUDFExec.scala | 89 +++++++++
.../execution/externalUDF/ExternalUDFPlanner.scala | 127 +++++++++++++
.../externalUDF/MapPartitionsExternalUDFExec.scala | 73 ++++++++
.../externalUDF/PythonUDFWorkerSpecification.scala | 116 ++++++++++++
.../sql/internal/BaseSessionStateBuilder.scala | 18 +-
.../apache/spark/sql/internal/SessionState.scala | 4 +-
.../externalUDF/ExternalUDFPlanningSuite.scala | 133 ++++++++++++++
.../PythonUDFWorkerSpecificationSuite.scala | 201 +++++++++++++++++++++
udf/worker/core/pom.xml | 5 +
.../udf/worker/core/UDFDispatcherFactory.scala | 41 +++++
.../udf/worker/core/UDFDispatcherManager.scala | 126 +++++++++++++
.../spark/udf/worker/core/WorkerLogger.scala | 4 +
.../core/direct/DirectWorkerDispatcher.scala | 3 +-
.../worker/core/DirectWorkerDispatcherSuite.scala | 54 +-----
.../udf/worker/core/TestDirectWorkerHelpers.scala | 79 ++++++++
.../worker/core/UDFDispatcherManagerSuite.scala | 131 ++++++++++++++
28 files changed, 1444 insertions(+), 70 deletions(-)
diff --git a/core/pom.xml b/core/pom.xml
index ce4103346a1f..53b0c6584adf 100644
--- a/core/pom.xml
+++ b/core/pom.xml
@@ -35,6 +35,16 @@
</properties>
<dependencies>
+ <dependency>
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-udf-worker-proto_${scala.binary.version}</artifactId>
+ <version>${project.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-udf-worker-core_${scala.binary.version}</artifactId>
+ <version>${project.version}</version>
+ </dependency>
<dependency>
<groupId>org.scala-lang.modules</groupId>
<artifactId>scala-parallel-collections_${scala.binary.version}</artifactId>
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala
b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index 7dcf66a60957..4e56c88501ed 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -47,8 +47,11 @@ import org.apache.spark.security.CryptoStreamUtils
import org.apache.spark.serializer.{JavaSerializer, Serializer,
SerializerManager}
import org.apache.spark.shuffle.ShuffleManager
import org.apache.spark.storage._
+import org.apache.spark.udf.worker.UDFWorkerSpecification
+import org.apache.spark.udf.worker.core.{UDFDispatcherFactory,
UDFDispatcherManager, WorkerDispatcher}
import org.apache.spark.util.{RpcUtils, Utils}
import org.apache.spark.util.ArrayImplicits._
+import org.apache.spark.util.SparkUDFWorkerLogger
/**
* :: DeveloperApi ::
@@ -120,6 +123,48 @@ class SparkEnv (
pythonExec: String, workerModule: String, daemonModule: String, envVars:
Map[String, String])
private val pythonWorkers = mutable.HashMap[PythonWorkersKey,
PythonWorkerFactory]()
+ /**
+ * :: Experimental ::
+ * Dispatcher factory to generate UDF worker dispatchers
+ * using the new UDF framework proposed in SPARK-55278.
+ * Initialized on first use via [[getExternalUDFDispatcher]].
+ */
+ @volatile private var udfDispatcherManager: Option[UDFDispatcherManager] =
None
+
+ private def createUDFDispatcherManager(): UDFDispatcherManager = {
+ val factory = new UDFDispatcherFactory {
+ override def createDispatcher(
+ workerSpec: UDFWorkerSpecification,
+ logger: org.apache.spark.udf.worker.core.WorkerLogger
+ ): WorkerDispatcher = {
+ // TODO [SPARK-55278]: Wire in the correct dispatcher factory
+ throw new UnsupportedOperationException(
+ "No UDF dispatcher factory configured. " +
+ "Set up a concrete factory for SPARK-55278.")
+ }
+ }
+ new UDFDispatcherManager(factory, new SparkUDFWorkerLogger())
+ }
+
+ /**
+ * :: Experimental ::
+ * Returns the [[WorkerDispatcher]] for the given worker
+ * specification via the [[UDFDispatcherManager]].
+ */
+ private[spark] def getExternalUDFDispatcher(
+ workerSpec: UDFWorkerSpecification): WorkerDispatcher = {
+ val manager : UDFDispatcherManager = udfDispatcherManager.getOrElse {
+ synchronized {
+ // Get or Else synchronized to protect
+ // against concurrent creation requests.
+ udfDispatcherManager.getOrElse {
+ createUDFDispatcherManager()
+ }
+ }
+ }
+ manager.getDispatcher(workerSpec)
+ }
+
// A general, soft-reference map for metadata needed during HadoopRDD split
computation
// (e.g., HadoopFileRDD uses this to cache JobConfs and InputFormats).
private[spark] val hadoopJobMetadata =
@@ -134,6 +179,7 @@ class SparkEnv (
if (!isStopped) {
isStopped = true
pythonWorkers.values.foreach(_.stop())
+ udfDispatcherManager.foreach(_.close())
mapOutputTracker.stop()
if (shuffleManager != null) {
shuffleManager.stop()
diff --git
a/core/src/main/scala/org/apache/spark/util/SparkUDFWorkerLogger.scala
b/core/src/main/scala/org/apache/spark/util/SparkUDFWorkerLogger.scala
new file mode 100644
index 000000000000..d18aebd348ce
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/SparkUDFWorkerLogger.scala
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.util
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.udf.worker.core.WorkerLogger
+
+/**
+ * Adapts the UDF worker framework's [[WorkerLogger]] to Spark's
+ * [[Logging]] trait so that worker log messages go through the
+ * standard Spark logging pipeline.
+ */
+private[spark] class SparkUDFWorkerLogger
+ extends WorkerLogger with Logging {
+
+ override def info(msg: => String): Unit = {
+ logInfo(msg)
+ }
+ override def info(msg: => String, t: Throwable): Unit =
+ logInfo(msg, t)
+ override def warn(msg: => String): Unit = logWarning(msg)
+ override def warn(msg: => String, t: Throwable): Unit =
+ logWarning(msg, t)
+ override def debug(msg: => String): Unit = logDebug(msg)
+ override def debug(msg: => String, t: Throwable): Unit =
+ logDebug(msg, t)
+}
diff --git a/dev/deps/spark-deps-hadoop-3-hive-2.3
b/dev/deps/spark-deps-hadoop-3-hive-2.3
index fbf8f06a3404..842f841e1438 100644
--- a/dev/deps/spark-deps-hadoop-3-hive-2.3
+++ b/dev/deps/spark-deps-hadoop-3-hive-2.3
@@ -12,6 +12,7 @@ aliyun-java-sdk-kms/2.11.0//aliyun-java-sdk-kms-2.11.0.jar
aliyun-java-sdk-ram/3.1.0//aliyun-java-sdk-ram-3.1.0.jar
aliyun-sdk-oss/3.18.1//aliyun-sdk-oss-3.18.1.jar
analyticsaccelerator-s3/1.3.1//analyticsaccelerator-s3-1.3.1.jar
+animal-sniffer-annotations/1.24//animal-sniffer-annotations-1.24.jar
antlr-runtime/3.5.2//antlr-runtime-3.5.2.jar
antlr4-runtime/4.13.1//antlr4-runtime-4.13.1.jar
aopalliance-repackaged/3.0.6//aopalliance-repackaged-3.0.6.jar
@@ -64,10 +65,15 @@ derbyshared/10.16.1.1//derbyshared-10.16.1.1.jar
derbytools/10.16.1.1//derbytools-10.16.1.1.jar
dom4j/2.1.4//dom4j-2.1.4.jar
dropwizard-metrics-hadoop-metrics2-reporter/0.1.2//dropwizard-metrics-hadoop-metrics2-reporter-0.1.2.jar
+error_prone_annotations/2.18.0//error_prone_annotations-2.18.0.jar
esdk-obs-java/3.20.4.2//esdk-obs-java-3.20.4.2.jar
failureaccess/1.0.3//failureaccess-1.0.3.jar
flatbuffers-java/25.2.10//flatbuffers-java-25.2.10.jar
gmetric4j/1.0.10//gmetric4j-1.0.10.jar
+grpc-api/1.76.0//grpc-api-1.76.0.jar
+grpc-protobuf-lite/1.76.0//grpc-protobuf-lite-1.76.0.jar
+grpc-protobuf/1.76.0//grpc-protobuf-1.76.0.jar
+grpc-stub/1.76.0//grpc-stub-1.76.0.jar
gson/2.13.2//gson-2.13.2.jar
guava/33.6.0-jre//guava-33.6.0-jre.jar
hadoop-aliyun/3.5.0//hadoop-aliyun-3.5.0.jar
@@ -248,6 +254,7 @@
parquet-format-structures/1.17.0//parquet-format-structures-1.17.0.jar
parquet-hadoop/1.17.0//parquet-hadoop-1.17.0.jar
parquet-jackson/1.17.0//parquet-jackson-1.17.0.jar
pickle/1.5//pickle-1.5.jar
+proto-google-common-protos/2.59.2//proto-google-common-protos-2.59.2.jar
py4j/0.10.9.9//py4j-0.10.9.9.jar
reactive-streams/1.0.3//reactive-streams-1.0.3.jar
remotetea-oncrpc/1.1.2//remotetea-oncrpc-1.1.2.jar
diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml
index fc4ed86bcabb..bbc4eb262506 100644
--- a/sql/catalyst/pom.xml
+++ b/sql/catalyst/pom.xml
@@ -62,6 +62,16 @@
<groupId>org.apache.spark</groupId>
<artifactId>spark-tags_${scala.binary.version}</artifactId>
</dependency>
+ <dependency>
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-udf-worker-proto_${scala.binary.version}</artifactId>
+ <version>${project.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-udf-worker-core_${scala.binary.version}</artifactId>
+ <version>${project.version}</version>
+ </dependency>
<!--
This spark-tags test-dep is needed even though it isn't used in this
module, otherwise testing-cmds that exclude
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExternalUserDefinedFunction.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExternalUserDefinedFunction.scala
new file mode 100644
index 000000000000..ede4adf32e11
--- /dev/null
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExternalUserDefinedFunction.scala
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.sql.catalyst.trees.TreePattern.{EXTERNAL_UDF,
TreePattern}
+import org.apache.spark.sql.types.DataType
+
+/**
+ * :: Experimental ::
+ * A serialized external UDF that is executed in an external worker process
+ * via the language-agnostic UDF worker framework.
+ *
+ * This is a Catalyst expression analogous to [[PythonUDF]] but
+ * language-agnostic. The [[payload]] carries an opaque serialized
+ * function definition whose interpretation is left to the worker.
+ * The optional [[inputTypes]] declare the expected argument types for
+ * validation during analysis; when absent, any input types are accepted.
+ *
+ * This expression is [[Unevaluable]] and requires a dedicated physical
+ * operator (e.g.
[[org.apache.spark.sql.execution.externalUDF.MapPartitionsExternalUDFExec]])
+ * to execute.
+ *
+ * @param name Optional name of the UDF.
+ * @param payload Opaque serialized function definition.
+ * @param dataType Return type of the UDF.
+ * @param children Input argument expressions.
+ * @param inputTypes Optional declared input types for validation.
+ * @param udfDeterministic Whether this UDF is deterministic.
+ * @param udfNullable Whether this UDF can return null.
+ * @param resultId Unique expression ID for this invocation.
+ */
+@Experimental
+case class ExternalUserDefinedFunction(
+ name: Option[String],
+ payload: Array[Byte],
+ dataType: DataType,
+ children: Seq[Expression],
+ inputTypes: Option[Seq[DataType]] = None,
+ udfDeterministic: Boolean,
+ udfNullable: Boolean,
+ resultId: ExprId = NamedExpression.newExprId)
+ extends Expression with NonSQLExpression with Unevaluable {
+
+ override lazy val deterministic: Boolean = udfDeterministic &&
children.forall(_.deterministic)
+
+ override def nullable: Boolean = udfNullable
+
+ override lazy val canonicalized: Expression = {
+ val canonicalizedChildren = children.map(_.canonicalized)
+ // `resultId` can be seen as cosmetic variation in
ExternalUserDefinedFunction,
+ // as it doesn't affect the result.
+ this.copy(resultId = ExprId(-1)).withNewChildren(canonicalizedChildren)
+ }
+
+ final override val nodePatterns: Seq[TreePattern] = Seq(EXTERNAL_UDF)
+
+ override protected def withNewChildrenInternal(
+ newChildren: IndexedSeq[Expression]): ExternalUserDefinedFunction =
+ copy(children = newChildren)
+}
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/logicalExternalUDFOperators.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/logicalExternalUDFOperators.scala
new file mode 100644
index 000000000000..49c3e938d5b6
--- /dev/null
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/logicalExternalUDFOperators.scala
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.plans.logical
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.sql.catalyst.expressions.{Attribute,
+ ExternalUserDefinedFunction}
+import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.udf.worker.UDFWorkerSpecification
+
+/**
+ * :: Experimental ::
+ * Base trait for logical plan nodes representing UDFs that are executed
+ * in an external worker process. This covers Python UDFs, and any future
+ * UDF languages that use the language-agnostic UDF worker framework.
+ */
+@Experimental
+trait ExternalUDF extends UnaryNode {
+
+ /** Specification describing how to create and communicate with the UDF
worker. */
+ def workerSpec: UDFWorkerSpecification
+}
+
+/**
+ * :: Experimental ::
+ * Logical plan node for mapPartitions-style UDF execution in an
+ * external worker process.
+ *
+ * @param workerSpec Specification describing the UDF worker.
+ * @param function The UDF to invoke. Output attributes are
+ * derived from `function.dataType`.
+ * @param isBarrier Whether to use barrier execution.
+ * @param child Input relation whose partitions are processed.
+ */
+@Experimental
+case class MapPartitionsExternalUDF(
+ workerSpec: UDFWorkerSpecification,
+ function: ExternalUserDefinedFunction,
+ isBarrier: Boolean,
+ child: LogicalPlan)
+ extends ExternalUDF {
+
+ // Map partitions always operate on StructTypes
+ override def output: Seq[Attribute] = toAttributes(
+ function.dataType.asInstanceOf[StructType]
+ )
+
+ override protected def withNewChildInternal(
+ newChild: LogicalPlan): MapPartitionsExternalUDF =
+ copy(child = newChild)
+}
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
index d94a506da82d..cca9bcd673d6 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
@@ -48,6 +48,7 @@ object TreePattern extends Enumeration {
val DYNAMIC_PRUNING_SUBQUERY: Value = Value
val EXISTS_SUBQUERY = Value
val EXPRESSION_WITH_RANDOM_SEED: Value = Value
+ val EXTERNAL_UDF: Value = Value
val EXTRACT_VALUE: Value = Value
val FUNCTION_TABLE_RELATION_ARGUMENT_EXPRESSION: Value = Value
val GENERATOR: Value = Value
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index b3de3f0813a7..c34b52e15dbc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -4692,6 +4692,17 @@ object SQLConf {
.booleanConf
.createWithDefault(false)
+ val UNIFIED_UDF_EXECUTION_ENABLED =
+ buildConf("spark.sql.execution.udf.unified.execution.enabled")
+ .doc("When true, UDFs that support the language-agnostic " +
+ "UDF worker protocol are executed via the unified, " +
+ "external UDF worker framework instead of the " +
+ "language-specific runners. Experimental.")
+ .version("4.2.0")
+ .withBindingPolicy(ConfigBindingPolicy.SESSION)
+ .booleanConf
+ .createWithDefault(false)
+
val PYTHON_UDF_ARROW_ENABLED =
buildConf("spark.sql.execution.pythonUDF.arrow.enabled")
.doc("Enable Arrow optimization in regular Python UDFs. This
optimization " +
diff --git a/sql/core/pom.xml b/sql/core/pom.xml
index 883395a0ef33..43417797f72e 100644
--- a/sql/core/pom.xml
+++ b/sql/core/pom.xml
@@ -302,6 +302,23 @@
<groupId>org.apache.arrow</groupId>
<artifactId>arrow-compression</artifactId>
</dependency>
+ <dependency>
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-udf-worker-proto_${scala.binary.version}</artifactId>
+ <version>${project.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-udf-worker-core_${scala.binary.version}</artifactId>
+ <version>${project.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-udf-worker-core_${scala.binary.version}</artifactId>
+ <version>${project.version}</version>
+ <type>test-jar</type>
+ <scope>test</scope>
+ </dependency>
</dependencies>
<build>
<outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.scala
b/sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.scala
index d83a4df51cd5..833b3f451273 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.scala
@@ -49,7 +49,6 @@ import org.apache.spark.sql.catalyst.parser.{ParseException,
ParserUtils}
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, TreeNodeTag,
TreePattern}
-import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes
import org.apache.spark.sql.catalyst.util.{CharVarcharUtils, IntervalUtils}
import org.apache.spark.sql.catalyst.util.TypeUtils.toSQLId
import org.apache.spark.sql.classic.ClassicConversions._
@@ -1520,15 +1519,10 @@ class Dataset[T] private[sql](
funcCol: Column,
isBarrier: Boolean = false,
profile: ResourceProfile = null): DataFrame = {
- val func = funcCol.expr
Dataset.ofRows(
sparkSession,
- MapInPandas(
- func,
- toAttributes(func.dataType.asInstanceOf[StructType]),
- logicalPlan,
- isBarrier,
- Option(profile)))
+ sparkSession.sessionState.externalUDFPlanner.planPythonMapInPandas(
+ funcCol.expr, logicalPlan, isBarrier, Option(profile)))
}
/**
@@ -1540,15 +1534,10 @@ class Dataset[T] private[sql](
funcCol: Column,
isBarrier: Boolean = false,
profile: ResourceProfile = null): DataFrame = {
- val func = funcCol.expr
Dataset.ofRows(
sparkSession,
- MapInArrow(
- func,
- toAttributes(func.dataType.asInstanceOf[StructType]),
- logicalPlan,
- isBarrier,
- Option(profile)))
+ sparkSession.sessionState.externalUDFPlanner.planPythonMapInArrow(
+ funcCol.expr, logicalPlan, isBarrier, Option(profile)))
}
/** @inheritdoc */
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 92818c12bfa0..455933d8e085 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -972,6 +972,11 @@ abstract class SparkStrategies extends
QueryPlanner[SparkPlan] {
execution.python.MapInPandasExec(func, output, planLater(child),
isBarrier, profile) :: Nil
case logical.MapInArrow(func, output, child, isBarrier, profile) =>
execution.python.MapInArrowExec(func, output, planLater(child),
isBarrier, profile) :: Nil
+ case logical.MapPartitionsExternalUDF(
+ workerSpec, functionExpr, isBarrier, child) =>
+ execution.externalUDF.MapPartitionsExternalUDFExec(
+ workerSpec, functionExpr,
+ isBarrier, planLater(child)) :: Nil
case logical.AttachDistributedSequence(attr, child, cache) =>
execution.python.AttachDistributedSequenceExec(attr, planLater(child),
cache) :: Nil
case logical.PythonWorkerLogs(jsonAttr) =>
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/externalUDF/ExternalUDFExec.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/externalUDF/ExternalUDFExec.scala
new file mode 100644
index 000000000000..541b03c52a0b
--- /dev/null
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/externalUDF/ExternalUDFExec.scala
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.externalUDF
+
+import org.apache.spark.{SparkEnv, TaskContext}
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.execution.UnaryExecNode
+import org.apache.spark.sql.execution.metric.SQLMetric
+import org.apache.spark.udf.worker.UDFWorkerSpecification
+import org.apache.spark.udf.worker.core.{WorkerSecurityScope, WorkerSession}
+
+/**
+ * :: Experimental ::
+ * Base trait for physical plan nodes that execute UDFs in an external
+ * worker process via the language-agnostic UDF worker framework.
+ *
+ * Dispatchers are obtained via [[SparkEnv#getExternalUDFDispatcher]],
+ * which uses the [[UDFDispatcherManager]] registered on the
+ * environment. This avoids serializing the manager as part of the
+ * physical plan.
+ */
+@Experimental
+trait ExternalUDFExec extends UnaryExecNode {
+
+ /**
+ * Specification describing how to create and communicate with the UDF
worker.
+ * There is exactly one specification per [[ExternalUDFExec]] node.
+ */
+ def workerSpec: UDFWorkerSpecification
+
+ //
---------------------------------------------------------------------------
+ // Metrics
+ //
---------------------------------------------------------------------------
+
+ protected def externalUdfMetrics: Map[String, SQLMetric] = Map(
+ // TODO [SPARK-55278]: Emit the correct metrics here
+ )
+
+ override lazy val metrics: Map[String, SQLMetric] = externalUdfMetrics
+
+ //
---------------------------------------------------------------------------
+ // Session lifecycle
+ //
---------------------------------------------------------------------------
+
+ /**
+ * Creates a [[WorkerSession]] via [[SparkEnv#getExternalUDFDispatcher]].
+ * Registers session cancellation on task failure and session termination
+ * on task completion. The provided function receives the session
+ * and must return the result iterator. The function CAN but MUST NOT
+ * cancel or close the session.
+ */
+ protected def withUDFWorkerSession(
+ taskContext: TaskContext,
+ securityScope: Option[WorkerSecurityScope] = None)(
+ f: WorkerSession => Iterator[InternalRow]
+ ): Iterator[InternalRow] = {
+ val dispatcher = SparkEnv.get.getExternalUDFDispatcher(
+ workerSpec)
+ val session = dispatcher.createSession(securityScope)
+
+ // Make sure to cancel the session, if the task fails
+ taskContext.addTaskFailureListener { (_, _) =>
+ session.cancel()
+ }
+
+ // Make sure to close the session once we are done
+ taskContext.addTaskCompletionListener[Unit] { _ =>
+ session.close()
+ }
+
+ f(session)
+ }
+}
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/externalUDF/ExternalUDFPlanner.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/externalUDF/ExternalUDFPlanner.scala
new file mode 100644
index 000000000000..589fa3e1696b
--- /dev/null
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/externalUDF/ExternalUDFPlanner.scala
@@ -0,0 +1,127 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.externalUDF
+
+import org.apache.spark.SparkConf
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.resource.ResourceProfile
+import org.apache.spark.sql.catalyst.expressions.{Expression,
+ ExternalUserDefinedFunction, PythonUDF}
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan,
+ MapInArrow, MapInPandas, MapPartitionsExternalUDF}
+import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes
+import org.apache.spark.sql.types.StructType
+
+/**
+ * Strategy for converting UDF calls into logical plan nodes.
+ * Implementations decide whether to use the classic
+ * language-specific runner or the unified external UDF worker
+ * framework.
+ *
+ * Wired into [[org.apache.spark.sql.internal.SessionState]] via
+ * [[org.apache.spark.sql.internal.BaseSessionStateBuilder]].
+ */
+trait ExternalUDFPlanner {
+
+ /**
+ * Creates the logical plan node for a Python mapInPandas
+ * operation.
+ */
+ def planPythonMapInPandas(
+ func: Expression,
+ child: LogicalPlan,
+ isBarrier: Boolean,
+ profile: Option[ResourceProfile]): LogicalPlan
+
+ /**
+ * Creates the logical plan node for a Python mapInArrow
+ * operation.
+ */
+ def planPythonMapInArrow(
+ func: Expression,
+ child: LogicalPlan,
+ isBarrier: Boolean,
+ profile: Option[ResourceProfile]): LogicalPlan
+}
+
+/**
+ * Classic [[ExternalUDFPlanner]] that uses the built-in Python
+ * runner.
+ */
+class ClassicExternalUDFPlanner extends ExternalUDFPlanner {
+
+ override def planPythonMapInPandas(
+ func: Expression,
+ child: LogicalPlan,
+ isBarrier: Boolean,
+ profile: Option[ResourceProfile]): LogicalPlan = {
+ val output = toAttributes(
+ func.dataType.asInstanceOf[StructType])
+ MapInPandas(func, output, child, isBarrier, profile)
+ }
+
+ override def planPythonMapInArrow(
+ func: Expression,
+ child: LogicalPlan,
+ isBarrier: Boolean,
+ profile: Option[ResourceProfile]): LogicalPlan = {
+ val output = toAttributes(
+ func.dataType.asInstanceOf[StructType])
+ MapInArrow(func, output, child, isBarrier, profile)
+ }
+}
+
+/**
+ * :: Experimental ::
+ * Unified [[ExternalUDFPlanner]] that uses the language-agnostic
+ * external UDF worker framework.
+ */
+@Experimental
+class UnifiedExternalUDFPlanner(
+ private val conf: SparkConf) extends ExternalUDFPlanner {
+
+ override def planPythonMapInPandas(
+ func: Expression,
+ child: LogicalPlan,
+ isBarrier: Boolean,
+ profile: Option[ResourceProfile]): LogicalPlan = {
+ val pythonUdf = func.asInstanceOf[PythonUDF]
+ val workerSpec =
+ PythonUDFWorkerSpecification.fromPythonFunction(
+ pythonUdf.func, conf)
+ val udf = ExternalUserDefinedFunction(
+ name = Some(pythonUdf.name),
+ payload = pythonUdf.func.command.toArray,
+ dataType = pythonUdf.dataType,
+ children = Seq.empty,
+ udfDeterministic = pythonUdf.udfDeterministic,
+ udfNullable = true)
+ MapPartitionsExternalUDF(workerSpec, udf, isBarrier, child)
+ }
+
+ override def planPythonMapInArrow(
+ func: Expression,
+ child: LogicalPlan,
+ isBarrier: Boolean,
+ profile: Option[ResourceProfile]): LogicalPlan = {
+ // TODO [SPARK-55278]: Implement unified mapInArrow support.
+ // For now, fall back to the classic path.
+ val output = toAttributes(
+ func.dataType.asInstanceOf[StructType])
+ MapInArrow(func, output, child, isBarrier, profile)
+ }
+}
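Note on the planner split above: the two implementations differ only in which logical node they emit for the same call. A minimal sketch of that strategy pattern follows. Every type in it (`Plan`, `ClassicMapNode`, `ExternalMapNode`, `Planner`, `PlannerSelection`) is a hypothetical stand-in and not a Spark class; the sketch only shows the shape of the dispatch, not the real Catalyst wiring.

```scala
// Hypothetical stand-ins for Catalyst plan nodes; none of these are Spark classes.
sealed trait Plan
case class Source(name: String) extends Plan
case class ClassicMapNode(udfName: String, child: Plan) extends Plan
case class ExternalMapNode(udfName: String, child: Plan) extends Plan

// One planning entry point, two interchangeable implementations.
trait Planner {
  def planMap(udfName: String, child: Plan): Plan
}

class ClassicPlanner extends Planner {
  override def planMap(udfName: String, child: Plan): Plan =
    ClassicMapNode(udfName, child)
}

class UnifiedPlanner extends Planner {
  override def planMap(udfName: String, child: Plan): Plan =
    ExternalMapNode(udfName, child)
}

object PlannerSelection {
  // A single boolean flag decides which planner is handed out.
  def select(unifiedEnabled: Boolean): Planner =
    if (unifiedEnabled) new UnifiedPlanner else new ClassicPlanner
}
```

Callers only ever see `Planner`, so switching the flag changes the emitted node without touching call sites.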
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/externalUDF/MapPartitionsExternalUDFExec.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/externalUDF/MapPartitionsExternalUDFExec.scala
new file mode 100644
index 000000000000..01d27ba115db
--- /dev/null
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/externalUDF/MapPartitionsExternalUDFExec.scala
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.externalUDF
+
+import org.apache.spark.TaskContext
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{
+ Attribute,
+ ExternalUserDefinedFunction
+}
+import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes
+import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.udf.worker.UDFWorkerSpecification
+
+/**
+ * :: Experimental ::
+ * Physical plan node that executes a mapPartitions-style UDF in an
+ * external worker process.
+ *
+ * @param workerSpec Specification describing the UDF worker.
+ * @param function The UDF to invoke.
+ * @param isBarrier Whether the UDF should be invoked using barrier
+ *   execution.
+ * @param child Child plan providing input partitions.
+@Experimental
+case class MapPartitionsExternalUDFExec(
+ workerSpec: UDFWorkerSpecification,
+ function: ExternalUserDefinedFunction,
+ isBarrier: Boolean,
+ child: SparkPlan)
+ extends ExternalUDFExec {
+
+ // Map partitions always operate on StructTypes
+ override def output: Seq[Attribute] = toAttributes(
+ function.dataType.asInstanceOf[StructType]
+ )
+
+ override protected def doExecute(): RDD[InternalRow] = {
+ child.execute().mapPartitionsInternal { rows =>
+ withUDFWorkerSession(TaskContext.get(), securityScope = None) {
+ session =>
+ // TODO [SPARK-55278]: Stream rows to/from the worker
+ // via session.process().
+ // scalastyle:off throwerror
+ throw new NotImplementedError("doExecute() is not yet implemented.")
+ // scalastyle:on throwerror
+ }
+ }
+ }
+
+ override protected def withNewChildInternal(
+ newChild: SparkPlan): MapPartitionsExternalUDFExec =
+ copy(child = newChild)
+}
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/externalUDF/PythonUDFWorkerSpecification.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/externalUDF/PythonUDFWorkerSpecification.scala
new file mode 100644
index 000000000000..aa5a0850580a
--- /dev/null
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/externalUDF/PythonUDFWorkerSpecification.scala
@@ -0,0 +1,116 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.externalUDF
+
+import scala.jdk.CollectionConverters._
+
+import org.apache.spark.SparkConf
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.api.python.{PythonFunction, PythonUtils}
+import org.apache.spark.internal.config.Python.PYTHON_WORKER_MODULE
+import org.apache.spark.udf.worker._
+
+/**
+ * :: Experimental ::
+ * Builds a [[UDFWorkerSpecification]] for Python UDFs from a
+ * [[PythonFunction]] and [[SparkConf]].
+ *
+ * Reuses the same information the existing
+ * [[org.apache.spark.api.python.PythonWorkerFactory]] uses:
+ * - `pythonExec` from the function
+ * - Environment variables from the function (which already
+ * contain the caller-assembled `PYTHONPATH`), merged with
+ * Spark's built-in Python path and the system `PYTHONPATH`
+ * - Worker module from `spark.python.worker.module`
+ *
+ * Note: `pythonIncludes` are not added to the process
+ * environment. They are sent over the data channel to the
+ * already-running worker by the runner (see
+ * [[org.apache.spark.api.python.PythonRunner]]).
+ */
+@Experimental
+object PythonUDFWorkerSpecification {
+
+ /**
+ * Creates a [[UDFWorkerSpecification]] from a [[PythonFunction]].
+ *
+ * @param func the Python function containing pythonExec, env vars,
+ * and includes
+ * @param conf the SparkConf for reading the worker module config
+ * @return a fully populated [[UDFWorkerSpecification]]
+ */
+ def fromPythonFunction(
+ func: PythonFunction,
+ conf: SparkConf): UDFWorkerSpecification = {
+
+ val workerModule = conf.get(PYTHON_WORKER_MODULE)
+ .getOrElse("pyspark.worker")
+
+ // Assemble PYTHONPATH the same way PythonWorkerFactory does
+ val pythonPath = PythonUtils.mergePythonPaths(
+ PythonUtils.sparkPythonPath,
+ func.envVars.asScala
+ .getOrElse("PYTHONPATH", ""),
+ sys.env.getOrElse("PYTHONPATH", ""))
+
+ // Merge func.envVars with the assembled PYTHONPATH
+ val envVars = new java.util.HashMap[String, String]()
+ envVars.putAll(func.envVars)
+ envVars.put("PYTHONPATH", pythonPath)
+ // Match PythonWorkerFactory behavior
+ envVars.put("PYTHONUNBUFFERED", "YES")
+ // Required by pyspark.worker_util to allow import
+ envVars.put("SPARK_PYTHON_RUNTIME", "PYTHON_WORKER")
+ // Enable the execution mode supporting the new UDF execution
+ // framework.
+ // TODO [SPARK-55278]: Enable this in the Python code
+ envVars.put("PYTHON_WORKER_UNIFIED_EXECUTION_ENABLED", "YES")
+
+ // Build the ProcessCallable:
+ // command = [pythonExec, "-m", workerModule]
+ val callable = ProcessCallable.newBuilder()
+ callable.addCommand(func.pythonExec)
+ callable.addCommand("-m")
+ callable.addCommand(workerModule)
+ // TODO [SPARK-55278]: Add additional, Python-specific env vars
+ // or transform them into init-message fields
+ envVars.forEach((k, v) => callable.putEnvironmentVariables(k, v))
+
+ // Capabilities: ARROW data format, bidirectional streaming
+ val caps = WorkerCapabilities.newBuilder()
+ .addSupportedDataFormats(UDFWorkerDataFormat.ARROW)
+ .addSupportedCommunicationPatterns(
+ UDFProtoCommunicationPattern.BIDIRECTIONAL_STREAMING)
+
+ // Connection: Unix domain socket
+ val conn = WorkerConnectionSpec.newBuilder()
+ .setUnixDomainSocket(UnixDomainSocket.newBuilder())
+
+ val props = UDFWorkerProperties.newBuilder()
+ .setConnection(conn)
+
+ val direct = DirectWorker.newBuilder()
+ .setRunner(callable)
+ .setProperties(props)
+
+ UDFWorkerSpecification.newBuilder()
+ .setEnvironment(WorkerEnvironment.newBuilder())
+ .setCapabilities(caps)
+ .setDirect(direct)
+ .build()
+ }
+}
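Note on the specification builder above: the part that is easy to get wrong is the order in which `PYTHONPATH` is merged and then written back over the caller's value. The sketch below reproduces only that step with plain Scala collections. `WorkerLaunchSketch`, `mergePaths`, and `build` are hypothetical names, and the merge rule (drop empty entries, join with the platform path separator) is an assumption about what a helper like `PythonUtils.mergePythonPaths` does, not a copy of it.

```scala
import java.io.File

// Hypothetical helper; not part of the commit.
object WorkerLaunchSketch {

  // Assumed merge rule: drop empty entries, join the rest with the
  // platform path separator.
  def mergePaths(paths: String*): String =
    paths.filter(_.nonEmpty).mkString(File.pathSeparator)

  // Returns (command, environment) for a module-style worker launch.
  def build(
      pythonExec: String,
      workerModule: String,
      funcEnv: Map[String, String],
      sparkPath: String,
      systemPath: String): (Seq[String], Map[String, String]) = {
    val pythonPath = mergePaths(
      sparkPath, funcEnv.getOrElse("PYTHONPATH", ""), systemPath)
    // Right-hand entries win in `++`, so the merged PYTHONPATH
    // replaces the caller-supplied one.
    val env = funcEnv ++ Map(
      "PYTHONPATH" -> pythonPath,
      "PYTHONUNBUFFERED" -> "YES")
    (Seq(pythonExec, "-m", workerModule), env)
  }
}
```

Other caller-supplied variables pass through unchanged; only the keys set explicitly are overwritten.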
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
index cc8c9dcb71f5..c82651595bc5 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
@@ -38,6 +38,8 @@ import
org.apache.spark.sql.execution.analysis.DetectAmbiguousSelfJoin
import org.apache.spark.sql.execution.command.{CheckViewReferences,
CommandCheck}
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.execution.datasources.v2.{TableCapabilityCheck,
V2SessionCatalog}
+import org.apache.spark.sql.execution.externalUDF.{ClassicExternalUDFPlanner,
+ ExternalUDFPlanner, UnifiedExternalUDFPlanner}
import org.apache.spark.sql.execution.streaming.runtime.ResolveWriteToStream
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
import org.apache.spark.sql.util.ExecutionListenerManager
@@ -395,6 +397,19 @@ abstract class BaseSessionStateBuilder(
extensions.buildQueryPostPlannerStrategyRules(session))
}
+ /**
+ * Strategy for converting (Python) UDF calls into logical plan
+ * nodes. Uses the unified external UDF worker framework when
+ * the config is enabled, otherwise the classic Python runner.
+ */
+ protected def externalUDFPlanner: ExternalUDFPlanner = {
+ if (conf.getConf(SQLConf.UNIFIED_UDF_EXECUTION_ENABLED)) {
+ new UnifiedExternalUDFPlanner(session.sparkContext.conf)
+ } else {
+ new ClassicExternalUDFPlanner()
+ }
+ }
+
protected def planNormalizationRules: Seq[Rule[LogicalPlan]] = {
NormalizeCTEIds +:
extensions.buildPlanNormalizationRules(session)
@@ -476,7 +491,8 @@ abstract class BaseSessionStateBuilder(
columnarRules,
adaptiveRulesHolder,
planNormalizationRules,
- () => artifactManager)
+ () => artifactManager,
+ externalUDFPlanner)
}
}
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
index 2e921d0054e6..d2b0f901ccb4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
@@ -37,6 +37,7 @@ import org.apache.spark.sql.connector.catalog.CatalogManager
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.adaptive.AdaptiveRulesHolder
import org.apache.spark.sql.execution.datasources.DataSourceManager
+import org.apache.spark.sql.execution.externalUDF.ExternalUDFPlanner
import org.apache.spark.sql.util.ExecutionListenerManager
import org.apache.spark.util.{DependencyUtils, Utils}
@@ -92,7 +93,8 @@ private[sql] class SessionState(
val columnarRules: Seq[ColumnarRule],
val adaptiveRulesHolder: AdaptiveRulesHolder,
val planNormalizationRules: Seq[Rule[LogicalPlan]],
- val artifactManagerBuilder: () => ArtifactManager) {
+ val artifactManagerBuilder: () => ArtifactManager,
+ val externalUDFPlanner: ExternalUDFPlanner) {
// The following fields are lazy to avoid creating the Hive client when creating SessionState.
lazy val catalog: SessionCatalog = catalogBuilder()
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/execution/externalUDF/ExternalUDFPlanningSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/externalUDF/ExternalUDFPlanningSuite.scala
new file mode 100644
index 000000000000..7f10e349d0da
--- /dev/null
+++
b/sql/core/src/test/scala/org/apache/spark/sql/execution/externalUDF/ExternalUDFPlanningSuite.scala
@@ -0,0 +1,133 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.externalUDF
+
+import scala.collection.mutable.ArrayBuffer
+import scala.jdk.CollectionConverters._
+import scala.reflect.ClassTag
+
+import org.apache.spark.SparkConf
+import org.apache.spark.api.python.{PythonEvalType, SimplePythonFunction}
+import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan,
+ MapInPandas, MapPartitionsExternalUDF}
+import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.python.{MapInPandasExec,
+ UserDefinedPythonFunction}
+import org.apache.spark.sql.functions.col
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.types.{IntegerType, StringType,
+ StructField, StructType}
+
+/**
+ * Shared UDF fixtures and assertion helpers for planning tests.
+ */
+trait ExternalUDFPlanningTestBase extends SharedSparkSession {
+ import testImplicits._
+
+ protected val outputSchema: StructType = StructType(Seq(
+ StructField("a", StringType),
+ StructField("b", IntegerType)))
+
+ protected def dummyPythonFunction: SimplePythonFunction =
+ new SimplePythonFunction(
+ command = Array.emptyByteArray,
+ envVars = Map.empty[String, String].asJava,
+ pythonIncludes = ArrayBuffer.empty[String].asJava,
+ pythonExec = "python3",
+ pythonVer = "3.12",
+ broadcastVars = null,
+ accumulator = null)
+
+ protected val mapInPandasUDF: UserDefinedPythonFunction =
+ UserDefinedPythonFunction(
+ name = "dummyMapInPandas",
+ func = dummyPythonFunction,
+ dataType = outputSchema,
+ pythonEvalType = PythonEvalType.SQL_MAP_PANDAS_ITER_UDF,
+ udfDeterministic = true)
+
+ protected def applyMapInPandas(): DataFrame = {
+ val inputDF = Seq(("hello", 1)).toDF("a", "b")
+ inputDF.mapInPandas(mapInPandasUDF(col("a"), col("b")))
+ }
+
+ protected def assertLogicalNode[T <: LogicalPlan: ClassTag](
+ df: DataFrame): Unit = {
+ val tag = implicitly[ClassTag[T]]
+ val node = df.queryExecution.analyzed.collectFirst {
+ case n if tag.runtimeClass.isInstance(n) => n
+ }
+ assert(node.isDefined,
+ s"Expected ${tag.runtimeClass.getSimpleName}" +
+ " in logical plan")
+ }
+
+ protected def assertPhysicalNode[T <: SparkPlan: ClassTag](
+ df: DataFrame): Unit = {
+ val tag = implicitly[ClassTag[T]]
+ val node =
+ df.queryExecution.executedPlan.collectFirst {
+ case n if tag.runtimeClass.isInstance(n) => n
+ }
+ assert(node.isDefined,
+ s"Expected ${tag.runtimeClass.getSimpleName}" +
+ " in physical plan")
+ }
+}
+
+/**
+ * Tests that the classic Python runner path is used when the
+ * unified UDF execution config is disabled (default).
+ */
+class ClassicUDFPlanningSuite
+ extends ExternalUDFPlanningTestBase {
+
+ test("mapInPandas uses MapInPandas logical node") {
+ val result = applyMapInPandas()
+ assertLogicalNode[MapInPandas](result)
+ }
+
+ test("mapInPandas uses MapInPandasExec physical node") {
+ val result = applyMapInPandas()
+ assertPhysicalNode[MapInPandasExec](result)
+ }
+}
+
+/**
+ * Tests that the unified external UDF worker framework is used
+ * when the config is enabled.
+ */
+class UnifiedUDFPlanningSuite
+ extends ExternalUDFPlanningTestBase {
+
+ override def sparkConf: SparkConf =
+ super.sparkConf.set(
+ SQLConf.UNIFIED_UDF_EXECUTION_ENABLED.key, "true")
+
+ test("mapInPandas uses MapPartitionsExternalUDF logical node") {
+ val result = applyMapInPandas()
+ assertLogicalNode[MapPartitionsExternalUDF](result)
+ }
+
+ test("mapInPandas uses MapPartitionsExternalUDFExec" +
+ " physical node") {
+ val result = applyMapInPandas()
+ assertPhysicalNode[MapPartitionsExternalUDFExec](result)
+ }
+}
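Note on the assertion helpers above: both rely on a `ClassTag` context bound to recover the runtime class of `T` and then search the plan tree for the first matching node. The sketch below shows the same technique on a toy tree. `Node`, `Leaf`, `Project`, `MapUdf`, and `PlanSearch` are hypothetical and stand in for Catalyst's `TreeNode` and its `collectFirst`.

```scala
import scala.reflect.ClassTag

// Hypothetical toy plan tree; not Catalyst's TreeNode.
sealed trait Node { def children: Seq[Node] }
case class Leaf(name: String) extends Node {
  def children: Seq[Node] = Nil
}
case class Project(child: Node) extends Node {
  def children: Seq[Node] = Seq(child)
}
case class MapUdf(child: Node) extends Node {
  def children: Seq[Node] = Seq(child)
}

object PlanSearch {
  // Pre-order search for the first node whose runtime class is T
  // (or a subclass of T). The ClassTag survives type erasure.
  def findFirst[T <: Node: ClassTag](root: Node): Option[Node] = {
    val cls = implicitly[ClassTag[T]].runtimeClass
    if (cls.isInstance(root)) {
      Some(root)
    } else {
      root.children.iterator
        .map(c => findFirst[T](c))
        .collectFirst { case Some(n) => n }
    }
  }
}
```

Using an iterator keeps the search lazy, so it stops at the first match instead of visiting the whole tree.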
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/execution/externalUDF/PythonUDFWorkerSpecificationSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/externalUDF/PythonUDFWorkerSpecificationSuite.scala
new file mode 100644
index 000000000000..f44d3adf7b7d
--- /dev/null
+++
b/sql/core/src/test/scala/org/apache/spark/sql/execution/externalUDF/PythonUDFWorkerSpecificationSuite.scala
@@ -0,0 +1,201 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.externalUDF
+
+import java.io.File
+import java.nio.charset.StandardCharsets
+import java.nio.file.Files
+
+import scala.collection.mutable.ArrayBuffer
+import scala.jdk.CollectionConverters._
+
+import org.apache.spark.api.python.SimplePythonFunction
+import org.apache.spark.sql.IntegratedUDFTestUtils
+import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.udf.worker.UDFWorkerSpecification
+import org.apache.spark.udf.worker.core.{TestDirectWorkerDispatcher,
+ UnixSocketWorkerConnection}
+
+/**
+ * A test [[UnixSocketWorkerConnection]] that opens a real Unix
+ * domain socket channel to the worker.
+ */
+private class RealSocketConnection(
+ socketPath: String,
+ private val channel: java.nio.channels.SocketChannel)
+ extends UnixSocketWorkerConnection(socketPath) {
+
+ override def isActive: Boolean = channel.isOpen
+
+ override def close(): Unit = {
+ channel.close()
+ super.close()
+ }
+}
+
+private object RealSocketConnection {
+ def connect(socketPath: String): RealSocketConnection = {
+ val address =
+ java.net.UnixDomainSocketAddress.of(socketPath)
+ val channel = java.nio.channels.SocketChannel.open(
+ java.net.StandardProtocolFamily.UNIX)
+ channel.connect(address)
+ new RealSocketConnection(socketPath, channel)
+ }
+}
+
+/**
+ * Extends [[TestDirectWorkerDispatcher]] to use a real Unix
+ * domain socket connection instead of just checking file
+ * existence.
+ */
+private class RealSocketTestDispatcher(
+ spec: UDFWorkerSpecification)
+ extends TestDirectWorkerDispatcher(spec) {
+
+ // Use /tmp to avoid UDS path length limits in deep
+ // worktree paths.
+ override protected def newEndpointAddress(
+ workerId: String): String = {
+ val socketDir = java.nio.file.Files.createTempDirectory(
+ java.nio.file.Paths.get("/tmp"), "udf-test-")
+ socketDir.toFile.deleteOnExit()
+ socketDir.resolve(s"w-$workerId.sock").toString
+ }
+
+ override protected def createConnection(
+ socketPath: String): UnixSocketWorkerConnection =
+ RealSocketConnection.connect(socketPath)
+}
+
+/**
+ * Tests that [[PythonUDFWorkerSpecification#fromPythonFunction]]
+ * produces a valid [[UDFWorkerSpecification]] that can be used by
+ * a [[org.apache.spark.udf.worker.core.WorkerDispatcher]]
+ * to spawn a real Python worker process.
+ *
+ * The test overrides `spark.python.worker.module` to point to a
+ * test-only Python module that creates the UDS socket and waits
+ * for SIGTERM, verifying that pythonExec, PYTHONPATH, env vars,
+ * and command construction all work end-to-end.
+ */
+class PythonUDFWorkerSpecificationSuite
+ extends SharedSparkSession {
+
+ import IntegratedUDFTestUtils.{
+ isPySparkAvailable, pythonExec, pythonVer,
+ pysparkPythonPath, pythonPath
+ }
+
+ // A test Python module that imports pyspark (validates
+ // the environment), creates the UDS socket at the
+ // --connection path, and waits for SIGTERM.
+ // scalastyle:off line.size.limit
+ private val testWorkerModuleSource =
+ """
+ |import argparse, signal, socket, sys, os
+ |# Validate PySpark is importable
+ |import pyspark
+ |
+ |parser = argparse.ArgumentParser()
+ |parser.add_argument('--id', required=True)
+ |parser.add_argument('--connection', required=True)
+ |args = parser.parse_args()
+ |
+ |sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
+ |sock.bind(args.connection)
+ |sock.listen(1)
+ |
+ |running = True
+ |def handle_sigterm(signum, frame):
+ | global running
+ | running = False
+ |signal.signal(signal.SIGTERM, handle_sigterm)
+ |
+ |while running:
+ | # Wait till we receive a signal
+ | signal.pause()
+ |
+ |sock.close()
+ |try:
+ | os.unlink(args.connection)
+ |except OSError:
+ | pass
+ |""".stripMargin.trim
+ // scalastyle:on line.size.limit
+
+ /**
+ * Writes the test worker module to a temp directory and
+ * returns the directory path (to be added to PYTHONPATH)
+ * and the module name.
+ */
+ private def createTestWorkerModule(): (File, String) = {
+ val moduleName = "test_udf_worker"
+ val moduleDir = Files.createTempDirectory(
+ "udf-test-module-").toFile
+ moduleDir.deleteOnExit()
+ val moduleFile = new File(moduleDir, s"$moduleName.py")
+ moduleFile.deleteOnExit()
+ Files.write(moduleFile.toPath,
+ testWorkerModuleSource.getBytes(StandardCharsets.UTF_8))
+ (moduleDir, moduleName)
+ }
+
+ test("PythonUDFWorkerSpecification.fromPythonFunction" +
+ " produces a spec that spawns a Python worker") {
+ assume(isPySparkAvailable,
+ "Python and PySpark must be available")
+
+ val (moduleDir, moduleName) = createTestWorkerModule()
+
+ // Create a PythonFunction with the test module dir
+ // added to PYTHONPATH so the module is importable.
+ val envVars = new java.util.HashMap[String, String]()
+ envVars.put("PYTHONPATH",
+ s"${moduleDir.getAbsolutePath}:" +
+ s"$pysparkPythonPath:$pythonPath")
+ val func = new SimplePythonFunction(
+ command = Array.emptyByteArray,
+ envVars = envVars,
+ pythonIncludes = ArrayBuffer.empty[String].asJava,
+ pythonExec = pythonExec,
+ pythonVer = pythonVer,
+ broadcastVars = null,
+ accumulator = null)
+
+ // Override the worker module config to use our test module
+ val conf = spark.sparkContext.conf.clone()
+ conf.set("spark.python.worker.module", moduleName)
+
+ // Build the spec via the function under test
+ val workerSpec =
+ PythonUDFWorkerSpecification.fromPythonFunction(func, conf)
+
+ // Verify the spec works end-to-end with a real
+ // socket connection
+ val dispatcher = new RealSocketTestDispatcher(workerSpec)
+ try {
+ val session = dispatcher.createSession(
+ securityScope = None)
+ assert(session != null,
+ "Expected a non-null session from the dispatcher")
+ session.close()
+ } finally {
+ dispatcher.close()
+ }
+ }
+}
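Note on the `RealSocketConnection` test helper above: it uses the Unix domain socket support added to `java.nio.channels` in JDK 16. The sketch below is a self-contained round trip using the same JDK calls, with a listening server channel playing the part of the Python worker. `UdsRoundTrip` and `connectOnce` are hypothetical names, and the sketch assumes a platform with Unix domain socket support.

```scala
import java.net.{StandardProtocolFamily, UnixDomainSocketAddress}
import java.nio.channels.{ServerSocketChannel, SocketChannel}
import java.nio.file.Files

// Hypothetical demo object; not part of the commit.
object UdsRoundTrip {

  // Binds a server channel to a fresh socket file, connects a client
  // to it, and reports whether the client ended up connected.
  def connectOnce(): Boolean = {
    val dir = Files.createTempDirectory("uds-sketch-")
    val path = dir.resolve("w.sock")
    val address = UnixDomainSocketAddress.of(path)
    val server = ServerSocketChannel.open(StandardProtocolFamily.UNIX)
    try {
      server.bind(address)
      val client = SocketChannel.open(StandardProtocolFamily.UNIX)
      try {
        // Blocking connect; the pending connection sits in the
        // server's backlog, so no accept() is needed for this check.
        client.connect(address)
        client.isConnected
      } finally {
        client.close()
      }
    } finally {
      server.close()
      // The socket file is not removed automatically on close.
      Files.deleteIfExists(path)
      Files.deleteIfExists(dir)
    }
  }
}
```

The short temp-directory path matters for the same reason the test dispatcher uses `/tmp`: socket paths have a small platform-dependent length limit.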
diff --git a/udf/worker/core/pom.xml b/udf/worker/core/pom.xml
index f045fa1abd2a..0857636fb9c2 100644
--- a/udf/worker/core/pom.xml
+++ b/udf/worker/core/pom.xml
@@ -51,6 +51,11 @@
<groupId>org.scala-lang</groupId>
<artifactId>scala-library</artifactId>
</dependency>
+ <dependency>
+ <groupId>org.mockito</groupId>
+ <artifactId>mockito-core</artifactId>
+ <scope>test</scope>
+ </dependency>
<dependency>
<groupId>io.grpc</groupId>
<artifactId>grpc-inprocess</artifactId>
diff --git
a/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/UDFDispatcherFactory.scala
b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/UDFDispatcherFactory.scala
new file mode 100644
index 000000000000..bd46821bf0f0
--- /dev/null
+++
b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/UDFDispatcherFactory.scala
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.udf.worker.core
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.udf.worker.UDFWorkerSpecification
+
+/**
+ * :: Experimental ::
+ * Creates [[WorkerDispatcher]] instances.
+ *
+ * Implementations are passed to [[UDFDispatcherManager]] which
+ * handles caching and shutdown.
+ */
+@Experimental
+trait UDFDispatcherFactory {
+
+ /**
+ * Creates a new [[WorkerDispatcher]] for the given specification.
+ * At most one dispatcher will be requested per unique
+ * [[UDFWorkerSpecification]].
+ */
+ def createDispatcher(
+ workerSpec: UDFWorkerSpecification,
+ logger: WorkerLogger): WorkerDispatcher
+
+}
diff --git
a/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/UDFDispatcherManager.scala
b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/UDFDispatcherManager.scala
new file mode 100644
index 000000000000..cf3d36a7c5d7
--- /dev/null
+++
b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/UDFDispatcherManager.scala
@@ -0,0 +1,126 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.udf.worker.core
+
+import java.util.HashMap
+import java.util.concurrent.locks.ReentrantReadWriteLock
+
+import scala.util.control.NonFatal
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.udf.worker.UDFWorkerSpecification
+
+/**
+ * :: Experimental ::
+ * Manages [[WorkerDispatcher]] instances, caching them by
+ * [[UDFWorkerSpecification]] (protobuf value equality).
+ *
+ * Callers obtain a dispatcher via [[getDispatcher]] and create
+ * sessions on it directly. On [[close]], all cached dispatchers
+ * are closed -- dispatchers are responsible for cleaning up
+ * their own sessions.
+ *
+ * Thread safety: a [[ReentrantReadWriteLock]] allows concurrent
+ * [[getDispatcher]] lookups (read lock) while dispatcher creation
+ * and [[close]] have exclusive access (write lock).
+ */
+@Experimental
+class UDFDispatcherManager(
+ private val dispatcherFactory: UDFDispatcherFactory,
+ workerLogger: WorkerLogger = WorkerLogger.NoOp
+) {
+
+ // Guarded by `rwLock`. getDispatcher looks up under the read lock
+ // and, on a miss, releases it and takes the write lock to add a
+ // dispatcher. close always takes the write lock.
+ private val rwLock = new ReentrantReadWriteLock()
+ private val dispatchers =
+ new HashMap[UDFWorkerSpecification, WorkerDispatcher]()
+ private var closed = false
+
+ /**
+ * Returns the [[WorkerDispatcher]] for the given spec, creating
+ * one via the [[UDFDispatcherFactory]] if none exists yet.
+ */
+ def getDispatcher(
+ workerSpec: UDFWorkerSpecification): WorkerDispatcher = {
+ // First, try to read an existing dispatcher = quick path
+ rwLock.readLock().lock()
+ try {
+ if (closed) throwClosed()
+
+ // Cache hit: return without taking the write lock.
+ val dispatcher = dispatchers.get(workerSpec)
+ if (dispatcher != null) {
+ return dispatcher
+ }
+ } finally {
+ rwLock.readLock().unlock()
+ }
+
+ // We need to acquire a new dispatcher
+ // = slower path with global lock
+ rwLock.writeLock().lock()
+ try {
+ if (closed) throwClosed()
+ // Re-check after acquiring write lock.
+ var dispatcher = dispatchers.get(workerSpec)
+ if (dispatcher == null) {
+ dispatcher = dispatcherFactory.createDispatcher(
+ workerSpec, workerLogger)
+ workerLogger.info(
+ "Created new dispatcher")
+ dispatchers.put(workerSpec, dispatcher)
+ }
+ dispatcher
+ } finally {
+ rwLock.writeLock().unlock()
+ }
+ }
+
+ private def throwClosed(): Nothing =
+ throw new IllegalStateException("UDFDispatcherManager is stopped")
+
+ /**
+ * Closes all cached dispatchers and resets internal state.
+ * Dispatchers are responsible for cleaning up their own
+ * sessions.
+ */
+ def close(): Unit = {
+ rwLock.writeLock().lock()
+ try {
+ if (closed) return
+ closed = true
+ workerLogger.info(
+ "UDFDispatcherManager closing" +
+ s" (${dispatchers.size()} dispatchers)")
+ dispatchers.forEach { (_, dispatcher) =>
+ try {
+ dispatcher.close()
+ } catch {
+ case NonFatal(e) =>
+ workerLogger.warn(
+ "Error closing dispatcher during shutdown", e)
+ }
+ }
+ dispatchers.clear()
+ workerLogger.info("UDFDispatcherManager closed")
+ } finally {
+ rwLock.writeLock().unlock()
+ }
+ }
+}
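Note on the locking in the manager above: `ReentrantReadWriteLock` cannot upgrade a held read lock to a write lock, so the lookup releases the read lock before taking the write lock and must re-check the map afterwards. The sketch below isolates that pattern in a generic cache. `LockedCache` is a hypothetical class written for illustration and is not part of the commit.

```scala
import java.util.concurrent.locks.ReentrantReadWriteLock
import scala.collection.mutable

// Hypothetical generic cache; not part of the commit.
class LockedCache[K, V](create: K => V) {
  private val lock = new ReentrantReadWriteLock()
  private val entries = mutable.HashMap.empty[K, V]
  private var closed = false

  def get(key: K): V = {
    // Fast path: concurrent readers share the read lock.
    lock.readLock().lock()
    val cached = try {
      if (closed) throw new IllegalStateException("cache is closed")
      entries.get(key)
    } finally {
      lock.readLock().unlock()
    }

    cached.getOrElse {
      // Slow path: the read lock is already released, because a
      // read lock cannot be upgraded in place.
      lock.writeLock().lock()
      try {
        if (closed) throw new IllegalStateException("cache is closed")
        // Re-check: another thread may have added the entry between
        // the unlock above and the lock here.
        entries.getOrElseUpdate(key, create(key))
      } finally {
        lock.writeLock().unlock()
      }
    }
  }

  def close(): Unit = {
    lock.writeLock().lock()
    try {
      closed = true
      entries.clear()
    } finally {
      lock.writeLock().unlock()
    }
  }
}
```

The re-check under the write lock is what guarantees `create` runs at most once per key, even when several threads miss at the same time.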
diff --git
a/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/WorkerLogger.scala
b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/WorkerLogger.scala
index a8f135f68890..a1acd348bf21 100644
---
a/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/WorkerLogger.scala
+++
b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/WorkerLogger.scala
@@ -36,6 +36,8 @@ import org.apache.spark.annotation.Experimental
trait WorkerLogger {
def warn(msg: => String): Unit
def warn(msg: => String, t: Throwable): Unit
+ def info(msg: => String): Unit
+ def info(msg: => String, t: Throwable): Unit
def debug(msg: => String): Unit
def debug(msg: => String, t: Throwable): Unit
}
@@ -45,6 +47,8 @@ object WorkerLogger {
val NoOp: WorkerLogger = new WorkerLogger {
override def warn(msg: => String): Unit = ()
override def warn(msg: => String, t: Throwable): Unit = ()
+ override def info(msg: => String): Unit = ()
+ override def info(msg: => String, t: Throwable): Unit = ()
override def debug(msg: => String): Unit = ()
override def debug(msg: => String, t: Throwable): Unit = ()
}
diff --git
a/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerDispatcher.scala
b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerDispatcher.scala
index 14db8da7ac89..2438546ddafb 100644
--- a/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerDispatcher.scala
+++ b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerDispatcher.scala
@@ -222,7 +222,8 @@ abstract class DirectWorkerDispatcher(
if (!closed.compareAndSet(false, true)) {
return
}
- // TODO: close workers in parallel -- today shutdown is serialised at
+ // TODO [SPARK-55278]: Cleanup sessions as well?
+ // TODO [SPARK-55278]: close workers in parallel -- today shutdown is serialised at
// N * gracefulTimeoutMs worst case.
workers.values().iterator().asScala.foreach { w =>
try {
diff --git a/udf/worker/core/src/test/scala/org/apache/spark/udf/worker/core/DirectWorkerDispatcherSuite.scala b/udf/worker/core/src/test/scala/org/apache/spark/udf/worker/core/DirectWorkerDispatcherSuite.scala
index 7302c697d93c..06574fdf1013 100644
--- a/udf/worker/core/src/test/scala/org/apache/spark/udf/worker/core/DirectWorkerDispatcherSuite.scala
+++ b/udf/worker/core/src/test/scala/org/apache/spark/udf/worker/core/DirectWorkerDispatcherSuite.scala
@@ -27,63 +27,13 @@ import org.scalatest.BeforeAndAfterEach
import org.scalatest.funsuite.AnyFunSuite
import org.apache.spark.udf.worker.{
- DirectWorker, Init, LocalTcpConnection, ProcessCallable, UDFWorkerProperties,
+ DirectWorker, LocalTcpConnection, ProcessCallable, UDFWorkerProperties,
UDFWorkerSpecification, UnixDomainSocket, WorkerConnectionSpec,
WorkerEnvironment}
import org.apache.spark.udf.worker.core.direct.{DirectUnixSocketWorkerDispatcher,
- DirectWorkerException, DirectWorkerProcess, DirectWorkerSession,
+ DirectWorkerException, DirectWorkerProcess,
DirectWorkerTimeoutException}
-/**
- * A [[WorkerConnection]] test implementation that considers the connection
- * active as long as the socket file exists on disk. Inherits socket-file
- * deletion from [[UnixSocketWorkerConnection.close]].
- */
-class SocketFileConnection(socketPath: String)
- extends UnixSocketWorkerConnection(socketPath) {
- override def isActive: Boolean = new File(socketPath).exists()
-}
-
-/**
- * A stub [[DirectWorkerSession]] for process-lifecycle tests that don't
- * need actual data transmission.
- *
- * TODO: [[cancel]] is a no-op here. Once a concrete [[DirectWorkerSession]]
- * with real data-plane wiring lands, add tests exercising cancel() in
- * particular: cancel from a different thread than process(), cancel
- * after process() has returned, and cancel before init (should be a no-op).
- * See the thread-safety contract in the docstring on
- * [[org.apache.spark.udf.worker.core.WorkerSession.cancel]].
- */
-class StubWorkerSession(
- workerProcess: DirectWorkerProcess) extends DirectWorkerSession(workerProcess) {
-
- override protected def doInit(message: Init): Unit = {}
-
- override protected def doProcess(
- input: Iterator[Array[Byte]]): Iterator[Array[Byte]] =
- Iterator.empty
-
- override def cancel(): Unit = {}
-}
-
-/**
- * A [[DirectUnixSocketWorkerDispatcher]] subclass for testing that uses
- * a socket-file connection and stub sessions instead of a real protocol
- * implementation.
- */
-class TestDirectWorkerDispatcher(spec: UDFWorkerSpecification)
- extends DirectUnixSocketWorkerDispatcher(spec) {
-
- override protected def createConnection(
- socketPath: String): UnixSocketWorkerConnection =
- new SocketFileConnection(socketPath)
-
- override protected def createSessionForWorker(
- worker: DirectWorkerProcess): WorkerSession =
- new StubWorkerSession(worker)
-}
-
/**
* Tests for [[DirectWorkerDispatcher]] process lifecycle: spawning workers
* and terminating them on close.
diff --git a/udf/worker/core/src/test/scala/org/apache/spark/udf/worker/core/TestDirectWorkerHelpers.scala b/udf/worker/core/src/test/scala/org/apache/spark/udf/worker/core/TestDirectWorkerHelpers.scala
new file mode 100644
index 000000000000..1843024e9fb8
--- /dev/null
+++ b/udf/worker/core/src/test/scala/org/apache/spark/udf/worker/core/TestDirectWorkerHelpers.scala
@@ -0,0 +1,79 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.udf.worker.core
+
+import java.io.File
+
+import org.apache.spark.udf.worker.{Init, UDFWorkerSpecification}
+import org.apache.spark.udf.worker.core.direct.{
+ DirectUnixSocketWorkerDispatcher, DirectWorkerProcess,
+ DirectWorkerSession}
+
+/**
+ * A [[WorkerConnection]] test implementation that considers the
+ * connection active as long as the socket file exists on disk.
+ * Inherits socket-file deletion from
+ * [[UnixSocketWorkerConnection.close]].
+ */
+class SocketFileConnection(socketPath: String)
+ extends UnixSocketWorkerConnection(socketPath) {
+ override def isActive: Boolean = new File(socketPath).exists()
+}
+
+/**
+ * A stub [[DirectWorkerSession]] for process-lifecycle tests
+ * that don't need actual data transmission.
+ *
+ * TODO: [[cancel]] is a no-op here. Once a concrete
+ * [[DirectWorkerSession]] with real data-plane wiring lands, add
+ * tests exercising cancel() in particular: cancel from a
+ * different thread than process(), cancel after process() has
+ * returned, and cancel before init (should be a no-op). See the
+ * thread-safety contract in the docstring on
+ * [[org.apache.spark.udf.worker.core.WorkerSession.cancel]].
+ */
+class StubWorkerSession(workerProcess: DirectWorkerProcess)
+ extends DirectWorkerSession(workerProcess) {
+
+ override protected def doInit(message: Init): Unit = {}
+
+ override protected def doProcess(
+ input: Iterator[Array[Byte]]
+ ): Iterator[Array[Byte]] =
+ Iterator.empty
+
+ override def cancel(): Unit = {}
+}
+
+/**
+ * A [[DirectUnixSocketWorkerDispatcher]] subclass for testing
+ * that uses a socket-file connection and stub sessions instead
+ * of a real protocol implementation.
+ */
+class TestDirectWorkerDispatcher(spec: UDFWorkerSpecification)
+ extends DirectUnixSocketWorkerDispatcher(spec) {
+
+ override protected def createConnection(
+ socketPath: String
+ ): UnixSocketWorkerConnection =
+ new SocketFileConnection(socketPath)
+
+ override protected def createSessionForWorker(
+ worker: DirectWorkerProcess
+ ): WorkerSession =
+ new StubWorkerSession(worker)
+}
diff --git a/udf/worker/core/src/test/scala/org/apache/spark/udf/worker/core/UDFDispatcherManagerSuite.scala b/udf/worker/core/src/test/scala/org/apache/spark/udf/worker/core/UDFDispatcherManagerSuite.scala
new file mode 100644
index 000000000000..858fa6ce6b78
--- /dev/null
+++ b/udf/worker/core/src/test/scala/org/apache/spark/udf/worker/core/UDFDispatcherManagerSuite.scala
@@ -0,0 +1,131 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.udf.worker.core
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.mockito.Mockito.{mock, verify}
+// scalastyle:off funsuite
+import org.scalatest.funsuite.AnyFunSuite
+
+import org.apache.spark.udf.worker._
+
+class UDFDispatcherManagerSuite
+ extends AnyFunSuite { // scalastyle:ignore funsuite
+
+ private def createManager(): (
+ UDFDispatcherManager,
+ ArrayBuffer[WorkerDispatcher]) = {
+ val createdDispatchers = ArrayBuffer[WorkerDispatcher]()
+ val factory = new UDFDispatcherFactory {
+ override def createDispatcher(
+ workerSpec: UDFWorkerSpecification,
+ logger: WorkerLogger): WorkerDispatcher = {
+ val dispatcher = mock(classOf[WorkerDispatcher])
+ createdDispatchers += dispatcher
+ dispatcher
+ }
+ }
+ (new UDFDispatcherManager(factory), createdDispatchers)
+ }
+
+ private def makeSpec(
+ command: String): UDFWorkerSpecification = {
+ val callable = ProcessCallable.newBuilder()
+ callable.addCommand(command)
+ val caps = WorkerCapabilities.newBuilder()
+ .addSupportedDataFormats(UDFWorkerDataFormat.ARROW)
+ .addSupportedCommunicationPatterns(
+ UDFProtoCommunicationPattern.BIDIRECTIONAL_STREAMING)
+ val conn = WorkerConnectionSpec.newBuilder()
+ .setTcp(LocalTcpConnection.newBuilder())
+ val props = UDFWorkerProperties.newBuilder()
+ .setConnection(conn)
+ val direct = DirectWorker.newBuilder()
+ .setRunner(callable).setProperties(props)
+ UDFWorkerSpecification.newBuilder()
+ .setEnvironment(WorkerEnvironment.newBuilder())
+ .setCapabilities(caps).setDirect(direct).build()
+ }
+
+ test("Same spec returns the same dispatcher") {
+ val (manager, createdDispatchers) = createManager()
+ val spec = makeSpec("worker.bin")
+
+ val d1 = manager.getDispatcher(spec)
+ val d2 = manager.getDispatcher(spec)
+
+ assert(d1 eq d2)
+ assert(createdDispatchers.size === 1)
+ }
+
+ test("Value-equal specs return the same dispatcher") {
+ val (manager, createdDispatchers) = createManager()
+ val spec1 = makeSpec("worker.bin")
+ val spec2 = makeSpec("worker.bin")
+ assert(spec1 ne spec2)
+ assert(spec1 == spec2)
+
+ val d1 = manager.getDispatcher(spec1)
+ val d2 = manager.getDispatcher(spec2)
+
+ assert(d1 eq d2)
+ assert(createdDispatchers.size === 1)
+ }
+
+ test("Different specs create different dispatchers") {
+ val (manager, createdDispatchers) = createManager()
+
+ val dA = manager.getDispatcher(makeSpec("worker-a.bin"))
+ val dB = manager.getDispatcher(makeSpec("worker-b.bin"))
+
+ assert(dA ne dB)
+ assert(createdDispatchers.size === 2)
+ }
+
+ test("Close closes all cached dispatchers") {
+ val (manager, createdDispatchers) = createManager()
+
+ manager.getDispatcher(makeSpec("worker-a.bin"))
+ manager.getDispatcher(makeSpec("worker-b.bin"))
+ manager.close()
+
+ createdDispatchers.foreach(
+ dispatcher => verify(dispatcher).close())
+ }
+
+ test("Close is idempotent") {
+ val (manager, createdDispatchers) = createManager()
+ manager.getDispatcher(makeSpec("worker.bin"))
+
+ manager.close()
+ manager.close()
+
+ // close() should only be called once per dispatcher
+ createdDispatchers.foreach(
+ dispatcher => verify(dispatcher).close())
+ }
+
+ test("getDispatcher throws after close") {
+ val (manager, _) = createManager()
+ manager.close()
+
+ intercept[IllegalStateException] {
+ manager.getDispatcher(makeSpec("worker.bin"))
+ }
+ }
+}