This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new ad3059f23c4 [SPARK-44246][CONNECT][FOLLOW-UP] Miscellaneous cleanups 
for Spark Connect Jar/Classfile Isolation
ad3059f23c4 is described below

commit ad3059f23c407e464cf0b203bbafd7655c480866
Author: vicennial <[email protected]>
AuthorDate: Mon Jul 3 09:35:27 2023 +0900

    [SPARK-44246][CONNECT][FOLLOW-UP] Miscellaneous cleanups for Spark Connect 
Jar/Classfile Isolation
    
    ### What changes were proposed in this pull request?
    
    This PR is a follow-up of #41701 and addresses the comments mentioned 
[here](https://github.com/apache/spark/pull/41701#issuecomment-1608577372). The 
summary is:
    
    - `pythonIncludes` are fetched directly from the `ArtifactManager` via 
`SessionHolder` instead of being propagated through the Spark conf.
    - `SessionHolder#withContext` is renamed to 
`SessionHolder#withContextClassLoader` to decrease ambiguity.
    - Increased test coverage for isolated classloading (a new unit test 
in `ArtifactManagerSuite` and a new suite, `ClassLoaderIsolationSuite`).
    
    ### Why are the changes needed?
    
    General follow-ups from 
[here](https://github.com/apache/spark/pull/41701#issuecomment-1608577372).
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    New test suite and unit tests.
    
    Closes #41789 from vicennial/SPARK-44246.
    
    Authored-by: vicennial <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 .../sql/connect/planner/SparkConnectPlanner.scala  |  13 +--
 .../spark/sql/connect/service/SessionHolder.scala  |  45 +--------
 .../service/SparkConnectAnalyzeHandler.scala       |   2 +-
 .../connect/artifact/ArtifactManagerSuite.scala    |  20 +++-
 core/src/test/resources/TestHelloV2.jar            | Bin 0 -> 3784 bytes
 core/src/test/resources/TestHelloV3.jar            | Bin 0 -> 3595 bytes
 .../spark/executor/ClassLoaderIsolationSuite.scala | 102 +++++++++++++++++++++
 7 files changed, 126 insertions(+), 56 deletions(-)

diff --git 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index cdad4fc6190..149d5512953 100644
--- 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -26,8 +26,6 @@ import com.google.protobuf.{Any => ProtoAny, ByteString}
 import io.grpc.{Context, Status, StatusRuntimeException}
 import io.grpc.stub.StreamObserver
 import org.apache.commons.lang3.exception.ExceptionUtils
-import org.json4s._
-import org.json4s.jackson.JsonMethods.parse
 
 import org.apache.spark.{Partition, SparkEnv, TaskContext}
 import org.apache.spark.api.python.{PythonEvalType, SimplePythonFunction}
@@ -91,15 +89,6 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) 
extends Logging {
   private lazy val pythonExec =
     sys.env.getOrElse("PYSPARK_PYTHON", 
sys.env.getOrElse("PYSPARK_DRIVER_PYTHON", "python3"))
 
-  // SparkConnectPlanner is used per request.
-  private lazy val pythonIncludes = {
-    implicit val formats = DefaultFormats
-    parse(session.conf.get("spark.connect.pythonUDF.includes", "[]"))
-      .extract[Array[String]]
-      .toList
-      .asJava
-  }
-
   // The root of the query plan is a relation and we apply the transformations 
to it.
   def transformRelation(rel: proto.Relation): LogicalPlan = {
     val plan = rel.getRelTypeCase match {
@@ -1527,7 +1516,7 @@ class SparkConnectPlanner(val sessionHolder: 
SessionHolder) extends Logging {
       command = fun.getCommand.toByteArray,
       // Empty environment variables
       envVars = Maps.newHashMap(),
-      pythonIncludes = pythonIncludes,
+      pythonIncludes = 
sessionHolder.artifactManager.getSparkConnectPythonIncludes.asJava,
       pythonExec = pythonExec,
       pythonVer = fun.getPythonVer,
       // Empty broadcast variables
diff --git 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala
 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala
index 24502fccd96..56ef68abbc2 100644
--- 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala
+++ 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala
@@ -24,9 +24,6 @@ import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap}
 import scala.collection.JavaConverters._
 import scala.util.control.NonFatal
 
-import org.json4s.JsonDSL._
-import org.json4s.jackson.JsonMethods.{compact, render}
-
 import org.apache.spark.JobArtifactSet
 import org.apache.spark.SparkException
 import org.apache.spark.connect.proto
@@ -114,7 +111,7 @@ case class SessionHolder(userId: String, sessionId: String, 
session: SparkSessio
    * @param f
    * @tparam T
    */
-  def withContext[T](f: => T): T = {
+  def withContextClassLoader[T](f: => T): T = {
     // Needed for deserializing and evaluating the UDF on the driver
     Utils.withContextClassLoader(classloader) {
       // Needed for propagating the dependencies to the executors.
@@ -124,49 +121,15 @@ case class SessionHolder(userId: String, sessionId: 
String, session: SparkSessio
     }
   }
 
-  /**
-   * Set the session-based Python paths to include in Python UDF.
-   * @param f
-   * @tparam T
-   */
-  def withSessionBasedPythonPaths[T](f: => T): T = {
-    try {
-      session.conf.set(
-        "spark.connect.pythonUDF.includes",
-        compact(render(artifactManager.getSparkConnectPythonIncludes)))
-      f
-    } finally {
-      session.conf.unset("spark.connect.pythonUDF.includes")
-    }
-  }
-
   /**
    * Execute a block of code with this session as the active SparkConnect 
session.
    * @param f
    * @tparam T
    */
   def withSession[T](f: SparkSession => T): T = {
-    withSessionBasedPythonPaths {
-      withContext {
-        session.withActive {
-          f(session)
-        }
-      }
-    }
-  }
-
-  /**
-   * Execute a block of code using the session from this [[SessionHolder]] as 
the active
-   * SparkConnect session.
-   * @param f
-   * @tparam T
-   */
-  def withSessionHolder[T](f: SessionHolder => T): T = {
-    withSessionBasedPythonPaths {
-      withContext {
-        session.withActive {
-          f(this)
-        }
+    withContextClassLoader {
+      session.withActive {
+        f(session)
       }
     }
   }
diff --git 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAnalyzeHandler.scala
 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAnalyzeHandler.scala
index 5c069bfaf5d..414a852380f 100644
--- 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAnalyzeHandler.scala
+++ 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAnalyzeHandler.scala
@@ -38,7 +38,7 @@ private[connect] class SparkConnectAnalyzeHandler(
       request.getSessionId)
     // `withSession` ensures that session-specific artifacts (such as JARs and 
class files) are
     // available during processing (such as deserialization).
-    sessionHolder.withSessionHolder { sessionHolder =>
+    sessionHolder.withSession { _ =>
       val response = process(request, sessionHolder)
       responseObserver.onNext(response)
       responseObserver.onCompleted()
diff --git 
a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/artifact/ArtifactManagerSuite.scala
 
b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/artifact/ArtifactManagerSuite.scala
index 42ab8ca18f6..612bf096b22 100644
--- 
a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/artifact/ArtifactManagerSuite.scala
+++ 
b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/artifact/ArtifactManagerSuite.scala
@@ -224,6 +224,7 @@ class ArtifactManagerSuite extends SharedSparkSession with 
ResourceHelper {
   test("Classloaders for spark sessions are isolated") {
     val holder1 = SparkConnectService.getOrCreateIsolatedSession("c1", 
"session1")
     val holder2 = SparkConnectService.getOrCreateIsolatedSession("c2", 
"session2")
+    val holder3 = SparkConnectService.getOrCreateIsolatedSession("c3", 
"session3")
 
     def addHelloClass(holder: SessionHolder): Unit = {
       val copyDir = Utils.createTempDir().toPath
@@ -234,7 +235,7 @@ class ArtifactManagerSuite extends SharedSparkSession with 
ResourceHelper {
       holder.addArtifact(remotePath, stagingPath, None)
     }
 
-    // Add the classfile only for the first user
+    // Add the "Hello" classfile for the first user
     addHelloClass(holder1)
 
     val classLoader1 = holder1.classloader
@@ -246,7 +247,8 @@ class ArtifactManagerSuite extends SharedSparkSession with 
ResourceHelper {
     val udf1 = org.apache.spark.sql.functions.udf(instance1)
 
     holder1.withSession { session =>
-      session.range(10).select(udf1(col("id").cast("string"))).collect()
+      val result = 
session.range(10).select(udf1(col("id").cast("string"))).collect()
+      assert(result.forall(_.getString(0).contains("Talon")))
     }
 
     assertThrows[ClassNotFoundException] {
@@ -257,6 +259,20 @@ class ArtifactManagerSuite extends SharedSparkSession with 
ResourceHelper {
         .newInstance("Talon")
         .asInstanceOf[String => String]
     }
+
+    // Add the "Hello" classfile for the third user
+    addHelloClass(holder3)
+    val instance3 = holder3.classloader
+      .loadClass("Hello")
+      .getDeclaredConstructor(classOf[String])
+      .newInstance("Ahri")
+      .asInstanceOf[String => String]
+    val udf3 = org.apache.spark.sql.functions.udf(instance3)
+
+    holder3.withSession { session =>
+      val result = 
session.range(10).select(udf3(col("id").cast("string"))).collect()
+      assert(result.forall(_.getString(0).contains("Ahri")))
+    }
   }
 }
 
diff --git a/core/src/test/resources/TestHelloV2.jar 
b/core/src/test/resources/TestHelloV2.jar
new file mode 100644
index 00000000000..d89cf6543a2
Binary files /dev/null and b/core/src/test/resources/TestHelloV2.jar differ
diff --git a/core/src/test/resources/TestHelloV3.jar 
b/core/src/test/resources/TestHelloV3.jar
new file mode 100644
index 00000000000..b175a6c8640
Binary files /dev/null and b/core/src/test/resources/TestHelloV3.jar differ
diff --git 
a/core/src/test/scala/org/apache/spark/executor/ClassLoaderIsolationSuite.scala 
b/core/src/test/scala/org/apache/spark/executor/ClassLoaderIsolationSuite.scala
new file mode 100644
index 00000000000..33c1baccd72
--- /dev/null
+++ 
b/core/src/test/scala/org/apache/spark/executor/ClassLoaderIsolationSuite.scala
@@ -0,0 +1,102 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.executor
+
+import org.apache.spark.{JobArtifactSet, LocalSparkContext, SparkConf, 
SparkContext, SparkFunSuite}
+import org.apache.spark.util.Utils
+
+class ClassLoaderIsolationSuite extends SparkFunSuite with LocalSparkContext  {
+  val jar1 = 
Thread.currentThread().getContextClassLoader.getResource("TestUDTF.jar").toString
+
+  // package com.example
+  // object Hello { def test(): Int = 2 }
+  // case class Hello(x: Int, y: Int)
+  val jar2 = 
Thread.currentThread().getContextClassLoader.getResource("TestHelloV2.jar").toString
+
+  // package com.example
+  // object Hello { def test(): Int = 3 }
+  // case class Hello(x: String)
+  val jar3 = 
Thread.currentThread().getContextClassLoader.getResource("TestHelloV3.jar").toString
+
+  test("Executor classloader isolation with JobArtifactSet") {
+    sc = new SparkContext(new 
SparkConf().setAppName("test").setMaster("local"))
+    sc.addJar(jar1)
+    sc.addJar(jar2)
+    sc.addJar(jar3)
+
+    // TestHelloV2's test method returns '2'
+    val artifactSetWithHelloV2 = new JobArtifactSet(
+      uuid = Some("hello2"),
+      replClassDirUri = None,
+      jars = Map(jar2 -> 1L),
+      files = Map.empty,
+      archives = Map.empty
+    )
+
+    JobArtifactSet.withActive(artifactSetWithHelloV2) {
+      sc.parallelize(1 to 1).foreach { i =>
+        val cls = Utils.classForName("com.example.Hello$")
+        val module = cls.getField("MODULE$").get(null)
+        val result = cls.getMethod("test").invoke(module).asInstanceOf[Int]
+        if (result != 2) {
+          throw new RuntimeException("Unexpected result: " + result)
+        }
+      }
+    }
+
+    // TestHelloV3's test method returns '3'
+    val artifactSetWithHelloV3 = new JobArtifactSet(
+      uuid = Some("hello3"),
+      replClassDirUri = None,
+      jars = Map(jar3 -> 1L),
+      files = Map.empty,
+      archives = Map.empty
+    )
+
+    JobArtifactSet.withActive(artifactSetWithHelloV3) {
+      sc.parallelize(1 to 1).foreach { i =>
+        val cls = Utils.classForName("com.example.Hello$")
+        val module = cls.getField("MODULE$").get(null)
+        val result = cls.getMethod("test").invoke(module).asInstanceOf[Int]
+        if (result != 3) {
+          throw new RuntimeException("Unexpected result: " + result)
+        }
+      }
+    }
+
+    // Should not be able to see any "Hello" class if they're excluded from 
the artifact set
+    val artifactSetWithoutHello = new JobArtifactSet(
+      uuid = Some("Jar 1"),
+      replClassDirUri = None,
+      jars = Map(jar1 -> 1L),
+      files = Map.empty,
+      archives = Map.empty
+    )
+
+    JobArtifactSet.withActive(artifactSetWithoutHello) {
+      sc.parallelize(1 to 1).foreach { i =>
+        try {
+          Utils.classForName("com.example.Hello$")
+          throw new RuntimeException("Import should fail")
+        } catch {
+          case _: ClassNotFoundException =>
+        }
+      }
+    }
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to