This is an automated email from the ASF dual-hosted git repository. hvanhovell pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 6d0fed9a18f [SPARK-43744][CONNECT] Fix class loading problem caused by stub user classes not found on the server classpath 6d0fed9a18f is described below commit 6d0fed9a18ff87e73fdf1ee46b6b0d2df8dd5a1b Author: Zhen Li <zhenli...@users.noreply.github.com> AuthorDate: Fri Jul 28 22:59:07 2023 -0400 [SPARK-43744][CONNECT] Fix class loading problem caused by stub user classes not found on the server classpath ### What changes were proposed in this pull request? This PR introduces a stub class loader for unpacking Scala UDFs in the driver and the executor. When encountering user classes that are not found on the server session classpath, the stub class loader tries to stub the class. This solves the problem that, when serializing UDFs, the Java serializer might include unnecessary user code, e.g. user classes used in the lambda definition signatures in the same class where the UDF is defined. If the user code is actually needed to execute the UDF, we return an error message suggesting that the user add the missing classes using the `addArtifact` method. ### Why are the changes needed? To enhance the user experience of UDFs. This PR should be merged to master and 3.5. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Added tests for both Scala 2.12 and 2.13. 4 tests in SparkSessionE2ESuite still fail to run with maven after the fix because the client test jar is installed on the system classpath (added using --jars at server start), while the stub classloader can only stub classes missing from the session classpath (added using `session.addArtifact`). Moving the test jar to the session classpath causes failures in tests for `flatMapGroupsWithState` (SPARK-44576). Finish moving the test jar to the session classpath once the `flatMapGroupsWithState` test failures are fixed. Closes #42069 from zhenlineo/ref-spark-result. 
Authored-by: Zhen Li <zhenli...@users.noreply.github.com> Signed-off-by: Herman van Hovell <her...@databricks.com> --- .../scala/org/apache/spark/sql/SparkSession.scala | 2 +- .../sql/expressions/UserDefinedFunction.scala | 2 +- .../jvm/src/test/resources/StubClassDummyUdf.scala | 56 +++++++++ .../connect/client/jvm/src/test/resources/udf2.12 | Bin 0 -> 1520 bytes .../client/jvm/src/test/resources/udf2.12.jar | Bin 0 -> 5332 bytes .../connect/client/jvm/src/test/resources/udf2.13 | Bin 0 -> 1630 bytes .../client/jvm/src/test/resources/udf2.13.jar | Bin 0 -> 5674 bytes .../connect/client/UDFClassLoadingE2ESuite.scala | 83 +++++++++++++ .../connect/client/util/IntegrationTestUtils.scala | 2 +- .../connect/client/util/RemoteSparkSession.scala | 2 +- .../artifact/SparkConnectArtifactManager.scala | 17 ++- .../sql/connect/planner/SparkConnectPlanner.scala | 23 +++- connector/connect/server/src/test/resources/udf | Bin 0 -> 973 bytes .../connect/server/src/test/resources/udf_noA.jar | Bin 0 -> 5545 bytes .../connect/artifact/StubClassLoaderSuite.scala | 132 +++++++++++++++++++++ .../spark/util/ChildFirstURLClassLoader.java | 9 ++ .../scala/org/apache/spark/executor/Executor.scala | 86 +++++++++++--- .../org/apache/spark/internal/config/package.scala | 14 +++ .../org/apache/spark/util/StubClassLoader.scala | 79 ++++++++++++ 19 files changed, 480 insertions(+), 27 deletions(-) diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala index d1832e65f3e..4b3de91b56f 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -554,7 +554,7 @@ class SparkSession private[sql] ( val command = proto.Command.newBuilder().setRegisterFunction(udf).build() val plan = proto.Plan.newBuilder().setCommand(command).build() - 
client.execute(plan) + client.execute(plan).asScala.foreach(_ => ()) } @DeveloperApi diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala index 18aef8a2e4c..e5c89d90c19 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala @@ -92,7 +92,7 @@ sealed abstract class UserDefinedFunction { /** * Holder class for a scalar user-defined function and it's input/output encoder(s). */ -case class ScalarUserDefinedFunction private ( +case class ScalarUserDefinedFunction private[sql] ( // SPARK-43198: Eagerly serialize to prevent the UDF from containing a reference to this class. serializedUdfPacket: Array[Byte], inputTypes: Seq[proto.DataType], diff --git a/connector/connect/client/jvm/src/test/resources/StubClassDummyUdf.scala b/connector/connect/client/jvm/src/test/resources/StubClassDummyUdf.scala new file mode 100644 index 00000000000..ff1b3deafaf --- /dev/null +++ b/connector/connect/client/jvm/src/test/resources/StubClassDummyUdf.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.connect.client + +// To generate a jar from the source file: +// `scalac StubClassDummyUdf.scala -d udf.jar` +// To remove class A from the jar: +// `jar -xvf udf.jar` -> delete A.class and A$.class +// `jar -cvf udf_noA.jar org/` +class StubClassDummyUdf { + val udf: Int => Int = (x: Int) => x + 1 + val dummy = (x: Int) => A(x) +} + +case class A(x: Int) { def get: Int = x + 5 } + +// The code to generate the udf file +object StubClassDummyUdf { + import java.io.{BufferedOutputStream, File, FileOutputStream} + import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.PrimitiveIntEncoder + import org.apache.spark.sql.connect.common.UdfPacket + import org.apache.spark.util.Utils + + def packDummyUdf(): String = { + val byteArray = + Utils.serialize[UdfPacket]( + new UdfPacket( + new StubClassDummyUdf().udf, + Seq(PrimitiveIntEncoder), + PrimitiveIntEncoder + ) + ) + val file = new File("src/test/resources/udf") + val target = new BufferedOutputStream(new FileOutputStream(file)) + try { + target.write(byteArray) + file.getAbsolutePath + } finally { + target.close + } + } +} diff --git a/connector/connect/client/jvm/src/test/resources/udf2.12 b/connector/connect/client/jvm/src/test/resources/udf2.12 new file mode 100644 index 00000000000..1090bc90d9b Binary files /dev/null and b/connector/connect/client/jvm/src/test/resources/udf2.12 differ diff --git a/connector/connect/client/jvm/src/test/resources/udf2.12.jar b/connector/connect/client/jvm/src/test/resources/udf2.12.jar new file mode 100644 index 00000000000..6ce6799678f Binary files /dev/null and b/connector/connect/client/jvm/src/test/resources/udf2.12.jar differ diff --git a/connector/connect/client/jvm/src/test/resources/udf2.13 b/connector/connect/client/jvm/src/test/resources/udf2.13 new file mode 100644 index 00000000000..863ac32a76d Binary files /dev/null and 
b/connector/connect/client/jvm/src/test/resources/udf2.13 differ diff --git a/connector/connect/client/jvm/src/test/resources/udf2.13.jar b/connector/connect/client/jvm/src/test/resources/udf2.13.jar new file mode 100644 index 00000000000..c89830f127c Binary files /dev/null and b/connector/connect/client/jvm/src/test/resources/udf2.13.jar differ diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/UDFClassLoadingE2ESuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/UDFClassLoadingE2ESuite.scala new file mode 100644 index 00000000000..8fdb7efbcba --- /dev/null +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/UDFClassLoadingE2ESuite.scala @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.connect.client + +import java.io.File +import java.nio.file.{Files, Paths} + +import scala.util.Properties + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.connect.client.util.RemoteSparkSession +import org.apache.spark.sql.connect.common.ProtoDataTypes +import org.apache.spark.sql.expressions.ScalarUserDefinedFunction + +class UDFClassLoadingE2ESuite extends RemoteSparkSession { + + private val scalaVersion = Properties.versionNumberString + .split("\\.") + .take(2) + .mkString(".") + + // See src/test/resources/StubClassDummyUdf for how the UDFs and jars are created. + private val udfByteArray: Array[Byte] = + Files.readAllBytes(Paths.get(s"src/test/resources/udf$scalaVersion")) + private val udfJar = + new File(s"src/test/resources/udf$scalaVersion.jar").toURI.toURL + + private def registerUdf(session: SparkSession): Unit = { + val udf = ScalarUserDefinedFunction( + serializedUdfPacket = udfByteArray, + inputTypes = Seq(ProtoDataTypes.IntegerType), + outputType = ProtoDataTypes.IntegerType, + name = Some("dummyUdf"), + nullable = true, + deterministic = true) + session.registerUdf(udf.toProto) + } + + test("update class loader after stubbing: new session") { + // Session1 should stub the missing class, but fail to call methods on it + val session1 = spark.newSession() + + assert( + intercept[Exception] { + registerUdf(session1) + }.getMessage.contains( + "java.lang.NoSuchMethodException: org.apache.spark.sql.connect.client.StubClassDummyUdf")) + + // Session2 uses the real class + val session2 = spark.newSession() + session2.addArtifact(udfJar.toURI) + registerUdf(session2) + } + + test("update class loader after stubbing: same session") { + // Session should stub the missing class, but fail to call methods on it + val session = spark.newSession() + + assert( + intercept[Exception] { + registerUdf(session) + }.getMessage.contains( + "java.lang.NoSuchMethodException: 
org.apache.spark.sql.connect.client.StubClassDummyUdf")) + + // Session uses the real class + session.addArtifact(udfJar.toURI) + registerUdf(session) + } +} diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/IntegrationTestUtils.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/IntegrationTestUtils.scala index 819df5fc25b..4d88565308f 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/IntegrationTestUtils.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/IntegrationTestUtils.scala @@ -30,7 +30,7 @@ object IntegrationTestUtils { // System properties used for testing and debugging private val DEBUG_SC_JVM_CLIENT = "spark.debug.sc.jvm.client" - // Enable this flag to print all client debug log + server logs to the console + // Enable this flag to print all server logs to the console private[connect] val isDebug = System.getProperty(DEBUG_SC_JVM_CLIENT, "false").toBoolean private[sql] lazy val scalaVersion = { diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/RemoteSparkSession.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/RemoteSparkSession.scala index 594d3c369fe..1c1cb1403fe 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/RemoteSparkSession.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/RemoteSparkSession.scala @@ -96,7 +96,7 @@ object SparkConnectServerUtils { // To find InMemoryTableCatalog for V2 writer tests val catalystTestJar = tryFindJar("sql/catalyst", "spark-catalyst", "spark-catalyst", test = true) - .map(clientTestJar => Seq("--jars", clientTestJar.getCanonicalPath)) + .map(clientTestJar => Seq(clientTestJar.getCanonicalPath)) .getOrElse(Seq.empty) // For UDF maven E2E tests, the 
server needs the client code to find the UDFs defined in tests. diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/artifact/SparkConnectArtifactManager.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/artifact/SparkConnectArtifactManager.scala index d8f290639c2..03391cef68b 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/artifact/SparkConnectArtifactManager.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/artifact/SparkConnectArtifactManager.scala @@ -31,12 +31,13 @@ import org.apache.hadoop.fs.{LocalFileSystem, Path => FSPath} import org.apache.spark.{JobArtifactSet, JobArtifactState, SparkContext, SparkEnv} import org.apache.spark.internal.Logging +import org.apache.spark.internal.config.CONNECT_SCALA_UDF_STUB_CLASSES import org.apache.spark.sql.SparkSession import org.apache.spark.sql.connect.artifact.util.ArtifactUtils import org.apache.spark.sql.connect.config.Connect.CONNECT_COPY_FROM_LOCAL_TO_FS_ALLOW_DEST_LOCAL import org.apache.spark.sql.connect.service.SessionHolder import org.apache.spark.storage.{CacheId, StorageLevel} -import org.apache.spark.util.Utils +import org.apache.spark.util.{ChildFirstURLClassLoader, StubClassLoader, Utils} /** * The Artifact Manager for the [[SparkConnectService]]. 
@@ -161,7 +162,19 @@ class SparkConnectArtifactManager(sessionHolder: SessionHolder) extends Logging */ def classloader: ClassLoader = { val urls = getSparkConnectAddedJars :+ classDir.toUri.toURL - new URLClassLoader(urls.toArray, Utils.getContextOrSparkClassLoader) + val loader = if (SparkEnv.get.conf.get(CONNECT_SCALA_UDF_STUB_CLASSES).nonEmpty) { + val stubClassLoader = + StubClassLoader(null, SparkEnv.get.conf.get(CONNECT_SCALA_UDF_STUB_CLASSES)) + new ChildFirstURLClassLoader( + urls.toArray, + stubClassLoader, + Utils.getContextOrSparkClassLoader) + } else { + new URLClassLoader(urls.toArray, Utils.getContextOrSparkClassLoader) + } + + logDebug(s"Using class loader: $loader, containing urls: $urls") + loader } /** diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index e4ac34715fb..ebed8af48f0 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.connect.planner +import java.io.IOException + import scala.collection.JavaConverters._ import scala.collection.mutable import scala.util.Try @@ -1504,15 +1506,24 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging { } private def unpackUdf(fun: proto.CommonInlineUserDefinedFunction): UdfPacket = { - Utils.deserialize[UdfPacket]( - fun.getScalarScalaUdf.getPayload.toByteArray, - Utils.getContextOrSparkClassLoader) + unpackScalarScalaUDF[UdfPacket](fun.getScalarScalaUdf) } private def unpackForeachWriter(fun: proto.ScalarScalaUDF): ForeachWriterPacket = { - Utils.deserialize[ForeachWriterPacket]( - fun.getPayload.toByteArray, - Utils.getContextOrSparkClassLoader) + 
unpackScalarScalaUDF[ForeachWriterPacket](fun) + } + + private def unpackScalarScalaUDF[T](fun: proto.ScalarScalaUDF): T = { + try { + logDebug(s"Unpack using class loader: ${Utils.getContextOrSparkClassLoader}") + Utils.deserialize[T](fun.getPayload.toByteArray, Utils.getContextOrSparkClassLoader) + } catch { + case e: IOException if e.getCause.isInstanceOf[NoSuchMethodException] => + throw new ClassNotFoundException( + s"Failed to load class correctly due to ${e.getCause}. " + + "Make sure the artifact where the class is defined is installed by calling" + + " session.addArtifact.") + } } /** diff --git a/connector/connect/server/src/test/resources/udf b/connector/connect/server/src/test/resources/udf new file mode 100644 index 00000000000..55a3264a017 Binary files /dev/null and b/connector/connect/server/src/test/resources/udf differ diff --git a/connector/connect/server/src/test/resources/udf_noA.jar b/connector/connect/server/src/test/resources/udf_noA.jar new file mode 100644 index 00000000000..4d8c423ab6d Binary files /dev/null and b/connector/connect/server/src/test/resources/udf_noA.jar differ diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/artifact/StubClassLoaderSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/artifact/StubClassLoaderSuite.scala new file mode 100644 index 00000000000..0f6e0543151 --- /dev/null +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/artifact/StubClassLoaderSuite.scala @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.connect.artifact + +import java.io.File + +import org.apache.spark.SparkFunSuite +import org.apache.spark.util.{ChildFirstURLClassLoader, StubClassLoader} + +class StubClassLoaderSuite extends SparkFunSuite { + + // See src/test/resources/StubClassDummyUdf for how the UDFs and jars are created. + private val udfNoAJar = new File("src/test/resources/udf_noA.jar").toURI.toURL + private val classDummyUdf = "org.apache.spark.sql.connect.client.StubClassDummyUdf" + private val classA = "org.apache.spark.sql.connect.client.A" + + test("find class with stub class") { + val cl = new RecordedStubClassLoader(getClass().getClassLoader(), _ => true) + val cls = cl.findClass("my.name.HelloWorld") + assert(cls.getName === "my.name.HelloWorld") + assert(cl.lastStubbed === "my.name.HelloWorld") + } + + test("class for name with stub class") { + val cl = new RecordedStubClassLoader(getClass().getClassLoader(), _ => true) + // scalastyle:off classforname + val cls = Class.forName("my.name.HelloWorld", false, cl) + // scalastyle:on classforname + assert(cls.getName === "my.name.HelloWorld") + assert(cl.lastStubbed === "my.name.HelloWorld") + } + + test("filter class to stub") { + val list = "my.name" :: Nil + val cl = StubClassLoader(getClass().getClassLoader(), list) + val cls = cl.findClass("my.name.HelloWorld") + assert(cls.getName === "my.name.HelloWorld") + + intercept[ClassNotFoundException] { + cl.findClass("name.my.GoodDay") + } + } + + test("stub missing class") { + val sysClassLoader = getClass.getClassLoader() + val 
stubClassLoader = new RecordedStubClassLoader(null, _ => true) + + // Install artifact without class A. + val sessionClassLoader = + new ChildFirstURLClassLoader(Array(udfNoAJar), stubClassLoader, sysClassLoader) + // Load udf with A used in the same class. + loadDummyUdf(sessionClassLoader) + // Class A should be stubbed. + assert(stubClassLoader.lastStubbed === classA) + } + + test("unload stub class") { + val sysClassLoader = getClass.getClassLoader() + val stubClassLoader = new RecordedStubClassLoader(null, _ => true) + + val cl1 = new ChildFirstURLClassLoader(Array.empty, stubClassLoader, sysClassLoader) + + // Failed to load DummyUdf + intercept[Exception] { + loadDummyUdf(cl1) + } + // Successfully stubbed the missing class. + assert(stubClassLoader.lastStubbed === classDummyUdf) + + // Creating a new class loader will unpack the udf correctly. + val cl2 = new ChildFirstURLClassLoader( + Array(udfNoAJar), + stubClassLoader, // even with the same stub class loader. + sysClassLoader) + // Should be able to load after the artifact is added + loadDummyUdf(cl2) + } + + test("throw no such method if trying to access methods on stub class") { + val sysClassLoader = getClass.getClassLoader() + val stubClassLoader = new RecordedStubClassLoader(null, _ => true) + + val sessionClassLoader = + new ChildFirstURLClassLoader(Array.empty, stubClassLoader, sysClassLoader) + + // Failed to load DummyUdf because of missing methods + assert(intercept[NoSuchMethodException] { + loadDummyUdf(sessionClassLoader) + }.getMessage.contains(classDummyUdf)) + // Successfully stubbed the missing class. + assert(stubClassLoader.lastStubbed === classDummyUdf) + } + + private def loadDummyUdf(sessionClassLoader: ClassLoader): Unit = { + // Load DummyUdf and call a method on it. 
+ // scalastyle:off classforname + val cls = Class.forName(classDummyUdf, false, sessionClassLoader) + // scalastyle:on classforname + cls.getDeclaredMethod("dummy") + + // Load class A used inside DummyUdf + // scalastyle:off classforname + Class.forName(classA, false, sessionClassLoader) + // scalastyle:on classforname + } +} + +class RecordedStubClassLoader(parent: ClassLoader, shouldStub: String => Boolean) + extends StubClassLoader(parent, shouldStub) { + var lastStubbed: String = _ + + override def findClass(name: String): Class[_] = { + if (shouldStub(name)) { + lastStubbed = name + } + super.findClass(name) + } +} diff --git a/core/src/main/java/org/apache/spark/util/ChildFirstURLClassLoader.java b/core/src/main/java/org/apache/spark/util/ChildFirstURLClassLoader.java index 57d96756c8b..2791209e019 100644 --- a/core/src/main/java/org/apache/spark/util/ChildFirstURLClassLoader.java +++ b/core/src/main/java/org/apache/spark/util/ChildFirstURLClassLoader.java @@ -40,6 +40,15 @@ public class ChildFirstURLClassLoader extends MutableURLClassLoader { this.parent = new ParentClassLoader(parent); } + /** + * Specify the grandparent if there is a need to load in the order of + * `grandparent -> urls (child) -> parent`. 
+ */ + public ChildFirstURLClassLoader(URL[] urls, ClassLoader parent, ClassLoader grandparent) { + super(urls, grandparent); + this.parent = new ParentClassLoader(parent); + } + @Override public Class<?> loadClass(String name, boolean resolve) throws ClassNotFoundException { try { diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index b30569dc964..9327ea4d3dd 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -56,11 +56,12 @@ import org.apache.spark.util._ private[spark] class IsolatedSessionState( val sessionUUID: String, - val urlClassLoader: MutableURLClassLoader, + var urlClassLoader: MutableURLClassLoader, var replClassLoader: ClassLoader, val currentFiles: HashMap[String, Long], val currentJars: HashMap[String, Long], - val currentArchives: HashMap[String, Long]) + val currentArchives: HashMap[String, Long], + val replClassDirUri: Option[String]) /** * Spark executor, backed by a threadpool to run tasks. 
@@ -173,14 +174,20 @@ private[spark] class Executor( val currentFiles = new HashMap[String, Long] val currentJars = new HashMap[String, Long] val currentArchives = new HashMap[String, Long] - val urlClassLoader = createClassLoader(currentJars) + val urlClassLoader = createClassLoader(currentJars, !isDefaultState(jobArtifactState.uuid)) val replClassLoader = addReplClassLoaderIfNeeded( - urlClassLoader, jobArtifactState.replClassDirUri) + urlClassLoader, jobArtifactState.replClassDirUri, jobArtifactState.uuid) new IsolatedSessionState( jobArtifactState.uuid, urlClassLoader, replClassLoader, - currentFiles, currentJars, currentArchives) + currentFiles, + currentJars, + currentArchives, + jobArtifactState.replClassDirUri + ) } + private def isDefaultState(name: String) = name == "default" + // Classloader isolation // The default isolation group val defaultSessionState = newSessionState(JobArtifactState("default", None)) @@ -514,9 +521,8 @@ private[spark] class Executor( // Classloader isolation val isolatedSession = taskDescription.artifacts.state match { - case Some(jobArtifactState) => isolatedSessionCache.get( - jobArtifactState.uuid, - () => newSessionState(jobArtifactState)) + case Some(jobArtifactState) => + isolatedSessionCache.get(jobArtifactState.uuid, () => newSessionState(jobArtifactState)) case _ => defaultSessionState } @@ -548,6 +554,9 @@ private[spark] class Executor( taskDescription.artifacts.jars, taskDescription.artifacts.archives, isolatedSession) + // Always reset the thread class loader to ensure if any updates, all threads (not only + // the thread that updated the dependencies) can update to the new class loader. 
+ Thread.currentThread.setContextClassLoader(isolatedSession.replClassLoader) task = ser.deserialize[Task[Any]]( taskDescription.serializedTask, Thread.currentThread.getContextClassLoader) task.localProperties = taskDescription.properties @@ -999,7 +1008,9 @@ private[spark] class Executor( * Create a ClassLoader for use in tasks, adding any JARs specified by the user or any classes * created by the interpreter to the search path */ - private def createClassLoader(currentJars: HashMap[String, Long]): MutableURLClassLoader = { + private def createClassLoader( + currentJars: HashMap[String, Long], + useStub: Boolean): MutableURLClassLoader = { // Bootstrap the list of jars with the user class path. val now = System.currentTimeMillis() userClassPath.foreach { url => @@ -1011,8 +1022,23 @@ private[spark] class Executor( val urls = userClassPath.toArray ++ currentJars.keySet.map { uri => new File(uri.split("/").last).toURI.toURL } - logInfo(s"Starting executor with user classpath (userClassPathFirst = $userClassPathFirst): " + - urls.mkString("'", ",", "'")) + createClassLoader(urls, useStub) + } + + private def createClassLoader(urls: Array[URL], useStub: Boolean): MutableURLClassLoader = { + logInfo( + s"Starting executor with user classpath (userClassPathFirst = $userClassPathFirst): " + + urls.mkString("'", ",", "'") + ) + + if (useStub && conf.get(CONNECT_SCALA_UDF_STUB_CLASSES).nonEmpty) { + createClassLoaderWithStub(urls, conf.get(CONNECT_SCALA_UDF_STUB_CLASSES)) + } else { + createClassLoader(urls) + } + } + + private def createClassLoader(urls: Array[URL]): MutableURLClassLoader = { if (userClassPathFirst) { new ChildFirstURLClassLoader(urls, systemLoader) } else { @@ -1020,20 +1046,39 @@ private[spark] class Executor( } } + private def createClassLoaderWithStub( + urls: Array[URL], + binaryName: Seq[String]): MutableURLClassLoader = { + if (userClassPathFirst) { + // user -> (sys -> stub) + val stubClassLoader = + StubClassLoader(systemLoader, binaryName) + new 
ChildFirstURLClassLoader(urls, stubClassLoader) + } else { + // sys -> user -> stub + val stubClassLoader = + StubClassLoader(null, binaryName) + new ChildFirstURLClassLoader(urls, stubClassLoader, systemLoader) + } + } + /** * If the REPL is in use, add another ClassLoader that will read * new classes defined by the REPL as the user types code */ private def addReplClassLoaderIfNeeded( parent: ClassLoader, - sessionClassUri: Option[String]): ClassLoader = { + sessionClassUri: Option[String], + sessionUUID: String): ClassLoader = { val classUri = sessionClassUri.getOrElse(conf.get("spark.repl.class.uri", null)) - if (classUri != null) { + val classLoader = if (classUri != null) { logInfo("Using REPL class URI: " + classUri) new ExecutorClassLoader(conf, env, classUri, parent, userClassPathFirst) } else { parent } + logInfo(s"Created or updated repl class loader $classLoader for $sessionUUID.") + classLoader } /** @@ -1048,6 +1093,7 @@ private[spark] class Executor( state: IsolatedSessionState, testStartLatch: Option[CountDownLatch] = None, testEndLatch: Option[CountDownLatch] = None): Unit = { + var updated = false; lazy val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf) updateDependenciesLock.lockInterruptibly() try { @@ -1056,7 +1102,7 @@ private[spark] class Executor( // If the session ID was specified from SparkSession, it's from a Spark Connect client. // Specify a dedicated directory for Spark Connect client. 
- lazy val root = if (state.sessionUUID != "default") { + lazy val root = if (!isDefaultState(state.sessionUUID)) { val newDest = new File(SparkFiles.getRootDirectory(), state.sessionUUID) newDest.mkdir() newDest @@ -1101,11 +1147,21 @@ private[spark] class Executor( // Add it to our class loader val url = new File(root, localName).toURI.toURL if (!state.urlClassLoader.getURLs().contains(url)) { - logInfo(s"Adding $url to class loader") + logInfo(s"Adding $url to class loader ${state.sessionUUID}") state.urlClassLoader.addURL(url) + if (!isDefaultState(state.sessionUUID)) { + updated = true + } } } } + if (updated) { + // When a new url is added for non-default class loader, recreate the class loader + // to ensure all classes are updated. + state.urlClassLoader = createClassLoader(state.urlClassLoader.getURLs, useStub = true) + state.replClassLoader = + addReplClassLoaderIfNeeded(state.urlClassLoader, state.replClassDirUri, state.sessionUUID) + } // For testing, so we can simulate a slow file download: testEndLatch.foreach(_.await()) } finally { diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 83e64f6f8a8..ba809b7a3b1 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -2555,4 +2555,18 @@ package object config { .version("3.5.0") .booleanConf .createWithDefault(false) + + private[spark] val CONNECT_SCALA_UDF_STUB_CLASSES = + ConfigBuilder("spark.connect.scalaUdf.stubClasses") + .internal() + .doc(""" + |Comma-separated list of binary names of classes/packages that should be stubbed during + |the Scala UDF serde and execution if not found on the server classpath. + |An empty list effectively disables stubbing for all missing classes. + |By default, the server stubs classes from the Scala client package. 
+ |""".stripMargin) + .version("3.5.0") + .stringConf + .toSequence + .createWithDefault("org.apache.spark.sql.connect.client" :: Nil) } diff --git a/core/src/main/scala/org/apache/spark/util/StubClassLoader.scala b/core/src/main/scala/org/apache/spark/util/StubClassLoader.scala new file mode 100644 index 00000000000..a0bc753f488 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/StubClassLoader.scala @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.util + +import org.apache.xbean.asm9.{ClassWriter, Opcodes} + +/** + * [[ClassLoader]] that replaces missing classes with stubs, if they cannot be found. It will only + * do this for classes that are marked for stubbing. + * + * While this is generally not a good idea, in this particular case it is used to load lambdas + * whose capturing class contains unknown (and unneeded) classes. The lambda itself does not need + * the class, which is therefore safe to replace with a stub. 
+ */ +class StubClassLoader(parent: ClassLoader, shouldStub: String => Boolean) + extends ClassLoader(parent) { + override def findClass(name: String): Class[_] = { + if (!shouldStub(name)) { + throw new ClassNotFoundException(name) + } + val bytes = StubClassLoader.generateStub(name) + defineClass(name, bytes, 0, bytes.length) + } +} + +object StubClassLoader { + def apply(parent: ClassLoader, binaryName: Seq[String]): StubClassLoader = { + new StubClassLoader(parent, name => binaryName.exists(p => name.startsWith(p))) + } + + def generateStub(binaryName: String): Array[Byte] = { + // Convert binary names to internal names. + val name = binaryName.replace('.', '/') + val classWriter = new ClassWriter(0) + classWriter.visit( + 49, + Opcodes.ACC_PUBLIC + Opcodes.ACC_SUPER, + name, + null, + "java/lang/Object", + null) + classWriter.visitSource(name + ".java", null) + + // Generate constructor. + val ctorWriter = classWriter.visitMethod( + Opcodes.ACC_PUBLIC, + "<init>", + "()V", + null, + null) + ctorWriter.visitVarInsn(Opcodes.ALOAD, 0) + ctorWriter.visitMethodInsn( + Opcodes.INVOKESPECIAL, + "java/lang/Object", + "<init>", + "()V", + false) + + ctorWriter.visitInsn(Opcodes.RETURN) + ctorWriter.visitMaxs(1, 1) + ctorWriter.visitEnd() + classWriter.visitEnd() + classWriter.toByteArray + } +} --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org