This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 26330355836f [SPARK-49249][SPARK-49122] Artifact isolation in Spark
Classic
26330355836f is described below
commit 26330355836f5b2dad9b7bd4c72d9830c7ce6788
Author: Paddy Xu <[email protected]>
AuthorDate: Wed Nov 13 10:51:02 2024 +0900
[SPARK-49249][SPARK-49122] Artifact isolation in Spark Classic
### What changes were proposed in this pull request?
This PR makes the isolation feature introduced by
`SparkSession.addArtifact` API (added in
https://github.com/apache/spark/pull/47631) work with Spark SQL.
Note that this PR does not enable isolation for the following two use cases:
- PySpark
- Future work is needed to add API to support adding isolated
Python UDTFs.
- When Hive is used as the metastore
- Hive UDF is a huge blocker because artifacts can be used outside a
`SparkSession`, which lets resources escape from our session scope.
### Why are the changes needed?
Because it didn't work before :)
### Does this PR introduce _any_ user-facing change?
Yes, the user can add a new artifact in the REPL and use it in the current
REPL session.
### How was this patch tested?
Added a new test.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #48120 from xupefei/session-artifact-apply.
Authored-by: Paddy Xu <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
.../org/apache/spark/sql/UDFRegistration.scala | 6 +
.../main/scala/org/apache/spark/SparkFiles.scala | 8 +-
python/pyspark/core/context.py | 2 +
python/pyspark/sql/connect/session.py | 6 +-
.../main/scala/org/apache/spark/repl/Main.scala | 4 +
repl/src/test/resources/IntSumUdf.class | Bin 0 -> 1333 bytes
.../src/test/resources/IntSumUdf.scala | 23 +--
.../scala/org/apache/spark/repl/ReplSuite.scala | 63 ++++++++
.../org/apache/spark/sql/api/UDFRegistration.scala | 17 ++
.../org/apache/spark/sql/internal/SQLConf.scala | 22 +++
.../sql/connect/SimpleSparkConnectService.scala | 3 +
.../sql/connect/service/SparkConnectServer.scala | 7 +-
.../scala/org/apache/spark/sql/SparkSession.scala | 4 +
.../org/apache/spark/sql/UDFRegistration.scala | 18 +--
.../spark/sql/artifact/ArtifactManager.scala | 74 ++++++---
.../apache/spark/sql/execution/SQLExecution.scala | 173 +++++++++++----------
.../sql/execution/streaming/StreamExecution.scala | 4 +-
.../sql/internal/BaseSessionStateBuilder.scala | 2 +-
.../spark/sql/artifact/ArtifactManagerSuite.scala | 27 ++--
.../spark/sql/execution/command/DDLSuite.scala | 7 +-
.../spark/sql/hive/execution/HiveQuerySuite.scala | 8 +
.../org/apache/spark/sql/hive/test/TestHive.scala | 6 +-
.../spark/sql/hive/test/TestHiveSingleton.scala | 7 +
23 files changed, 325 insertions(+), 166 deletions(-)
diff --git
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
index 3a84d43ceae3..93d085a25c7b 100644
---
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
+++
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.internal.UdfToProtoUtils
+import org.apache.spark.sql.types.DataType
/**
* Functions for registering user-defined functions. Use `SparkSession.udf` to
access this:
@@ -30,6 +31,11 @@ import org.apache.spark.sql.internal.UdfToProtoUtils
* @since 3.5.0
*/
class UDFRegistration(session: SparkSession) extends api.UDFRegistration {
+ override def registerJava(name: String, className: String, returnDataType:
DataType): Unit = {
+ throw new UnsupportedOperationException(
+ "registerJava is currently not supported in Spark Connect.")
+ }
+
override protected def register(
name: String,
udf: UserDefinedFunction,
diff --git a/core/src/main/scala/org/apache/spark/SparkFiles.scala
b/core/src/main/scala/org/apache/spark/SparkFiles.scala
index 44f4444a1fa8..f4165c2fc6f2 100644
--- a/core/src/main/scala/org/apache/spark/SparkFiles.scala
+++ b/core/src/main/scala/org/apache/spark/SparkFiles.scala
@@ -27,8 +27,12 @@ object SparkFiles {
/**
* Get the absolute path of a file added through `SparkContext.addFile()`.
*/
- def get(filename: String): String =
- new File(getRootDirectory(), filename).getAbsolutePath()
+ def get(filename: String): String = {
+ val jobArtifactUUID = JobArtifactSet
+ .getCurrentJobArtifactState.map(_.uuid).getOrElse("default")
+ val withUuid = if (jobArtifactUUID == "default") filename else
s"$jobArtifactUUID/$filename"
+ new File(getRootDirectory(), withUuid).getAbsolutePath
+ }
/**
* Get the root directory that contains files added through
`SparkContext.addFile()`.
diff --git a/python/pyspark/core/context.py b/python/pyspark/core/context.py
index 63d41c11dafd..6ea793a11838 100644
--- a/python/pyspark/core/context.py
+++ b/python/pyspark/core/context.py
@@ -84,6 +84,8 @@ __all__ = ["SparkContext"]
DEFAULT_CONFIGS: Dict[str, Any] = {
"spark.serializer.objectStreamReset": 100,
"spark.rdd.compress": True,
+ # Disable artifact isolation in PySpark, or user-added .py file won't work
+ "spark.sql.artifact.isolation.enabled": "false",
}
T = TypeVar("T")
diff --git a/python/pyspark/sql/connect/session.py
b/python/pyspark/sql/connect/session.py
index e9984fae9ddb..83b0496a8427 100644
--- a/python/pyspark/sql/connect/session.py
+++ b/python/pyspark/sql/connect/session.py
@@ -1037,7 +1037,11 @@ class SparkSession:
os.environ["SPARK_LOCAL_CONNECT"] = "1"
# Configurations to be set if unset.
- default_conf = {"spark.plugins":
"org.apache.spark.sql.connect.SparkConnectPlugin"}
+ default_conf = {
+ "spark.plugins":
"org.apache.spark.sql.connect.SparkConnectPlugin",
+ "spark.sql.artifact.isolation.enabled": "true",
+ "spark.sql.artifact.isolation.always.apply.classloader":
"true",
+ }
if "SPARK_TESTING" in os.environ:
# For testing, we use 0 to use an ephemeral port to allow
parallel testing.
diff --git a/repl/src/main/scala/org/apache/spark/repl/Main.scala
b/repl/src/main/scala/org/apache/spark/repl/Main.scala
index 7b126c357271..4d3465b32039 100644
--- a/repl/src/main/scala/org/apache/spark/repl/Main.scala
+++ b/repl/src/main/scala/org/apache/spark/repl/Main.scala
@@ -25,6 +25,7 @@ import scala.tools.nsc.GenericRunnerSettings
import org.apache.spark._
import org.apache.spark.internal.Logging
import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION
import org.apache.spark.util.Utils
@@ -95,6 +96,9 @@ object Main extends Logging {
// initialization in certain cases, there's an initialization order
issue that prevents
// this from being set after SparkContext is instantiated.
conf.set("spark.repl.class.outputDir", outputDir.getAbsolutePath())
+ // Disable isolation for REPL, to avoid having in-line classes stored in
an isolated directory,
+ // preventing the REPL classloader from finding them.
+ conf.set(SQLConf.ARTIFACTS_SESSION_ISOLATION_ENABLED, false)
if (execUri != null) {
conf.set("spark.executor.uri", execUri)
}
diff --git a/repl/src/test/resources/IntSumUdf.class
b/repl/src/test/resources/IntSumUdf.class
new file mode 100644
index 000000000000..75a41446cfca
Binary files /dev/null and b/repl/src/test/resources/IntSumUdf.class differ
diff --git a/core/src/main/scala/org/apache/spark/SparkFiles.scala
b/repl/src/test/resources/IntSumUdf.scala
similarity index 61%
copy from core/src/main/scala/org/apache/spark/SparkFiles.scala
copy to repl/src/test/resources/IntSumUdf.scala
index 44f4444a1fa8..9678caaed5db 100644
--- a/core/src/main/scala/org/apache/spark/SparkFiles.scala
+++ b/repl/src/test/resources/IntSumUdf.scala
@@ -15,25 +15,8 @@
* limitations under the License.
*/
-package org.apache.spark
-
-import java.io.File
-
-/**
- * Resolves paths to files added through `SparkContext.addFile()`.
- */
-object SparkFiles {
-
- /**
- * Get the absolute path of a file added through `SparkContext.addFile()`.
- */
- def get(filename: String): String =
- new File(getRootDirectory(), filename).getAbsolutePath()
-
- /**
- * Get the root directory that contains files added through
`SparkContext.addFile()`.
- */
- def getRootDirectory(): String =
- SparkEnv.get.driverTmpDir.getOrElse(".")
+import org.apache.spark.sql.api.java.UDF2
+class IntSumUdf extends UDF2[Long, Long, Long] {
+ override def call(t1: Long, t2: Long): Long = t1 + t2
}
diff --git a/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala
b/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala
index 1a7be083d2d9..327ef3d07420 100644
--- a/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala
+++ b/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala
@@ -396,4 +396,67 @@ class ReplSuite extends SparkFunSuite {
Main.sparkContext.stop()
System.clearProperty("spark.driver.port")
}
+
+ test("register UDF via SparkSession.addArtifact") {
+ val artifactPath = new File("src/test/resources").toPath
+ val intSumUdfPath = artifactPath.resolve("IntSumUdf.class")
+ val output = runInterpreterInPasteMode("local",
+ s"""
+ |import org.apache.spark.sql.api.java.UDF2
+ |import org.apache.spark.sql.types.DataTypes
+ |
+ |spark.addArtifact("${intSumUdfPath.toString}")
+ |
+ |spark.udf.registerJava("intSum", "IntSumUdf", DataTypes.LongType)
+ |
+ |val r = spark.range(5)
+ | .withColumn("id2", col("id") + 1)
+ | .selectExpr("intSum(id, id2)")
+ | .collect()
+ |assert(r.map(_.getLong(0)).toSeq == Seq(1, 3, 5, 7, 9))
+ |
+ """.stripMargin)
+ assertContains("Array([1], [3], [5], [7], [9])", output)
+ assertDoesNotContain("error:", output)
+ assertDoesNotContain("Exception", output)
+ assertDoesNotContain("assertion failed", output)
+
+ // The UDF should not work in a new REPL session.
+ val anotherOutput = runInterpreterInPasteMode("local",
+ s"""
+ |val r = spark.range(5)
+ | .withColumn("id2", col("id") + 1)
+ | .selectExpr("intSum(id, id2)")
+ | .collect()
+ |
+ """.stripMargin)
+ assertContains(
+ "[UNRESOLVED_ROUTINE] Cannot resolve routine `intSum` on search path",
+ anotherOutput)
+ }
+
+ test("register a class via SparkSession.addArtifact") {
+ val artifactPath = new File("src/test/resources").toPath
+ val intSumUdfPath = artifactPath.resolve("IntSumUdf.class")
+ val output = runInterpreterInPasteMode("local",
+ s"""
+ |import org.apache.spark.sql.functions.udf
+ |
+ |spark.addArtifact("${intSumUdfPath.toString}")
+ |
+ |val intSumUdf = udf((x: Long, y: Long) => new IntSumUdf().call(x, y))
+ |spark.udf.register("intSum", intSumUdf)
+ |
+ |val r = spark.range(5)
+ | .withColumn("id2", col("id") + 1)
+ | .selectExpr("intSum(id, id2)")
+ | .collect()
+ |assert(r.map(_.getLong(0)).toSeq == Seq(1, 3, 5, 7, 9))
+ |
+ """.stripMargin)
+ assertContains("Array([1], [3], [5], [7], [9])", output)
+ assertDoesNotContain("error:", output)
+ assertDoesNotContain("Exception", output)
+ assertDoesNotContain("assertion failed", output)
+ }
}
diff --git
a/sql/api/src/main/scala/org/apache/spark/sql/api/UDFRegistration.scala
b/sql/api/src/main/scala/org/apache/spark/sql/api/UDFRegistration.scala
index c11e266827ff..a8e8f5c5f855 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/api/UDFRegistration.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/api/UDFRegistration.scala
@@ -35,6 +35,23 @@ import org.apache.spark.sql.types.DataType
*/
abstract class UDFRegistration {
+ /**
+ * Register a Java UDF class using its class name. The class must implement
one of the UDF
+ * interfaces in the [[org.apache.spark.sql.api.java]] package, and be
discoverable by the current
+ * session's class loader.
+ *
+ * @param name
+ * Name of the UDF.
+ * @param className
+ * Fully qualified class name of the UDF.
+ * @param returnDataType
+ * Return type of UDF. If it is `null`, Spark would try to infer via
reflection.
+ * @note
+ * this method is currently not supported in Spark Connect.
+ * @since 4.0.0
+ */
+ def registerJava(name: String, className: String, returnDataType: DataType):
Unit
+
/**
* Registers a user-defined function (UDF), for a UDF that's already defined
using the Dataset
* API (i.e. of type UserDefinedFunction). To change a UDF to
nondeterministic, call the API
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index d17ab656fe6b..eac89212b9da 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -3957,6 +3957,28 @@ object SQLConf {
.intConf
.createWithDefault(20)
+ val ARTIFACTS_SESSION_ISOLATION_ENABLED =
+ buildConf("spark.sql.artifact.isolation.enabled")
+ .internal()
+ .doc("When enabled for a Spark Session, artifacts (such as JARs, files,
archives) added to " +
+ "this session are isolated from other sessions within the same Spark
instance. When " +
+ "disabled for a session, artifacts added to this session are visible
to other sessions " +
+ "that have this config disabled. This config can only be set during
the creation of a " +
+ "Spark Session and will have no effect when changed in the middle of
session usage.")
+ .version("4.0.0")
+ .booleanConf
+ .createWithDefault(true)
+
+ val ARTIFACTS_SESSION_ISOLATION_ALWAYS_APPLY_CLASSLOADER =
+ buildConf("spark.sql.artifact.isolation.always.apply.classloader")
+ .internal()
+ .doc("When enabled, the classloader holding per-session artifacts will
always be applied " +
+ "during SQL executions (useful for Spark Connect). When disabled, the
classloader will " +
+ "be applied only when any artifact is added to the session.")
+ .version("4.0.0")
+ .booleanConf
+ .createWithDefault(false)
+
val FAST_HASH_AGGREGATE_MAX_ROWS_CAPACITY_BIT =
buildConf("spark.sql.codegen.aggregate.fastHashMap.capacityBit")
.internal()
diff --git
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/SimpleSparkConnectService.scala
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/SimpleSparkConnectService.scala
index 1b6bdd8cd939..8061e913dc0d 100644
---
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/SimpleSparkConnectService.scala
+++
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/SimpleSparkConnectService.scala
@@ -25,6 +25,7 @@ import scala.sys.exit
import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.connect.service.SparkConnectService
+import org.apache.spark.sql.internal.SQLConf
/**
* A simple main class method to start the spark connect server as a service
for client tests
@@ -40,6 +41,8 @@ private[sql] object SimpleSparkConnectService {
def main(args: Array[String]): Unit = {
val conf = new SparkConf()
.set("spark.plugins", "org.apache.spark.sql.connect.SparkConnectPlugin")
+ .set(SQLConf.ARTIFACTS_SESSION_ISOLATION_ENABLED, true)
+ .set(SQLConf.ARTIFACTS_SESSION_ISOLATION_ALWAYS_APPLY_CLASSLOADER, true)
val sparkSession = SparkSession.builder().config(conf).getOrCreate()
val sparkContext = sparkSession.sparkContext // init spark context
// scalastyle:off println
diff --git
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectServer.scala
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectServer.scala
index 4f05ea927e12..b2c4d1abb17b 100644
---
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectServer.scala
+++
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectServer.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.connect.service
import org.apache.spark.internal.{Logging, MDC}
import org.apache.spark.internal.LogKeys.{HOST, PORT}
import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.internal.SQLConf
/**
* The Spark Connect server
@@ -28,7 +29,11 @@ object SparkConnectServer extends Logging {
def main(args: Array[String]): Unit = {
// Set the active Spark Session, and starts SparkEnv instance (via Spark
Context)
logInfo("Starting Spark session.")
- val session = SparkSession.builder().getOrCreate()
+ val session = SparkSession
+ .builder()
+ .config(SQLConf.ARTIFACTS_SESSION_ISOLATION_ENABLED.key, true)
+
.config(SQLConf.ARTIFACTS_SESSION_ISOLATION_ALWAYS_APPLY_CLASSLOADER.key, true)
+ .getOrCreate()
try {
try {
SparkConnectService.start(session.sparkContext)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
index 3af4a26cf187..afc0a2d7df60 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -789,7 +789,11 @@ object SparkSession extends api.BaseSparkSessionCompanion
with Logging {
/** @inheritdoc */
override def enableHiveSupport(): this.type = synchronized {
if (hiveClassesArePresent) {
+ // TODO(SPARK-50244): We now isolate artifacts added by the `ADD JAR`
command. This will
+ // break an existing Hive use case (one session adds JARs and another
session uses them).
+ // We need to decide whether/how to enable isolation for Hive.
super.enableHiveSupport()
+ .config(SQLConf.ARTIFACTS_SESSION_ISOLATION_ENABLED.key, false)
} else {
throw new IllegalArgumentException(
"Unable to instantiate SparkSession with Hive support because " +
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
index 2724399a1a84..6715673cf3d1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
@@ -32,7 +32,6 @@ import
org.apache.spark.sql.execution.python.UserDefinedPythonFunction
import org.apache.spark.sql.expressions.{SparkUserDefinedFunction,
UserDefinedAggregateFunction, UserDefinedAggregator, UserDefinedFunction}
import org.apache.spark.sql.internal.UserDefinedFunctionUtils.toScalaUDF
import org.apache.spark.sql.types.DataType
-import org.apache.spark.util.Utils
/**
* Functions for registering user-defined functions. Use `SparkSession.udf` to
access this:
@@ -44,7 +43,7 @@ import org.apache.spark.util.Utils
* @since 1.3.0
*/
@Stable
-class UDFRegistration private[sql] (functionRegistry: FunctionRegistry)
+class UDFRegistration private[sql] (session: SparkSession, functionRegistry:
FunctionRegistry)
extends api.UDFRegistration
with Logging {
protected[sql] def registerPython(name: String, udf:
UserDefinedPythonFunction): Unit = {
@@ -121,7 +120,7 @@ class UDFRegistration private[sql] (functionRegistry:
FunctionRegistry)
*/
private[sql] def registerJavaUDAF(name: String, className: String): Unit = {
try {
- val clazz = Utils.classForName[AnyRef](className)
+ val clazz = session.artifactManager.classloader.loadClass(className)
if (!classOf[UserDefinedAggregateFunction].isAssignableFrom(clazz)) {
throw QueryCompilationErrors
.classDoesNotImplementUserDefinedAggregateFunctionError(className)
@@ -137,17 +136,10 @@ class UDFRegistration private[sql] (functionRegistry:
FunctionRegistry)
}
// scalastyle:off line.size.limit
- /**
- * Register a Java UDF class using reflection, for use from pyspark
- *
- * @param name udf name
- * @param className fully qualified class name of udf
- * @param returnDataType return type of udf. If it is null, spark would try
to infer
- * via reflection.
- */
- private[sql] def registerJava(name: String, className: String,
returnDataType: DataType): Unit = {
+
+ override def registerJava(name: String, className: String, returnDataType:
DataType): Unit = {
try {
- val clazz = Utils.classForName[AnyRef](className)
+ val clazz = session.artifactManager.classloader.loadClass(className)
val udfInterfaces = clazz.getGenericInterfaces
.filter(_.isInstanceOf[ParameterizedType])
.map(_.asInstanceOf[ParameterizedType])
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/artifact/ArtifactManager.scala
b/sql/core/src/main/scala/org/apache/spark/sql/artifact/ArtifactManager.scala
index b81c369f7e9c..d362c5bef878 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/artifact/ArtifactManager.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/artifact/ArtifactManager.scala
@@ -22,6 +22,7 @@ import java.net.{URI, URL, URLClassLoader}
import java.nio.ByteBuffer
import java.nio.file.{CopyOption, Files, Path, Paths, StandardCopyOption}
import java.util.concurrent.CopyOnWriteArrayList
+import java.util.concurrent.atomic.AtomicBoolean
import scala.jdk.CollectionConverters._
import scala.reflect.ClassTag
@@ -68,18 +69,43 @@ class ArtifactManager(session: SparkSession) extends
Logging {
s"$artifactRootURI${File.separator}${session.sessionUUID}")
// The base directory/URI where all class file artifacts are stored for this
`sessionUUID`.
- protected[artifact] val (classDir, classURI): (Path, String) =
+ protected[artifact] val (classDir, replClassURI): (Path, String) =
(ArtifactUtils.concatenatePaths(artifactPath, "classes"),
s"$artifactURI${File.separator}classes${File.separator}")
- protected[artifact] val state: JobArtifactState =
- JobArtifactState(session.sessionUUID, Option(classURI))
+ private lazy val alwaysApplyClassLoader =
+
session.conf.get(SQLConf.ARTIFACTS_SESSION_ISOLATION_ALWAYS_APPLY_CLASSLOADER.key).toBoolean
- def withResources[T](f: => T): T = {
- Utils.withContextClassLoader(classloader) {
- JobArtifactSet.withActiveJobArtifactState(state) {
+ private lazy val sessionIsolated =
+ session.conf.get(SQLConf.ARTIFACTS_SESSION_ISOLATION_ENABLED.key).toBoolean
+
+ protected[sql] lazy val state: JobArtifactState =
+ if (sessionIsolated) JobArtifactState(session.sessionUUID,
Some(replClassURI)) else null
+
+ /**
+ * Whether any artifact has been added to this artifact manager. We use this
to determine whether
+ * we should apply the classloader to the session, see
`withClassLoaderIfNeeded`.
+ */
+ protected val sessionArtifactAdded = new AtomicBoolean(false)
+
+ private def withClassLoaderIfNeeded[T](f: => T): T = {
+ val log = s" classloader for session ${session.sessionUUID} because " +
+ s"alwaysApplyClassLoader=$alwaysApplyClassLoader, " +
+ s"sessionArtifactAdded=${sessionArtifactAdded.get()}."
+ if (alwaysApplyClassLoader || sessionArtifactAdded.get()) {
+ logDebug(s"Applying $log")
+ Utils.withContextClassLoader(classloader) {
f
}
+ } else {
+ logDebug(s"Not applying $log")
+ f
+ }
+ }
+
+ def withResources[T](f: => T): T = withClassLoaderIfNeeded {
+ JobArtifactSet.withActiveJobArtifactState(state) {
+ f
}
}
@@ -176,6 +202,7 @@ class ArtifactManager(session: SparkSession) extends
Logging {
target,
allowOverwrite = true,
deleteSource = deleteStagedFile)
+ sessionArtifactAdded.set(true)
} else {
val target = ArtifactUtils.concatenatePaths(artifactPath,
normalizedRemoteRelativePath)
// Disallow overwriting with modified version
@@ -199,6 +226,7 @@ class ArtifactManager(session: SparkSession) extends
Logging {
sparkContextRelativePaths.add(
(SparkContextResourceType.JAR, normalizedRemoteRelativePath,
fragment))
jarsList.add(normalizedRemoteRelativePath)
+ sessionArtifactAdded.set(true)
} else if
(normalizedRemoteRelativePath.startsWith(s"pyfiles${File.separator}")) {
session.sparkContext.addFile(uri)
sparkContextRelativePaths.add(
@@ -258,9 +286,10 @@ class ArtifactManager(session: SparkSession) extends
Logging {
* Returns a [[ClassLoader]] for session-specific jar/class file resources.
*/
def classloader: ClassLoader = {
- val urls = getAddedJars :+ classDir.toUri.toURL
+ val urls = (getAddedJars :+ classDir.toUri.toURL).toArray
val prefixes = SparkEnv.get.conf.get(CONNECT_SCALA_UDF_STUB_PREFIXES)
val userClasspathFirst =
SparkEnv.get.conf.get(EXECUTOR_USER_CLASS_PATH_FIRST)
+ val fallbackClassLoader = session.sharedState.jarClassLoader
val loader = if (prefixes.nonEmpty) {
// Two things you need to know about classloader for all of this to make
sense:
// 1. A classloader needs to be able to fully define a class.
@@ -274,21 +303,16 @@ class ArtifactManager(session: SparkSession) extends
Logging {
// it delegates to.
if (userClasspathFirst) {
// USER -> SYSTEM -> STUB
- new ChildFirstURLClassLoader(
- urls.toArray,
- StubClassLoader(Utils.getContextOrSparkClassLoader, prefixes))
+ new ChildFirstURLClassLoader(urls,
StubClassLoader(fallbackClassLoader, prefixes))
} else {
// SYSTEM -> USER -> STUB
- new ChildFirstURLClassLoader(
- urls.toArray,
- StubClassLoader(null, prefixes),
- Utils.getContextOrSparkClassLoader)
+ new ChildFirstURLClassLoader(urls, StubClassLoader(null, prefixes),
fallbackClassLoader)
}
} else {
if (userClasspathFirst) {
- new ChildFirstURLClassLoader(urls.toArray,
Utils.getContextOrSparkClassLoader)
+ new ChildFirstURLClassLoader(urls, fallbackClassLoader)
} else {
- new URLClassLoader(urls.toArray, Utils.getContextOrSparkClassLoader)
+ new URLClassLoader(urls, fallbackClassLoader)
}
}
@@ -347,14 +371,16 @@ class ArtifactManager(session: SparkSession) extends
Logging {
// Clean up added files
val fileserver = SparkEnv.get.rpcEnv.fileServer
val sparkContext = session.sparkContext
- val shouldUpdateEnv = sparkContext.addedFiles.contains(state.uuid) ||
- sparkContext.addedArchives.contains(state.uuid) ||
- sparkContext.addedJars.contains(state.uuid)
- if (shouldUpdateEnv) {
-
sparkContext.addedFiles.remove(state.uuid).foreach(_.keys.foreach(fileserver.removeFile))
-
sparkContext.addedArchives.remove(state.uuid).foreach(_.keys.foreach(fileserver.removeFile))
-
sparkContext.addedJars.remove(state.uuid).foreach(_.keys.foreach(fileserver.removeJar))
- sparkContext.postEnvironmentUpdate()
+ if (state != null) {
+ val shouldUpdateEnv = sparkContext.addedFiles.contains(state.uuid) ||
+ sparkContext.addedArchives.contains(state.uuid) ||
+ sparkContext.addedJars.contains(state.uuid)
+ if (shouldUpdateEnv) {
+
sparkContext.addedFiles.remove(state.uuid).foreach(_.keys.foreach(fileserver.removeFile))
+
sparkContext.addedArchives.remove(state.uuid).foreach(_.keys.foreach(fileserver.removeFile))
+
sparkContext.addedJars.remove(state.uuid).foreach(_.keys.foreach(fileserver.removeJar))
+ sparkContext.postEnvironmentUpdate()
+ }
}
// Clean up cached relations
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
index 5db14a866213..e805aabe013c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
@@ -120,93 +120,97 @@ object SQLExecution extends Logging {
val redactedConfigs =
sparkSession.sessionState.conf.redactOptions(modifiedConfigs)
withSQLConfPropagated(sparkSession) {
- withSessionTagsApplied(sparkSession) {
- var ex: Option[Throwable] = None
- var isExecutedPlanAvailable = false
- val startTime = System.nanoTime()
- val startEvent = SparkListenerSQLExecutionStart(
- executionId = executionId,
- rootExecutionId = Some(rootExecutionId),
- description = desc,
- details = callSite.longForm,
- physicalPlanDescription = "",
- sparkPlanInfo = SparkPlanInfo.EMPTY,
- time = System.currentTimeMillis(),
- modifiedConfigs = redactedConfigs,
- jobTags = sc.getJobTags(),
- jobGroupId =
Option(sc.getLocalProperty(SparkContext.SPARK_JOB_GROUP_ID))
- )
- try {
- body match {
- case Left(e) =>
- sc.listenerBus.post(startEvent)
+ sparkSession.artifactManager.withResources {
+ withSessionTagsApplied(sparkSession) {
+ var ex: Option[Throwable] = None
+ var isExecutedPlanAvailable = false
+ val startTime = System.nanoTime()
+ val startEvent = SparkListenerSQLExecutionStart(
+ executionId = executionId,
+ rootExecutionId = Some(rootExecutionId),
+ description = desc,
+ details = callSite.longForm,
+ physicalPlanDescription = "",
+ sparkPlanInfo = SparkPlanInfo.EMPTY,
+ time = System.currentTimeMillis(),
+ modifiedConfigs = redactedConfigs,
+ jobTags = sc.getJobTags(),
+ jobGroupId =
Option(sc.getLocalProperty(SparkContext.SPARK_JOB_GROUP_ID))
+ )
+ try {
+ body match {
+ case Left(e) =>
+ sc.listenerBus.post(startEvent)
+ throw e
+ case Right(f) =>
+ val planDescriptionMode =
+
ExplainMode.fromString(sparkSession.sessionState.conf.uiExplainMode)
+ val planDesc =
queryExecution.explainString(planDescriptionMode)
+ val planInfo = try {
+ SparkPlanInfo.fromSparkPlan(queryExecution.executedPlan)
+ } catch {
+ case NonFatal(e) =>
+ logDebug("Failed to generate SparkPlanInfo", e)
+ // If the queryExecution already failed before this, we
are not able to
+ // generate the plan info, so we use an empty
graphviz node to make the
+ // UI happy
+ SparkPlanInfo.EMPTY
+ }
+ sc.listenerBus.post(
+ startEvent.copy(physicalPlanDescription = planDesc,
sparkPlanInfo = planInfo))
+ isExecutedPlanAvailable = true
+ f()
+ }
+ } catch {
+ case e: Throwable =>
+ ex = Some(e)
throw e
- case Right(f) =>
- val planDescriptionMode =
-
ExplainMode.fromString(sparkSession.sessionState.conf.uiExplainMode)
- val planDesc =
queryExecution.explainString(planDescriptionMode)
- val planInfo = try {
- SparkPlanInfo.fromSparkPlan(queryExecution.executedPlan)
- } catch {
- case NonFatal(e) =>
- logDebug("Failed to generate SparkPlanInfo", e)
- // If the queryExecution already failed before this, we
are not able to generate
- // the the plan info, so we use and empty graphviz node to
make the UI happy
- SparkPlanInfo.EMPTY
- }
- sc.listenerBus.post(
- startEvent.copy(physicalPlanDescription = planDesc,
sparkPlanInfo = planInfo))
- isExecutedPlanAvailable = true
- f()
- }
- } catch {
- case e: Throwable =>
- ex = Some(e)
- throw e
- } finally {
- val endTime = System.nanoTime()
- val errorMessage = ex.map {
- case e: SparkThrowable =>
- SparkThrowableHelper.getMessage(e, ErrorMessageFormat.PRETTY)
- case e =>
- Utils.exceptionString(e)
- }
- if (queryExecution.shuffleCleanupMode != DoNotCleanup
- && isExecutedPlanAvailable) {
- val shuffleIds = queryExecution.executedPlan match {
- case ae: AdaptiveSparkPlanExec =>
- ae.context.shuffleIds.asScala.keys
- case _ =>
- Iterable.empty
+ } finally {
+ val endTime = System.nanoTime()
+ val errorMessage = ex.map {
+ case e: SparkThrowable =>
+ SparkThrowableHelper.getMessage(e, ErrorMessageFormat.PRETTY)
+ case e =>
+ Utils.exceptionString(e)
}
- shuffleIds.foreach { shuffleId =>
- queryExecution.shuffleCleanupMode match {
- case RemoveShuffleFiles =>
- // Same as what we do in ContextCleaner.doCleanupShuffle,
but do not unregister
- // the shuffle on MapOutputTracker, so that stage retries
would be triggered.
- // Set blocking to Utils.isTesting to deflake unit tests.
- sc.shuffleDriverComponents.removeShuffle(shuffleId,
Utils.isTesting)
- case SkipMigration =>
-
SparkEnv.get.blockManager.migratableResolver.addShuffleToSkip(shuffleId)
- case _ => // this should not happen
+ if (queryExecution.shuffleCleanupMode != DoNotCleanup
+ && isExecutedPlanAvailable) {
+ val shuffleIds = queryExecution.executedPlan match {
+ case ae: AdaptiveSparkPlanExec =>
+ ae.context.shuffleIds.asScala.keys
+ case _ =>
+ Iterable.empty
+ }
+ shuffleIds.foreach { shuffleId =>
+ queryExecution.shuffleCleanupMode match {
+ case RemoveShuffleFiles =>
+ // Same as what we do in
ContextCleaner.doCleanupShuffle, but do not
+ // unregister the shuffle on MapOutputTracker, so that
stage retries would be
+ // triggered.
+ // Set blocking to Utils.isTesting to deflake unit tests.
+ sc.shuffleDriverComponents.removeShuffle(shuffleId,
Utils.isTesting)
+ case SkipMigration =>
+
SparkEnv.get.blockManager.migratableResolver.addShuffleToSkip(shuffleId)
+ case _ => // this should not happen
+ }
}
}
+ val event = SparkListenerSQLExecutionEnd(
+ executionId,
+ System.currentTimeMillis(),
+ // Use empty string to indicate no error, as None may mean
events generated by old
+ // versions of Spark.
+ errorMessage.orElse(Some("")))
+ // Currently only `Dataset.withAction` and
`DataFrameWriter.runCommand` specify the
+ // `name` parameter. The `ExecutionListenerManager` only watches
SQL executions with
+ // name. We can specify the execution name in more places in the
future, so that
+ // `QueryExecutionListener` can track more cases.
+ event.executionName = name
+ event.duration = endTime - startTime
+ event.qe = queryExecution
+ event.executionFailure = ex
+ sc.listenerBus.post(event)
}
- val event = SparkListenerSQLExecutionEnd(
- executionId,
- System.currentTimeMillis(),
- // Use empty string to indicate no error, as None may mean
events generated by old
- // versions of Spark.
- errorMessage.orElse(Some("")))
- // Currently only `Dataset.withAction` and
`DataFrameWriter.runCommand` specify the
- // `name` parameter. The `ExecutionListenerManager` only watches
SQL executions with
- // name. We can specify the execution name in more places in the
future, so that
- // `QueryExecutionListener` can track more cases.
- event.executionName = name
- event.duration = endTime - startTime
- event.qe = queryExecution
- event.executionFailure = ex
- sc.listenerBus.post(event)
}
}
}
@@ -301,7 +305,10 @@ object SQLExecution extends Logging {
val activeSession = sparkSession
val sc = sparkSession.sparkContext
val localProps = Utils.cloneProperties(sc.getLocalProperties)
- val artifactState = JobArtifactSet.getCurrentJobArtifactState.orNull
+ // `getCurrentJobArtifactState` will return a state only in Spark Connect
mode. In non-Connect
+ // mode, we default back to the resources of the current Spark session.
+ val artifactState = JobArtifactSet.getCurrentJobArtifactState.getOrElse(
+ activeSession.artifactManager.state)
exec.submit(() => JobArtifactSet.withActiveJobArtifactState(artifactState)
{
val originalSession = SparkSession.getActiveSession
val originalLocalProps = sc.getLocalProperties
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
index d8f32a2cb922..bd501c935723 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
@@ -223,9 +223,7 @@ abstract class StreamExecution(
// To fix call site like "run at <unknown>:0", we bridge the call site
from the caller
// thread to this micro batch thread
sparkSession.sparkContext.setCallSite(callSite)
- JobArtifactSet.withActiveJobArtifactState(jobArtifactState) {
- runStream()
- }
+ runStream()
}
}
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
index f22d4fe32668..59a873ef982f 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
@@ -181,7 +181,7 @@ abstract class BaseSessionStateBuilder(
* Note 1: The user-defined functions must be deterministic.
* Note 2: This depends on the `functionRegistry` field.
*/
- protected def udfRegistration: UDFRegistration = new
UDFRegistration(functionRegistry)
+ protected def udfRegistration: UDFRegistration = new
UDFRegistration(session, functionRegistry)
protected def udtfRegistration: UDTFRegistration = new
UDTFRegistration(tableFunctionRegistry)
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/artifact/ArtifactManagerSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/artifact/ArtifactManagerSuite.scala
index e929a6b5303a..e935af8b8bf8 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/artifact/ArtifactManagerSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/artifact/ArtifactManagerSuite.scala
@@ -24,8 +24,8 @@ import org.apache.commons.io.FileUtils
import org.apache.spark.{SparkConf, SparkException}
import org.apache.spark.sql.SparkSession
-import org.apache.spark.sql.api.java.UDF2
import org.apache.spark.sql.functions.col
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types.DataTypes
import org.apache.spark.storage.CacheId
@@ -36,6 +36,8 @@ class ArtifactManagerSuite extends SharedSparkSession {
override protected def sparkConf: SparkConf = {
val conf = super.sparkConf
conf.set("spark.sql.artifact.copyFromLocalToFs.allowDestLocal", "true")
+ conf.set(SQLConf.ARTIFACTS_SESSION_ISOLATION_ENABLED, true)
+ conf.set(SQLConf.ARTIFACTS_SESSION_ISOLATION_ALWAYS_APPLY_CLASSLOADER,
true)
}
private val artifactPath = new
File("src/test/resources/artifact-tests").toPath
@@ -331,24 +333,17 @@ class ArtifactManagerSuite extends SharedSparkSession {
}
}
- test("Add UDF as artifact") {
+ test("Added artifact can be loaded by the current SparkSession") {
val buffer = Files.readAllBytes(artifactPath.resolve("IntSumUdf.class"))
spark.addArtifact(buffer, "IntSumUdf.class")
- val instance = artifactManager.classloader
- .loadClass("IntSumUdf")
- .getDeclaredConstructor()
- .newInstance()
- .asInstanceOf[UDF2[Long, Long, Long]]
- spark.udf.register("intSum", instance, DataTypes.LongType)
-
- artifactManager.withResources {
- val r = spark.range(5)
- .withColumn("id2", col("id") + 1)
- .selectExpr("intSum(id, id2)")
- .collect()
- assert(r.map(_.getLong(0)).toSeq == Seq(1, 3, 5, 7, 9))
- }
+ spark.udf.registerJava("intSum", "IntSumUdf", DataTypes.LongType)
+
+ val r = spark.range(5)
+ .withColumn("id2", col("id") + 1)
+ .selectExpr("intSum(id, id2)")
+ .collect()
+ assert(r.map(_.getLong(0)).toSeq == Seq(1, 3, 5, 7, 9))
}
private def testAddArtifactToLocalSession(
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala
index fec7183bc75e..32a63f5c6197 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala
@@ -2141,7 +2141,12 @@ abstract class DDLSuite extends QueryTest with
DDLSuiteBase {
root = Utils.createTempDir().getCanonicalPath, namePrefix =
"addDirectory")
val testFile = File.createTempFile("testFile", "1", directoryToAdd)
spark.sql(s"ADD FILE $directoryToAdd")
- assert(new
File(SparkFiles.get(s"${directoryToAdd.getName}/${testFile.getName}")).exists())
+ // TODO(SPARK-50244): ADD JAR is inside `sql()` thus isolated. This will
break an existing Hive
+ // use case (one session adds JARs and another session uses them). After
we sort out the Hive
+ // isolation issue we will decide if the next assert should be wrapped
inside `withResources`.
+ spark.artifactManager.withResources {
+ assert(new
File(SparkFiles.get(s"${directoryToAdd.getName}/${testFile.getName}")).exists())
+ }
}
test(s"Add a directory when
${SQLConf.LEGACY_ADD_SINGLE_FILE_IN_ADD_FILE.key} set to true") {
diff --git
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
index 42fc50e5b163..c41370c96241 100644
---
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
+++
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
@@ -70,6 +70,14 @@ class HiveQuerySuite extends HiveComparisonTest with
SQLTestUtils with BeforeAnd
}
}
+ override def afterEach(): Unit = {
+ try {
+ spark.artifactManager.cleanUpResources()
+ } finally {
+ super.afterEach()
+ }
+ }
+
private def assertUnsupportedFeature(
body: => Unit,
operation: String,
diff --git
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHive.scala
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHive.scala
index 9611a37ef0d0..247a1c7096cb 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHive.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHive.scala
@@ -630,7 +630,11 @@ private[hive] object TestHiveContext {
val overrideConfs: Map[String, String] =
Map(
// Fewer shuffle partitions to speed up testing.
- SQLConf.SHUFFLE_PARTITIONS.key -> "5"
+ SQLConf.SHUFFLE_PARTITIONS.key -> "5",
+ // TODO(SPARK-50244): We now isolate artifacts added by the `ADD JAR`
command. This will break
+ // an existing Hive use case (one session adds JARs and another session
uses them). We need
+ // to decide whether/how to enable isolation for Hive.
+ SQLConf.ARTIFACTS_SESSION_ISOLATION_ENABLED.key -> "false"
)
def makeWarehouseDir(): File = {
diff --git
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHiveSingleton.scala
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHiveSingleton.scala
index d50bf0b8fd60..770e1da94a1c 100644
---
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHiveSingleton.scala
+++
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHiveSingleton.scala
@@ -40,4 +40,11 @@ trait TestHiveSingleton extends SparkFunSuite with
BeforeAndAfterAll {
}
}
+ protected override def afterEach(): Unit = {
+ try {
+ spark.artifactManager.cleanUpResources()
+ } finally {
+ super.afterEach()
+ }
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]