This is an automated email from the ASF dual-hosted git repository. aljoscha pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/flink.git
commit 1b9cdf559adacd9af8a1f84647a1aed99b2c56dc Author: Aljoscha Krettek <[email protected]> AuthorDate: Tue Oct 2 10:12:03 2018 +0200 [FLINK-7816] Support Scala 2.12 closures and Java 8 lambdas in ClosureCleaner This updates the ClosureCleaner with recent changes from SPARK-14540. --- flink-scala/pom.xml | 5 + .../apache/flink/api/scala/ClosureCleaner.scala | 564 ++++++++++++++++----- .../api/scala/functions/ClosureCleanerITCase.scala | 2 +- pom.xml | 6 + 4 files changed, 446 insertions(+), 131 deletions(-) diff --git a/flink-scala/pom.xml b/flink-scala/pom.xml index 508a8de..4472207 100644 --- a/flink-scala/pom.xml +++ b/flink-scala/pom.xml @@ -51,6 +51,11 @@ under the License. </dependency> <dependency> + <groupId>org.apache.flink</groupId> + <artifactId>flink-shaded-asm-6</artifactId> + </dependency> + + <dependency> <groupId>org.scala-lang</groupId> <artifactId>scala-reflect</artifactId> </dependency> diff --git a/flink-scala/src/main/scala/org/apache/flink/api/scala/ClosureCleaner.scala b/flink-scala/src/main/scala/org/apache/flink/api/scala/ClosureCleaner.scala index 7965346..2932ed5 100644 --- a/flink-scala/src/main/scala/org/apache/flink/api/scala/ClosureCleaner.scala +++ b/flink-scala/src/main/scala/org/apache/flink/api/scala/ClosureCleaner.scala @@ -18,35 +18,40 @@ package org.apache.flink.api.scala import java.io._ +import java.lang.invoke.SerializedLambda import org.apache.flink.annotation.Internal import org.apache.flink.api.common.InvalidProgramException -import org.apache.flink.util.InstantiationUtil +import org.apache.flink.util.{FlinkException, InstantiationUtil} import org.slf4j.LoggerFactory import scala.collection.mutable.Map import scala.collection.mutable.Set +import org.apache.flink.shaded.asm6.org.objectweb.asm.{ClassReader, ClassVisitor, MethodVisitor, Type} +import org.apache.flink.shaded.asm6.org.objectweb.asm.Opcodes._ -import org.apache.flink.shaded.asm5.org.objectweb.asm.{ClassReader, ClassVisitor, MethodVisitor, Type} -import 
org.apache.flink.shaded.asm5.org.objectweb.asm.Opcodes._ +import scala.collection.mutable /* This code is originally from the Apache Spark project. */ @Internal object ClosureCleaner { + val LOG = LoggerFactory.getLogger(this.getClass) + private val isScala2_11 = scala.util.Properties.versionString.contains("2.11") + // Get an ASM class reader for a given class from the JAR that loaded it - private def getClassReader(cls: Class[_]): ClassReader = { + private[scala] def getClassReader(cls: Class[_]): ClassReader = { // Copy data over, before delegating to ClassReader - else we can run out of open file handles. val className = cls.getName.replaceFirst("^.*\\.", "") + ".class" val resourceStream = cls.getResourceAsStream(className) - // todo: Fixme - continuing with earlier behavior ... - if (resourceStream == null) return new ClassReader(resourceStream) - - val baos = new ByteArrayOutputStream(128) - - copyStream(resourceStream, baos, true) - new ClassReader(new ByteArrayInputStream(baos.toByteArray)) + if (resourceStream == null) { + null + } else { + val baos = new ByteArrayOutputStream(128) + copyStream(resourceStream, baos, true) + new ClassReader(new ByteArrayInputStream(baos.toByteArray)) + } } // Check whether a class represents a Scala closure @@ -54,110 +59,339 @@ object ClosureCleaner { cls.getName.contains("$anonfun$") } - // Get a list of the classes of the outer objects of a given closure object, obj; + // Get a list of the outer objects and their classes of a given closure object, obj; // the outer objects are defined as any closures that obj is nested within, plus // possibly the class that the outermost closure is in, if any. We stop searching // for outer objects beyond that because cloning the user's object is probably // not a good idea (whereas we can clone closure objects just fine since we // understand how all their fields are used). 
- private def getOuterClasses(obj: AnyRef): List[Class[_]] = { + private def getOuterClassesAndObjects(obj: AnyRef): (List[Class[_]], List[AnyRef]) = { for (f <- obj.getClass.getDeclaredFields if f.getName == "$outer") { f.setAccessible(true) - if (isClosure(f.getType)) { - return f.getType :: getOuterClasses(f.get(obj)) - } else { - return f.getType :: Nil // Stop at the first $outer that is not a closure + val outer = f.get(obj) + // The outer pointer may be null if we have cleaned this closure before + if (outer != null) { + if (isClosure(f.getType)) { + val recurRet = getOuterClassesAndObjects(outer) + return (f.getType :: recurRet._1, outer :: recurRet._2) + } else { + return (f.getType :: Nil, outer :: Nil) // Stop at the first $outer that is not a closure + } } } - Nil + (Nil, Nil) } - - // Get a list of the outer objects for a given closure object. - private def getOuterObjects(obj: AnyRef): List[AnyRef] = { - for (f <- obj.getClass.getDeclaredFields if f.getName == "$outer") { - f.setAccessible(true) - if (isClosure(f.getType)) { - return f.get(obj) :: getOuterObjects(f.get(obj)) - } else { - return f.get(obj) :: Nil // Stop at the first $outer that is not a closure + /** + * Return a list of classes that represent closures enclosed in the given closure object. 
+ */ + private def getInnerClosureClasses(obj: AnyRef): List[Class[_]] = { + val seen = Set[Class[_]](obj.getClass) + val stack = mutable.Stack[Class[_]](obj.getClass) + while (!stack.isEmpty) { + val cr = getClassReader(stack.pop()) + if (cr != null) { + val set = Set.empty[Class[_]] + cr.accept(new InnerClosureFinder(set), 0) + for (cls <- set -- seen) { + seen += cls + stack.push(cls) + } } } - Nil + (seen - obj.getClass).toList } - private def getInnerClasses(obj: AnyRef): List[Class[_]] = { - val seen = Set[Class[_]](obj.getClass) - var stack = List[Class[_]](obj.getClass) - while (stack.nonEmpty) { - val cr = getClassReader(stack.head) - stack = stack.tail - val set = Set[Class[_]]() - cr.accept(new InnerClosureFinder(set), 0) - for (cls <- set -- seen) { - seen += cls - stack = cls :: stack + /** Initializes the accessed fields for outer classes and their super classes. */ + private def initAccessedFields( + accessedFields: Map[Class[_], Set[String]], + outerClasses: Seq[Class[_]]): Unit = { + for (cls <- outerClasses) { + var currentClass = cls + assert(currentClass != null, "The outer class can't be null.") + + while (currentClass != null) { + accessedFields(currentClass) = Set.empty[String] + currentClass = currentClass.getSuperclass() } } - (seen - obj.getClass).toList } - private def createNullValue(cls: Class[_]): AnyRef = { - if (cls.isPrimitive) { - new java.lang.Byte(0: Byte) // Should be convertible to any primitive type - } else { - null + /** Sets accessed fields for given class in clone object based on given object. 
*/ + private def setAccessedFields( + outerClass: Class[_], + clone: AnyRef, + obj: AnyRef, + accessedFields: Map[Class[_], Set[String]]): Unit = { + for (fieldName <- accessedFields(outerClass)) { + val field = outerClass.getDeclaredField(fieldName) + field.setAccessible(true) + val value = field.get(obj) + field.set(clone, value) } } - def clean(func: AnyRef, checkSerializable: Boolean = true) { - // TODO: cache outerClasses / innerClasses / accessedFields - val outerClasses = getOuterClasses(func) - val innerClasses = getInnerClasses(func) - val outerObjects = getOuterObjects(func) + /** Clones a given object and sets accessed fields in cloned object. */ + private def cloneAndSetFields( + parent: AnyRef, + obj: AnyRef, + outerClass: Class[_], + accessedFields: Map[Class[_], Set[String]]): AnyRef = { + val clone = instantiateClass(outerClass, parent) - val accessedFields = Map[Class[_], Set[String]]() + var currentClass = outerClass + assert(currentClass != null, "The outer class can't be null.") - getClassReader(func.getClass).accept(new ReturnStatementFinder(), 0) + while (currentClass != null) { + setAccessedFields(currentClass, clone, obj, accessedFields) + currentClass = currentClass.getSuperclass() + } - for (cls <- outerClasses) - accessedFields(cls) = Set[String]() - for (cls <- func.getClass :: innerClasses) - getClassReader(cls).accept(new FieldAccessFinder(accessedFields), 0) + clone + } - if (LOG.isDebugEnabled) { - LOG.debug("accessedFields: " + accessedFields) - } + /** + * Clean the given closure in place. + * + * More specifically, this renders the given closure serializable as long as it does not + * explicitly reference unserializable objects. 
+ * + * @param closure the closure to clean + * @param checkSerializable whether to verify that the closure is serializable after cleaning + * @param cleanTransitively whether to clean enclosing closures transitively + */ + def clean( + closure: AnyRef, + checkSerializable: Boolean = true, + cleanTransitively: Boolean = true): Unit = { + clean(closure, checkSerializable, cleanTransitively, Map.empty) + } - var outerPairs: List[(Class[_], AnyRef)] = (outerClasses zip outerObjects).reverse - var outer: AnyRef = null - if (outerPairs.nonEmpty && !isClosure(outerPairs.head._1)) { - // The closure is ultimately nested inside a class; keep the object of that - // class without cloning it since we don't want to clone the user's objects. - outer = outerPairs.head._2 - outerPairs = outerPairs.tail + /** + * Try to get a serialized Lambda from the closure. + * + * @param closure the closure to check. + */ + private def getSerializedLambda(closure: AnyRef): Option[SerializedLambda] = { + if (isScala2_11) { + return None } - // Clone the closure objects themselves, nulling out any fields that are not - // used in the closure we're working on or any of its inner closures. 
- for ((cls, obj) <- outerPairs) { - outer = instantiateClass(cls, outer) - for (fieldName <- accessedFields(cls)) { - val field = cls.getDeclaredField(fieldName) - field.setAccessible(true) - val value = field.get(obj) - if (LOG.isDebugEnabled) { - LOG.debug("1: Setting " + fieldName + " on " + cls + " to " + value) - } - field.set(outer, value) + val isClosureCandidate = + closure.getClass.isSynthetic && + closure + .getClass + .getInterfaces.exists(_.getName == "scala.Serializable") + + if (isClosureCandidate) { + try { + Option(inspect(closure)) + } catch { + case e: Exception => + if (LOG.isDebugEnabled) { + LOG.debug("Closure is not a serialized lambda.", e) + } + None } + } else { + None } + } + + private def inspect(closure: AnyRef): SerializedLambda = { + val writeReplace = closure.getClass.getDeclaredMethod("writeReplace") + writeReplace.setAccessible(true) + writeReplace.invoke(closure).asInstanceOf[java.lang.invoke.SerializedLambda] + } + + /** + * Helper method to clean the given closure in place. + * + * The mechanism is to traverse the hierarchy of enclosing closures and null out any + * references along the way that are not actually used by the starting closure, but are + * nevertheless included in the compiled anonymous classes. Note that it is unsafe to + * simply mutate the enclosing closures in place, as other code paths may depend on them. + * Instead, we clone each enclosing closure and set the parent pointers accordingly. + * + * By default, closures are cleaned transitively. This means we detect whether enclosing + * objects are actually referenced by the starting one, either directly or transitively, + * and, if not, sever these closures from the hierarchy. In other words, in addition to + * nulling out unused field references, we also null out any parent pointers that refer + * to enclosing objects not actually needed by the starting closure. 
We determine + * transitivity by tracing through the tree of all methods ultimately invoked by the + * inner closure and record all the fields referenced in the process. + * + * For instance, transitive cleaning is necessary in the following scenario: + * + * class SomethingNotSerializable { + * def someValue = 1 + * def scope(name: String)(body: => Unit) = body + * def someMethod(): Unit = scope("one") { + * def x = someValue + * def y = 2 + * scope("two") { println(y + 1) } + * } + * } + * + * In this example, scope "two" is not serializable because it references scope "one", which + * references SomethingNotSerializable. Note that, however, the body of scope "two" does not + * actually depend on SomethingNotSerializable. This means we can safely null out the parent + * pointer of a cloned scope "one" and set it as the parent of scope "two", such that scope "two" + * no longer references SomethingNotSerializable transitively. + * + * @param func the starting closure to clean + * @param checkSerializable whether to verify that the closure is serializable after cleaning + * @param cleanTransitively whether to clean enclosing closures transitively + * @param accessedFields a map from a class to a set of its fields that are accessed by + * the starting closure + */ + private def clean( + func: AnyRef, + checkSerializable: Boolean, + cleanTransitively: Boolean, + accessedFields: Map[Class[_], Set[String]]): Unit = { + + // most likely to be the case with 2.12, 2.13 + // so we check first + // non LMF-closures should be less frequent from now on + val lambdaFunc = getSerializedLambda(func) + + if (!isClosure(func.getClass) && lambdaFunc.isEmpty) { + LOG.debug(s"Expected a closure; got ${func.getClass.getName}") + return + } + + // TODO: clean all inner closures first. This requires us to find the inner objects. 
+ // TODO: cache outerClasses / innerClasses / accessedFields + + if (func == null) { + return + } + + if (lambdaFunc.isEmpty) { + LOG.debug(s"+++ Cleaning closure $func (${func.getClass.getName}) +++") + + // A list of classes that represents closures enclosed in the given one + val innerClasses = getInnerClosureClasses(func) + + // A list of enclosing objects and their respective classes, from innermost to outermost + // An outer object at a given index is of type outer class at the same index + val (outerClasses, outerObjects) = getOuterClassesAndObjects(func) + + // For logging purposes only + val declaredFields = func.getClass.getDeclaredFields + val declaredMethods = func.getClass.getDeclaredMethods - if (outer != null) { if (LOG.isDebugEnabled) { - LOG.debug("2: Setting $outer on " + func.getClass + " to " + outer) + LOG.debug(s" + declared fields: ${declaredFields.size}") + declaredFields.foreach { f => LOG.debug(s" $f") } + LOG.debug(s" + declared methods: ${declaredMethods.size}") + declaredMethods.foreach { m => LOG.debug(s" $m") } + LOG.debug(s" + inner classes: ${innerClasses.size}") + innerClasses.foreach { c => LOG.debug(s" ${c.getName}") } + LOG.debug(s" + outer classes: ${outerClasses.size}" ) + outerClasses.foreach { c => LOG.debug(s" ${c.getName}") } + LOG.debug(s" + outer objects: ${outerObjects.size}") + outerObjects.foreach { o => LOG.debug(s" $o") } } - val field = func.getClass.getDeclaredField("$outer") - field.setAccessible(true) - field.set(func, outer) + + // Fail fast if we detect return statements in closures + getClassReader(func.getClass).accept(new ReturnStatementFinder(), 0) + + // If accessed fields is not populated yet, we assume that + // the closure we are trying to clean is the starting one + if (accessedFields.isEmpty) { + LOG.debug(" + populating accessed fields because this is the starting closure") + // Initialize accessed fields with the outer classes first + // This step is needed to associate the fields to the correct 
classes later + initAccessedFields(accessedFields, outerClasses) + + // Populate accessed fields by visiting all fields and methods accessed by this and + // all of its inner closures. If transitive cleaning is enabled, this may recursively + // visit methods that belong to other classes in search of transitively referenced fields. + for (cls <- func.getClass :: innerClasses) { + getClassReader(cls).accept(new FieldAccessFinder(accessedFields, cleanTransitively), 0) + } + } + + LOG.debug(s" + fields accessed by starting closure: " + accessedFields.size) + accessedFields.foreach { f => LOG.debug(" " + f) } + + // List of outer (class, object) pairs, ordered from outermost to innermost + // Note that all outer objects but the outermost one (first one in this list) must be closures + var outerPairs: List[(Class[_], AnyRef)] = outerClasses.zip(outerObjects).reverse + var parent: AnyRef = null + if (outerPairs.nonEmpty) { + val (outermostClass, outermostObject) = outerPairs.head + if (isClosure(outermostClass)) { + LOG.debug(s" + outermost object is a closure, so we clone it: ${outerPairs.head}") + } else if (outermostClass.getName.startsWith("$line")) { + // SPARK-14558: if the outermost object is a REPL line object, we should clone + // and clean it as it may carry a lot of unnecessary information, + // e.g. hadoop conf, spark conf, etc. + LOG.debug( + s" + outermost object is a REPL line object, so we clone it: ${outerPairs.head}") + } else { + // The closure is ultimately nested inside a class; keep the object of that + // class without cloning it since we don't want to clone the user's objects. + // Note that we still need to keep around the outermost object itself because + // we need it to clone its child closure later (see below). + LOG.debug(" + outermost object is not a closure or REPL line object," + + "so do not clone it: " + outerPairs.head) + parent = outermostObject // e.g. 
SparkContext + outerPairs = outerPairs.tail + } + } else { + LOG.debug(" + there are no enclosing objects!") + } + + // Clone the closure objects themselves, nulling out any fields that are not + // used in the closure we're working on or any of its inner closures. + for ((cls, obj) <- outerPairs) { + LOG.debug(s" + cloning the object $obj of class ${cls.getName}") + // We null out these unused references by cloning each object and then filling in all + // required fields from the original object. We need the parent here because the Java + // language specification requires the first constructor parameter of any closure to be + // its enclosing object. + val clone = cloneAndSetFields(parent, obj, cls, accessedFields) + + // If transitive cleaning is enabled, we recursively clean any enclosing closure using + // the already populated accessed fields map of the starting closure + if (cleanTransitively && isClosure(clone.getClass)) { + LOG.debug(s" + cleaning cloned closure $clone recursively (${cls.getName})") + // No need to check serializable here for the outer closures because we're + // only interested in the serializability of the starting closure + clean(clone, checkSerializable = false, cleanTransitively, accessedFields) + } + parent = clone + } + + // Update the parent pointer ($outer) of this closure + if (parent != null) { + val field = func.getClass.getDeclaredField("$outer") + field.setAccessible(true) + // If the starting closure doesn't actually need our enclosing object, then just null it out + if (accessedFields.contains(func.getClass) && + !accessedFields(func.getClass).contains("$outer")) { + LOG.debug(s" + the starting closure doesn't actually need $parent, so we null it out") + field.set(func, null) + } else { + // Update this closure's parent pointer to point to our enclosing object, + // which could either be a cloned closure or the original user object + field.set(func, parent) + } + } + + LOG.debug(s" +++ closure $func 
(${func.getClass.getName}) is now cleaned +++") + } else { + LOG.debug(s"Cleaning lambda: ${lambdaFunc.get.getImplMethodName}") + + // scalastyle:off classforname + val captClass = Class.forName(lambdaFunc.get.getCapturingClass.replace('/', '.'), + false, Thread.currentThread.getContextClassLoader) + // scalastyle:on classforname + // Fail fast if we detect return statements in closures + getClassReader(captClass) + .accept(new ReturnStatementFinder(Some(lambdaFunc.get.getImplMethodName)), 0) + LOG.debug(s" +++ Lambda closure (${lambdaFunc.get.getImplMethodName}) is now cleaned +++") } if (checkSerializable) { @@ -165,7 +399,7 @@ object ClosureCleaner { } } - def ensureSerializable(func: AnyRef) { + private[flink] def ensureSerializable(func: AnyRef) { try { InstantiationUtil.serializeObject(func) } catch { @@ -173,24 +407,28 @@ object ClosureCleaner { } } - private def instantiateClass(cls: Class[_], outer: AnyRef): AnyRef = { - if (LOG.isDebugEnabled) { - LOG.debug("Creating a " + cls + " with outer = " + outer) - } - // This is a bona fide closure class, whose constructor has no effects - // other than to set its fields, so use its constructor - val cons = cls.getConstructors()(0) - val params = cons.getParameterTypes.map(createNullValue) - if (outer != null) { - params(0) = outer // First param is always outer object + private def instantiateClass( + cls: Class[_], + enclosingObject: AnyRef): AnyRef = { + // Use reflection to instantiate object without calling constructor + val rf = sun.reflect.ReflectionFactory.getReflectionFactory() + val parentCtor = classOf[java.lang.Object].getDeclaredConstructor() + val newCtor = rf.newConstructorForSerialization(cls, parentCtor) + val obj = newCtor.newInstance().asInstanceOf[AnyRef] + if (enclosingObject != null) { + val field = cls.getDeclaredField("$outer") + field.setAccessible(true) + field.set(obj, enclosingObject) } - cons.newInstance(params: _*).asInstanceOf[AnyRef] + obj } + /** Copy all data from an InputStream 
to an OutputStream */ - def copyStream(in: InputStream, - out: OutputStream, - closeStreams: Boolean = false): Long = + def copyStream( + in: InputStream, + out: OutputStream, + closeStreams: Boolean = false): Long = { var count = 0L try { @@ -228,46 +466,107 @@ object ClosureCleaner { } } -@Internal -private[flink] -class ReturnStatementFinder extends ClassVisitor(ASM5) { - override def visitMethod(access: Int, name: String, desc: String, - sig: String, exceptions: Array[String]): MethodVisitor = { - if (name.contains("apply")) { - new MethodVisitor(ASM5) { +private class ReturnStatementInClosureException + extends FlinkException("Return statements aren't allowed in Flink closures") + +private class ReturnStatementFinder(targetMethodName: Option[String] = None) + extends ClassVisitor(ASM6) { + override def visitMethod( + access: Int, + name: String, + desc: String, + sig: String, exceptions: Array[String]): MethodVisitor = { + + // $anonfun$ covers Java 8 lambdas + if (name.contains("apply") || name.contains("$anonfun$")) { + // A method with suffix "$adapted" will be generated in cases like + // { _:Int => return; Seq()} but not { _:Int => return; true} + // closure passed is $anonfun$t$1$adapted while actual code resides in $anonfun$s$1 + // visitor will see only $anonfun$s$1$adapted, so we remove the suffix, see + // https://github.com/scala/scala-dev/issues/109 + val isTargetMethod = targetMethodName.isEmpty || + name == targetMethodName.get || name == targetMethodName.get.stripSuffix("$adapted") + + new MethodVisitor(ASM6) { override def visitTypeInsn(op: Int, tp: String) { - if (op == NEW && tp.contains("scala/runtime/NonLocalReturnControl")) { - throw new InvalidProgramException("Return statements aren't allowed in Flink closures") + if (op == NEW && tp.contains("scala/runtime/NonLocalReturnControl") && isTargetMethod) { + throw new ReturnStatementInClosureException } } } } else { - new MethodVisitor(ASM5) {} + new MethodVisitor(ASM6) {} } } } -@Internal 
-private[flink] -class FieldAccessFinder(output: Map[Class[_], Set[String]]) extends ClassVisitor(ASM5) { - override def visitMethod(access: Int, name: String, desc: String, - sig: String, exceptions: Array[String]): MethodVisitor = { - new MethodVisitor(ASM5) { +/** Helper class to identify a method. */ +private case class MethodIdentifier[T](cls: Class[T], name: String, desc: String) + +/** + * Find the fields accessed by a given class. + * + * The resulting fields are stored in the mutable map passed in through the constructor. + * This map is assumed to have its keys already populated with the classes of interest. + * + * @param fields the mutable map that stores the fields to return + * @param findTransitively if true, find fields indirectly referenced through method calls + * @param specificMethod if not empty, visit only this specific method + * @param visitedMethods a set of visited methods to avoid cycles + */ +private class FieldAccessFinder( + fields: Map[Class[_], Set[String]], + findTransitively: Boolean, + specificMethod: Option[MethodIdentifier[_]] = None, + visitedMethods: Set[MethodIdentifier[_]] = Set.empty) + extends ClassVisitor(ASM6) { + + override def visitMethod( + access: Int, + name: String, + desc: String, + sig: String, + exceptions: Array[String]): MethodVisitor = { + + // If we are told to visit only a certain method and this is not the one, ignore it + if (specificMethod.isDefined && + (specificMethod.get.name != name || specificMethod.get.desc != desc)) { + return null + } + + new MethodVisitor(ASM6) { override def visitFieldInsn(op: Int, owner: String, name: String, desc: String) { if (op == GETFIELD) { - for (cl <- output.keys if cl.getName == owner.replace('/', '.')) { - output(cl) += name + for (cl <- fields.keys if cl.getName == owner.replace('/', '.')) { + fields(cl) += name } } } - override def visitMethodInsn(op: Int, owner: String, name: String, - desc: String) { - // Check for calls a getter method for a variable in an 
interpreter wrapper object. - // This means that the corresponding field will be accessed, so we should save it. - if (op == INVOKEVIRTUAL && owner.endsWith("$iwC") && !name.endsWith("$outer")) { - for (cl <- output.keys if cl.getName == owner.replace('/', '.')) { - output(cl) += name + override def visitMethodInsn( + op: Int, owner: String, name: String, desc: String, itf: Boolean) { + for (cl <- fields.keys if cl.getName == owner.replace('/', '.')) { + // Check for calls a getter method for a variable in an interpreter wrapper object. + // This means that the corresponding field will be accessed, so we should save it. + if (op == INVOKEVIRTUAL && owner.endsWith("$iwC") && !name.endsWith("$outer")) { + fields(cl) += name + } + // Optionally visit other methods to find fields that are transitively referenced + if (findTransitively) { + val m = MethodIdentifier(cl, name, desc) + if (!visitedMethods.contains(m)) { + // Keep track of visited methods to avoid potential infinite cycles + visitedMethods += m + + var currentClass = cl + assert(currentClass != null, "The outer class can't be null.") + + while (currentClass != null) { + ClosureCleaner.getClassReader(currentClass).accept( + new FieldAccessFinder(fields, findTransitively, Some(m), visitedMethods), 0) + currentClass = currentClass.getSuperclass() + } + } } } } @@ -275,10 +574,14 @@ class FieldAccessFinder(output: Map[Class[_], Set[String]]) extends ClassVisitor } } -@Internal -private[flink] class InnerClosureFinder(output: Set[Class[_]]) extends ClassVisitor(ASM5) { +private class InnerClosureFinder(output: Set[Class[_]]) extends ClassVisitor(ASM6) { var myName: String = null + // TODO: Recursively find inner closures that we indirectly reference, e.g. 
+ // val closure1 = () => { () => 1 } + // val closure2 = () => { (1 to 5).map(closure1) } + // The second closure technically has two inner closures, but this finder only finds one + + override def visit(version: Int, access: Int, name: String, sig: String, superName: String, interfaces: Array[String]) { myName = name @@ -286,20 +589,21 @@ private[flink] class InnerClosureFinder(output: Set[Class[_]]) extends ClassVisi override def visitMethod(access: Int, name: String, desc: String, sig: String, exceptions: Array[String]): MethodVisitor = { - new MethodVisitor(ASM5) { - override def visitMethodInsn(op: Int, owner: String, name: String, - desc: String) { + new MethodVisitor(ASM6) { + override def visitMethodInsn( + op: Int, owner: String, name: String, desc: String, itf: Boolean) { val argTypes = Type.getArgumentTypes(desc) - if (op == INVOKESPECIAL && name == "<init>" && argTypes.nonEmpty + if (op == INVOKESPECIAL && name == "<init>" && argTypes.length > 0 && argTypes(0).toString.startsWith("L") // is it an object? 
&& argTypes(0).getInternalName == myName) { + // scalastyle:off classforname output += Class.forName( owner.replace('/', '.'), false, Thread.currentThread.getContextClassLoader) + // scalastyle:on classforname } } } } } - diff --git a/flink-tests/src/test/scala/org/apache/flink/api/scala/functions/ClosureCleanerITCase.scala b/flink-tests/src/test/scala/org/apache/flink/api/scala/functions/ClosureCleanerITCase.scala index 8f1e1f8..2a4dcf3 100644 --- a/flink-tests/src/test/scala/org/apache/flink/api/scala/functions/ClosureCleanerITCase.scala +++ b/flink-tests/src/test/scala/org/apache/flink/api/scala/functions/ClosureCleanerITCase.scala @@ -184,7 +184,7 @@ object TestObjectWithBogusReturns { try { nums.map { x => return 1; x * 2}.print() } catch { - case inv: InvalidProgramException => // all good + case inv: ReturnStatementInClosureException => // all good case _: Throwable => fail("Bogus return statement not detected.") } diff --git a/pom.xml b/pom.xml index 96bc64a..875b32d 100644 --- a/pom.xml +++ b/pom.xml @@ -246,6 +246,12 @@ under the License. <dependency> <groupId>org.apache.flink</groupId> + <artifactId>flink-shaded-asm-6</artifactId> + <version>6.2.1-${flink.shaded.version}</version> + </dependency> + + <dependency> + <groupId>org.apache.flink</groupId> <artifactId>flink-shaded-guava</artifactId> <version>18.0-${flink.shaded.version}</version> </dependency>
