This is an automated email from the ASF dual-hosted git repository.

hvanhovell pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 3e2b146eb81 [SPARK-45136][CONNECT] Enhance ClosureCleaner with Ammonite support
3e2b146eb81 is described below

commit 3e2b146eb81d9a5727f07b58f7bb1760a71a8697
Author: Vsevolod Stepanov <vsevolod.stepa...@databricks.com>
AuthorDate: Wed Oct 25 21:35:07 2023 -0400

    [SPARK-45136][CONNECT] Enhance ClosureCleaner with Ammonite support
    
    ### What changes were proposed in this pull request?
    This PR enhances the existing ClosureCleaner implementation to support cleaning
    closures defined in Ammonite. Please refer to [this
    gist](https://gist.github.com/vsevolodstep-db/b8e4d676745d6e2d047ecac291e5254c)
    for more context on how Ammonite code wrapping works and which problems I'm
    trying to solve here.
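
    For readers without the gist at hand, here is a rough, hand-written sketch of
    how Ammonite wraps REPL commands (the real generated code differs in detail):
    each command N is compiled into a `cmdN` class holding references to previous
    command helpers, an inner `cmdN$Helper` class holding the user code, and a
    companion object holding instances of both.

    ```scala
    // Hand-written approximation of Ammonite's wrapping, for illustration only.
    // A session like `val x = 100` (cmd0) followed by `val f = (i: Int) => i + x` (cmd1)
    // becomes, roughly:
    class cmd0 extends Serializable {
      class Helper extends Serializable {
        val x = 100                             // the user code lives in the Helper
      }
      val helper = new Helper
    }
    object cmd0 { val instance = new cmd0 }

    class cmd1 extends Serializable {
      val cmd0Helper = cmd0.instance.helper     // dependency on the previous command
      class Helper extends Serializable {
        val f = (i: Int) => i + cmd0Helper.x    // captures cmd1, and through it cmd0's Helper
      }
      val helper = new Helper
    }
    object cmd1 { val instance = new cmd1 }
    ```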
    
    Overall, it makes the following logical changes in `ClosureCleaner`:
    1. Making it recognize and clean closures defined in Ammonite (previously it
    only checked whether the capturing class name starts with `$line` and ends
    with `$iw`, which is specific to the native Scala REPL)
    2. Making it clean closures that are defined inside a user class in a REPL
    (see corner case 1 in the gist)
    3. Making it clean nested closures properly for the Ammonite REPL (see corner
    case 2 in the gist)
    4. Making it transitively follow other Ammonite commands that are captured by
    the target closure.
    
    Please note that the `cleanTransitively` option of `ClosureCleaner.clean()`
    refers to following references transitively within the enclosing command
    object; it does not follow other command objects (see the sketch below).
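
    A minimal, self-contained analogue of that distinction, with plain Scala
    objects standing in for Ammonite's generated command helpers (all names here
    are made up for illustration):

    ```scala
    // Plain Scala stand-ins for Ammonite command helpers (hypothetical names).
    object Cmd0Helper extends Serializable { val payload = new Array[Byte](8 << 20) } // never used below
    object Cmd1Helper extends Serializable { val t = 200; val prev = Cmd0Helper }
    object Cmd2Helper extends Serializable {
      val prev = Cmd1Helper                        // a captured "other command" object
      val f: Int => Int = (i: Int) => i + prev.t   // the closure we want to ship
    }
    // `cleanTransitively` only walks references reachable within Cmd2Helper itself;
    // change 4 above is what additionally follows the captured Cmd1Helper (and, through
    // it, Cmd0Helper), so the unused `payload` is not dragged into the serialized closure.
    ```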
    
    As we need `ClosureCleaner` to be available in Spark Connect, I also moved the
    implementation to the `common-utils` module. This brings in a new
    `xbean-asm9-shaded` dependency, which is fairly small.
    
    Also, this PR moves the `checkSerializable` check from `ClosureCleaner` to
    `SparkClosureCleaner`, as it is specific to Spark core.
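
    Condensed from the new `SparkClosureCleaner` added in this diff, the core-side
    wrapper now looks like this:

    ```scala
    package org.apache.spark.util

    import scala.collection.mutable

    import org.apache.spark.{SparkEnv, SparkException}

    private[spark] object SparkClosureCleaner {
      def clean(
          closure: AnyRef,
          checkSerializable: Boolean = true,
          cleanTransitively: Boolean = true): Unit = {
        // bytecode-level cleaning lives in the shared ClosureCleaner (common-utils);
        // the Spark-core-specific serializability check stays here
        if (ClosureCleaner.clean(closure, cleanTransitively, mutable.Map.empty)) {
          try {
            if (checkSerializable && SparkEnv.get != null) {
              SparkEnv.get.closureSerializer.newInstance().serialize(closure)
            }
          } catch {
            case ex: Exception => throw new SparkException("Task not serializable", ex)
          }
        }
      }
    }
    ```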
    
    The important changes affect `ClosureCleaner` only. They should not affect the
    existing code path for normal Scala closures or closures defined in the native
    Scala REPL, and they only cover closures defined in Ammonite.
    
    Also, this PR modifies Spark Connect's `UserDefinedFunction` to actually use
    `ClosureCleaner` and clean closures in Spark Connect.
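
    Concretely (excerpted from the Connect client change below), the closure is now
    cleaned right before the UDF payload is serialized:

    ```scala
    package org.apache.spark.sql.expressions   // excerpt from UserDefinedFunction.scala below

    import scala.collection.mutable

    import org.apache.spark.util.ClosureCleaner

    private object SparkConnectClosureCleaner {
      def clean(closure: AnyRef): Unit = {
        ClosureCleaner.clean(closure, cleanTransitively = true, mutable.Map.empty)
      }
    }

    // In ScalarUserDefinedFunction.apply:
    //   SparkConnectClosureCleaner.clean(function)
    //   val udfPacketBytes =
    //     SparkSerDeUtils.serialize(UdfPacket(function, inputEncoders, outputEncoder))
    ```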
    
    ### Why are the changes needed?
    To properly support closures defined in Ammonite, reduce UDF payload size, and
    avoid possible `NonSerializable` exceptions. This includes:
    - a lambda capturing the outer command object, leading to a circular dependency
    - a lambda capturing other command objects transitively, exploding the payload
    size (illustrated below)
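
    For example (condensed from the new "closure cleaner" test in `ReplE2ESuite`
    below; `spark`, `udf`, `col` and `sum` are in scope in the Connect Ammonite
    REPL), a UDF capturing REPL state must not drag unused or non-serializable
    values into its payload:

    ```scala
    // Run inside the Spark Connect Ammonite REPL (condensed from ReplE2ESuite below).
    class NonSerializable(val id: Int = -1)

    val x = 100
    val y = new NonSerializable          // must not end up in the UDF payload
    val t = 200
    def foo(): Int = x

    val myUdf = udf((a: Int) => a + t + foo())
    spark.range(0, 10)
      .withColumn("result", myUdf(col("id")))
      .agg(sum("result"))
      .collect()(0)(0).asInstanceOf[Long]   // expected: 3045
    ```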
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    Existing tests.
    New tests in `ReplE2ESuite` covering various scenarios that use Spark Connect +
    the Ammonite REPL to make sure closures are actually cleaned.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No
    
    Closes #42995 from vsevolodstep-db/SPARK-45136/closure-cleaner.
    
    Authored-by: Vsevolod Stepanov <vsevolod.stepa...@databricks.com>
    Signed-off-by: Herman van Hovell <her...@databricks.com>
---
 common/utils/pom.xml                               |   4 +
 .../org/apache/spark/util/ClosureCleaner.scala     | 636 ++++++++++++++-------
 .../org/apache/spark/util/SparkStreamUtils.scala   | 109 ++++
 .../sql/expressions/UserDefinedFunction.scala      |  10 +-
 .../spark/sql/application/ReplE2ESuite.scala       | 143 +++++
 .../CheckConnectJvmClientCompatibility.scala       |   8 +
 core/pom.xml                                       |   4 -
 .../main/scala/org/apache/spark/SparkContext.scala |   2 +-
 .../apache/spark/util/SparkClosureCleaner.scala    |  49 ++
 .../main/scala/org/apache/spark/util/Utils.scala   |  85 +--
 .../apache/spark/util/ClosureCleanerSuite.scala    |   2 +-
 .../apache/spark/util/ClosureCleanerSuite2.scala   |   4 +-
 project/MimaExcludes.scala                         |   4 +-
 .../catalyst/encoders/ExpressionEncoderSuite.scala |   4 +-
 .../org/apache/spark/streaming/StateSpec.scala     |   6 +-
 15 files changed, 756 insertions(+), 314 deletions(-)

diff --git a/common/utils/pom.xml b/common/utils/pom.xml
index 37d1ea48d97..44cb30a19ff 100644
--- a/common/utils/pom.xml
+++ b/common/utils/pom.xml
@@ -39,6 +39,10 @@
       <groupId>org.apache.spark</groupId>
       <artifactId>spark-tags_${scala.binary.version}</artifactId>
     </dependency>
+    <dependency>
+      <groupId>org.apache.xbean</groupId>
+      <artifactId>xbean-asm9-shaded</artifactId>
+    </dependency>
     <dependency>
       <groupId>com.fasterxml.jackson.core</groupId>
       <artifactId>jackson-databind</artifactId>
diff --git a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala 
b/common/utils/src/main/scala/org/apache/spark/util/ClosureCleaner.scala
similarity index 61%
rename from core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala
rename to common/utils/src/main/scala/org/apache/spark/util/ClosureCleaner.scala
index 29fb0206f90..ffa2f0e60b2 100644
--- a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala
+++ b/common/utils/src/main/scala/org/apache/spark/util/ClosureCleaner.scala
@@ -21,7 +21,7 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
 import java.lang.invoke.{MethodHandleInfo, SerializedLambda}
 import java.lang.reflect.{Field, Modifier}
 
-import scala.collection.mutable.{Map, Set, Stack}
+import scala.collection.mutable.{Map, Queue, Set, Stack}
 import scala.jdk.CollectionConverters._
 
 import org.apache.commons.lang3.ClassUtils
@@ -29,14 +29,13 @@ import org.apache.xbean.asm9.{ClassReader, ClassVisitor, 
Handle, MethodVisitor,
 import org.apache.xbean.asm9.Opcodes._
 import org.apache.xbean.asm9.tree.{ClassNode, MethodNode}
 
-import org.apache.spark.{SparkEnv, SparkException}
+import org.apache.spark.SparkException
 import org.apache.spark.internal.Logging
 
 /**
  * A cleaner that renders closures serializable if they can be done so safely.
  */
 private[spark] object ClosureCleaner extends Logging {
-
   // Get an ASM class reader for a given class from the JAR that loaded it
   private[util] def getClassReader(cls: Class[_]): ClassReader = {
     // Copy data over, before delegating to ClassReader - else we can run out 
of open file handles.
@@ -46,11 +45,18 @@ private[spark] object ClosureCleaner extends Logging {
       null
     } else {
       val baos = new ByteArrayOutputStream(128)
-      Utils.copyStream(resourceStream, baos, true)
+
+      SparkStreamUtils.copyStream(resourceStream, baos, closeStreams = true)
       new ClassReader(new ByteArrayInputStream(baos.toByteArray))
     }
   }
 
+  private[util] def isAmmoniteCommandOrHelper(clazz: Class[_]): Boolean = 
clazz.getName.matches(
+    """^ammonite\.\$sess\.cmd[0-9]*(\$Helper\$?)?""")
+
+  private[util] def isDefinedInAmmonite(clazz: Class[_]): Boolean = 
clazz.getName.matches(
+    """^ammonite\.\$sess\.cmd[0-9]*.*""")
+
   // Check whether a class represents a Scala closure
   private def isClosure(cls: Class[_]): Boolean = {
     cls.getName.contains("$anonfun$")
@@ -146,23 +152,6 @@ private[spark] object ClosureCleaner extends Logging {
     clone
   }
 
-  /**
-   * Clean the given closure in place.
-   *
-   * More specifically, this renders the given closure serializable as long as 
it does not
-   * explicitly reference unserializable objects.
-   *
-   * @param closure the closure to clean
-   * @param checkSerializable whether to verify that the closure is 
serializable after cleaning
-   * @param cleanTransitively whether to clean enclosing closures transitively
-   */
-  def clean(
-      closure: AnyRef,
-      checkSerializable: Boolean = true,
-      cleanTransitively: Boolean = true): Unit = {
-    clean(closure, checkSerializable, cleanTransitively, Map.empty)
-  }
-
   /**
    * Helper method to clean the given closure in place.
    *
@@ -198,18 +187,15 @@ private[spark] object ClosureCleaner extends Logging {
    * pointer of a cloned scope "one" and set it the parent of scope "two", 
such that scope "two"
    * no longer references SomethingNotSerializable transitively.
    *
-   * @param func the starting closure to clean
-   * @param checkSerializable whether to verify that the closure is 
serializable after cleaning
+   * @param func              the starting closure to clean
    * @param cleanTransitively whether to clean enclosing closures transitively
-   * @param accessedFields a map from a class to a set of its fields that are 
accessed by
-   *                       the starting closure
+   * @param accessedFields    a map from a class to a set of its fields that 
are accessed by
+   *                          the starting closure
    */
-  private def clean(
+  private[spark] def clean(
       func: AnyRef,
-      checkSerializable: Boolean,
       cleanTransitively: Boolean,
-      accessedFields: Map[Class[_], Set[String]]): Unit = {
-
+      accessedFields: Map[Class[_], Set[String]]): Boolean = {
     // indylambda check. Most likely to be the case with 2.12, 2.13
     // so we check first
     // non LMF-closures should be less frequent from now on
@@ -217,131 +203,18 @@ private[spark] object ClosureCleaner extends Logging {
 
     if (!isClosure(func.getClass) && maybeIndylambdaProxy.isEmpty) {
       logDebug(s"Expected a closure; got ${func.getClass.getName}")
-      return
+      return false
     }
 
     // TODO: clean all inner closures first. This requires us to find the 
inner objects.
     // TODO: cache outerClasses / innerClasses / accessedFields
 
     if (func == null) {
-      return
+      return false
     }
 
     if (maybeIndylambdaProxy.isEmpty) {
-      logDebug(s"+++ Cleaning closure $func (${func.getClass.getName}) +++")
-
-      // A list of classes that represents closures enclosed in the given one
-      val innerClasses = getInnerClosureClasses(func)
-
-      // A list of enclosing objects and their respective classes, from 
innermost to outermost
-      // An outer object at a given index is of type outer class at the same 
index
-      val (outerClasses, outerObjects) = getOuterClassesAndObjects(func)
-
-      // For logging purposes only
-      val declaredFields = func.getClass.getDeclaredFields
-      val declaredMethods = func.getClass.getDeclaredMethods
-
-      if (log.isDebugEnabled) {
-        logDebug(s" + declared fields: ${declaredFields.size}")
-        declaredFields.foreach { f => logDebug(s"     $f") }
-        logDebug(s" + declared methods: ${declaredMethods.size}")
-        declaredMethods.foreach { m => logDebug(s"     $m") }
-        logDebug(s" + inner classes: ${innerClasses.size}")
-        innerClasses.foreach { c => logDebug(s"     ${c.getName}") }
-        logDebug(s" + outer classes: ${outerClasses.size}" )
-        outerClasses.foreach { c => logDebug(s"     ${c.getName}") }
-      }
-
-      // Fail fast if we detect return statements in closures
-      getClassReader(func.getClass).accept(new ReturnStatementFinder(), 0)
-
-      // If accessed fields is not populated yet, we assume that
-      // the closure we are trying to clean is the starting one
-      if (accessedFields.isEmpty) {
-        logDebug(" + populating accessed fields because this is the starting 
closure")
-        // Initialize accessed fields with the outer classes first
-        // This step is needed to associate the fields to the correct classes 
later
-        initAccessedFields(accessedFields, outerClasses)
-
-        // Populate accessed fields by visiting all fields and methods 
accessed by this and
-        // all of its inner closures. If transitive cleaning is enabled, this 
may recursively
-        // visits methods that belong to other classes in search of 
transitively referenced fields.
-        for (cls <- func.getClass :: innerClasses) {
-          getClassReader(cls).accept(new FieldAccessFinder(accessedFields, 
cleanTransitively), 0)
-        }
-      }
-
-      logDebug(s" + fields accessed by starting closure: 
${accessedFields.size} classes")
-      accessedFields.foreach { f => logDebug("     " + f) }
-
-      // List of outer (class, object) pairs, ordered from outermost to 
innermost
-      // Note that all outer objects but the outermost one (first one in this 
list) must be closures
-      var outerPairs: List[(Class[_], AnyRef)] = 
outerClasses.zip(outerObjects).reverse
-      var parent: AnyRef = null
-      if (outerPairs.nonEmpty) {
-        val outermostClass = outerPairs.head._1
-        val outermostObject = outerPairs.head._2
-
-        if (isClosure(outermostClass)) {
-          logDebug(s" + outermost object is a closure, so we clone it: 
${outermostClass}")
-        } else if (outermostClass.getName.startsWith("$line")) {
-          // SPARK-14558: if the outermost object is a REPL line object, we 
should clone
-          // and clean it as it may carry a lot of unnecessary information,
-          // e.g. hadoop conf, spark conf, etc.
-          logDebug(s" + outermost object is a REPL line object, so we clone 
it:" +
-            s" ${outermostClass}")
-        } else {
-          // The closure is ultimately nested inside a class; keep the object 
of that
-          // class without cloning it since we don't want to clone the user's 
objects.
-          // Note that we still need to keep around the outermost object 
itself because
-          // we need it to clone its child closure later (see below).
-          logDebug(s" + outermost object is not a closure or REPL line 
object," +
-            s" so do not clone it: ${outermostClass}")
-          parent = outermostObject // e.g. SparkContext
-          outerPairs = outerPairs.tail
-        }
-      } else {
-        logDebug(" + there are no enclosing objects!")
-      }
-
-      // Clone the closure objects themselves, nulling out any fields that are 
not
-      // used in the closure we're working on or any of its inner closures.
-      for ((cls, obj) <- outerPairs) {
-        logDebug(s" + cloning instance of class ${cls.getName}")
-        // We null out these unused references by cloning each object and then 
filling in all
-        // required fields from the original object. We need the parent here 
because the Java
-        // language specification requires the first constructor parameter of 
any closure to be
-        // its enclosing object.
-        val clone = cloneAndSetFields(parent, obj, cls, accessedFields)
-
-        // If transitive cleaning is enabled, we recursively clean any 
enclosing closure using
-        // the already populated accessed fields map of the starting closure
-        if (cleanTransitively && isClosure(clone.getClass)) {
-          logDebug(s" + cleaning cloned closure recursively (${cls.getName})")
-          // No need to check serializable here for the outer closures because 
we're
-          // only interested in the serializability of the starting closure
-          clean(clone, checkSerializable = false, cleanTransitively, 
accessedFields)
-        }
-        parent = clone
-      }
-
-      // Update the parent pointer ($outer) of this closure
-      if (parent != null) {
-        val field = func.getClass.getDeclaredField("$outer")
-        field.setAccessible(true)
-        // If the starting closure doesn't actually need our enclosing object, 
then just null it out
-        if (accessedFields.contains(func.getClass) &&
-          !accessedFields(func.getClass).contains("$outer")) {
-          logDebug(s" + the starting closure doesn't actually need $parent, so 
we null it out")
-          field.set(func, null)
-        } else {
-          // Update this closure's parent pointer to point to our enclosing 
object,
-          // which could either be a cloned closure or the original user object
-          field.set(func, parent)
-        }
-      }
-
-      logDebug(s" +++ closure $func (${func.getClass.getName}) is now cleaned 
+++")
+      cleanNonIndyLambdaClosure(func, cleanTransitively, accessedFields)
     } else {
       val lambdaProxy = maybeIndylambdaProxy.get
       val implMethodName = lambdaProxy.getImplMethodName
@@ -359,62 +232,339 @@ private[spark] object ClosureCleaner extends Logging {
       val capturingClassReader = getClassReader(capturingClass)
       capturingClassReader.accept(new 
ReturnStatementFinder(Option(implMethodName)), 0)
 
-      val isClosureDeclaredInScalaRepl = 
capturingClassName.startsWith("$line") &&
-        capturingClassName.endsWith("$iw")
-      val outerThisOpt = if (lambdaProxy.getCapturedArgCount > 0) {
-        Option(lambdaProxy.getCapturedArg(0))
+      val outerThis = if (lambdaProxy.getCapturedArgCount > 0) {
+        // only need to clean when there is an enclosing non-null "this" 
captured by the closure
+        Option(lambdaProxy.getCapturedArg(0)).getOrElse(return false)
       } else {
-        None
+        return false
       }
 
-      // only need to clean when there is an enclosing "this" captured by the 
closure, and it
-      // should be something cleanable, i.e. a Scala REPL line object
-      val needsCleaning = isClosureDeclaredInScalaRepl &&
-        outerThisOpt.isDefined && outerThisOpt.get.getClass.getName == 
capturingClassName
-
-      if (needsCleaning) {
-        // indylambda closures do not reference enclosing closures via an 
`$outer` chain, so no
-        // transitive cleaning on the `$outer` chain is needed.
-        // Thus clean() shouldn't be recursively called with a non-empty 
accessedFields.
-        assert(accessedFields.isEmpty)
-
-        initAccessedFields(accessedFields, Seq(capturingClass))
-        IndylambdaScalaClosures.findAccessedFields(
-          lambdaProxy, classLoader, accessedFields, cleanTransitively)
-
-        logDebug(s" + fields accessed by starting closure: 
${accessedFields.size} classes")
-        accessedFields.foreach { f => logDebug("     " + f) }
-
-        if (accessedFields(capturingClass).size < 
capturingClass.getDeclaredFields.length) {
-          // clone and clean the enclosing `this` only when there are fields 
to null out
-
-          val outerThis = outerThisOpt.get
-
-          logDebug(s" + cloning instance of REPL class $capturingClassName")
-          val clonedOuterThis = cloneAndSetFields(
-            parent = null, outerThis, capturingClass, accessedFields)
-
-          val outerField = func.getClass.getDeclaredField("arg$1")
-          // SPARK-37072: When Java 17 is used and `outerField` is read-only,
-          // the content of `outerField` cannot be set by reflect api directly.
-          // But we can remove the `final` modifier of `outerField` before set 
value
-          // and reset the modifier after set value.
-          val modifiersField = getFinalModifiersFieldForJava17(outerField)
-          modifiersField
-            .foreach(m => m.setInt(outerField, outerField.getModifiers & 
~Modifier.FINAL))
-          outerField.setAccessible(true)
-          outerField.set(func, clonedOuterThis)
-          modifiersField
-            .foreach(m => m.setInt(outerField, outerField.getModifiers | 
Modifier.FINAL))
+      // clean only if enclosing "this" is something cleanable, i.e. a Scala 
REPL line object or
+      // Ammonite command helper object.
+      // For Ammonite closures, we do not care about actual capturing class 
name,
+      // as closure needs to be cleaned if it captures Ammonite command helper 
object
+      if (isDefinedInAmmonite(outerThis.getClass)) {
+        // If outerThis is a lambda, we have to clean that instead
+        IndylambdaScalaClosures.getSerializationProxy(outerThis).foreach { _ =>
+          return clean(outerThis, cleanTransitively, accessedFields)
+        }
+        cleanupAmmoniteReplClosure(func, lambdaProxy, outerThis, 
cleanTransitively)
+      } else {
+        val isClosureDeclaredInScalaRepl = 
capturingClassName.startsWith("$line") &&
+          capturingClassName.endsWith("$iw")
+        if (isClosureDeclaredInScalaRepl && outerThis.getClass.getName == 
capturingClassName) {
+          assert(accessedFields.isEmpty)
+          cleanupScalaReplClosure(func, lambdaProxy, outerThis, 
cleanTransitively)
         }
       }
 
       logDebug(s" +++ indylambda closure ($implMethodName) is now cleaned +++")
     }
 
-    if (checkSerializable) {
-      ensureSerializable(func)
+    true
+  }
+
+  /**
+   * Cleans non-indylambda closure in place
+   *
+   * @param func              the starting closure to clean
+   * @param cleanTransitively whether to clean enclosing closures transitively
+   * @param accessedFields    a map from a class to a set of its fields that 
are accessed by
+   *                          the starting closure
+   */
+  private def cleanNonIndyLambdaClosure(
+      func: AnyRef,
+      cleanTransitively: Boolean,
+      accessedFields: Map[Class[_], Set[String]]): Unit = {
+    logDebug(s"+++ Cleaning closure $func (${func.getClass.getName}) +++")
+
+    // A list of classes that represents closures enclosed in the given one
+    val innerClasses = getInnerClosureClasses(func)
+
+    // A list of enclosing objects and their respective classes, from 
innermost to outermost
+    // An outer object at a given index is of type outer class at the same 
index
+    val (outerClasses, outerObjects) = getOuterClassesAndObjects(func)
+
+    // For logging purposes only
+    val declaredFields = func.getClass.getDeclaredFields
+    val declaredMethods = func.getClass.getDeclaredMethods
+
+    if (log.isDebugEnabled) {
+      logDebug(s" + declared fields: ${declaredFields.size}")
+      declaredFields.foreach { f => logDebug(s"     $f") }
+      logDebug(s" + declared methods: ${declaredMethods.size}")
+      declaredMethods.foreach { m => logDebug(s"     $m") }
+      logDebug(s" + inner classes: ${innerClasses.size}")
+      innerClasses.foreach { c => logDebug(s"     ${c.getName}") }
+      logDebug(s" + outer classes: ${outerClasses.size}")
+      outerClasses.foreach { c => logDebug(s"     ${c.getName}") }
     }
+
+    // Fail fast if we detect return statements in closures
+    getClassReader(func.getClass).accept(new ReturnStatementFinder(), 0)
+
+    // If accessed fields is not populated yet, we assume that
+    // the closure we are trying to clean is the starting one
+    if (accessedFields.isEmpty) {
+      logDebug(" + populating accessed fields because this is the starting 
closure")
+      // Initialize accessed fields with the outer classes first
+      // This step is needed to associate the fields to the correct classes 
later
+      initAccessedFields(accessedFields, outerClasses)
+
+      // Populate accessed fields by visiting all fields and methods accessed 
by this and
+      // all of its inner closures. If transitive cleaning is enabled, this 
may recursively
+      // visits methods that belong to other classes in search of transitively 
referenced fields.
+      for (cls <- func.getClass :: innerClasses) {
+        getClassReader(cls).accept(new FieldAccessFinder(accessedFields, 
cleanTransitively), 0)
+      }
+    }
+
+    logDebug(s" + fields accessed by starting closure: ${accessedFields.size} 
classes")
+    accessedFields.foreach { f => logDebug("     " + f) }
+
+    // List of outer (class, object) pairs, ordered from outermost to innermost
+    // Note that all outer objects but the outermost one (first one in this 
list) must be closures
+    var outerPairs: List[(Class[_], AnyRef)] = 
outerClasses.zip(outerObjects).reverse
+    var parent: AnyRef = null
+    if (outerPairs.nonEmpty) {
+      val outermostClass = outerPairs.head._1
+      val outermostObject = outerPairs.head._2
+
+      if (isClosure(outermostClass)) {
+        logDebug(s" + outermost object is a closure, so we clone it: 
${outermostClass}")
+      } else if (outermostClass.getName.startsWith("$line")) {
+        // SPARK-14558: if the outermost object is a REPL line object, we 
should clone
+        // and clean it as it may carry a lot of unnecessary information,
+        // e.g. hadoop conf, spark conf, etc.
+        logDebug(s" + outermost object is a REPL line object, so we clone it:" 
+
+          s" ${outermostClass}")
+      } else {
+        // The closure is ultimately nested inside a class; keep the object of 
that
+        // class without cloning it since we don't want to clone the user's 
objects.
+        // Note that we still need to keep around the outermost object itself 
because
+        // we need it to clone its child closure later (see below).
+        logDebug(s" + outermost object is not a closure or REPL line object," +
+          s" so do not clone it: ${outermostClass}")
+        parent = outermostObject // e.g. SparkContext
+        outerPairs = outerPairs.tail
+      }
+    } else {
+      logDebug(" + there are no enclosing objects!")
+    }
+
+    // Clone the closure objects themselves, nulling out any fields that are 
not
+    // used in the closure we're working on or any of its inner closures.
+    for ((cls, obj) <- outerPairs) {
+      logDebug(s" + cloning instance of class ${cls.getName}")
+      // We null out these unused references by cloning each object and then 
filling in all
+      // required fields from the original object. We need the parent here 
because the Java
+      // language specification requires the first constructor parameter of 
any closure to be
+      // its enclosing object.
+      val clone = cloneAndSetFields(parent, obj, cls, accessedFields)
+
+      // If transitive cleaning is enabled, we recursively clean any enclosing 
closure using
+      // the already populated accessed fields map of the starting closure
+      if (cleanTransitively && isClosure(clone.getClass)) {
+        logDebug(s" + cleaning cloned closure recursively (${cls.getName})")
+        clean(clone, cleanTransitively, accessedFields)
+      }
+      parent = clone
+    }
+
+    // Update the parent pointer ($outer) of this closure
+    if (parent != null) {
+      val field = func.getClass.getDeclaredField("$outer")
+      field.setAccessible(true)
+      // If the starting closure doesn't actually need our enclosing object, 
then just null it out
+      if (accessedFields.contains(func.getClass) &&
+        !accessedFields(func.getClass).contains("$outer")) {
+        logDebug(s" + the starting closure doesn't actually need $parent, so 
we null it out")
+        field.set(func, null)
+      } else {
+        // Update this closure's parent pointer to point to our enclosing 
object,
+        // which could either be a cloned closure or the original user object
+        field.set(func, parent)
+      }
+    }
+
+    logDebug(s" +++ closure $func (${func.getClass.getName}) is now cleaned 
+++")
+  }
+
+  /**
+   * Null out fields of enclosing class which are not actually accessed by a 
closure
+   * @param func the starting closure to clean
+   * @param lambdaProxy starting closure proxy
+   * @param outerThis lambda enclosing class
+   * @param cleanTransitively whether to clean enclosing closures transitively
+   */
+  private def cleanupScalaReplClosure(
+      func: AnyRef,
+      lambdaProxy: SerializedLambda,
+      outerThis: AnyRef,
+      cleanTransitively: Boolean): Unit = {
+
+    val capturingClass = outerThis.getClass
+    val accessedFields: Map[Class[_], Set[String]] = Map.empty
+    initAccessedFields(accessedFields, Seq(capturingClass))
+
+    IndylambdaScalaClosures.findAccessedFields(
+      lambdaProxy,
+      func.getClass.getClassLoader,
+      accessedFields,
+      Map.empty,
+      Map.empty,
+      cleanTransitively)
+
+    logDebug(s" + fields accessed by starting closure: ${accessedFields.size} 
classes")
+    accessedFields.foreach { f => logDebug("     " + f) }
+
+    if (accessedFields(capturingClass).size < 
capturingClass.getDeclaredFields.length) {
+      // clone and clean the enclosing `this` only when there are fields to 
null out
+      logDebug(s" + cloning instance of REPL class ${capturingClass.getName}")
+      val clonedOuterThis = cloneAndSetFields(
+        parent = null, outerThis, capturingClass, accessedFields)
+
+      val outerField = func.getClass.getDeclaredField("arg$1")
+      // SPARK-37072: When Java 17 is used and `outerField` is read-only,
+      // the content of `outerField` cannot be set by reflect api directly.
+      // But we can remove the `final` modifier of `outerField` before set 
value
+      // and reset the modifier after set value.
+      setFieldAndIgnoreModifiers(func, outerField, clonedOuterThis)
+    }
+  }
+
+
+  /**
+   * Cleans up Ammonite closures and nulls out fields captured from cmd & 
cmd$Helper objects
+   * but not actually accessed by the closure. To achieve this, it does:
+   * 1. Identify all accessed Ammonite cmd & cmd$Helper objects
+   * 2. Clone all accessed cmdX objects
+   * 3. Clone all accessed cmdX$Helper objects and set their $outer field to 
the cmdX clone
+   * 4. Iterate over these clones and set all other accessed fields to
+   *   - a clone, if the field refers to an Ammonite object
+   *   - a previous value otherwise
+   * 5. If the capturing object is an inner class of an Ammonite cmd$Helper object,
+   *    clone and update this capturing object as well
+   *
+   * As a result:
+   *   - For all accessed cmdX objects all their references to cmdY$Helper 
objects are
+   * either nulled out or updated to cmdY clone
+   *   - For cmdX$Helper objects it means that variables defined in this 
command are
+   * nulled out if not accessed
+   * - lambda enclosing class is cleaned up as it's done for normal Scala 
closures
+   *
+   * @param func              the starting closure to clean
+   * @param lambdaProxy       starting closure proxy
+   * @param outerThis         lambda enclosing class
+   * @param cleanTransitively whether to clean enclosing closures transitively
+   */
+  private def cleanupAmmoniteReplClosure(
+      func: AnyRef,
+      lambdaProxy: SerializedLambda,
+      outerThis: AnyRef,
+      cleanTransitively: Boolean): Unit = {
+
+    val accessedFields: Map[Class[_], Set[String]] = Map.empty
+    initAccessedFields(accessedFields, Seq(outerThis.getClass))
+
+    // Ammonite generates 3 classes for a command number X:
+    //   - cmdX class containing all dependencies needed to execute the command
+    //   (i.e. previous command helpers)
+    //   - cmdX$Helper - inner class of cmdX - containing the user code. It 
pulls
+    //   required dependencies (i.e. variables defined in other commands) from 
outer command
+    //   - cmdX companion object holding an instance of cmdX and cmdX$Helper 
classes.
+    // Here, we care only about command objects and their helpers, companion 
objects are
+    // not captured by closure
+
+    // instances of cmdX and cmdX$Helper
+    val ammCmdInstances: Map[Class[_], AnyRef] = Map.empty
+    // fields accessed in those commands
+    val accessedAmmCmdFields: Map[Class[_], Set[String]] = Map.empty
+    // outer class may be either Ammonite cmd / cmd$Helper class or an inner 
class
+    // defined in a user code. We need to clean up Ammonite classes only
+    if (isAmmoniteCommandOrHelper(outerThis.getClass)) {
+      ammCmdInstances(outerThis.getClass) = outerThis
+      accessedAmmCmdFields(outerThis.getClass) = Set.empty
+    }
+
+    IndylambdaScalaClosures.findAccessedFields(
+      lambdaProxy,
+      func.getClass.getClassLoader,
+      accessedFields,
+      accessedAmmCmdFields,
+      ammCmdInstances,
+      cleanTransitively)
+
+    logTrace(s" + command fields accessed by starting closure: " +
+      s"${accessedAmmCmdFields.size} classes")
+    accessedAmmCmdFields.foreach { f => logTrace("     " + f) }
+
+    val cmdClones = Map[Class[_], AnyRef]()
+    for ((cmdClass, _) <- ammCmdInstances if 
!cmdClass.getName.contains("Helper")) {
+      logDebug(s" + Cloning instance of Ammonite command class 
${cmdClass.getName}")
+      cmdClones(cmdClass) = instantiateClass(cmdClass, enclosingObject = null)
+    }
+    for ((cmdHelperClass, cmdHelperInstance) <- ammCmdInstances
+         if cmdHelperClass.getName.contains("Helper")) {
+      val cmdHelperOuter = cmdHelperClass.getDeclaredFields
+        .find(_.getName == "$outer")
+        .map { field =>
+          field.setAccessible(true)
+          field.get(cmdHelperInstance)
+        }
+      val outerClone = cmdHelperOuter.flatMap(o => 
cmdClones.get(o.getClass)).orNull
+      logDebug(s" + Cloning instance of Ammonite command helper class 
${cmdHelperClass.getName}")
+      cmdClones(cmdHelperClass) =
+        instantiateClass(cmdHelperClass, enclosingObject = outerClone)
+    }
+
+    // set accessed fields
+    for ((_, cmdClone) <- cmdClones) {
+      val cmdClass = cmdClone.getClass
+      val accessedFields = accessedAmmCmdFields(cmdClass)
+      for (field <- cmdClone.getClass.getDeclaredFields
+           // outer fields were initialized during clone construction
+           if accessedFields.contains(field.getName) && field.getName != 
"$outer") {
+        // get command clone if exists, otherwise use an original field value
+        val value = cmdClones.getOrElse(field.getType, {
+          field.setAccessible(true)
+          field.get(ammCmdInstances(cmdClass))
+        })
+        setFieldAndIgnoreModifiers(cmdClone, field, value)
+      }
+    }
+
+    val outerThisClone = if (!isAmmoniteCommandOrHelper(outerThis.getClass)) {
+      // if the outer class is not an Ammonite helper / command object then it was not
+      // cloned in the code above. We still need to clone it and update accessed fields
+      logDebug(s" + Cloning instance of lambda capturing class 
${outerThis.getClass.getName}")
+      val clone = cloneAndSetFields(parent = null, outerThis, 
outerThis.getClass, accessedFields)
+      // making sure that the code below will update references to Ammonite 
objects if they exist
+      for (field <- outerThis.getClass.getDeclaredFields) {
+        field.setAccessible(true)
+        cmdClones.get(field.getType).foreach { value =>
+          setFieldAndIgnoreModifiers(clone, field, value)
+        }
+      }
+      clone
+    } else {
+      cmdClones(outerThis.getClass)
+    }
+
+    val outerField = func.getClass.getDeclaredField("arg$1")
+    // update lambda capturing class reference
+    setFieldAndIgnoreModifiers(func, outerField, outerThisClone)
+  }
+
+  private def setFieldAndIgnoreModifiers(obj: AnyRef, field: Field, value: 
AnyRef): Unit = {
+    val modifiersField = getFinalModifiersFieldForJava17(field)
+    modifiersField
+      .foreach(m => m.setInt(field, field.getModifiers & ~Modifier.FINAL))
+    field.setAccessible(true)
+    field.set(obj, value)
+
+    modifiersField
+      .foreach(m => m.setInt(field, field.getModifiers | Modifier.FINAL))
   }
 
   /**
@@ -434,19 +584,7 @@ private[spark] object ClosureCleaner extends Logging {
     } else None
   }
 
-  private def ensureSerializable(func: AnyRef): Unit = {
-    try {
-      if (SparkEnv.get != null) {
-        SparkEnv.get.closureSerializer.newInstance().serialize(func)
-      }
-    } catch {
-      case ex: Exception => throw new SparkException("Task not serializable", 
ex)
-    }
-  }
-
-  private def instantiateClass(
-      cls: Class[_],
-      enclosingObject: AnyRef): AnyRef = {
+  private def instantiateClass(cls: Class[_], enclosingObject: AnyRef): AnyRef 
= {
     // Use reflection to instantiate object without calling constructor
     val rf = sun.reflect.ReflectionFactory.getReflectionFactory()
     val parentCtor = classOf[java.lang.Object].getDeclaredConstructor()
@@ -561,6 +699,9 @@ private[spark] object IndylambdaScalaClosures extends 
Logging {
    * same for all three combined, so they can be fused together easily while 
maintaining the same
    * ordering as the existing implementation.
    *
+   * It also transitively visits any Ammonite cmd and cmd$Helper objects it encounters
+   * and populates their accessed fields so that they can be cleaned up as well
+   *
    * Precondition: this function expects the `accessedFields` to be populated 
with all known
    *               outer classes and their super classes to be in the map as 
keys, e.g.
    *               initializing via ClosureCleaner.initAccessedFields.
@@ -630,6 +771,8 @@ private[spark] object IndylambdaScalaClosures extends 
Logging {
       lambdaProxy: SerializedLambda,
       lambdaClassLoader: ClassLoader,
       accessedFields: Map[Class[_], Set[String]],
+      accessedAmmCmdFields: Map[Class[_], Set[String]],
+      ammCmdInstances: Map[Class[_], AnyRef],
       findTransitively: Boolean): Unit = {
 
     // We may need to visit the same class multiple times for different 
methods on it, and we'll
@@ -642,15 +785,30 @@ private[spark] object IndylambdaScalaClosures extends 
Logging {
         // scalastyle:off classforname
         val clazz = Class.forName(classExternalName, false, lambdaClassLoader)
         // scalastyle:on classforname
-        val classNode = new ClassNode()
-        val classReader = ClosureCleaner.getClassReader(clazz)
-        classReader.accept(classNode, 0)
 
-        for (m <- classNode.methods.asScala) {
-          methodNodeById(MethodIdentifier(clazz, m.name, m.desc)) = m
+        def getClassNode(clazz: Class[_]): ClassNode = {
+          val classNode = new ClassNode()
+          val classReader = ClosureCleaner.getClassReader(clazz)
+          classReader.accept(classNode, 0)
+          classNode
         }
 
-        (clazz, classNode)
+        var curClazz = clazz
+        // we need to add superclass methods as well
+        // e.g. consider the following closure:
+        // object Enclosing {
+        //   val closure = () => getClass.getName
+        // }
+        // To scan this closure properly, we need to add Object.getClass method
+        // to methodNodeById map
+        while (curClazz != null) {
+          for (m <- getClassNode(curClazz).methods.asScala) {
+            methodNodeById(MethodIdentifier(clazz, m.name, m.desc)) = m
+          }
+          curClazz = curClazz.getSuperclass
+        }
+
+        (clazz, getClassNode(clazz))
       })
       classInfo
     }
@@ -674,21 +832,55 @@ private[spark] object IndylambdaScalaClosures extends 
Logging {
     // to better find and track field accesses.
     val trackedClassInternalNames = Set[String](implClassInternalName)
 
-    // Depth-first search for inner closures and track the fields that were 
accessed in them.
+    // Breadth-first search for inner closures and track the fields that were 
accessed in them.
     // Start from the lambda body's implementation method, follow method 
invocations
     val visited = Set.empty[MethodIdentifier[_]]
-    val stack = Stack[MethodIdentifier[_]](implMethodId)
+    // Depth-first search will not work here: to make addAmmoniteCommandFieldsToTracking
+    // work, we need to process objects in the order they appear in the reference tree.
+    // E.g. if there was a reference chain a -> b -> c, then DFS will process 
these nodes in order
+    // a -> c -> b. However, to initialize ammCmdInstances(c.getClass) we need 
to process node b
+    // first.
+    val queue = Queue[MethodIdentifier[_]](implMethodId)
     def pushIfNotVisited(methodId: MethodIdentifier[_]): Unit = {
       if (!visited.contains(methodId)) {
-        stack.push(methodId)
+        queue.enqueue(methodId)
       }
     }
 
-    while (!stack.isEmpty) {
-      val currentId = stack.pop()
+    def addAmmoniteCommandFieldsToTracking(currentClass: Class[_]): Unit = {
+      // get an instance of currentClass. It can be either lambda enclosing 
this
+      // or another already processed Ammonite object
+      val currentInstance = if (currentClass == 
lambdaProxy.getCapturedArg(0).getClass) {
+        Some(lambdaProxy.getCapturedArg(0))
+      } else {
+        // This key exists if we encountered a non-null reference to 
`currentClass` before
+        // as we're processing nodes with a breadth-first search (see comment 
above)
+        ammCmdInstances.get(currentClass)
+      }
+      currentInstance.foreach { cmdInstance =>
+        // track only cmdX and cmdX$Helper objects generated by Ammonite
+        for (otherCmdField <- cmdInstance.getClass.getDeclaredFields
+             if 
ClosureCleaner.isAmmoniteCommandOrHelper(otherCmdField.getType)) {
+          otherCmdField.setAccessible(true)
+          val otherCmdHelperRef = otherCmdField.get(cmdInstance)
+          val otherCmdClass = otherCmdField.getType
+          // Ammonite is clever enough to sometimes nullify references to 
unused commands.
+          // Ignoring these references for simplicity
+          if (otherCmdHelperRef != null && 
!ammCmdInstances.contains(otherCmdClass)) {
+            logTrace(s"      started tracking ${otherCmdClass.getName} 
Ammonite object")
+            ammCmdInstances(otherCmdClass) = otherCmdHelperRef
+            accessedAmmCmdFields(otherCmdClass) = Set()
+          }
+        }
+      }
+    }
+
+    while (queue.nonEmpty) {
+      val currentId = queue.dequeue()
       visited += currentId
 
       val currentClass = currentId.cls
+      addAmmoniteCommandFieldsToTracking(currentClass)
       val currentMethodNode = methodNodeById(currentId)
       logTrace(s"  scanning 
${currentId.cls.getName}.${currentId.name}${currentId.desc}")
       currentMethodNode.accept(new MethodVisitor(ASM9) {
@@ -704,6 +896,10 @@ private[spark] object IndylambdaScalaClosures extends 
Logging {
               logTrace(s"    found field access $name on $ownerExternalName")
               accessedFields(cl) += name
             }
+            for (cl <- accessedAmmCmdFields.keys if cl.getName == 
ownerExternalName) {
+              logTrace(s"    found Ammonite command field access $name on 
$ownerExternalName")
+              accessedAmmCmdFields(cl) += name
+            }
           }
         }
 
@@ -714,6 +910,10 @@ private[spark] object IndylambdaScalaClosures extends 
Logging {
             logTrace(s"    found intra class call to 
$ownerExternalName.$name$desc")
             // could be invoking a helper method or a field accessor method, 
just follow it.
             pushIfNotVisited(MethodIdentifier(currentClass, name, desc))
+          } else if (owner.startsWith("ammonite/$sess/cmd")) {
+            // we're inside Ammonite command / command helper object, track 
all calls from here
+            val classInfo = getOrUpdateClassInfo(owner)
+            pushIfNotVisited(MethodIdentifier(classInfo._1, name, desc))
           } else if (isInnerClassCtorCapturingOuter(
               op, owner, name, desc, currentClassInternalName)) {
             // Discover inner classes.
@@ -894,8 +1094,10 @@ private class InnerClosureFinder(output: Set[Class[_]]) 
extends ClassVisitor(ASM
         if (op == INVOKESPECIAL && name == "<init>" && argTypes.length > 0
             && argTypes(0).toString.startsWith("L") // is it an object?
             && argTypes(0).getInternalName == myName) {
-          output += Utils.classForName(owner.replace('/', '.'),
-            initialize = false, noSparkClassLoader = true)
+          output += SparkClassUtils.classForName(
+            owner.replace('/', '.'),
+            initialize = false,
+            noSparkClassLoader = true)
         }
       }
     }
diff --git 
a/common/utils/src/main/scala/org/apache/spark/util/SparkStreamUtils.scala 
b/common/utils/src/main/scala/org/apache/spark/util/SparkStreamUtils.scala
new file mode 100644
index 00000000000..b9148901f1a
--- /dev/null
+++ b/common/utils/src/main/scala/org/apache/spark/util/SparkStreamUtils.scala
@@ -0,0 +1,109 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.util
+
+import java.io.{FileInputStream, FileOutputStream, InputStream, OutputStream}
+import java.nio.channels.{FileChannel, WritableByteChannel}
+
+import org.apache.spark.util.SparkErrorUtils.tryWithSafeFinally
+
+private[spark] trait SparkStreamUtils {
+
+  /**
+   * Copy all data from an InputStream to an OutputStream. NIO way of file 
stream to file stream
+   * copying is disabled by default unless explicitly set transferToEnabled as 
true, the parameter
+   * transferToEnabled should be configured by spark.file.transferTo = 
[true|false].
+   */
+  def copyStream(
+      in: InputStream,
+      out: OutputStream,
+      closeStreams: Boolean = false,
+      transferToEnabled: Boolean = false): Long = {
+    tryWithSafeFinally {
+      (in, out) match {
+        case (input: FileInputStream, output: FileOutputStream) if 
transferToEnabled =>
+          // When both streams are File stream, use transferTo to improve copy 
performance.
+          val inChannel = input.getChannel
+          val outChannel = output.getChannel
+          val size = inChannel.size()
+          copyFileStreamNIO(inChannel, outChannel, 0, size)
+          size
+        case (input, output) =>
+          var count = 0L
+          val buf = new Array[Byte](8192)
+          var n = 0
+          while (n != -1) {
+            n = input.read(buf)
+            if (n != -1) {
+              output.write(buf, 0, n)
+              count += n
+            }
+          }
+          count
+      }
+    } {
+      if (closeStreams) {
+        try {
+          in.close()
+        } finally {
+          out.close()
+        }
+      }
+    }
+  }
+
+  def copyFileStreamNIO(
+      input: FileChannel,
+      output: WritableByteChannel,
+      startPosition: Long,
+      bytesToCopy: Long): Unit = {
+    val outputInitialState = output match {
+      case outputFileChannel: FileChannel =>
+        Some((outputFileChannel.position(), outputFileChannel))
+      case _ => None
+    }
+    var count = 0L
+    // In case transferTo method transferred less data than we have required.
+    while (count < bytesToCopy) {
+      count += input.transferTo(count + startPosition, bytesToCopy - count, 
output)
+    }
+    assert(
+      count == bytesToCopy,
+      s"request to copy $bytesToCopy bytes, but actually copied $count bytes.")
+
+    // Check the position after transferTo loop to see if it is in the right 
position and
+    // give user information if not.
+    // Position will not be increased to the expected length after calling 
transferTo in
+    // kernel version 2.6.32, this issue can be seen in
+    // https://bugs.openjdk.java.net/browse/JDK-7052359
+    // This will lead to stream corruption issue when using sort-based shuffle 
(SPARK-3948).
+    outputInitialState.foreach { case (initialPos, outputFileChannel) =>
+      val finalPos = outputFileChannel.position()
+      val expectedPos = initialPos + bytesToCopy
+      assert(
+        finalPos == expectedPos,
+        s"""
+           |Current position $finalPos do not equal to expected position 
$expectedPos
+           |after transferTo, please check your kernel version to see if it is 
2.6.32,
+           |this is a kernel bug which will lead to unexpected behavior when 
using transferTo.
+           |You can set spark.file.transferTo = false to disable this NIO 
feature.
+         """.stripMargin)
+    }
+  }
+}
+
+private [spark] object SparkStreamUtils extends SparkStreamUtils
diff --git 
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala
 
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala
index dcc038eb51d..c4431e9a87f 100644
--- 
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala
+++ 
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala
@@ -16,6 +16,7 @@
  */
 package org.apache.spark.sql.expressions
 
+import scala.collection.mutable
 import scala.jdk.CollectionConverters._
 import scala.reflect.runtime.universe.TypeTag
 import scala.util.control.NonFatal
@@ -29,7 +30,7 @@ import org.apache.spark.sql.catalyst.ScalaReflection
 import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, RowEncoder}
 import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, UdfPacket}
 import org.apache.spark.sql.types.DataType
-import org.apache.spark.util.{SparkClassUtils, SparkSerDeUtils}
+import org.apache.spark.util.{ClosureCleaner, SparkClassUtils, SparkSerDeUtils}
 
 /**
  * A user-defined function. To create one, use the `udf` functions in 
`functions`.
@@ -183,6 +184,7 @@ object ScalarUserDefinedFunction {
       function: AnyRef,
       inputEncoders: Seq[AgnosticEncoder[_]],
       outputEncoder: AgnosticEncoder[_]): ScalarUserDefinedFunction = {
+    SparkConnectClosureCleaner.clean(function)
     val udfPacketBytes =
       SparkSerDeUtils.serialize(UdfPacket(function, inputEncoders, 
outputEncoder))
     checkDeserializable(udfPacketBytes)
@@ -202,3 +204,9 @@ object ScalarUserDefinedFunction {
       outputEncoder = RowEncoder.encoderForDataType(returnType, lenient = 
false))
   }
 }
+
+private object SparkConnectClosureCleaner {
+  def clean(closure: AnyRef): Unit = {
+    ClosureCleaner.clean(closure, cleanTransitively = true, mutable.Map.empty)
+  }
+}
diff --git 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala
 
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala
index 5bb8cbf3543..51e58f9b0bb 100644
--- 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala
+++ 
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala
@@ -362,4 +362,147 @@ class ReplE2ESuite extends RemoteSparkSession with 
BeforeAndAfterEach {
     val output = runCommandsInShell(input)
     assertContains("noException: Boolean = true", output)
   }
+
+  test("closure cleaner") {
+    val input =
+      """
+        |class NonSerializable(val id: Int = -1) { }
+        |
+        |{
+        |  val x = 100
+        |  val y = new NonSerializable
+        |}
+        |
+        |val t = 200
+        |
+        |{
+        |  def foo(): Int = { x }
+        |  def bar(): Int = { y.id }
+        |  val z = new NonSerializable
+        |}
+        |
+        |{
+        |  val myLambda = (a: Int) => a + t + foo()
+        |  val myUdf = udf(myLambda)
+        |}
+        |
+        |spark.range(0, 10).
+        |  withColumn("result", myUdf(col("id"))).
+        |  agg(sum("result")).
+        |  collect()(0)(0).asInstanceOf[Long]
+        |""".stripMargin
+    val output = runCommandsInShell(input)
+    assertContains(": Long = 3045", output)
+  }
+
+  test("closure cleaner with function") {
+    val input =
+      """
+        |class NonSerializable(val id: Int = -1) { }
+        |
+        |{
+        |  val x = 100
+        |  val y = new NonSerializable
+        |}
+        |
+        |{
+        |  def foo(): Int = { x }
+        |  def bar(): Int = { y.id }
+        |  val z = new NonSerializable
+        |}
+        |
+        |def example() = {
+        |  val myLambda = (a: Int) => a + foo()
+        |  val myUdf = udf(myLambda)
+        |  spark.range(0, 10).
+        |    withColumn("result", myUdf(col("id"))).
+        |    agg(sum("result")).
+        |    collect()(0)(0).asInstanceOf[Long]
+        |}
+        |
+        |example()
+        |""".stripMargin
+    val output = runCommandsInShell(input)
+    assertContains(": Long = 1045", output)
+  }
+
+  test("closure cleaner nested") {
+    val input =
+      """
+        |class NonSerializable(val id: Int = -1) { }
+        |
+        |{
+        |  val x = 100
+        |  val y = new NonSerializable
+        |}
+        |
+        |{
+        |  def foo(): Int = { x }
+        |  def bar(): Int = { y.id }
+        |  val z = new NonSerializable
+        |}
+        |
+        |val example = () => {
+        |  val nested = () => {
+        |    val myLambda = (a: Int) => a + foo()
+        |    val myUdf = udf(myLambda)
+        |    spark.range(0, 10).
+        |      withColumn("result", myUdf(col("id"))).
+        |      agg(sum("result")).
+        |      collect()(0)(0).asInstanceOf[Long]
+        |  }
+        |  nested()
+        |}
+        |example()
+        |""".stripMargin
+    val output = runCommandsInShell(input)
+    assertContains(": Long = 1045", output)
+  }
+
+  test("closure cleaner with enclosing lambdas") {
+    val input =
+      """
+        |class NonSerializable(val id: Int = -1) { }
+        |
+        |{
+        |  val x = 100
+        |  val y = new NonSerializable
+        |}
+        |
+        |val z = new NonSerializable
+        |
+        |spark.range(0, 10).
+        |// for this call UdfUtils will create a new lambda and this lambda 
becomes enclosing
+        |  map(i => i + x).
+        |  agg(sum("value")).
+        |  collect()(0)(0).asInstanceOf[Long]
+        |""".stripMargin
+    val output = runCommandsInShell(input)
+    assertContains(": Long = 1045", output)
+  }
+
+  test("closure cleaner cleans capturing class") {
+    val input =
+      """
+        |class NonSerializable(val id: Int = -1) { }
+        |
+        |{
+        |  val x = 100
+        |  val y = new NonSerializable
+        |}
+        |
+        |class Test extends Serializable {
+        |  // capturing class is cmd$Helper$Test
+        |  val myUdf = udf((i: Int) => i + x)
+        |  val z = new NonSerializable
+        |  val res = spark.range(0, 10).
+        |    withColumn("result", myUdf(col("id"))).
+        |    agg(sum("result")).
+        |    collect()(0)(0).asInstanceOf[Long]
+        |}
+        |(new Test()).res
+        |""".stripMargin
+    val output = runCommandsInShell(input)
+    assertContains(": Long = 1045", output)
+  }
 }
diff --git 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
 
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
index 785e1fa4017..7ddb339b12d 100644
--- 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
+++ 
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
@@ -324,6 +324,14 @@ object CheckConnectJvmClientCompatibility {
       ProblemFilters.exclude[MissingClassProblem](
         "org.apache.spark.sql.expressions.ScalarUserDefinedFunction$"),
 
+      // New private API added in the client
+      ProblemFilters
+        .exclude[MissingClassProblem](
+          "org.apache.spark.sql.expressions.SparkConnectClosureCleaner"),
+      ProblemFilters
+        .exclude[MissingClassProblem](
+          "org.apache.spark.sql.expressions.SparkConnectClosureCleaner$"),
+
       // Dataset
       ProblemFilters.exclude[DirectMissingMethodProblem](
         "org.apache.spark.sql.Dataset.plan"
diff --git a/core/pom.xml b/core/pom.xml
index 5ac3d5bb4de..e55283b75fa 100644
--- a/core/pom.xml
+++ b/core/pom.xml
@@ -64,10 +64,6 @@
       <artifactId>jnr-posix</artifactId>
       <scope>test</scope>
     </dependency>
-    <dependency>
-      <groupId>org.apache.xbean</groupId>
-      <artifactId>xbean-asm9-shaded</artifactId>
-    </dependency>
     <dependency>
       <groupId>org.apache.hadoop</groupId>
       <artifactId>hadoop-client-api</artifactId>
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala 
b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 893895e8fb2..c86f755bbd1 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -2699,7 +2699,7 @@ class SparkContext(config: SparkConf) extends Logging {
    * @return the cleaned closure
    */
   private[spark] def clean[F <: AnyRef](f: F, checkSerializable: Boolean = true): F = {
-    ClosureCleaner.clean(f, checkSerializable)
+    SparkClosureCleaner.clean(f, checkSerializable)
     f
   }
 
diff --git a/core/src/main/scala/org/apache/spark/util/SparkClosureCleaner.scala b/core/src/main/scala/org/apache/spark/util/SparkClosureCleaner.scala
new file mode 100644
index 00000000000..44e0efb4494
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/SparkClosureCleaner.scala
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+import scala.collection.mutable
+
+import org.apache.spark.{SparkEnv, SparkException}
+
+private[spark] object SparkClosureCleaner {
+  /**
+   * Clean the given closure in place.
+   *
+   * More specifically, this renders the given closure serializable as long as it does not
+   * explicitly reference unserializable objects.
+   *
+   * @param closure           the closure to clean
+   * @param checkSerializable whether to verify that the closure is serializable after cleaning
+   * @param cleanTransitively whether to clean enclosing closures transitively
+   */
+  def clean(
+      closure: AnyRef,
+      checkSerializable: Boolean = true,
+      cleanTransitively: Boolean = true): Unit = {
+    if (ClosureCleaner.clean(closure, cleanTransitively, mutable.Map.empty)) {
+      try {
+        if (checkSerializable && SparkEnv.get != null) {
+          SparkEnv.get.closureSerializer.newInstance().serialize(closure)
+        }
+      } catch {
+        case ex: Exception => throw new SparkException("Task not serializable", ex)
+      }
+    }
+  }
+}
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index f8decbcff5f..f22bec5c2be 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -24,7 +24,7 @@ import java.lang.reflect.InvocationTargetException
 import java.math.{MathContext, RoundingMode}
 import java.net._
 import java.nio.ByteBuffer
-import java.nio.channels.{Channels, FileChannel, WritableByteChannel}
+import java.nio.channels.Channels
 import java.nio.charset.StandardCharsets
 import java.nio.file.Files
 import java.security.SecureRandom
@@ -97,7 +97,8 @@ private[spark] object Utils
   with SparkClassUtils
   with SparkErrorUtils
   with SparkFileUtils
-  with SparkSerDeUtils {
+  with SparkSerDeUtils
+  with SparkStreamUtils {
 
   private val sparkUncaughtExceptionHandler = new SparkUncaughtExceptionHandler
   @volatile private var cachedLocalDir: String = ""
@@ -244,49 +245,6 @@ private[spark] object Utils
     dir
   }
 
-  /**
-   * Copy all data from an InputStream to an OutputStream. NIO way of file stream to file stream
-   * copying is disabled by default unless explicitly set transferToEnabled as true,
-   * the parameter transferToEnabled should be configured by spark.file.transferTo = [true|false].
-   */
-  def copyStream(
-      in: InputStream,
-      out: OutputStream,
-      closeStreams: Boolean = false,
-      transferToEnabled: Boolean = false): Long = {
-    tryWithSafeFinally {
-      (in, out) match {
-        case (input: FileInputStream, output: FileOutputStream) if transferToEnabled =>
-          // When both streams are File stream, use transferTo to improve copy performance.
-          val inChannel = input.getChannel
-          val outChannel = output.getChannel
-          val size = inChannel.size()
-          copyFileStreamNIO(inChannel, outChannel, 0, size)
-          size
-        case (input, output) =>
-          var count = 0L
-          val buf = new Array[Byte](8192)
-          var n = 0
-          while (n != -1) {
-            n = input.read(buf)
-            if (n != -1) {
-              output.write(buf, 0, n)
-              count += n
-            }
-          }
-          count
-      }
-    } {
-      if (closeStreams) {
-        try {
-          in.close()
-        } finally {
-          out.close()
-        }
-      }
-    }
-  }
-
   /**
    * Copy the first `maxSize` bytes of data from the InputStream to an in-memory
    * buffer, primarily to check for corruption.
@@ -331,43 +289,6 @@ private[spark] object Utils
     }
   }
 
-  def copyFileStreamNIO(
-      input: FileChannel,
-      output: WritableByteChannel,
-      startPosition: Long,
-      bytesToCopy: Long): Unit = {
-    val outputInitialState = output match {
-      case outputFileChannel: FileChannel =>
-        Some((outputFileChannel.position(), outputFileChannel))
-      case _ => None
-    }
-    var count = 0L
-    // In case transferTo method transferred less data than we have required.
-    while (count < bytesToCopy) {
-      count += input.transferTo(count + startPosition, bytesToCopy - count, output)
-    }
-    assert(count == bytesToCopy,
-      s"request to copy $bytesToCopy bytes, but actually copied $count bytes.")
-
-    // Check the position after transferTo loop to see if it is in the right position and
-    // give user information if not.
-    // Position will not be increased to the expected length after calling transferTo in
-    // kernel version 2.6.32, this issue can be seen in
-    // https://bugs.openjdk.java.net/browse/JDK-7052359
-    // This will lead to stream corruption issue when using sort-based shuffle (SPARK-3948).
-    outputInitialState.foreach { case (initialPos, outputFileChannel) =>
-      val finalPos = outputFileChannel.position()
-      val expectedPos = initialPos + bytesToCopy
-      assert(finalPos == expectedPos,
-        s"""
-           |Current position $finalPos do not equal to expected position $expectedPos
-           |after transferTo, please check your kernel version to see if it is 2.6.32,
-           |this is a kernel bug which will lead to unexpected behavior when using transferTo.
-           |You can set spark.file.transferTo = false to disable this NIO feature.
-         """.stripMargin)
-    }
-  }
-
   /**
    * A file name may contain some invalid URI characters, such as " ". This method will convert the
    * file name to a raw path accepted by `java.net.URI(String)`.
diff --git a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala
index cef0d8c1de0..2f084b2037e 100644
--- a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala
@@ -373,7 +373,7 @@ class TestCreateNullValue {
         println(getX)
       }
       // scalastyle:on println
-      ClosureCleaner.clean(closure)
+      SparkClosureCleaner.clean(closure)
     }
     nestedClosure()
   }
diff --git a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite2.scala b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite2.scala
index 0635b4a358a..b055dae3994 100644
--- a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite2.scala
+++ b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite2.scala
@@ -96,10 +96,10 @@ class ClosureCleanerSuite2 extends SparkFunSuite with BeforeAndAfterAll with Pri
     // If the resulting closure is not serializable even after
     // cleaning, we expect ClosureCleaner to throw a SparkException
     if (serializableAfter) {
-      ClosureCleaner.clean(closure, checkSerializable = true, transitive)
+      SparkClosureCleaner.clean(closure, checkSerializable = true, transitive)
     } else {
       intercept[SparkException] {
-        ClosureCleaner.clean(closure, checkSerializable = true, transitive)
+        SparkClosureCleaner.clean(closure, checkSerializable = true, transitive)
       }
     }
     assertSerializable(closure, serializableAfter)
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 47fd7881d2f..10864390e3f 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -43,7 +43,9 @@ object MimaExcludes {
     // [SPARK-44198][CORE] Support propagation of the log level to the executors
     ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages$SparkAppConfig$"),
     // [SPARK-45427][CORE] Add RPC SSL settings to SSLOptions and SparkTransportConf
-    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.network.netty.SparkTransportConf.fromSparkConf")
+    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.network.netty.SparkTransportConf.fromSparkConf"),
+    // [SPARK-45136][CONNECT] Enhance ClosureCleaner with Ammonite support
+    ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.util.MethodIdentifier$")
   )
 
   // Default exclude rules
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
index 1c77b87dbf1..dc5e22f0571 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
@@ -38,7 +38,7 @@ import org.apache.spark.sql.errors.QueryErrorsBase
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
-import org.apache.spark.util.{ClosureCleaner, Utils}
+import org.apache.spark.util.{SparkClosureCleaner, Utils}
 
 case class RepeatedStruct(s: Seq[PrimitiveData])
 
@@ -689,7 +689,7 @@ class ExpressionEncoderSuite extends CodegenInterpretedPlanTest with AnalysisTes
       val encoder = implicitly[ExpressionEncoder[T]]
 
       // Make sure encoder is serializable.
-      ClosureCleaner.clean((s: String) => encoder.getClass.getName)
+      SparkClosureCleaner.clean((s: String) => encoder.getClass.getName)
 
       val row = encoder.createSerializer().apply(input)
       val schema = toAttributes(encoder.schema)
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala b/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala
index dcd698c860d..f04b9da9b45 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala
@@ -22,7 +22,7 @@ import org.apache.spark.annotation.Experimental
 import org.apache.spark.api.java.{JavaPairRDD, JavaUtils, Optional}
 import org.apache.spark.api.java.function.{Function3 => JFunction3, Function4 => JFunction4}
 import org.apache.spark.rdd.RDD
-import org.apache.spark.util.ClosureCleaner
+import org.apache.spark.util.SparkClosureCleaner
 
 /**
  * :: Experimental ::
@@ -157,7 +157,7 @@ object StateSpec {
   def function[KeyType, ValueType, StateType, MappedType](
       mappingFunction: (Time, KeyType, Option[ValueType], State[StateType]) => Option[MappedType]
     ): StateSpec[KeyType, ValueType, StateType, MappedType] = {
-    ClosureCleaner.clean(mappingFunction, checkSerializable = true)
+    SparkClosureCleaner.clean(mappingFunction, checkSerializable = true)
     new StateSpecImpl(mappingFunction)
   }
 
@@ -175,7 +175,7 @@ object StateSpec {
   def function[KeyType, ValueType, StateType, MappedType](
       mappingFunction: (KeyType, Option[ValueType], State[StateType]) => MappedType
     ): StateSpec[KeyType, ValueType, StateType, MappedType] = {
-    ClosureCleaner.clean(mappingFunction, checkSerializable = true)
+    SparkClosureCleaner.clean(mappingFunction, checkSerializable = true)
     val wrappedFunction =
       (time: Time, key: KeyType, value: Option[ValueType], state: State[StateType]) => {
         Some(mappingFunction(key, value, state))


