This is an automated email from the ASF dual-hosted git repository.

aljoscha pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git

commit 1b9cdf559adacd9af8a1f84647a1aed99b2c56dc
Author: Aljoscha Krettek <[email protected]>
AuthorDate: Tue Oct 2 10:12:03 2018 +0200

    [FLINK-7816] Support Scala 2.12 closures and Java 8 lambdas in ClosureCleaner
    
    This updates the ClosureCleaner with recent changes from Apache Spark (SPARK-14540).
---
 flink-scala/pom.xml                                |   5 +
 .../apache/flink/api/scala/ClosureCleaner.scala    | 564 ++++++++++++++++-----
 .../api/scala/functions/ClosureCleanerITCase.scala |   2 +-
 pom.xml                                            |   6 +
 4 files changed, 446 insertions(+), 131 deletions(-)

diff --git a/flink-scala/pom.xml b/flink-scala/pom.xml
index 508a8de..4472207 100644
--- a/flink-scala/pom.xml
+++ b/flink-scala/pom.xml
@@ -51,6 +51,11 @@ under the License.
                </dependency>
 
                <dependency>
+                       <groupId>org.apache.flink</groupId>
+                       <artifactId>flink-shaded-asm-6</artifactId>
+               </dependency>
+
+               <dependency>
                        <groupId>org.scala-lang</groupId>
                        <artifactId>scala-reflect</artifactId>
                </dependency>
diff --git a/flink-scala/src/main/scala/org/apache/flink/api/scala/ClosureCleaner.scala b/flink-scala/src/main/scala/org/apache/flink/api/scala/ClosureCleaner.scala
index 7965346..2932ed5 100644
--- a/flink-scala/src/main/scala/org/apache/flink/api/scala/ClosureCleaner.scala
+++ b/flink-scala/src/main/scala/org/apache/flink/api/scala/ClosureCleaner.scala
@@ -18,35 +18,40 @@
 package org.apache.flink.api.scala
 
 import java.io._
+import java.lang.invoke.SerializedLambda
 
 import org.apache.flink.annotation.Internal
 import org.apache.flink.api.common.InvalidProgramException
-import org.apache.flink.util.InstantiationUtil
+import org.apache.flink.util.{FlinkException, InstantiationUtil}
 import org.slf4j.LoggerFactory
 
 import scala.collection.mutable.Map
 import scala.collection.mutable.Set
+import org.apache.flink.shaded.asm6.org.objectweb.asm.{ClassReader, 
ClassVisitor, MethodVisitor, Type}
+import org.apache.flink.shaded.asm6.org.objectweb.asm.Opcodes._
 
-import org.apache.flink.shaded.asm5.org.objectweb.asm.{ClassReader, 
ClassVisitor, MethodVisitor, Type}
-import org.apache.flink.shaded.asm5.org.objectweb.asm.Opcodes._
+import scala.collection.mutable
 
 /* This code is originally from the Apache Spark project. */
 @Internal
 object ClosureCleaner {
+
   val LOG = LoggerFactory.getLogger(this.getClass)
 
+  private val isScala2_11 = 
scala.util.Properties.versionString.contains("2.11")
+
   // Get an ASM class reader for a given class from the JAR that loaded it
-  private def getClassReader(cls: Class[_]): ClassReader = {
+  private[scala] def getClassReader(cls: Class[_]): ClassReader = {
     // Copy data over, before delegating to ClassReader - else we can run out 
of open file handles.
     val className = cls.getName.replaceFirst("^.*\\.", "") + ".class"
     val resourceStream = cls.getResourceAsStream(className)
-    // todo: Fixme - continuing with earlier behavior ...
-    if (resourceStream == null) return new ClassReader(resourceStream)
-
-    val baos = new ByteArrayOutputStream(128)
-
-    copyStream(resourceStream, baos, true)
-    new ClassReader(new ByteArrayInputStream(baos.toByteArray))
+    if (resourceStream == null) {
+      null
+    } else {
+      val baos = new ByteArrayOutputStream(128)
+      copyStream(resourceStream, baos, true)
+      new ClassReader(new ByteArrayInputStream(baos.toByteArray))
+    }
   }
 
   // Check whether a class represents a Scala closure
@@ -54,110 +59,339 @@ object ClosureCleaner {
     cls.getName.contains("$anonfun$")
   }
 
-  // Get a list of the classes of the outer objects of a given closure object, 
obj;
+  // Get a list of the outer objects and their classes of a given closure 
object, obj;
   // the outer objects are defined as any closures that obj is nested within, 
plus
   // possibly the class that the outermost closure is in, if any. We stop 
searching
   // for outer objects beyond that because cloning the user's object is 
probably
   // not a good idea (whereas we can clone closure objects just fine since we
   // understand how all their fields are used).
-  private def getOuterClasses(obj: AnyRef): List[Class[_]] = {
+  private def getOuterClassesAndObjects(obj: AnyRef): (List[Class[_]], 
List[AnyRef]) = {
     for (f <- obj.getClass.getDeclaredFields if f.getName == "$outer") {
       f.setAccessible(true)
-      if (isClosure(f.getType)) {
-        return f.getType :: getOuterClasses(f.get(obj))
-      } else {
-        return f.getType :: Nil // Stop at the first $outer that is not a 
closure
+      val outer = f.get(obj)
+      // The outer pointer may be null if we have cleaned this closure before
+      if (outer != null) {
+        if (isClosure(f.getType)) {
+          val recurRet = getOuterClassesAndObjects(outer)
+          return (f.getType :: recurRet._1, outer :: recurRet._2)
+        } else {
+          return (f.getType :: Nil, outer :: Nil) // Stop at the first $outer 
that is not a closure
+        }
       }
     }
-    Nil
+    (Nil, Nil)
   }
-
-  // Get a list of the outer objects for a given closure object.
-  private def getOuterObjects(obj: AnyRef): List[AnyRef] = {
-    for (f <- obj.getClass.getDeclaredFields if f.getName == "$outer") {
-      f.setAccessible(true)
-      if (isClosure(f.getType)) {
-        return f.get(obj) :: getOuterObjects(f.get(obj))
-      } else {
-        return f.get(obj) :: Nil // Stop at the first $outer that is not a 
closure
+  /**
+   * Return a list of classes that represent closures enclosed in the given 
closure object.
+   */
+  private def getInnerClosureClasses(obj: AnyRef): List[Class[_]] = {
+    val seen = Set[Class[_]](obj.getClass)
+    val stack = mutable.Stack[Class[_]](obj.getClass)
+    while (!stack.isEmpty) {
+      val cr = getClassReader(stack.pop())
+      if (cr != null) {
+        val set = Set.empty[Class[_]]
+        cr.accept(new InnerClosureFinder(set), 0)
+        for (cls <- set -- seen) {
+          seen += cls
+          stack.push(cls)
+        }
       }
     }
-    Nil
+    (seen - obj.getClass).toList
   }
 
-  private def getInnerClasses(obj: AnyRef): List[Class[_]] = {
-    val seen = Set[Class[_]](obj.getClass)
-    var stack = List[Class[_]](obj.getClass)
-    while (stack.nonEmpty) {
-      val cr = getClassReader(stack.head)
-      stack = stack.tail
-      val set = Set[Class[_]]()
-      cr.accept(new InnerClosureFinder(set), 0)
-      for (cls <- set -- seen) {
-        seen += cls
-        stack = cls :: stack
+  /** Initializes the accessed fields for outer classes and their super 
classes. */
+  private def initAccessedFields(
+      accessedFields: Map[Class[_], Set[String]],
+      outerClasses: Seq[Class[_]]): Unit = {
+    for (cls <- outerClasses) {
+      var currentClass = cls
+      assert(currentClass != null, "The outer class can't be null.")
+
+      while (currentClass != null) {
+        accessedFields(currentClass) = Set.empty[String]
+        currentClass = currentClass.getSuperclass()
       }
     }
-    (seen - obj.getClass).toList
   }
 
-  private def createNullValue(cls: Class[_]): AnyRef = {
-    if (cls.isPrimitive) {
-      new java.lang.Byte(0: Byte) // Should be convertible to any primitive 
type
-    } else {
-      null
+  /** Sets accessed fields for given class in clone object based on given 
object. */
+  private def setAccessedFields(
+      outerClass: Class[_],
+      clone: AnyRef,
+      obj: AnyRef,
+      accessedFields: Map[Class[_], Set[String]]): Unit = {
+    for (fieldName <- accessedFields(outerClass)) {
+      val field = outerClass.getDeclaredField(fieldName)
+      field.setAccessible(true)
+      val value = field.get(obj)
+      field.set(clone, value)
     }
   }
 
-  def clean(func: AnyRef, checkSerializable: Boolean = true) {
-    // TODO: cache outerClasses / innerClasses / accessedFields
-    val outerClasses = getOuterClasses(func)
-    val innerClasses = getInnerClasses(func)
-    val outerObjects = getOuterObjects(func)
+  /** Clones a given object and sets accessed fields in cloned object. */
+  private def cloneAndSetFields(
+      parent: AnyRef,
+      obj: AnyRef,
+      outerClass: Class[_],
+      accessedFields: Map[Class[_], Set[String]]): AnyRef = {
+    val clone = instantiateClass(outerClass, parent)
 
-    val accessedFields = Map[Class[_], Set[String]]()
+    var currentClass = outerClass
+    assert(currentClass != null, "The outer class can't be null.")
 
-    getClassReader(func.getClass).accept(new ReturnStatementFinder(), 0)
+    while (currentClass != null) {
+      setAccessedFields(currentClass, clone, obj, accessedFields)
+      currentClass = currentClass.getSuperclass()
+    }
 
-    for (cls <- outerClasses)
-      accessedFields(cls) = Set[String]()
-    for (cls <- func.getClass :: innerClasses)
-      getClassReader(cls).accept(new FieldAccessFinder(accessedFields), 0)
+    clone
+  }
 
-    if (LOG.isDebugEnabled) {
-      LOG.debug("accessedFields: " + accessedFields)
-    }
+  /**
+   * Clean the given closure in place.
+   *
+   * More specifically, this renders the given closure serializable as long as 
it does not
+   * explicitly reference unserializable objects.
+   *
+   * @param closure the closure to clean
+   * @param checkSerializable whether to verify that the closure is 
serializable after cleaning
+   * @param cleanTransitively whether to clean enclosing closures transitively
+   */
+  def clean(
+      closure: AnyRef,
+      checkSerializable: Boolean = true,
+      cleanTransitively: Boolean = true): Unit = {
+    clean(closure, checkSerializable, cleanTransitively, Map.empty)
+  }
 
-    var outerPairs: List[(Class[_], AnyRef)] = (outerClasses zip 
outerObjects).reverse
-    var outer: AnyRef = null
-    if (outerPairs.nonEmpty && !isClosure(outerPairs.head._1)) {
-      // The closure is ultimately nested inside a class; keep the object of 
that
-      // class without cloning it since we don't want to clone the user's 
objects.
-      outer = outerPairs.head._2
-      outerPairs = outerPairs.tail
+  /**
+   * Try to get a serialized Lambda from the closure.
+   *
+   * @param closure the closure to check.
+   */
+  private def getSerializedLambda(closure: AnyRef): Option[SerializedLambda] = 
{
+    if (isScala2_11) {
+      return None
     }
-    // Clone the closure objects themselves, nulling out any fields that are 
not
-    // used in the closure we're working on or any of its inner closures.
-    for ((cls, obj) <- outerPairs) {
-      outer = instantiateClass(cls, outer)
-      for (fieldName <- accessedFields(cls)) {
-        val field = cls.getDeclaredField(fieldName)
-        field.setAccessible(true)
-        val value = field.get(obj)
-        if (LOG.isDebugEnabled) {
-          LOG.debug("1: Setting " + fieldName + " on " + cls + " to " + value)
-        }
-        field.set(outer, value)
+    val isClosureCandidate =
+      closure.getClass.isSynthetic &&
+        closure
+          .getClass
+          .getInterfaces.exists(_.getName == "scala.Serializable")
+
+    if (isClosureCandidate) {
+      try {
+        Option(inspect(closure))
+      } catch {
+        case e: Exception =>
+          if (LOG.isDebugEnabled) {
+            LOG.debug("Closure is not a serialized lambda.", e)
+          }
+          None
       }
+    } else {
+      None
     }
+  }
+
+  private def inspect(closure: AnyRef): SerializedLambda = {
+    val writeReplace = closure.getClass.getDeclaredMethod("writeReplace")
+    writeReplace.setAccessible(true)
+    
writeReplace.invoke(closure).asInstanceOf[java.lang.invoke.SerializedLambda]
+  }
+
+  /**
+   * Helper method to clean the given closure in place.
+   *
+   * The mechanism is to traverse the hierarchy of enclosing closures and null 
out any
+   * references along the way that are not actually used by the starting 
closure, but are
+   * nevertheless included in the compiled anonymous classes. Note that it is 
unsafe to
+   * simply mutate the enclosing closures in place, as other code paths may 
depend on them.
+   * Instead, we clone each enclosing closure and set the parent pointers 
accordingly.
+   *
+   * By default, closures are cleaned transitively. This means we detect 
whether enclosing
+   * objects are actually referenced by the starting one, either directly or 
transitively,
+   * and, if not, sever these closures from the hierarchy. In other words, in 
addition to
+   * nulling out unused field references, we also null out any parent pointers 
that refer
+   * to enclosing objects not actually needed by the starting closure. We 
determine
+   * transitivity by tracing through the tree of all methods ultimately 
invoked by the
+   * inner closure and record all the fields referenced in the process.
+   *
+   * For instance, transitive cleaning is necessary in the following scenario:
+   *
+   *   class SomethingNotSerializable {
+   *     def someValue = 1
+   *     def scope(name: String)(body: => Unit) = body
+   *     def someMethod(): Unit = scope("one") {
+   *       def x = someValue
+   *       def y = 2
+   *       scope("two") { println(y + 1) }
+   *     }
+   *   }
+   *
+   * In this example, scope "two" is not serializable because it references 
scope "one", which
+   * references SomethingNotSerializable. Note that, however, the body of 
scope "two" does not
+   * actually depend on SomethingNotSerializable. This means we can safely 
null out the parent
+   * pointer of a cloned scope "one" and set it the parent of scope "two", 
such that scope "two"
+   * no longer references SomethingNotSerializable transitively.
+   *
+   * @param func the starting closure to clean
+   * @param checkSerializable whether to verify that the closure is 
serializable after cleaning
+   * @param cleanTransitively whether to clean enclosing closures transitively
+   * @param accessedFields a map from a class to a set of its fields that are 
accessed by
+   *                       the starting closure
+   */
+  private def clean(
+      func: AnyRef,
+      checkSerializable: Boolean,
+      cleanTransitively: Boolean,
+      accessedFields: Map[Class[_], Set[String]]): Unit = {
+
+    // most likely to be the case with 2.12, 2.13
+    // so we check first
+    // non LMF-closures should be less frequent from now on
+    val lambdaFunc = getSerializedLambda(func)
+
+    if (!isClosure(func.getClass) && lambdaFunc.isEmpty) {
+      LOG.debug(s"Expected a closure; got ${func.getClass.getName}")
+      return
+    }
+
+    // TODO: clean all inner closures first. This requires us to find the 
inner objects.
+    // TODO: cache outerClasses / innerClasses / accessedFields
+
+    if (func == null) {
+      return
+    }
+
+    if (lambdaFunc.isEmpty) {
+      LOG.debug(s"+++ Cleaning closure $func (${func.getClass.getName}) +++")
+
+      // A list of classes that represents closures enclosed in the given one
+      val innerClasses = getInnerClosureClasses(func)
+
+      // A list of enclosing objects and their respective classes, from 
innermost to outermost
+      // An outer object at a given index is of type outer class at the same 
index
+      val (outerClasses, outerObjects) = getOuterClassesAndObjects(func)
+
+      // For logging purposes only
+      val declaredFields = func.getClass.getDeclaredFields
+      val declaredMethods = func.getClass.getDeclaredMethods
 
-    if (outer != null) {
       if (LOG.isDebugEnabled) {
-        LOG.debug("2: Setting $outer on " + func.getClass + " to " + outer)
+        LOG.debug(s" + declared fields: ${declaredFields.size}")
+        declaredFields.foreach { f => LOG.debug(s"     $f") }
+        LOG.debug(s" + declared methods: ${declaredMethods.size}")
+        declaredMethods.foreach { m => LOG.debug(s"     $m") }
+        LOG.debug(s" + inner classes: ${innerClasses.size}")
+        innerClasses.foreach { c => LOG.debug(s"     ${c.getName}") }
+        LOG.debug(s" + outer classes: ${outerClasses.size}" )
+        outerClasses.foreach { c => LOG.debug(s"     ${c.getName}") }
+        LOG.debug(s" + outer objects: ${outerObjects.size}")
+        outerObjects.foreach { o => LOG.debug(s"     $o") }
       }
-      val field = func.getClass.getDeclaredField("$outer")
-      field.setAccessible(true)
-      field.set(func, outer)
+
+      // Fail fast if we detect return statements in closures
+      getClassReader(func.getClass).accept(new ReturnStatementFinder(), 0)
+
+      // If accessed fields is not populated yet, we assume that
+      // the closure we are trying to clean is the starting one
+      if (accessedFields.isEmpty) {
+        LOG.debug(" + populating accessed fields because this is the starting 
closure")
+        // Initialize accessed fields with the outer classes first
+        // This step is needed to associate the fields to the correct classes 
later
+        initAccessedFields(accessedFields, outerClasses)
+
+        // Populate accessed fields by visiting all fields and methods 
accessed by this and
+        // all of its inner closures. If transitive cleaning is enabled, this 
may recursively
+        // visits methods that belong to other classes in search of 
transitively referenced fields.
+        for (cls <- func.getClass :: innerClasses) {
+          getClassReader(cls).accept(new FieldAccessFinder(accessedFields, 
cleanTransitively), 0)
+        }
+      }
+
+      LOG.debug(s" + fields accessed by starting closure: " + 
accessedFields.size)
+      accessedFields.foreach { f => LOG.debug("     " + f) }
+
+      // List of outer (class, object) pairs, ordered from outermost to 
innermost
+      // Note that all outer objects but the outermost one (first one in this 
list) must be closures
+      var outerPairs: List[(Class[_], AnyRef)] = 
outerClasses.zip(outerObjects).reverse
+      var parent: AnyRef = null
+      if (outerPairs.nonEmpty) {
+        val (outermostClass, outermostObject) = outerPairs.head
+        if (isClosure(outermostClass)) {
+          LOG.debug(s" + outermost object is a closure, so we clone it: 
${outerPairs.head}")
+        } else if (outermostClass.getName.startsWith("$line")) {
+          // SPARK-14558: if the outermost object is a REPL line object, we 
should clone
+          // and clean it as it may carray a lot of unnecessary information,
+          // e.g. hadoop conf, spark conf, etc.
+          LOG.debug(
+            s" + outermost object is a REPL line object, so we clone it: 
${outerPairs.head}")
+        } else {
+          // The closure is ultimately nested inside a class; keep the object 
of that
+          // class without cloning it since we don't want to clone the user's 
objects.
+          // Note that we still need to keep around the outermost object 
itself because
+          // we need it to clone its child closure later (see below).
+          LOG.debug(" + outermost object is not a closure or REPL line 
object," +
+            "so do not clone it: " +  outerPairs.head)
+          parent = outermostObject // e.g. SparkContext
+          outerPairs = outerPairs.tail
+        }
+      } else {
+        LOG.debug(" + there are no enclosing objects!")
+      }
+
+      // Clone the closure objects themselves, nulling out any fields that are 
not
+      // used in the closure we're working on or any of its inner closures.
+      for ((cls, obj) <- outerPairs) {
+        LOG.debug(s" + cloning the object $obj of class ${cls.getName}")
+        // We null out these unused references by cloning each object and then 
filling in all
+        // required fields from the original object. We need the parent here 
because the Java
+        // language specification requires the first constructor parameter of 
any closure to be
+        // its enclosing object.
+        val clone = cloneAndSetFields(parent, obj, cls, accessedFields)
+
+        // If transitive cleaning is enabled, we recursively clean any 
enclosing closure using
+        // the already populated accessed fields map of the starting closure
+        if (cleanTransitively && isClosure(clone.getClass)) {
+          LOG.debug(s" + cleaning cloned closure $clone recursively 
(${cls.getName})")
+          // No need to check serializable here for the outer closures because 
we're
+          // only interested in the serializability of the starting closure
+          clean(clone, checkSerializable = false, cleanTransitively, 
accessedFields)
+        }
+        parent = clone
+      }
+
+      // Update the parent pointer ($outer) of this closure
+      if (parent != null) {
+        val field = func.getClass.getDeclaredField("$outer")
+        field.setAccessible(true)
+        // If the starting closure doesn't actually need our enclosing object, 
then just null it out
+        if (accessedFields.contains(func.getClass) &&
+          !accessedFields(func.getClass).contains("$outer")) {
+          LOG.debug(s" + the starting closure doesn't actually need $parent, 
so we null it out")
+          field.set(func, null)
+        } else {
+          // Update this closure's parent pointer to point to our enclosing 
object,
+          // which could either be a cloned closure or the original user object
+          field.set(func, parent)
+        }
+      }
+
+      LOG.debug(s" +++ closure $func (${func.getClass.getName}) is now cleaned 
+++")
+    } else {
+      LOG.debug(s"Cleaning lambda: ${lambdaFunc.get.getImplMethodName}")
+
+      // scalastyle:off classforname
+      val captClass = 
Class.forName(lambdaFunc.get.getCapturingClass.replace('/', '.'),
+        false, Thread.currentThread.getContextClassLoader)
+      // scalastyle:on classforname
+      // Fail fast if we detect return statements in closures
+      getClassReader(captClass)
+        .accept(new 
ReturnStatementFinder(Some(lambdaFunc.get.getImplMethodName)), 0)
+      LOG.debug(s" +++ Lambda closure (${lambdaFunc.get.getImplMethodName}) is 
now cleaned +++")
     }
 
     if (checkSerializable) {
@@ -165,7 +399,7 @@ object ClosureCleaner {
     }
   }
 
-  def ensureSerializable(func: AnyRef) {
+  private[flink] def ensureSerializable(func: AnyRef) {
     try {
       InstantiationUtil.serializeObject(func)
     } catch {
@@ -173,24 +407,28 @@ object ClosureCleaner {
     }
   }
 
-  private def instantiateClass(cls: Class[_], outer: AnyRef): AnyRef = {
-    if (LOG.isDebugEnabled) {
-      LOG.debug("Creating a " + cls + " with outer = " + outer)
-    }
-    // This is a bona fide closure class, whose constructor has no effects
-    // other than to set its fields, so use its constructor
-    val cons = cls.getConstructors()(0)
-    val params = cons.getParameterTypes.map(createNullValue)
-    if (outer != null) {
-      params(0) = outer // First param is always outer object
+  private def instantiateClass(
+      cls: Class[_],
+      enclosingObject: AnyRef): AnyRef = {
+    // Use reflection to instantiate object without calling constructor
+    val rf = sun.reflect.ReflectionFactory.getReflectionFactory()
+    val parentCtor = classOf[java.lang.Object].getDeclaredConstructor()
+    val newCtor = rf.newConstructorForSerialization(cls, parentCtor)
+    val obj = newCtor.newInstance().asInstanceOf[AnyRef]
+    if (enclosingObject != null) {
+      val field = cls.getDeclaredField("$outer")
+      field.setAccessible(true)
+      field.set(obj, enclosingObject)
     }
-    cons.newInstance(params: _*).asInstanceOf[AnyRef]
+    obj
   }
 
+
   /** Copy all data from an InputStream to an OutputStream */
-  def copyStream(in: InputStream,
-                 out: OutputStream,
-                 closeStreams: Boolean = false): Long =
+  def copyStream(
+      in: InputStream,
+      out: OutputStream,
+      closeStreams: Boolean = false): Long =
   {
     var count = 0L
     try {
@@ -228,46 +466,107 @@ object ClosureCleaner {
   }
 }
 
-@Internal
-private[flink]
-class ReturnStatementFinder extends ClassVisitor(ASM5) {
-  override def visitMethod(access: Int, name: String, desc: String,
-                           sig: String, exceptions: Array[String]): 
MethodVisitor = {
-    if (name.contains("apply")) {
-      new MethodVisitor(ASM5) {
+private class ReturnStatementInClosureException
+  extends FlinkException("Return statements aren't allowed in Flink closures")
+
+private class ReturnStatementFinder(targetMethodName: Option[String] = None)
+  extends ClassVisitor(ASM6) {
+  override def visitMethod(
+      access: Int,
+      name: String,
+      desc: String,
+      sig: String, exceptions: Array[String]): MethodVisitor = {
+
+    // $anonfun$ covers Java 8 lambdas
+    if (name.contains("apply") || name.contains("$anonfun$")) {
+      // A method with suffix "$adapted" will be generated in cases like
+      // { _:Int => return; Seq()} but not { _:Int => return; true}
+      // closure passed is $anonfun$t$1$adapted while actual code resides in 
$anonfun$s$1
+      // visitor will see only $anonfun$s$1$adapted, so we remove the suffix, 
see
+      // https://github.com/scala/scala-dev/issues/109
+      val isTargetMethod = targetMethodName.isEmpty ||
+        name == targetMethodName.get || name == 
targetMethodName.get.stripSuffix("$adapted")
+
+      new MethodVisitor(ASM6) {
         override def visitTypeInsn(op: Int, tp: String) {
-          if (op == NEW && tp.contains("scala/runtime/NonLocalReturnControl")) 
{
-            throw new InvalidProgramException("Return statements aren't 
allowed in Flink closures")
+          if (op == NEW && tp.contains("scala/runtime/NonLocalReturnControl") 
&& isTargetMethod) {
+            throw new ReturnStatementInClosureException
           }
         }
       }
     } else {
-      new MethodVisitor(ASM5) {}
+      new MethodVisitor(ASM6) {}
     }
   }
 }
 
-@Internal
-private[flink]
-class FieldAccessFinder(output: Map[Class[_], Set[String]]) extends 
ClassVisitor(ASM5) {
-  override def visitMethod(access: Int, name: String, desc: String,
-                           sig: String, exceptions: Array[String]): 
MethodVisitor = {
-    new MethodVisitor(ASM5) {
+/** Helper class to identify a method. */
+private case class MethodIdentifier[T](cls: Class[T], name: String, desc: 
String)
+
+/**
+ * Find the fields accessed by a given class.
+ *
+ * The resulting fields are stored in the mutable map passed in through the 
constructor.
+ * This map is assumed to have its keys already populated with the classes of 
interest.
+ *
+ * @param fields the mutable map that stores the fields to return
+ * @param findTransitively if true, find fields indirectly referenced through 
method calls
+ * @param specificMethod if not empty, visit only this specific method
+ * @param visitedMethods a set of visited methods to avoid cycles
+ */
+private class FieldAccessFinder(
+    fields: Map[Class[_], Set[String]],
+    findTransitively: Boolean,
+    specificMethod: Option[MethodIdentifier[_]] = None,
+    visitedMethods: Set[MethodIdentifier[_]] = Set.empty)
+  extends ClassVisitor(ASM6) {
+
+  override def visitMethod(
+      access: Int,
+      name: String,
+      desc: String,
+      sig: String,
+      exceptions: Array[String]): MethodVisitor = {
+
+    // If we are told to visit only a certain method and this is not the one, 
ignore it
+    if (specificMethod.isDefined &&
+      (specificMethod.get.name != name || specificMethod.get.desc != desc)) {
+      return null
+    }
+
+    new MethodVisitor(ASM6) {
       override def visitFieldInsn(op: Int, owner: String, name: String, desc: 
String) {
         if (op == GETFIELD) {
-          for (cl <- output.keys if cl.getName == owner.replace('/', '.')) {
-            output(cl) += name
+          for (cl <- fields.keys if cl.getName == owner.replace('/', '.')) {
+            fields(cl) += name
           }
         }
       }
 
-      override def visitMethodInsn(op: Int, owner: String, name: String,
-                                   desc: String) {
-        // Check for calls a getter method for a variable in an interpreter 
wrapper object.
-        // This means that the corresponding field will be accessed, so we 
should save it.
-        if (op == INVOKEVIRTUAL && owner.endsWith("$iwC") && 
!name.endsWith("$outer")) {
-          for (cl <- output.keys if cl.getName == owner.replace('/', '.')) {
-            output(cl) += name
+      override def visitMethodInsn(
+          op: Int, owner: String, name: String, desc: String, itf: Boolean) {
+        for (cl <- fields.keys if cl.getName == owner.replace('/', '.')) {
+          // Check for calls a getter method for a variable in an interpreter 
wrapper object.
+          // This means that the corresponding field will be accessed, so we 
should save it.
+          if (op == INVOKEVIRTUAL && owner.endsWith("$iwC") && 
!name.endsWith("$outer")) {
+            fields(cl) += name
+          }
+          // Optionally visit other methods to find fields that are 
transitively referenced
+          if (findTransitively) {
+            val m = MethodIdentifier(cl, name, desc)
+            if (!visitedMethods.contains(m)) {
+              // Keep track of visited methods to avoid potential infinite 
cycles
+              visitedMethods += m
+
+              var currentClass = cl
+              assert(currentClass != null, "The outer class can't be null.")
+
+              while (currentClass != null) {
+                ClosureCleaner.getClassReader(currentClass).accept(
+                  new FieldAccessFinder(fields, findTransitively, Some(m), 
visitedMethods), 0)
+                currentClass = currentClass.getSuperclass()
+              }
+            }
           }
         }
       }
@@ -275,10 +574,14 @@ class FieldAccessFinder(output: Map[Class[_], 
Set[String]]) extends ClassVisitor
   }
 }
 
-@Internal
-private[flink] class InnerClosureFinder(output: Set[Class[_]]) extends 
ClassVisitor(ASM5) {
+private class InnerClosureFinder(output: Set[Class[_]]) extends 
ClassVisitor(ASM6) {
   var myName: String = null
 
+  // TODO: Recursively find inner closures that we indirectly reference, e.g.
+  //   val closure1 = () = { () => 1 }
+  //   val closure2 = () => { (1 to 5).map(closure1) }
+  // The second closure technically has two inner closures, but this finder 
only finds one
+
   override def visit(version: Int, access: Int, name: String, sig: String,
                      superName: String, interfaces: Array[String]) {
     myName = name
@@ -286,20 +589,21 @@ private[flink] class InnerClosureFinder(output: 
Set[Class[_]]) extends ClassVisi
 
   override def visitMethod(access: Int, name: String, desc: String,
                            sig: String, exceptions: Array[String]): 
MethodVisitor = {
-    new MethodVisitor(ASM5) {
-      override def visitMethodInsn(op: Int, owner: String, name: String,
-                                   desc: String) {
+    new MethodVisitor(ASM6) {
+      override def visitMethodInsn(
+          op: Int, owner: String, name: String, desc: String, itf: Boolean) {
         val argTypes = Type.getArgumentTypes(desc)
-        if (op == INVOKESPECIAL && name == "<init>" && argTypes.nonEmpty
+        if (op == INVOKESPECIAL && name == "<init>" && argTypes.length > 0
           && argTypes(0).toString.startsWith("L") // is it an object?
           && argTypes(0).getInternalName == myName) {
+          // scalastyle:off classforname
           output += Class.forName(
             owner.replace('/', '.'),
             false,
             Thread.currentThread.getContextClassLoader)
+          // scalastyle:on classforname
         }
       }
     }
   }
 }
-
diff --git a/flink-tests/src/test/scala/org/apache/flink/api/scala/functions/ClosureCleanerITCase.scala b/flink-tests/src/test/scala/org/apache/flink/api/scala/functions/ClosureCleanerITCase.scala
index 8f1e1f8..2a4dcf3 100644
--- a/flink-tests/src/test/scala/org/apache/flink/api/scala/functions/ClosureCleanerITCase.scala
+++ b/flink-tests/src/test/scala/org/apache/flink/api/scala/functions/ClosureCleanerITCase.scala
@@ -184,7 +184,7 @@ object TestObjectWithBogusReturns {
     try {
       nums.map { x => return 1; x * 2}.print()
     } catch {
-      case inv: InvalidProgramException => // all good
+      case inv: ReturnStatementInClosureException => // all good
       case _: Throwable => fail("Bogus return statement not detected.")
     }
 
diff --git a/pom.xml b/pom.xml
index 96bc64a..875b32d 100644
--- a/pom.xml
+++ b/pom.xml
@@ -246,6 +246,12 @@ under the License.
 
                        <dependency>
                                <groupId>org.apache.flink</groupId>
+                               <artifactId>flink-shaded-asm-6</artifactId>
+                               <version>6.2.1-${flink.shaded.version}</version>
+                       </dependency>
+
+                       <dependency>
+                               <groupId>org.apache.flink</groupId>
                                <artifactId>flink-shaded-guava</artifactId>
                                <version>18.0-${flink.shaded.version}</version>
                        </dependency>

Reply via email to