This is an automated email from the ASF dual-hosted git repository.

hvanhovell pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 3e2b146eb81 [SPARK-45136][CONNECT] Enhance ClosureCleaner with Ammonite support
3e2b146eb81 is described below

commit 3e2b146eb81d9a5727f07b58f7bb1760a71a8697
Author: Vsevolod Stepanov <vsevolod.stepa...@databricks.com>
AuthorDate: Wed Oct 25 21:35:07 2023 -0400

    [SPARK-45136][CONNECT] Enhance ClosureCleaner with Ammonite support

    ### What changes were proposed in this pull request?
    This PR enhances the existing ClosureCleaner implementation to support cleaning closures defined in Ammonite. Please refer to [this gist](https://gist.github.com/vsevolodstep-db/b8e4d676745d6e2d047ecac291e5254c) for more context on how Ammonite code wrapping works and which problems this PR is trying to solve.

    Overall, it contains these logical changes in `ClosureCleaner`:
    1. Making it recognize and clean closures defined in Ammonite (previously it checked whether the capturing class name starts with `$line` and ends with `$iw`, which is specific to the native Scala REPL).
    2. Making it clean closures that are defined inside a user class in a REPL (see corner case 1 in the gist).
    3. Making it clean nested closures properly for the Ammonite REPL (see corner case 2 in the gist).
    4. Making it transitively follow other Ammonite commands that are captured by the target closure. Please note that the `cleanTransitively` option of the `ClosureCleaner.clean()` method refers to following references transitively within the enclosing command object; it does not follow other command objects.

    As we need `ClosureCleaner` to be available in Spark Connect, I also moved the implementation to the `common-utils` module. This brings in a new `xbean-asm9-shaded` dependency, which is fairly small. This PR also moves the `checkSerializable` check from `ClosureCleaner` to `SparkClosureCleaner`, as it is specific to Spark core.

    The important changes affect `ClosureCleaner` only. They should not affect the existing code path for normal Scala closures or closures defined in the native Scala REPL, and they cover only closures defined in Ammonite. In addition, this PR modifies Spark Connect's `UserDefinedFunction` to actually use `ClosureCleaner` and clean closures in Spark Connect.

    ### Why are the changes needed?
    To properly support closures defined in Ammonite, reduce UDF payload size, and avoid possible `NonSerializable` exceptions. This includes:
    - a lambda capturing the outer command object, leading to a circular dependency
    - a lambda capturing other command objects transitively, exploding the payload size

    ### Does this PR introduce _any_ user-facing change?
    No

    ### How was this patch tested?
    Existing tests. New tests in `ReplE2ESuite` covering various scenarios using Spark Connect + the Ammonite REPL to make sure closures are actually cleaned.

    ### Was this patch authored or co-authored using generative AI tooling?
    No

    Closes #42995 from vsevolodstep-db/SPARK-45136/closure-cleaner.
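[Editor's note, not part of the commit message] For readers unfamiliar with Ammonite's code wrapping, the sketch below is a rough, simplified model of the problem described above and in the linked gist. The class names `Cmd0Helper`/`Cmd1Helper` and the field layout are illustrative assumptions only; the real generated classes live in the `ammonite.$sess` package (`cmdX`, `cmdX$Helper`) and carry extra machinery.

```scala
// Simplified, hypothetical model of Ammonite's command wrapping.

class NonSerializable(val id: Int = -1) // user-defined in command 0

// REPL command 0:  val x = 100; val y = new NonSerializable
class Cmd0Helper {
  val x: Int = 100
  val y: NonSerializable = new NonSerializable
}

// REPL command 1:  val myLambda = (a: Int) => a + x
// The generated helper keeps a reference to the whole previous helper so that
// `x` stays visible, which is exactly what makes closure cleaning necessary.
class Cmd1Helper(val prevCmd: Cmd0Helper) extends Serializable {
  // Serializing this lambda naively may pull in Cmd1Helper -> Cmd0Helper -> y,
  // inflating the payload and failing on the non-serializable `y`.
  val myLambda: Int => Int = (a: Int) => a + prevCmd.x
}

object Demo extends App {
  val lambda = new Cmd1Helper(new Cmd0Helper).myLambda
  println(lambda(5)) // 105
}
```

With this change, the relocated cleaner and the core wrapper split the work roughly as follows (a sketch based on the signatures in this diff; both entry points are `private[spark]`, so this applies only to code inside Spark itself):

```scala
import scala.collection.mutable
import org.apache.spark.util.{ClosureCleaner, SparkClosureCleaner}

val closure = (i: Int) => i + 1

// common-utils / Spark Connect side: clean only, no serializability check.
ClosureCleaner.clean(closure, cleanTransitively = true, mutable.Map.empty)

// Spark core side: clean, then optionally verify with the closure serializer.
SparkClosureCleaner.clean(closure, checkSerializable = true)
```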
Authored-by: Vsevolod Stepanov <vsevolod.stepa...@databricks.com> Signed-off-by: Herman van Hovell <her...@databricks.com> --- common/utils/pom.xml | 4 + .../org/apache/spark/util/ClosureCleaner.scala | 636 ++++++++++++++------- .../org/apache/spark/util/SparkStreamUtils.scala | 109 ++++ .../sql/expressions/UserDefinedFunction.scala | 10 +- .../spark/sql/application/ReplE2ESuite.scala | 143 +++++ .../CheckConnectJvmClientCompatibility.scala | 8 + core/pom.xml | 4 - .../main/scala/org/apache/spark/SparkContext.scala | 2 +- .../apache/spark/util/SparkClosureCleaner.scala | 49 ++ .../main/scala/org/apache/spark/util/Utils.scala | 85 +-- .../apache/spark/util/ClosureCleanerSuite.scala | 2 +- .../apache/spark/util/ClosureCleanerSuite2.scala | 4 +- project/MimaExcludes.scala | 4 +- .../catalyst/encoders/ExpressionEncoderSuite.scala | 4 +- .../org/apache/spark/streaming/StateSpec.scala | 6 +- 15 files changed, 756 insertions(+), 314 deletions(-) diff --git a/common/utils/pom.xml b/common/utils/pom.xml index 37d1ea48d97..44cb30a19ff 100644 --- a/common/utils/pom.xml +++ b/common/utils/pom.xml @@ -39,6 +39,10 @@ <groupId>org.apache.spark</groupId> <artifactId>spark-tags_${scala.binary.version}</artifactId> </dependency> + <dependency> + <groupId>org.apache.xbean</groupId> + <artifactId>xbean-asm9-shaded</artifactId> + </dependency> <dependency> <groupId>com.fasterxml.jackson.core</groupId> <artifactId>jackson-databind</artifactId> diff --git a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala b/common/utils/src/main/scala/org/apache/spark/util/ClosureCleaner.scala similarity index 61% rename from core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala rename to common/utils/src/main/scala/org/apache/spark/util/ClosureCleaner.scala index 29fb0206f90..ffa2f0e60b2 100644 --- a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala +++ b/common/utils/src/main/scala/org/apache/spark/util/ClosureCleaner.scala @@ -21,7 +21,7 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream} import java.lang.invoke.{MethodHandleInfo, SerializedLambda} import java.lang.reflect.{Field, Modifier} -import scala.collection.mutable.{Map, Set, Stack} +import scala.collection.mutable.{Map, Queue, Set, Stack} import scala.jdk.CollectionConverters._ import org.apache.commons.lang3.ClassUtils @@ -29,14 +29,13 @@ import org.apache.xbean.asm9.{ClassReader, ClassVisitor, Handle, MethodVisitor, import org.apache.xbean.asm9.Opcodes._ import org.apache.xbean.asm9.tree.{ClassNode, MethodNode} -import org.apache.spark.{SparkEnv, SparkException} +import org.apache.spark.SparkException import org.apache.spark.internal.Logging /** * A cleaner that renders closures serializable if they can be done so safely. */ private[spark] object ClosureCleaner extends Logging { - // Get an ASM class reader for a given class from the JAR that loaded it private[util] def getClassReader(cls: Class[_]): ClassReader = { // Copy data over, before delegating to ClassReader - else we can run out of open file handles. 
@@ -46,11 +45,18 @@ private[spark] object ClosureCleaner extends Logging { null } else { val baos = new ByteArrayOutputStream(128) - Utils.copyStream(resourceStream, baos, true) + + SparkStreamUtils.copyStream(resourceStream, baos, closeStreams = true) new ClassReader(new ByteArrayInputStream(baos.toByteArray)) } } + private[util] def isAmmoniteCommandOrHelper(clazz: Class[_]): Boolean = clazz.getName.matches( + """^ammonite\.\$sess\.cmd[0-9]*(\$Helper\$?)?""") + + private[util] def isDefinedInAmmonite(clazz: Class[_]): Boolean = clazz.getName.matches( + """^ammonite\.\$sess\.cmd[0-9]*.*""") + // Check whether a class represents a Scala closure private def isClosure(cls: Class[_]): Boolean = { cls.getName.contains("$anonfun$") @@ -146,23 +152,6 @@ private[spark] object ClosureCleaner extends Logging { clone } - /** - * Clean the given closure in place. - * - * More specifically, this renders the given closure serializable as long as it does not - * explicitly reference unserializable objects. - * - * @param closure the closure to clean - * @param checkSerializable whether to verify that the closure is serializable after cleaning - * @param cleanTransitively whether to clean enclosing closures transitively - */ - def clean( - closure: AnyRef, - checkSerializable: Boolean = true, - cleanTransitively: Boolean = true): Unit = { - clean(closure, checkSerializable, cleanTransitively, Map.empty) - } - /** * Helper method to clean the given closure in place. * @@ -198,18 +187,15 @@ private[spark] object ClosureCleaner extends Logging { * pointer of a cloned scope "one" and set it the parent of scope "two", such that scope "two" * no longer references SomethingNotSerializable transitively. * - * @param func the starting closure to clean - * @param checkSerializable whether to verify that the closure is serializable after cleaning + * @param func the starting closure to clean * @param cleanTransitively whether to clean enclosing closures transitively - * @param accessedFields a map from a class to a set of its fields that are accessed by - * the starting closure + * @param accessedFields a map from a class to a set of its fields that are accessed by + * the starting closure */ - private def clean( + private[spark] def clean( func: AnyRef, - checkSerializable: Boolean, cleanTransitively: Boolean, - accessedFields: Map[Class[_], Set[String]]): Unit = { - + accessedFields: Map[Class[_], Set[String]]): Boolean = { // indylambda check. Most likely to be the case with 2.12, 2.13 // so we check first // non LMF-closures should be less frequent from now on @@ -217,131 +203,18 @@ private[spark] object ClosureCleaner extends Logging { if (!isClosure(func.getClass) && maybeIndylambdaProxy.isEmpty) { logDebug(s"Expected a closure; got ${func.getClass.getName}") - return + return false } // TODO: clean all inner closures first. This requires us to find the inner objects. 
// TODO: cache outerClasses / innerClasses / accessedFields if (func == null) { - return + return false } if (maybeIndylambdaProxy.isEmpty) { - logDebug(s"+++ Cleaning closure $func (${func.getClass.getName}) +++") - - // A list of classes that represents closures enclosed in the given one - val innerClasses = getInnerClosureClasses(func) - - // A list of enclosing objects and their respective classes, from innermost to outermost - // An outer object at a given index is of type outer class at the same index - val (outerClasses, outerObjects) = getOuterClassesAndObjects(func) - - // For logging purposes only - val declaredFields = func.getClass.getDeclaredFields - val declaredMethods = func.getClass.getDeclaredMethods - - if (log.isDebugEnabled) { - logDebug(s" + declared fields: ${declaredFields.size}") - declaredFields.foreach { f => logDebug(s" $f") } - logDebug(s" + declared methods: ${declaredMethods.size}") - declaredMethods.foreach { m => logDebug(s" $m") } - logDebug(s" + inner classes: ${innerClasses.size}") - innerClasses.foreach { c => logDebug(s" ${c.getName}") } - logDebug(s" + outer classes: ${outerClasses.size}" ) - outerClasses.foreach { c => logDebug(s" ${c.getName}") } - } - - // Fail fast if we detect return statements in closures - getClassReader(func.getClass).accept(new ReturnStatementFinder(), 0) - - // If accessed fields is not populated yet, we assume that - // the closure we are trying to clean is the starting one - if (accessedFields.isEmpty) { - logDebug(" + populating accessed fields because this is the starting closure") - // Initialize accessed fields with the outer classes first - // This step is needed to associate the fields to the correct classes later - initAccessedFields(accessedFields, outerClasses) - - // Populate accessed fields by visiting all fields and methods accessed by this and - // all of its inner closures. If transitive cleaning is enabled, this may recursively - // visits methods that belong to other classes in search of transitively referenced fields. - for (cls <- func.getClass :: innerClasses) { - getClassReader(cls).accept(new FieldAccessFinder(accessedFields, cleanTransitively), 0) - } - } - - logDebug(s" + fields accessed by starting closure: ${accessedFields.size} classes") - accessedFields.foreach { f => logDebug(" " + f) } - - // List of outer (class, object) pairs, ordered from outermost to innermost - // Note that all outer objects but the outermost one (first one in this list) must be closures - var outerPairs: List[(Class[_], AnyRef)] = outerClasses.zip(outerObjects).reverse - var parent: AnyRef = null - if (outerPairs.nonEmpty) { - val outermostClass = outerPairs.head._1 - val outermostObject = outerPairs.head._2 - - if (isClosure(outermostClass)) { - logDebug(s" + outermost object is a closure, so we clone it: ${outermostClass}") - } else if (outermostClass.getName.startsWith("$line")) { - // SPARK-14558: if the outermost object is a REPL line object, we should clone - // and clean it as it may carry a lot of unnecessary information, - // e.g. hadoop conf, spark conf, etc. - logDebug(s" + outermost object is a REPL line object, so we clone it:" + - s" ${outermostClass}") - } else { - // The closure is ultimately nested inside a class; keep the object of that - // class without cloning it since we don't want to clone the user's objects. - // Note that we still need to keep around the outermost object itself because - // we need it to clone its child closure later (see below). 
- logDebug(s" + outermost object is not a closure or REPL line object," + - s" so do not clone it: ${outermostClass}") - parent = outermostObject // e.g. SparkContext - outerPairs = outerPairs.tail - } - } else { - logDebug(" + there are no enclosing objects!") - } - - // Clone the closure objects themselves, nulling out any fields that are not - // used in the closure we're working on or any of its inner closures. - for ((cls, obj) <- outerPairs) { - logDebug(s" + cloning instance of class ${cls.getName}") - // We null out these unused references by cloning each object and then filling in all - // required fields from the original object. We need the parent here because the Java - // language specification requires the first constructor parameter of any closure to be - // its enclosing object. - val clone = cloneAndSetFields(parent, obj, cls, accessedFields) - - // If transitive cleaning is enabled, we recursively clean any enclosing closure using - // the already populated accessed fields map of the starting closure - if (cleanTransitively && isClosure(clone.getClass)) { - logDebug(s" + cleaning cloned closure recursively (${cls.getName})") - // No need to check serializable here for the outer closures because we're - // only interested in the serializability of the starting closure - clean(clone, checkSerializable = false, cleanTransitively, accessedFields) - } - parent = clone - } - - // Update the parent pointer ($outer) of this closure - if (parent != null) { - val field = func.getClass.getDeclaredField("$outer") - field.setAccessible(true) - // If the starting closure doesn't actually need our enclosing object, then just null it out - if (accessedFields.contains(func.getClass) && - !accessedFields(func.getClass).contains("$outer")) { - logDebug(s" + the starting closure doesn't actually need $parent, so we null it out") - field.set(func, null) - } else { - // Update this closure's parent pointer to point to our enclosing object, - // which could either be a cloned closure or the original user object - field.set(func, parent) - } - } - - logDebug(s" +++ closure $func (${func.getClass.getName}) is now cleaned +++") + cleanNonIndyLambdaClosure(func, cleanTransitively, accessedFields) } else { val lambdaProxy = maybeIndylambdaProxy.get val implMethodName = lambdaProxy.getImplMethodName @@ -359,62 +232,339 @@ private[spark] object ClosureCleaner extends Logging { val capturingClassReader = getClassReader(capturingClass) capturingClassReader.accept(new ReturnStatementFinder(Option(implMethodName)), 0) - val isClosureDeclaredInScalaRepl = capturingClassName.startsWith("$line") && - capturingClassName.endsWith("$iw") - val outerThisOpt = if (lambdaProxy.getCapturedArgCount > 0) { - Option(lambdaProxy.getCapturedArg(0)) + val outerThis = if (lambdaProxy.getCapturedArgCount > 0) { + // only need to clean when there is an enclosing non-null "this" captured by the closure + Option(lambdaProxy.getCapturedArg(0)).getOrElse(return false) } else { - None + return false } - // only need to clean when there is an enclosing "this" captured by the closure, and it - // should be something cleanable, i.e. a Scala REPL line object - val needsCleaning = isClosureDeclaredInScalaRepl && - outerThisOpt.isDefined && outerThisOpt.get.getClass.getName == capturingClassName - - if (needsCleaning) { - // indylambda closures do not reference enclosing closures via an `$outer` chain, so no - // transitive cleaning on the `$outer` chain is needed. 
- // Thus clean() shouldn't be recursively called with a non-empty accessedFields. - assert(accessedFields.isEmpty) - - initAccessedFields(accessedFields, Seq(capturingClass)) - IndylambdaScalaClosures.findAccessedFields( - lambdaProxy, classLoader, accessedFields, cleanTransitively) - - logDebug(s" + fields accessed by starting closure: ${accessedFields.size} classes") - accessedFields.foreach { f => logDebug(" " + f) } - - if (accessedFields(capturingClass).size < capturingClass.getDeclaredFields.length) { - // clone and clean the enclosing `this` only when there are fields to null out - - val outerThis = outerThisOpt.get - - logDebug(s" + cloning instance of REPL class $capturingClassName") - val clonedOuterThis = cloneAndSetFields( - parent = null, outerThis, capturingClass, accessedFields) - - val outerField = func.getClass.getDeclaredField("arg$1") - // SPARK-37072: When Java 17 is used and `outerField` is read-only, - // the content of `outerField` cannot be set by reflect api directly. - // But we can remove the `final` modifier of `outerField` before set value - // and reset the modifier after set value. - val modifiersField = getFinalModifiersFieldForJava17(outerField) - modifiersField - .foreach(m => m.setInt(outerField, outerField.getModifiers & ~Modifier.FINAL)) - outerField.setAccessible(true) - outerField.set(func, clonedOuterThis) - modifiersField - .foreach(m => m.setInt(outerField, outerField.getModifiers | Modifier.FINAL)) + // clean only if enclosing "this" is something cleanable, i.e. a Scala REPL line object or + // Ammonite command helper object. + // For Ammonite closures, we do not care about actual capturing class name, + // as closure needs to be cleaned if it captures Ammonite command helper object + if (isDefinedInAmmonite(outerThis.getClass)) { + // If outerThis is a lambda, we have to clean that instead + IndylambdaScalaClosures.getSerializationProxy(outerThis).foreach { _ => + return clean(outerThis, cleanTransitively, accessedFields) + } + cleanupAmmoniteReplClosure(func, lambdaProxy, outerThis, cleanTransitively) + } else { + val isClosureDeclaredInScalaRepl = capturingClassName.startsWith("$line") && + capturingClassName.endsWith("$iw") + if (isClosureDeclaredInScalaRepl && outerThis.getClass.getName == capturingClassName) { + assert(accessedFields.isEmpty) + cleanupScalaReplClosure(func, lambdaProxy, outerThis, cleanTransitively) } } logDebug(s" +++ indylambda closure ($implMethodName) is now cleaned +++") } - if (checkSerializable) { - ensureSerializable(func) + true + } + + /** + * Cleans non-indylambda closure in place + * + * @param func the starting closure to clean + * @param cleanTransitively whether to clean enclosing closures transitively + * @param accessedFields a map from a class to a set of its fields that are accessed by + * the starting closure + */ + private def cleanNonIndyLambdaClosure( + func: AnyRef, + cleanTransitively: Boolean, + accessedFields: Map[Class[_], Set[String]]): Unit = { + logDebug(s"+++ Cleaning closure $func (${func.getClass.getName}) +++") + + // A list of classes that represents closures enclosed in the given one + val innerClasses = getInnerClosureClasses(func) + + // A list of enclosing objects and their respective classes, from innermost to outermost + // An outer object at a given index is of type outer class at the same index + val (outerClasses, outerObjects) = getOuterClassesAndObjects(func) + + // For logging purposes only + val declaredFields = func.getClass.getDeclaredFields + val declaredMethods = 
func.getClass.getDeclaredMethods + + if (log.isDebugEnabled) { + logDebug(s" + declared fields: ${declaredFields.size}") + declaredFields.foreach { f => logDebug(s" $f") } + logDebug(s" + declared methods: ${declaredMethods.size}") + declaredMethods.foreach { m => logDebug(s" $m") } + logDebug(s" + inner classes: ${innerClasses.size}") + innerClasses.foreach { c => logDebug(s" ${c.getName}") } + logDebug(s" + outer classes: ${outerClasses.size}") + outerClasses.foreach { c => logDebug(s" ${c.getName}") } } + + // Fail fast if we detect return statements in closures + getClassReader(func.getClass).accept(new ReturnStatementFinder(), 0) + + // If accessed fields is not populated yet, we assume that + // the closure we are trying to clean is the starting one + if (accessedFields.isEmpty) { + logDebug(" + populating accessed fields because this is the starting closure") + // Initialize accessed fields with the outer classes first + // This step is needed to associate the fields to the correct classes later + initAccessedFields(accessedFields, outerClasses) + + // Populate accessed fields by visiting all fields and methods accessed by this and + // all of its inner closures. If transitive cleaning is enabled, this may recursively + // visits methods that belong to other classes in search of transitively referenced fields. + for (cls <- func.getClass :: innerClasses) { + getClassReader(cls).accept(new FieldAccessFinder(accessedFields, cleanTransitively), 0) + } + } + + logDebug(s" + fields accessed by starting closure: ${accessedFields.size} classes") + accessedFields.foreach { f => logDebug(" " + f) } + + // List of outer (class, object) pairs, ordered from outermost to innermost + // Note that all outer objects but the outermost one (first one in this list) must be closures + var outerPairs: List[(Class[_], AnyRef)] = outerClasses.zip(outerObjects).reverse + var parent: AnyRef = null + if (outerPairs.nonEmpty) { + val outermostClass = outerPairs.head._1 + val outermostObject = outerPairs.head._2 + + if (isClosure(outermostClass)) { + logDebug(s" + outermost object is a closure, so we clone it: ${outermostClass}") + } else if (outermostClass.getName.startsWith("$line")) { + // SPARK-14558: if the outermost object is a REPL line object, we should clone + // and clean it as it may carry a lot of unnecessary information, + // e.g. hadoop conf, spark conf, etc. + logDebug(s" + outermost object is a REPL line object, so we clone it:" + + s" ${outermostClass}") + } else { + // The closure is ultimately nested inside a class; keep the object of that + // class without cloning it since we don't want to clone the user's objects. + // Note that we still need to keep around the outermost object itself because + // we need it to clone its child closure later (see below). + logDebug(s" + outermost object is not a closure or REPL line object," + + s" so do not clone it: ${outermostClass}") + parent = outermostObject // e.g. SparkContext + outerPairs = outerPairs.tail + } + } else { + logDebug(" + there are no enclosing objects!") + } + + // Clone the closure objects themselves, nulling out any fields that are not + // used in the closure we're working on or any of its inner closures. + for ((cls, obj) <- outerPairs) { + logDebug(s" + cloning instance of class ${cls.getName}") + // We null out these unused references by cloning each object and then filling in all + // required fields from the original object. 
We need the parent here because the Java + // language specification requires the first constructor parameter of any closure to be + // its enclosing object. + val clone = cloneAndSetFields(parent, obj, cls, accessedFields) + + // If transitive cleaning is enabled, we recursively clean any enclosing closure using + // the already populated accessed fields map of the starting closure + if (cleanTransitively && isClosure(clone.getClass)) { + logDebug(s" + cleaning cloned closure recursively (${cls.getName})") + clean(clone, cleanTransitively, accessedFields) + } + parent = clone + } + + // Update the parent pointer ($outer) of this closure + if (parent != null) { + val field = func.getClass.getDeclaredField("$outer") + field.setAccessible(true) + // If the starting closure doesn't actually need our enclosing object, then just null it out + if (accessedFields.contains(func.getClass) && + !accessedFields(func.getClass).contains("$outer")) { + logDebug(s" + the starting closure doesn't actually need $parent, so we null it out") + field.set(func, null) + } else { + // Update this closure's parent pointer to point to our enclosing object, + // which could either be a cloned closure or the original user object + field.set(func, parent) + } + } + + logDebug(s" +++ closure $func (${func.getClass.getName}) is now cleaned +++") + } + + /** + * Null out fields of enclosing class which are not actually accessed by a closure + * @param func the starting closure to clean + * @param lambdaProxy starting closure proxy + * @param outerThis lambda enclosing class + * @param cleanTransitively whether to clean enclosing closures transitively + */ + private def cleanupScalaReplClosure( + func: AnyRef, + lambdaProxy: SerializedLambda, + outerThis: AnyRef, + cleanTransitively: Boolean): Unit = { + + val capturingClass = outerThis.getClass + val accessedFields: Map[Class[_], Set[String]] = Map.empty + initAccessedFields(accessedFields, Seq(capturingClass)) + + IndylambdaScalaClosures.findAccessedFields( + lambdaProxy, + func.getClass.getClassLoader, + accessedFields, + Map.empty, + Map.empty, + cleanTransitively) + + logDebug(s" + fields accessed by starting closure: ${accessedFields.size} classes") + accessedFields.foreach { f => logDebug(" " + f) } + + if (accessedFields(capturingClass).size < capturingClass.getDeclaredFields.length) { + // clone and clean the enclosing `this` only when there are fields to null out + logDebug(s" + cloning instance of REPL class ${capturingClass.getName}") + val clonedOuterThis = cloneAndSetFields( + parent = null, outerThis, capturingClass, accessedFields) + + val outerField = func.getClass.getDeclaredField("arg$1") + // SPARK-37072: When Java 17 is used and `outerField` is read-only, + // the content of `outerField` cannot be set by reflect api directly. + // But we can remove the `final` modifier of `outerField` before set value + // and reset the modifier after set value. + setFieldAndIgnoreModifiers(func, outerField, clonedOuterThis) + } + } + + + /** + * Cleans up Ammonite closures and nulls out fields captured from cmd & cmd$Helper objects + * but not actually accessed by the closure. To achieve this, it does: + * 1. Identify all accessed Ammonite cmd & cmd$Helper objects + * 2. Clone all accessed cmdX objects + * 3. Clone all accessed cmdX$Helper objects and set their $outer field to the cmdX clone + * 4. Iterate over these clones and set all other accessed fields to + * - a clone, if the field refers to an Ammonite object + * - a previous value otherwise + * 5. 
In case if capturing object is an inner class of Ammonite cmd$Helper object, clone & update + * this capturing object as well + * + * As a result: + * - For all accessed cmdX objects all their references to cmdY$Helper objects are + * either nulled out or updated to cmdY clone + * - For cmdX$Helper objects it means that variables defined in this command are + * nulled out if not accessed + * - lambda enclosing class is cleaned up as it's done for normal Scala closures + * + * @param func the starting closure to clean + * @param lambdaProxy starting closure proxy + * @param outerThis lambda enclosing class + * @param cleanTransitively whether to clean enclosing closures transitively + */ + private def cleanupAmmoniteReplClosure( + func: AnyRef, + lambdaProxy: SerializedLambda, + outerThis: AnyRef, + cleanTransitively: Boolean): Unit = { + + val accessedFields: Map[Class[_], Set[String]] = Map.empty + initAccessedFields(accessedFields, Seq(outerThis.getClass)) + + // Ammonite generates 3 classes for a command number X: + // - cmdX class containing all dependencies needed to execute the command + // (i.e. previous command helpers) + // - cmdX$Helper - inner class of cmdX - containing the user code. It pulls + // required dependencies (i.e. variables defined in other commands) from outer command + // - cmdX companion object holding an instance of cmdX and cmdX$Helper classes. + // Here, we care only about command objects and their helpers, companion objects are + // not captured by closure + + // instances of cmdX and cmdX$Helper + val ammCmdInstances: Map[Class[_], AnyRef] = Map.empty + // fields accessed in those commands + val accessedAmmCmdFields: Map[Class[_], Set[String]] = Map.empty + // outer class may be either Ammonite cmd / cmd$Helper class or an inner class + // defined in a user code. 
We need to clean up Ammonite classes only + if (isAmmoniteCommandOrHelper(outerThis.getClass)) { + ammCmdInstances(outerThis.getClass) = outerThis + accessedAmmCmdFields(outerThis.getClass) = Set.empty + } + + IndylambdaScalaClosures.findAccessedFields( + lambdaProxy, + func.getClass.getClassLoader, + accessedFields, + accessedAmmCmdFields, + ammCmdInstances, + cleanTransitively) + + logTrace(s" + command fields accessed by starting closure: " + + s"${accessedAmmCmdFields.size} classes") + accessedAmmCmdFields.foreach { f => logTrace(" " + f) } + + val cmdClones = Map[Class[_], AnyRef]() + for ((cmdClass, _) <- ammCmdInstances if !cmdClass.getName.contains("Helper")) { + logDebug(s" + Cloning instance of Ammonite command class ${cmdClass.getName}") + cmdClones(cmdClass) = instantiateClass(cmdClass, enclosingObject = null) + } + for ((cmdHelperClass, cmdHelperInstance) <- ammCmdInstances + if cmdHelperClass.getName.contains("Helper")) { + val cmdHelperOuter = cmdHelperClass.getDeclaredFields + .find(_.getName == "$outer") + .map { field => + field.setAccessible(true) + field.get(cmdHelperInstance) + } + val outerClone = cmdHelperOuter.flatMap(o => cmdClones.get(o.getClass)).orNull + logDebug(s" + Cloning instance of Ammonite command helper class ${cmdHelperClass.getName}") + cmdClones(cmdHelperClass) = + instantiateClass(cmdHelperClass, enclosingObject = outerClone) + } + + // set accessed fields + for ((_, cmdClone) <- cmdClones) { + val cmdClass = cmdClone.getClass + val accessedFields = accessedAmmCmdFields(cmdClass) + for (field <- cmdClone.getClass.getDeclaredFields + // outer fields were initialized during clone construction + if accessedFields.contains(field.getName) && field.getName != "$outer") { + // get command clone if exists, otherwise use an original field value + val value = cmdClones.getOrElse(field.getType, { + field.setAccessible(true) + field.get(ammCmdInstances(cmdClass)) + }) + setFieldAndIgnoreModifiers(cmdClone, field, value) + } + } + + val outerThisClone = if (!isAmmoniteCommandOrHelper(outerThis.getClass)) { + // if outer class is not Ammonite helper / command object then is was not cloned + // in the code above. 
We still need to clone it and update accessed fields + logDebug(s" + Cloning instance of lambda capturing class ${outerThis.getClass.getName}") + val clone = cloneAndSetFields(parent = null, outerThis, outerThis.getClass, accessedFields) + // making sure that the code below will update references to Ammonite objects if they exist + for (field <- outerThis.getClass.getDeclaredFields) { + field.setAccessible(true) + cmdClones.get(field.getType).foreach { value => + setFieldAndIgnoreModifiers(clone, field, value) + } + } + clone + } else { + cmdClones(outerThis.getClass) + } + + val outerField = func.getClass.getDeclaredField("arg$1") + // update lambda capturing class reference + setFieldAndIgnoreModifiers(func, outerField, outerThisClone) + } + + private def setFieldAndIgnoreModifiers(obj: AnyRef, field: Field, value: AnyRef): Unit = { + val modifiersField = getFinalModifiersFieldForJava17(field) + modifiersField + .foreach(m => m.setInt(field, field.getModifiers & ~Modifier.FINAL)) + field.setAccessible(true) + field.set(obj, value) + + modifiersField + .foreach(m => m.setInt(field, field.getModifiers | Modifier.FINAL)) } /** @@ -434,19 +584,7 @@ private[spark] object ClosureCleaner extends Logging { } else None } - private def ensureSerializable(func: AnyRef): Unit = { - try { - if (SparkEnv.get != null) { - SparkEnv.get.closureSerializer.newInstance().serialize(func) - } - } catch { - case ex: Exception => throw new SparkException("Task not serializable", ex) - } - } - - private def instantiateClass( - cls: Class[_], - enclosingObject: AnyRef): AnyRef = { + private def instantiateClass(cls: Class[_], enclosingObject: AnyRef): AnyRef = { // Use reflection to instantiate object without calling constructor val rf = sun.reflect.ReflectionFactory.getReflectionFactory() val parentCtor = classOf[java.lang.Object].getDeclaredConstructor() @@ -561,6 +699,9 @@ private[spark] object IndylambdaScalaClosures extends Logging { * same for all three combined, so they can be fused together easily while maintaining the same * ordering as the existing implementation. * + * It also visits transitively Ammonite cmd and cmd%Helper objects it encounters + * and populates accessed fields for them to be able to clean up these as well + * * Precondition: this function expects the `accessedFields` to be populated with all known * outer classes and their super classes to be in the map as keys, e.g. * initializing via ClosureCleaner.initAccessedFields. 
@@ -630,6 +771,8 @@ private[spark] object IndylambdaScalaClosures extends Logging { lambdaProxy: SerializedLambda, lambdaClassLoader: ClassLoader, accessedFields: Map[Class[_], Set[String]], + accessedAmmCmdFields: Map[Class[_], Set[String]], + ammCmdInstances: Map[Class[_], AnyRef], findTransitively: Boolean): Unit = { // We may need to visit the same class multiple times for different methods on it, and we'll @@ -642,15 +785,30 @@ private[spark] object IndylambdaScalaClosures extends Logging { // scalastyle:off classforname val clazz = Class.forName(classExternalName, false, lambdaClassLoader) // scalastyle:on classforname - val classNode = new ClassNode() - val classReader = ClosureCleaner.getClassReader(clazz) - classReader.accept(classNode, 0) - for (m <- classNode.methods.asScala) { - methodNodeById(MethodIdentifier(clazz, m.name, m.desc)) = m + def getClassNode(clazz: Class[_]): ClassNode = { + val classNode = new ClassNode() + val classReader = ClosureCleaner.getClassReader(clazz) + classReader.accept(classNode, 0) + classNode } - (clazz, classNode) + var curClazz = clazz + // we need to add superclass methods as well + // e.g. consider the following closure: + // object Enclosing { + // val closure = () => getClass.getName + // } + // To scan this closure properly, we need to add Object.getClass method + // to methodNodeById map + while (curClazz != null) { + for (m <- getClassNode(curClazz).methods.asScala) { + methodNodeById(MethodIdentifier(clazz, m.name, m.desc)) = m + } + curClazz = curClazz.getSuperclass + } + + (clazz, getClassNode(clazz)) }) classInfo } @@ -674,21 +832,55 @@ private[spark] object IndylambdaScalaClosures extends Logging { // to better find and track field accesses. val trackedClassInternalNames = Set[String](implClassInternalName) - // Depth-first search for inner closures and track the fields that were accessed in them. + // Breadth-first search for inner closures and track the fields that were accessed in them. // Start from the lambda body's implementation method, follow method invocations val visited = Set.empty[MethodIdentifier[_]] - val stack = Stack[MethodIdentifier[_]](implMethodId) + // Depth-first search will not work there. To make addAmmoniteCommandFieldsToTracking to work + // we need to process objects in order they appear in the reference tree. + // E.g. if there was a reference chain a -> b -> c, then DFS will process these nodes in order + // a -> c -> b. However, to initialize ammCmdInstances(c.getClass) we need to process node b + // first. + val queue = Queue[MethodIdentifier[_]](implMethodId) def pushIfNotVisited(methodId: MethodIdentifier[_]): Unit = { if (!visited.contains(methodId)) { - stack.push(methodId) + queue.enqueue(methodId) } } - while (!stack.isEmpty) { - val currentId = stack.pop() + def addAmmoniteCommandFieldsToTracking(currentClass: Class[_]): Unit = { + // get an instance of currentClass. 
It can be either lambda enclosing this + // or another already processed Ammonite object + val currentInstance = if (currentClass == lambdaProxy.getCapturedArg(0).getClass) { + Some(lambdaProxy.getCapturedArg(0)) + } else { + // This key exists if we encountered a non-null reference to `currentClass` before + // as we're processing nodes with a breadth-first search (see comment above) + ammCmdInstances.get(currentClass) + } + currentInstance.foreach { cmdInstance => + // track only cmdX and cmdX$Helper objects generated by Ammonite + for (otherCmdField <- cmdInstance.getClass.getDeclaredFields + if ClosureCleaner.isAmmoniteCommandOrHelper(otherCmdField.getType)) { + otherCmdField.setAccessible(true) + val otherCmdHelperRef = otherCmdField.get(cmdInstance) + val otherCmdClass = otherCmdField.getType + // Ammonite is clever enough to sometimes nullify references to unused commands. + // Ignoring these references for simplicity + if (otherCmdHelperRef != null && !ammCmdInstances.contains(otherCmdClass)) { + logTrace(s" started tracking ${otherCmdClass.getName} Ammonite object") + ammCmdInstances(otherCmdClass) = otherCmdHelperRef + accessedAmmCmdFields(otherCmdClass) = Set() + } + } + } + } + + while (queue.nonEmpty) { + val currentId = queue.dequeue() visited += currentId val currentClass = currentId.cls + addAmmoniteCommandFieldsToTracking(currentClass) val currentMethodNode = methodNodeById(currentId) logTrace(s" scanning ${currentId.cls.getName}.${currentId.name}${currentId.desc}") currentMethodNode.accept(new MethodVisitor(ASM9) { @@ -704,6 +896,10 @@ private[spark] object IndylambdaScalaClosures extends Logging { logTrace(s" found field access $name on $ownerExternalName") accessedFields(cl) += name } + for (cl <- accessedAmmCmdFields.keys if cl.getName == ownerExternalName) { + logTrace(s" found Ammonite command field access $name on $ownerExternalName") + accessedAmmCmdFields(cl) += name + } } } @@ -714,6 +910,10 @@ private[spark] object IndylambdaScalaClosures extends Logging { logTrace(s" found intra class call to $ownerExternalName.$name$desc") // could be invoking a helper method or a field accessor method, just follow it. pushIfNotVisited(MethodIdentifier(currentClass, name, desc)) + } else if (owner.startsWith("ammonite/$sess/cmd")) { + // we're inside Ammonite command / command helper object, track all calls from here + val classInfo = getOrUpdateClassInfo(owner) + pushIfNotVisited(MethodIdentifier(classInfo._1, name, desc)) } else if (isInnerClassCtorCapturingOuter( op, owner, name, desc, currentClassInternalName)) { // Discover inner classes. @@ -894,8 +1094,10 @@ private class InnerClosureFinder(output: Set[Class[_]]) extends ClassVisitor(ASM if (op == INVOKESPECIAL && name == "<init>" && argTypes.length > 0 && argTypes(0).toString.startsWith("L") // is it an object? && argTypes(0).getInternalName == myName) { - output += Utils.classForName(owner.replace('/', '.'), - initialize = false, noSparkClassLoader = true) + output += SparkClassUtils.classForName( + owner.replace('/', '.'), + initialize = false, + noSparkClassLoader = true) } } } diff --git a/common/utils/src/main/scala/org/apache/spark/util/SparkStreamUtils.scala b/common/utils/src/main/scala/org/apache/spark/util/SparkStreamUtils.scala new file mode 100644 index 00000000000..b9148901f1a --- /dev/null +++ b/common/utils/src/main/scala/org/apache/spark/util/SparkStreamUtils.scala @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.util + +import java.io.{FileInputStream, FileOutputStream, InputStream, OutputStream} +import java.nio.channels.{FileChannel, WritableByteChannel} + +import org.apache.spark.util.SparkErrorUtils.tryWithSafeFinally + +private[spark] trait SparkStreamUtils { + + /** + * Copy all data from an InputStream to an OutputStream. NIO way of file stream to file stream + * copying is disabled by default unless explicitly set transferToEnabled as true, the parameter + * transferToEnabled should be configured by spark.file.transferTo = [true|false]. + */ + def copyStream( + in: InputStream, + out: OutputStream, + closeStreams: Boolean = false, + transferToEnabled: Boolean = false): Long = { + tryWithSafeFinally { + (in, out) match { + case (input: FileInputStream, output: FileOutputStream) if transferToEnabled => + // When both streams are File stream, use transferTo to improve copy performance. + val inChannel = input.getChannel + val outChannel = output.getChannel + val size = inChannel.size() + copyFileStreamNIO(inChannel, outChannel, 0, size) + size + case (input, output) => + var count = 0L + val buf = new Array[Byte](8192) + var n = 0 + while (n != -1) { + n = input.read(buf) + if (n != -1) { + output.write(buf, 0, n) + count += n + } + } + count + } + } { + if (closeStreams) { + try { + in.close() + } finally { + out.close() + } + } + } + } + + def copyFileStreamNIO( + input: FileChannel, + output: WritableByteChannel, + startPosition: Long, + bytesToCopy: Long): Unit = { + val outputInitialState = output match { + case outputFileChannel: FileChannel => + Some((outputFileChannel.position(), outputFileChannel)) + case _ => None + } + var count = 0L + // In case transferTo method transferred less data than we have required. + while (count < bytesToCopy) { + count += input.transferTo(count + startPosition, bytesToCopy - count, output) + } + assert( + count == bytesToCopy, + s"request to copy $bytesToCopy bytes, but actually copied $count bytes.") + + // Check the position after transferTo loop to see if it is in the right position and + // give user information if not. + // Position will not be increased to the expected length after calling transferTo in + // kernel version 2.6.32, this issue can be seen in + // https://bugs.openjdk.java.net/browse/JDK-7052359 + // This will lead to stream corruption issue when using sort-based shuffle (SPARK-3948). 
+ outputInitialState.foreach { case (initialPos, outputFileChannel) => + val finalPos = outputFileChannel.position() + val expectedPos = initialPos + bytesToCopy + assert( + finalPos == expectedPos, + s""" + |Current position $finalPos do not equal to expected position $expectedPos + |after transferTo, please check your kernel version to see if it is 2.6.32, + |this is a kernel bug which will lead to unexpected behavior when using transferTo. + |You can set spark.file.transferTo = false to disable this NIO feature. + """.stripMargin) + } + } +} + +private [spark] object SparkStreamUtils extends SparkStreamUtils diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala index dcc038eb51d..c4431e9a87f 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala @@ -16,6 +16,7 @@ */ package org.apache.spark.sql.expressions +import scala.collection.mutable import scala.jdk.CollectionConverters._ import scala.reflect.runtime.universe.TypeTag import scala.util.control.NonFatal @@ -29,7 +30,7 @@ import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, RowEncoder} import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, UdfPacket} import org.apache.spark.sql.types.DataType -import org.apache.spark.util.{SparkClassUtils, SparkSerDeUtils} +import org.apache.spark.util.{ClosureCleaner, SparkClassUtils, SparkSerDeUtils} /** * A user-defined function. To create one, use the `udf` functions in `functions`. @@ -183,6 +184,7 @@ object ScalarUserDefinedFunction { function: AnyRef, inputEncoders: Seq[AgnosticEncoder[_]], outputEncoder: AgnosticEncoder[_]): ScalarUserDefinedFunction = { + SparkConnectClosureCleaner.clean(function) val udfPacketBytes = SparkSerDeUtils.serialize(UdfPacket(function, inputEncoders, outputEncoder)) checkDeserializable(udfPacketBytes) @@ -202,3 +204,9 @@ object ScalarUserDefinedFunction { outputEncoder = RowEncoder.encoderForDataType(returnType, lenient = false)) } } + +private object SparkConnectClosureCleaner { + def clean(closure: AnyRef): Unit = { + ClosureCleaner.clean(closure, cleanTransitively = true, mutable.Map.empty) + } +} diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala index 5bb8cbf3543..51e58f9b0bb 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala @@ -362,4 +362,147 @@ class ReplE2ESuite extends RemoteSparkSession with BeforeAndAfterEach { val output = runCommandsInShell(input) assertContains("noException: Boolean = true", output) } + + test("closure cleaner") { + val input = + """ + |class NonSerializable(val id: Int = -1) { } + | + |{ + | val x = 100 + | val y = new NonSerializable + |} + | + |val t = 200 + | + |{ + | def foo(): Int = { x } + | def bar(): Int = { y.id } + | val z = new NonSerializable + |} + | + |{ + | val myLambda = (a: Int) => a + t + foo() + | val myUdf = udf(myLambda) + |} + | + |spark.range(0, 10). + | withColumn("result", myUdf(col("id"))). 
+ | agg(sum("result")). + | collect()(0)(0).asInstanceOf[Long] + |""".stripMargin + val output = runCommandsInShell(input) + assertContains(": Long = 3045", output) + } + + test("closure cleaner with function") { + val input = + """ + |class NonSerializable(val id: Int = -1) { } + | + |{ + | val x = 100 + | val y = new NonSerializable + |} + | + |{ + | def foo(): Int = { x } + | def bar(): Int = { y.id } + | val z = new NonSerializable + |} + | + |def example() = { + | val myLambda = (a: Int) => a + foo() + | val myUdf = udf(myLambda) + | spark.range(0, 10). + | withColumn("result", myUdf(col("id"))). + | agg(sum("result")). + | collect()(0)(0).asInstanceOf[Long] + |} + | + |example() + |""".stripMargin + val output = runCommandsInShell(input) + assertContains(": Long = 1045", output) + } + + test("closure cleaner nested") { + val input = + """ + |class NonSerializable(val id: Int = -1) { } + | + |{ + | val x = 100 + | val y = new NonSerializable + |} + | + |{ + | def foo(): Int = { x } + | def bar(): Int = { y.id } + | val z = new NonSerializable + |} + | + |val example = () => { + | val nested = () => { + | val myLambda = (a: Int) => a + foo() + | val myUdf = udf(myLambda) + | spark.range(0, 10). + | withColumn("result", myUdf(col("id"))). + | agg(sum("result")). + | collect()(0)(0).asInstanceOf[Long] + | } + | nested() + |} + |example() + |""".stripMargin + val output = runCommandsInShell(input) + assertContains(": Long = 1045", output) + } + + test("closure cleaner with enclosing lambdas") { + val input = + """ + |class NonSerializable(val id: Int = -1) { } + | + |{ + | val x = 100 + | val y = new NonSerializable + |} + | + |val z = new NonSerializable + | + |spark.range(0, 10). + |// for this call UdfUtils will create a new lambda and this lambda becomes enclosing + | map(i => i + x). + | agg(sum("value")). + | collect()(0)(0).asInstanceOf[Long] + |""".stripMargin + val output = runCommandsInShell(input) + assertContains(": Long = 1045", output) + } + + test("closure cleaner cleans capturing class") { + val input = + """ + |class NonSerializable(val id: Int = -1) { } + | + |{ + | val x = 100 + | val y = new NonSerializable + |} + | + |class Test extends Serializable { + | // capturing class is cmd$Helper$Test + | val myUdf = udf((i: Int) => i + x) + | val z = new NonSerializable + | val res = spark.range(0, 10). + | withColumn("result", myUdf(col("id"))). + | agg(sum("result")). 
+ | collect()(0)(0).asInstanceOf[Long] + |} + |(new Test()).res + |""".stripMargin + val output = runCommandsInShell(input) + assertContains(": Long = 1045", output) + } } diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala index 785e1fa4017..7ddb339b12d 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala @@ -324,6 +324,14 @@ object CheckConnectJvmClientCompatibility { ProblemFilters.exclude[MissingClassProblem]( "org.apache.spark.sql.expressions.ScalarUserDefinedFunction$"), + // New private API added in the client + ProblemFilters + .exclude[MissingClassProblem]( + "org.apache.spark.sql.expressions.SparkConnectClosureCleaner"), + ProblemFilters + .exclude[MissingClassProblem]( + "org.apache.spark.sql.expressions.SparkConnectClosureCleaner$"), + // Dataset ProblemFilters.exclude[DirectMissingMethodProblem]( "org.apache.spark.sql.Dataset.plan" diff --git a/core/pom.xml b/core/pom.xml index 5ac3d5bb4de..e55283b75fa 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -64,10 +64,6 @@ <artifactId>jnr-posix</artifactId> <scope>test</scope> </dependency> - <dependency> - <groupId>org.apache.xbean</groupId> - <artifactId>xbean-asm9-shaded</artifactId> - </dependency> <dependency> <groupId>org.apache.hadoop</groupId> <artifactId>hadoop-client-api</artifactId> diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 893895e8fb2..c86f755bbd1 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -2699,7 +2699,7 @@ class SparkContext(config: SparkConf) extends Logging { * @return the cleaned closure */ private[spark] def clean[F <: AnyRef](f: F, checkSerializable: Boolean = true): F = { - ClosureCleaner.clean(f, checkSerializable) + SparkClosureCleaner.clean(f, checkSerializable) f } diff --git a/core/src/main/scala/org/apache/spark/util/SparkClosureCleaner.scala b/core/src/main/scala/org/apache/spark/util/SparkClosureCleaner.scala new file mode 100644 index 00000000000..44e0efb4494 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/SparkClosureCleaner.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.util + +import scala.collection.mutable + +import org.apache.spark.{SparkEnv, SparkException} + +private[spark] object SparkClosureCleaner { + /** + * Clean the given closure in place. + * + * More specifically, this renders the given closure serializable as long as it does not + * explicitly reference unserializable objects. + * + * @param closure the closure to clean + * @param checkSerializable whether to verify that the closure is serializable after cleaning + * @param cleanTransitively whether to clean enclosing closures transitively + */ + def clean( + closure: AnyRef, + checkSerializable: Boolean = true, + cleanTransitively: Boolean = true): Unit = { + if (ClosureCleaner.clean(closure, cleanTransitively, mutable.Map.empty)) { + try { + if (checkSerializable && SparkEnv.get != null) { + SparkEnv.get.closureSerializer.newInstance().serialize(closure) + } + } catch { + case ex: Exception => throw new SparkException("Task not serializable", ex) + } + } + } +} diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index f8decbcff5f..f22bec5c2be 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -24,7 +24,7 @@ import java.lang.reflect.InvocationTargetException import java.math.{MathContext, RoundingMode} import java.net._ import java.nio.ByteBuffer -import java.nio.channels.{Channels, FileChannel, WritableByteChannel} +import java.nio.channels.Channels import java.nio.charset.StandardCharsets import java.nio.file.Files import java.security.SecureRandom @@ -97,7 +97,8 @@ private[spark] object Utils with SparkClassUtils with SparkErrorUtils with SparkFileUtils - with SparkSerDeUtils { + with SparkSerDeUtils + with SparkStreamUtils { private val sparkUncaughtExceptionHandler = new SparkUncaughtExceptionHandler @volatile private var cachedLocalDir: String = "" @@ -244,49 +245,6 @@ private[spark] object Utils dir } - /** - * Copy all data from an InputStream to an OutputStream. NIO way of file stream to file stream - * copying is disabled by default unless explicitly set transferToEnabled as true, - * the parameter transferToEnabled should be configured by spark.file.transferTo = [true|false]. - */ - def copyStream( - in: InputStream, - out: OutputStream, - closeStreams: Boolean = false, - transferToEnabled: Boolean = false): Long = { - tryWithSafeFinally { - (in, out) match { - case (input: FileInputStream, output: FileOutputStream) if transferToEnabled => - // When both streams are File stream, use transferTo to improve copy performance. - val inChannel = input.getChannel - val outChannel = output.getChannel - val size = inChannel.size() - copyFileStreamNIO(inChannel, outChannel, 0, size) - size - case (input, output) => - var count = 0L - val buf = new Array[Byte](8192) - var n = 0 - while (n != -1) { - n = input.read(buf) - if (n != -1) { - output.write(buf, 0, n) - count += n - } - } - count - } - } { - if (closeStreams) { - try { - in.close() - } finally { - out.close() - } - } - } - } - /** * Copy the first `maxSize` bytes of data from the InputStream to an in-memory * buffer, primarily to check for corruption. 
@@ -331,43 +289,6 @@ private[spark] object Utils } } - def copyFileStreamNIO( - input: FileChannel, - output: WritableByteChannel, - startPosition: Long, - bytesToCopy: Long): Unit = { - val outputInitialState = output match { - case outputFileChannel: FileChannel => - Some((outputFileChannel.position(), outputFileChannel)) - case _ => None - } - var count = 0L - // In case transferTo method transferred less data than we have required. - while (count < bytesToCopy) { - count += input.transferTo(count + startPosition, bytesToCopy - count, output) - } - assert(count == bytesToCopy, - s"request to copy $bytesToCopy bytes, but actually copied $count bytes.") - - // Check the position after transferTo loop to see if it is in the right position and - // give user information if not. - // Position will not be increased to the expected length after calling transferTo in - // kernel version 2.6.32, this issue can be seen in - // https://bugs.openjdk.java.net/browse/JDK-7052359 - // This will lead to stream corruption issue when using sort-based shuffle (SPARK-3948). - outputInitialState.foreach { case (initialPos, outputFileChannel) => - val finalPos = outputFileChannel.position() - val expectedPos = initialPos + bytesToCopy - assert(finalPos == expectedPos, - s""" - |Current position $finalPos do not equal to expected position $expectedPos - |after transferTo, please check your kernel version to see if it is 2.6.32, - |this is a kernel bug which will lead to unexpected behavior when using transferTo. - |You can set spark.file.transferTo = false to disable this NIO feature. - """.stripMargin) - } - } - /** * A file name may contain some invalid URI characters, such as " ". This method will convert the * file name to a raw path accepted by `java.net.URI(String)`. 
diff --git a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala index cef0d8c1de0..2f084b2037e 100644 --- a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala @@ -373,7 +373,7 @@ class TestCreateNullValue { println(getX) } // scalastyle:on println - ClosureCleaner.clean(closure) + SparkClosureCleaner.clean(closure) } nestedClosure() } diff --git a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite2.scala b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite2.scala index 0635b4a358a..b055dae3994 100644 --- a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite2.scala +++ b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite2.scala @@ -96,10 +96,10 @@ class ClosureCleanerSuite2 extends SparkFunSuite with BeforeAndAfterAll with Pri // If the resulting closure is not serializable even after // cleaning, we expect ClosureCleaner to throw a SparkException if (serializableAfter) { - ClosureCleaner.clean(closure, checkSerializable = true, transitive) + SparkClosureCleaner.clean(closure, checkSerializable = true, transitive) } else { intercept[SparkException] { - ClosureCleaner.clean(closure, checkSerializable = true, transitive) + SparkClosureCleaner.clean(closure, checkSerializable = true, transitive) } } assertSerializable(closure, serializableAfter) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 47fd7881d2f..10864390e3f 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -43,7 +43,9 @@ object MimaExcludes { // [SPARK-44198][CORE] Support propagation of the log level to the executors ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages$SparkAppConfig$"), // [SPARK-45427][CORE] Add RPC SSL settings to SSLOptions and SparkTransportConf - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.network.netty.SparkTransportConf.fromSparkConf") + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.network.netty.SparkTransportConf.fromSparkConf"), + // [SPARK-45136][CONNECT] Enhance ClosureCleaner with Ammonite support + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.util.MethodIdentifier$") ) // Default exclude rules diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala index 1c77b87dbf1..dc5e22f0571 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala @@ -38,7 +38,7 @@ import org.apache.spark.sql.errors.QueryErrorsBase import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} -import org.apache.spark.util.{ClosureCleaner, Utils} +import org.apache.spark.util.{SparkClosureCleaner, Utils} case class RepeatedStruct(s: Seq[PrimitiveData]) @@ -689,7 +689,7 @@ class ExpressionEncoderSuite extends CodegenInterpretedPlanTest with AnalysisTes val encoder = implicitly[ExpressionEncoder[T]] // Make sure encoder is serializable. 
- ClosureCleaner.clean((s: String) => encoder.getClass.getName) + SparkClosureCleaner.clean((s: String) => encoder.getClass.getName) val row = encoder.createSerializer().apply(input) val schema = toAttributes(encoder.schema) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala b/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala index dcd698c860d..f04b9da9b45 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala @@ -22,7 +22,7 @@ import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.{JavaPairRDD, JavaUtils, Optional} import org.apache.spark.api.java.function.{Function3 => JFunction3, Function4 => JFunction4} import org.apache.spark.rdd.RDD -import org.apache.spark.util.ClosureCleaner +import org.apache.spark.util.SparkClosureCleaner /** * :: Experimental :: @@ -157,7 +157,7 @@ object StateSpec { def function[KeyType, ValueType, StateType, MappedType]( mappingFunction: (Time, KeyType, Option[ValueType], State[StateType]) => Option[MappedType] ): StateSpec[KeyType, ValueType, StateType, MappedType] = { - ClosureCleaner.clean(mappingFunction, checkSerializable = true) + SparkClosureCleaner.clean(mappingFunction, checkSerializable = true) new StateSpecImpl(mappingFunction) } @@ -175,7 +175,7 @@ object StateSpec { def function[KeyType, ValueType, StateType, MappedType]( mappingFunction: (KeyType, Option[ValueType], State[StateType]) => MappedType ): StateSpec[KeyType, ValueType, StateType, MappedType] = { - ClosureCleaner.clean(mappingFunction, checkSerializable = true) + SparkClosureCleaner.clean(mappingFunction, checkSerializable = true) val wrappedFunction = (time: Time, key: KeyType, value: Option[ValueType], state: State[StateType]) => { Some(mappingFunction(key, value, state)) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org