GitHub user pwendell commented on a diff in the pull request:

    https://github.com/apache/spark/pull/5685#discussion_r29405393
  
    --- Diff: 
core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite2.scala ---
    @@ -0,0 +1,562 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.util
    +
    +import java.io.NotSerializableException
    +
    +import scala.collection.mutable
    +
    +import org.scalatest.{BeforeAndAfterAll, FunSuite, PrivateMethodTester}
    +
    +import org.apache.spark.{SparkContext, SparkException}
    +import org.apache.spark.serializer.SerializerInstance
    +
    +/**
    + * Another test suite for the closure cleaner that is finer-grained.
    + * For tests involving end-to-end Spark jobs, see {{ClosureCleanerSuite}}.
    + */
    +class ClosureCleanerSuite2 extends FunSuite with BeforeAndAfterAll with 
PrivateMethodTester {
    +
    +  // Start a SparkContext so that the closure serializer is accessible
    +  // We do not actually use this explicitly otherwise
    +  private var sc: SparkContext = null
    +  private var closureSerializer: SerializerInstance = null
    +
    +  override def beforeAll(): Unit = {
    +    sc = new SparkContext("local", "test")
    +    closureSerializer = sc.env.closureSerializer.newInstance()
    +  }
    +
    +  override def afterAll(): Unit = {
    +    sc.stop()
    +    sc = null
    +    closureSerializer = null
    +  }
    +
    +  // Some fields and methods to reference in inner closures later
    +  private val someSerializableValue = 1
    +  private val someNonSerializableValue = new NonSerializable
    +  private def someSerializableMethod() = 1
    +  private def someNonSerializableMethod() = new NonSerializable
    +
    +  /** Assert that the given closure is serializable (or not). */
    +  private def assertSerializable(closure: AnyRef, serializable: Boolean): 
Unit = {
    +    if (serializable) {
    +      closureSerializer.serialize(closure)
    +    } else {
    +      intercept[NotSerializableException] {
    +        closureSerializer.serialize(closure)
    +      }
    +    }
    +  }
    +
    +  /**
    +   * Helper method for testing whether closure cleaning works as expected.
    +   * This cleans the given closure twice, with and without transitive 
cleaning.
    +   */
    +  private def testClean(
    +      closure: AnyRef,
    +      serializableBefore: Boolean,
    +      serializableAfter: Boolean): Unit = {
    +    testClean(closure, serializableBefore, serializableAfter, transitive = 
true)
    +    testClean(closure, serializableBefore, serializableAfter, transitive = 
false)
    +  }
    +
    +  /** Helper method for testing whether closure cleaning works as 
expected. */
    +  private def testClean(
    +      closure: AnyRef,
    +      serializableBefore: Boolean,
    +      serializableAfter: Boolean,
    +      transitive: Boolean): Unit = {
    +    assertSerializable(closure, serializableBefore)
    +    // If the resulting closure is not serializable even after
    +    // cleaning, we expect ClosureCleaner to throw a SparkException
    +    if (serializableAfter) {
    +      ClosureCleaner.clean(closure, checkSerializable = true, transitive)
    +    } else {
    +      intercept[SparkException] {
    +        ClosureCleaner.clean(closure, checkSerializable = true, transitive)
    +      }
    +    }
    +    assertSerializable(closure, serializableAfter)
    +  }
    +
    +  /**
    +   * Return the fields accessed by the given closure by class.
    +   * This also optionally finds the fields transitively referenced through 
methods invocations.
    +   */
    +  private def findAccessedFields(
    +      closure: AnyRef,
    +      outerClasses: Seq[Class[_]],
    +      findTransitively: Boolean): Map[Class[_], Set[String]] = {
    +    val fields = new mutable.HashMap[Class[_], mutable.Set[String]]
    +    outerClasses.foreach { c => fields(c) = new mutable.HashSet[String] }
    +    ClosureCleaner.getClassReader(closure.getClass)
    +      .accept(new FieldAccessFinder(fields, findTransitively), 0)
    +    fields.mapValues(_.toSet).toMap
    +  }
    +
    +  // Accessors for private methods
    +  private val _isClosure = PrivateMethod[Boolean]('isClosure)
    +  private val _getInnerClasses = 
PrivateMethod[List[Class[_]]]('getInnerClasses)
    +  private val _getOuterClasses = 
PrivateMethod[List[Class[_]]]('getOuterClasses)
    +  private val _getOuterObjects = 
PrivateMethod[List[AnyRef]]('getOuterObjects)
    +
    +  private def isClosure(obj: AnyRef): Boolean = {
    +    ClosureCleaner invokePrivate _isClosure(obj)
    +  }
    +
    +  private def getInnerClasses(closure: AnyRef): List[Class[_]] = {
    +    ClosureCleaner invokePrivate _getInnerClasses(closure)
    +  }
    +
    +  private def getOuterClasses(closure: AnyRef): List[Class[_]] = {
    +    ClosureCleaner invokePrivate _getOuterClasses(closure)
    +  }
    +
    +  private def getOuterObjects(closure: AnyRef): List[AnyRef] = {
    +    ClosureCleaner invokePrivate _getOuterObjects(closure)
    +  }
    +
    +  test("get inner classes") {
    +    val closure1 = () => 1
    +    val closure2 = () => { () => 1 }
    +    val closure3 = (i: Int) => {
    +      (1 to i).map { x => x + 1 }.filter { x => x > 5 }
    +    }
    +    val closure4 = (j: Int) => {
    +      (1 to j).flatMap { x =>
    +        (1 to x).flatMap { y =>
    +          (1 to y).map { z => z + 1 }
    +        }
    +      }
    +    }
    +    val inner1 = getInnerClasses(closure1)
    +    val inner2 = getInnerClasses(closure2)
    +    val inner3 = getInnerClasses(closure3)
    +    val inner4 = getInnerClasses(closure4)
    +    assert(inner1.isEmpty)
    +    assert(inner2.size === 1)
    +    assert(inner3.size === 2)
    +    assert(inner4.size === 3)
    +    assert(inner2.forall(isClosure))
    +    assert(inner3.forall(isClosure))
    +    assert(inner4.forall(isClosure))
    +  }
    +
    +  test("get outer classes and objects") {
    +    val localValue = someSerializableValue
    +    val closure1 = () => 1
    +    val closure2 = () => localValue
    +    val closure3 = () => someSerializableValue
    +    val closure4 = () => someSerializableMethod()
    +    val outerClasses1 = getOuterClasses(closure1)
    +    val outerClasses2 = getOuterClasses(closure2)
    +    val outerClasses3 = getOuterClasses(closure3)
    +    val outerClasses4 = getOuterClasses(closure4)
    +    val outerObjects1 = getOuterObjects(closure1)
    +    val outerObjects2 = getOuterObjects(closure2)
    +    val outerObjects3 = getOuterObjects(closure3)
    +    val outerObjects4 = getOuterObjects(closure4)
    +
    +    // The classes and objects should have the same size
    +    assert(outerClasses1.size === outerObjects1.size)
    +    assert(outerClasses2.size === outerObjects2.size)
    +    assert(outerClasses3.size === outerObjects3.size)
    +    assert(outerClasses4.size === outerObjects4.size)
    +
    +    // These do not have $outer pointers because they reference only local 
variables
    +    assert(outerClasses1.isEmpty)
    +    assert(outerClasses2.isEmpty)
    +
    +    // These closures do have $outer pointers because they ultimately 
reference `this`
    +    // The first $outer pointer refers to the closure defines this test 
(see FunSuite#test)
    +    // The second $outer pointer refers to ClosureCleanerSuite2
    +    assert(outerClasses3.size === 2)
    +    assert(outerClasses4.size === 2)
    +    assert(isClosure(outerClasses3(0)))
    +    assert(isClosure(outerClasses4(0)))
    +    assert(outerClasses3(0) === outerClasses4(0)) // part of the same 
"FunSuite#test" scope
    +    assert(outerClasses3(1) === this.getClass)
    +    assert(outerClasses4(1) === this.getClass)
    +    assert(outerObjects3(1) === this)
    +    assert(outerObjects4(1) === this)
    +  }
    +
    +  test("get outer classes and objects with nesting") {
    +    val localValue = someSerializableValue
    +
    +    val test1 = () => {
    +      val x = 1
    +      val closure1 = () => 1
    +      val closure2 = () => x
    +      val outerClasses1 = getOuterClasses(closure1)
    +      val outerClasses2 = getOuterClasses(closure2)
    +      val outerObjects1 = getOuterObjects(closure1)
    +      val outerObjects2 = getOuterObjects(closure2)
    +      assert(outerClasses1.size === outerObjects1.size)
    +      assert(outerClasses2.size === outerObjects2.size)
    +      // These inner closures only reference local variables, and so do 
not have $outer pointers
    +      assert(outerClasses1.isEmpty)
    +      assert(outerClasses2.isEmpty)
    +    }
    +
    +    val test2 = () => {
    +      def y = 1
    +      val closure1 = () => 1
    +      val closure2 = () => y
    +      val closure3 = () => localValue
    +      val outerClasses1 = getOuterClasses(closure1)
    +      val outerClasses2 = getOuterClasses(closure2)
    +      val outerClasses3 = getOuterClasses(closure3)
    +      val outerObjects1 = getOuterObjects(closure1)
    +      val outerObjects2 = getOuterObjects(closure2)
    +      val outerObjects3 = getOuterObjects(closure3)
    +      assert(outerClasses1.size === outerObjects1.size)
    +      assert(outerClasses2.size === outerObjects2.size)
    +      assert(outerClasses3.size === outerObjects3.size)
    +      // Same as above, this closure only references local variables
    +      assert(outerClasses1.isEmpty)
    +      // This closure references the "test2" scope because it needs to 
find the method `y`
    +      // Scope hierarchy: "test2" < "FunSuite#test" < ClosureCleanerSuite2
    +      assert(outerClasses2.size === 3)
    +      // This closure references the "test2" scope because it needs to 
find the `localValue`
    +      // defined outside of this scope
    +      assert(outerClasses3.size === 3)
    +      assert(isClosure(outerClasses2(0)))
    +      assert(isClosure(outerClasses3(0)))
    +      assert(isClosure(outerClasses2(1)))
    +      assert(isClosure(outerClasses3(1)))
    +      assert(outerClasses2(0) === outerClasses3(0)) // part of the same 
"test2" scope
    +      assert(outerClasses2(1) === outerClasses3(1)) // part of the same 
"FunSuite#test" scope
    +      assert(outerClasses2(2) === this.getClass)
    +      assert(outerClasses3(2) === this.getClass)
    +      assert(outerObjects2(2) === this)
    +      assert(outerObjects3(2) === this)
    +    }
    +
    +    test1()
    +    test2()
    +  }
    +
    +  test("find accessed fields") {
    +    val localValue = someSerializableValue
    +    val closure1 = () => 1
    +    val closure2 = () => localValue
    +    val closure3 = () => someSerializableValue
    +    val outerClasses1 = getOuterClasses(closure1)
    +    val outerClasses2 = getOuterClasses(closure2)
    +    val outerClasses3 = getOuterClasses(closure3)
    +
    +    val fields1 = findAccessedFields(closure1, outerClasses1, 
findTransitively = false)
    +    val fields2 = findAccessedFields(closure2, outerClasses2, 
findTransitively = false)
    +    val fields3 = findAccessedFields(closure3, outerClasses3, 
findTransitively = false)
    +    assert(fields1.isEmpty)
    +    assert(fields2.isEmpty)
    +    assert(fields3.size === 2)
    +    // This corresponds to the "FunSuite#test" closure. This is empty 
because the
    +    // `someSerializableValue` belongs to its parent (i.e. 
ClosureCleanerSuite2).
    +    assert(fields3(outerClasses3(0)).isEmpty)
    +    // This corresponds to the ClosureCleanerSuite2. This is also empty, 
however,
    +    // because accessing a `ClosureCleanerSuite2#someSerializableValue` 
actually involves a
    +    // method call. Since we do not find fields transitively, we will not 
recursively trace
    +    // through the fields referenced by this method.
    +    assert(fields3(outerClasses3(1)).isEmpty)
    +
    +    val fields1t = findAccessedFields(closure1, outerClasses1, 
findTransitively = true)
    +    val fields2t = findAccessedFields(closure2, outerClasses2, 
findTransitively = true)
    +    val fields3t = findAccessedFields(closure3, outerClasses3, 
findTransitively = true)
    +    assert(fields1t.isEmpty)
    +    assert(fields2t.isEmpty)
    +    assert(fields3t.size === 2)
    +    // Because we find fields transitively now, we are able to detect that 
we need the
    +    // $outer pointer to get the field from the ClosureCleanerSuite2
    +    assert(fields3t(outerClasses3(0)).size === 1)
    +    assert(fields3t(outerClasses3(0)).head === "$outer")
    +    assert(fields3t(outerClasses3(1)).size === 1)
    +    
assert(fields3t(outerClasses3(1)).head.contains("someSerializableValue"))
    +  }
    +
    +  test("find accessed fields with nesting") {
    +    val localValue = someSerializableValue
    +
    +    val test1 = () => {
    +      def a = localValue + 1
    +      val closure1 = () => 1
    +      val closure2 = () => a
    +      val closure3 = () => localValue
    +      val closure4 = () => someSerializableValue
    +      val outerClasses1 = getOuterClasses(closure1)
    +      val outerClasses2 = getOuterClasses(closure2)
    +      val outerClasses3 = getOuterClasses(closure3)
    +      val outerClasses4 = getOuterClasses(closure4)
    +
    +      // First, find only fields accessed directly, not transitively, by 
these closures
    +      val fields1 = findAccessedFields(closure1, outerClasses1, 
findTransitively = false)
    +      val fields2 = findAccessedFields(closure2, outerClasses2, 
findTransitively = false)
    +      val fields3 = findAccessedFields(closure3, outerClasses3, 
findTransitively = false)
    +      val fields4 = findAccessedFields(closure4, outerClasses4, 
findTransitively = false)
    +      assert(fields1.isEmpty)
    +      // "test1" < "FunSuite#test" < ClosureCleanerSuite2
    --- End diff --
    
    I think the middle element of that hierarchy ("FunSuite#test") should maybe make clear
that it's really the anonymous class (the closure) passed into the `test` method, right?


---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and you would like it enabled, or if the feature is enabled but not working, please
contact infrastructure at [email protected] or file a JIRA ticket
with INFRA.
---

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to