This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new ad3059f23c4 [SPARK-44246][CONNECT][FOLLOW-UP] Miscellaneous cleanups
for Spark Connect Jar/Classfile Isolation
ad3059f23c4 is described below
commit ad3059f23c407e464cf0b203bbafd7655c480866
Author: vicennial <[email protected]>
AuthorDate: Mon Jul 3 09:35:27 2023 +0900
[SPARK-44246][CONNECT][FOLLOW-UP] Miscellaneous cleanups for Spark Connect
Jar/Classfile Isolation
### What changes were proposed in this pull request?
This PR is a follow-up of #41701 and addresses the comments mentioned
[here](https://github.com/apache/spark/pull/41701#issuecomment-1608577372). The
summary is:
- `pythonIncludes` are directly fetched from the `ArtifactManager` via
`SessionHolder` instead of propagating through the spark conf
- `SessionHolder#withContext` renamed to
`SessionHolder#withContextClassLoader` to decrease ambiguity.
- Generally increased test coverage for isolated classloading (new unit test
in `ArtifactManagerSuite` and a new suite `ClassLoaderIsolationSuite`).
### Why are the changes needed?
General follow-ups from
[here](https://github.com/apache/spark/pull/41701#issuecomment-1608577372).
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
New test suite and unit tests.
Closes #41789 from vicennial/SPARK-44246.
Authored-by: vicennial <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
.../sql/connect/planner/SparkConnectPlanner.scala | 13 +--
.../spark/sql/connect/service/SessionHolder.scala | 45 +--------
.../service/SparkConnectAnalyzeHandler.scala | 2 +-
.../connect/artifact/ArtifactManagerSuite.scala | 20 +++-
core/src/test/resources/TestHelloV2.jar | Bin 0 -> 3784 bytes
core/src/test/resources/TestHelloV3.jar | Bin 0 -> 3595 bytes
.../spark/executor/ClassLoaderIsolationSuite.scala | 102 +++++++++++++++++++++
7 files changed, 126 insertions(+), 56 deletions(-)
diff --git
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index cdad4fc6190..149d5512953 100644
---
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -26,8 +26,6 @@ import com.google.protobuf.{Any => ProtoAny, ByteString}
import io.grpc.{Context, Status, StatusRuntimeException}
import io.grpc.stub.StreamObserver
import org.apache.commons.lang3.exception.ExceptionUtils
-import org.json4s._
-import org.json4s.jackson.JsonMethods.parse
import org.apache.spark.{Partition, SparkEnv, TaskContext}
import org.apache.spark.api.python.{PythonEvalType, SimplePythonFunction}
@@ -91,15 +89,6 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder)
extends Logging {
private lazy val pythonExec =
sys.env.getOrElse("PYSPARK_PYTHON",
sys.env.getOrElse("PYSPARK_DRIVER_PYTHON", "python3"))
- // SparkConnectPlanner is used per request.
- private lazy val pythonIncludes = {
- implicit val formats = DefaultFormats
- parse(session.conf.get("spark.connect.pythonUDF.includes", "[]"))
- .extract[Array[String]]
- .toList
- .asJava
- }
-
// The root of the query plan is a relation and we apply the transformations
to it.
def transformRelation(rel: proto.Relation): LogicalPlan = {
val plan = rel.getRelTypeCase match {
@@ -1527,7 +1516,7 @@ class SparkConnectPlanner(val sessionHolder:
SessionHolder) extends Logging {
command = fun.getCommand.toByteArray,
// Empty environment variables
envVars = Maps.newHashMap(),
- pythonIncludes = pythonIncludes,
+ pythonIncludes =
sessionHolder.artifactManager.getSparkConnectPythonIncludes.asJava,
pythonExec = pythonExec,
pythonVer = fun.getPythonVer,
// Empty broadcast variables
diff --git
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala
index 24502fccd96..56ef68abbc2 100644
---
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala
+++
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala
@@ -24,9 +24,6 @@ import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap}
import scala.collection.JavaConverters._
import scala.util.control.NonFatal
-import org.json4s.JsonDSL._
-import org.json4s.jackson.JsonMethods.{compact, render}
-
import org.apache.spark.JobArtifactSet
import org.apache.spark.SparkException
import org.apache.spark.connect.proto
@@ -114,7 +111,7 @@ case class SessionHolder(userId: String, sessionId: String,
session: SparkSessio
* @param f
* @tparam T
*/
- def withContext[T](f: => T): T = {
+ def withContextClassLoader[T](f: => T): T = {
// Needed for deserializing and evaluating the UDF on the driver
Utils.withContextClassLoader(classloader) {
// Needed for propagating the dependencies to the executors.
@@ -124,49 +121,15 @@ case class SessionHolder(userId: String, sessionId:
String, session: SparkSessio
}
}
- /**
- * Set the session-based Python paths to include in Python UDF.
- * @param f
- * @tparam T
- */
- def withSessionBasedPythonPaths[T](f: => T): T = {
- try {
- session.conf.set(
- "spark.connect.pythonUDF.includes",
- compact(render(artifactManager.getSparkConnectPythonIncludes)))
- f
- } finally {
- session.conf.unset("spark.connect.pythonUDF.includes")
- }
- }
-
/**
* Execute a block of code with this session as the active SparkConnect
session.
* @param f
* @tparam T
*/
def withSession[T](f: SparkSession => T): T = {
- withSessionBasedPythonPaths {
- withContext {
- session.withActive {
- f(session)
- }
- }
- }
- }
-
- /**
- * Execute a block of code using the session from this [[SessionHolder]] as
the active
- * SparkConnect session.
- * @param f
- * @tparam T
- */
- def withSessionHolder[T](f: SessionHolder => T): T = {
- withSessionBasedPythonPaths {
- withContext {
- session.withActive {
- f(this)
- }
+ withContextClassLoader {
+ session.withActive {
+ f(session)
}
}
}
diff --git
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAnalyzeHandler.scala
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAnalyzeHandler.scala
index 5c069bfaf5d..414a852380f 100644
---
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAnalyzeHandler.scala
+++
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAnalyzeHandler.scala
@@ -38,7 +38,7 @@ private[connect] class SparkConnectAnalyzeHandler(
request.getSessionId)
// `withSession` ensures that session-specific artifacts (such as JARs and
class files) are
// available during processing (such as deserialization).
- sessionHolder.withSessionHolder { sessionHolder =>
+ sessionHolder.withSession { _ =>
val response = process(request, sessionHolder)
responseObserver.onNext(response)
responseObserver.onCompleted()
diff --git
a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/artifact/ArtifactManagerSuite.scala
b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/artifact/ArtifactManagerSuite.scala
index 42ab8ca18f6..612bf096b22 100644
---
a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/artifact/ArtifactManagerSuite.scala
+++
b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/artifact/ArtifactManagerSuite.scala
@@ -224,6 +224,7 @@ class ArtifactManagerSuite extends SharedSparkSession with
ResourceHelper {
test("Classloaders for spark sessions are isolated") {
val holder1 = SparkConnectService.getOrCreateIsolatedSession("c1",
"session1")
val holder2 = SparkConnectService.getOrCreateIsolatedSession("c2",
"session2")
+ val holder3 = SparkConnectService.getOrCreateIsolatedSession("c3",
"session3")
def addHelloClass(holder: SessionHolder): Unit = {
val copyDir = Utils.createTempDir().toPath
@@ -234,7 +235,7 @@ class ArtifactManagerSuite extends SharedSparkSession with
ResourceHelper {
holder.addArtifact(remotePath, stagingPath, None)
}
- // Add the classfile only for the first user
+ // Add the "Hello" classfile for the first user
addHelloClass(holder1)
val classLoader1 = holder1.classloader
@@ -246,7 +247,8 @@ class ArtifactManagerSuite extends SharedSparkSession with
ResourceHelper {
val udf1 = org.apache.spark.sql.functions.udf(instance1)
holder1.withSession { session =>
- session.range(10).select(udf1(col("id").cast("string"))).collect()
+ val result =
session.range(10).select(udf1(col("id").cast("string"))).collect()
+ assert(result.forall(_.getString(0).contains("Talon")))
}
assertThrows[ClassNotFoundException] {
@@ -257,6 +259,20 @@ class ArtifactManagerSuite extends SharedSparkSession with
ResourceHelper {
.newInstance("Talon")
.asInstanceOf[String => String]
}
+
+ // Add the "Hello" classfile for the third user
+ addHelloClass(holder3)
+ val instance3 = holder3.classloader
+ .loadClass("Hello")
+ .getDeclaredConstructor(classOf[String])
+ .newInstance("Ahri")
+ .asInstanceOf[String => String]
+ val udf3 = org.apache.spark.sql.functions.udf(instance3)
+
+ holder3.withSession { session =>
+ val result =
session.range(10).select(udf3(col("id").cast("string"))).collect()
+ assert(result.forall(_.getString(0).contains("Ahri")))
+ }
}
}
diff --git a/core/src/test/resources/TestHelloV2.jar
b/core/src/test/resources/TestHelloV2.jar
new file mode 100644
index 00000000000..d89cf6543a2
Binary files /dev/null and b/core/src/test/resources/TestHelloV2.jar differ
diff --git a/core/src/test/resources/TestHelloV3.jar
b/core/src/test/resources/TestHelloV3.jar
new file mode 100644
index 00000000000..b175a6c8640
Binary files /dev/null and b/core/src/test/resources/TestHelloV3.jar differ
diff --git
a/core/src/test/scala/org/apache/spark/executor/ClassLoaderIsolationSuite.scala
b/core/src/test/scala/org/apache/spark/executor/ClassLoaderIsolationSuite.scala
new file mode 100644
index 00000000000..33c1baccd72
--- /dev/null
+++
b/core/src/test/scala/org/apache/spark/executor/ClassLoaderIsolationSuite.scala
@@ -0,0 +1,102 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.executor
+
+import org.apache.spark.{JobArtifactSet, LocalSparkContext, SparkConf,
SparkContext, SparkFunSuite}
+import org.apache.spark.util.Utils
+
+class ClassLoaderIsolationSuite extends SparkFunSuite with LocalSparkContext {
+ val jar1 =
Thread.currentThread().getContextClassLoader.getResource("TestUDTF.jar").toString
+
+ // package com.example
+ // object Hello { def test(): Int = 2 }
+ // case class Hello(x: Int, y: Int)
+ val jar2 =
Thread.currentThread().getContextClassLoader.getResource("TestHelloV2.jar").toString
+
+ // package com.example
+ // object Hello { def test(): Int = 3 }
+ // case class Hello(x: String)
+ val jar3 =
Thread.currentThread().getContextClassLoader.getResource("TestHelloV3.jar").toString
+
+ test("Executor classloader isolation with JobArtifactSet") {
+ sc = new SparkContext(new
SparkConf().setAppName("test").setMaster("local"))
+ sc.addJar(jar1)
+ sc.addJar(jar2)
+ sc.addJar(jar3)
+
+ // TestHelloV2's test method returns '2'
+ val artifactSetWithHelloV2 = new JobArtifactSet(
+ uuid = Some("hello2"),
+ replClassDirUri = None,
+ jars = Map(jar2 -> 1L),
+ files = Map.empty,
+ archives = Map.empty
+ )
+
+ JobArtifactSet.withActive(artifactSetWithHelloV2) {
+ sc.parallelize(1 to 1).foreach { i =>
+ val cls = Utils.classForName("com.example.Hello$")
+ val module = cls.getField("MODULE$").get(null)
+ val result = cls.getMethod("test").invoke(module).asInstanceOf[Int]
+ if (result != 2) {
+ throw new RuntimeException("Unexpected result: " + result)
+ }
+ }
+ }
+
+ // TestHelloV3's test method returns '3'
+ val artifactSetWithHelloV3 = new JobArtifactSet(
+ uuid = Some("hello3"),
+ replClassDirUri = None,
+ jars = Map(jar3 -> 1L),
+ files = Map.empty,
+ archives = Map.empty
+ )
+
+ JobArtifactSet.withActive(artifactSetWithHelloV3) {
+ sc.parallelize(1 to 1).foreach { i =>
+ val cls = Utils.classForName("com.example.Hello$")
+ val module = cls.getField("MODULE$").get(null)
+ val result = cls.getMethod("test").invoke(module).asInstanceOf[Int]
+ if (result != 3) {
+ throw new RuntimeException("Unexpected result: " + result)
+ }
+ }
+ }
+
+ // Should not be able to see any "Hello" class if they're excluded from
the artifact set
+ val artifactSetWithoutHello = new JobArtifactSet(
+ uuid = Some("Jar 1"),
+ replClassDirUri = None,
+ jars = Map(jar1 -> 1L),
+ files = Map.empty,
+ archives = Map.empty
+ )
+
+ JobArtifactSet.withActive(artifactSetWithoutHello) {
+ sc.parallelize(1 to 1).foreach { i =>
+ try {
+ Utils.classForName("com.example.Hello$")
+ throw new RuntimeException("Import should fail")
+ } catch {
+ case _: ClassNotFoundException =>
+ }
+ }
+ }
+ }
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]