Repository: spark
Updated Branches:
  refs/heads/master c7032290a -> cc3648774


[SPARK-3046] use executor's class loader as the default serializer classloader

The serializer is not always used in an executor thread (e.g. connection
manager, broadcast), in which case the thread's context classloader might not
include the user jar, leading to corruption in deserialization.

https://issues.apache.org/jira/browse/SPARK-3046

https://issues.apache.org/jira/browse/SPARK-2878

Author: Reynold Xin <[email protected]>

Closes #1972 from rxin/kryoBug and squashes the following commits:

c1c7bf0 [Reynold Xin] Made change to JavaSerializer.
7204c33 [Reynold Xin] Added imports back.
d879e67 [Reynold Xin] [SPARK-3046] use executor's class loader as the default 
serializer class loader.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/cc364877
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/cc364877
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/cc364877

Branch: refs/heads/master
Commit: cc3648774e9a744850107bb187f2828d447e0a48
Parents: c703229
Author: Reynold Xin <[email protected]>
Authored: Fri Aug 15 17:04:15 2014 -0700
Committer: Reynold Xin <[email protected]>
Committed: Fri Aug 15 17:04:15 2014 -0700

----------------------------------------------------------------------
 .../org/apache/spark/executor/Executor.scala    |  3 +
 .../spark/serializer/JavaSerializer.scala       |  9 ++-
 .../spark/serializer/KryoSerializer.scala       |  9 ++-
 .../apache/spark/serializer/Serializer.scala    | 17 +++++
 .../KryoSerializerDistributedSuite.scala        | 71 ++++++++++++++++++++
 .../spark/serializer/KryoSerializerSuite.scala  | 23 ++++++-
 6 files changed, 128 insertions(+), 4 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/cc364877/core/src/main/scala/org/apache/spark/executor/Executor.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala 
b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index eac1f23..fb3f7bd 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -99,6 +99,9 @@ private[spark] class Executor(
   private val urlClassLoader = createClassLoader()
   private val replClassLoader = addReplClassLoaderIfNeeded(urlClassLoader)
 
+  // Set the classloader for serializer
+  env.serializer.setDefaultClassLoader(urlClassLoader)
+
   // Akka's message frame size. If task result is bigger than this, we use the 
block manager
   // to send the result back.
   private val akkaFrameSize = AkkaUtils.maxFrameSizeBytes(conf)

http://git-wip-us.apache.org/repos/asf/spark/blob/cc364877/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala
----------------------------------------------------------------------
diff --git 
a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala 
b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala
index 34bc312..af33a2f 100644
--- a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala
@@ -63,7 +63,9 @@ extends DeserializationStream {
   def close() { objIn.close() }
 }
 
-private[spark] class JavaSerializerInstance(counterReset: Int) extends 
SerializerInstance {
+private[spark] class JavaSerializerInstance(counterReset: Int, 
defaultClassLoader: ClassLoader)
+  extends SerializerInstance {
+
   def serialize[T: ClassTag](t: T): ByteBuffer = {
     val bos = new ByteArrayOutputStream()
     val out = serializeStream(bos)
@@ -109,7 +111,10 @@ private[spark] class JavaSerializerInstance(counterReset: 
Int) extends Serialize
 class JavaSerializer(conf: SparkConf) extends Serializer with Externalizable {
   private var counterReset = conf.getInt("spark.serializer.objectStreamReset", 
100)
 
-  def newInstance(): SerializerInstance = new 
JavaSerializerInstance(counterReset)
+  override def newInstance(): SerializerInstance = {
+    val classLoader = 
defaultClassLoader.getOrElse(Thread.currentThread.getContextClassLoader)
+    new JavaSerializerInstance(counterReset, classLoader)
+  }
 
   override def writeExternal(out: ObjectOutput) {
     out.writeInt(counterReset)

http://git-wip-us.apache.org/repos/asf/spark/blob/cc364877/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
----------------------------------------------------------------------
diff --git 
a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala 
b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
index 85944ea..9968222 100644
--- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
@@ -61,7 +61,9 @@ class KryoSerializer(conf: SparkConf)
     val instantiator = new EmptyScalaKryoInstantiator
     val kryo = instantiator.newKryo()
     kryo.setRegistrationRequired(registrationRequired)
-    val classLoader = Thread.currentThread.getContextClassLoader
+
+    val oldClassLoader = Thread.currentThread.getContextClassLoader
+    val classLoader = 
defaultClassLoader.getOrElse(Thread.currentThread.getContextClassLoader)
 
     // Allow disabling Kryo reference tracking if user knows their object 
graphs don't have loops.
     // Do this before we invoke the user registrator so the user registrator 
can override this.
@@ -84,10 +86,15 @@ class KryoSerializer(conf: SparkConf)
       try {
         val reg = Class.forName(regCls, true, classLoader).newInstance()
           .asInstanceOf[KryoRegistrator]
+
+        // Use the default classloader when calling the user registrator.
+        Thread.currentThread.setContextClassLoader(classLoader)
         reg.registerClasses(kryo)
       } catch {
         case e: Exception => 
           throw new SparkException(s"Failed to invoke $regCls", e)
+      } finally {
+        Thread.currentThread.setContextClassLoader(oldClassLoader)
       }
     }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/cc364877/core/src/main/scala/org/apache/spark/serializer/Serializer.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala 
b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala
index f2f5cea..e674438 100644
--- a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala
@@ -44,6 +44,23 @@ import org.apache.spark.util.{ByteBufferInputStream, 
NextIterator}
  */
 @DeveloperApi
 trait Serializer {
+
+  /**
+   * Default ClassLoader to use in deserialization. Implementations of 
[[Serializer]] should
+   * make sure they use it when set.
+   */
+  @volatile protected var defaultClassLoader: Option[ClassLoader] = None
+
+  /**
+   * Sets a class loader for the serializer to use in deserialization.
+   *
+   * @return this Serializer object
+   */
+  def setDefaultClassLoader(classLoader: ClassLoader): Serializer = {
+    defaultClassLoader = Some(classLoader)
+    this
+  }
+
   def newInstance(): SerializerInstance
 }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/cc364877/core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala
----------------------------------------------------------------------
diff --git 
a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala
 
b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala
new file mode 100644
index 0000000..11e8c9c
--- /dev/null
+++ 
b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.serializer
+
+import org.apache.spark.util.Utils
+
+import com.esotericsoftware.kryo.Kryo
+import org.scalatest.FunSuite
+
+import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkEnv, 
TestUtils}
+import org.apache.spark.SparkContext._
+import org.apache.spark.serializer.KryoDistributedTest._
+
+class KryoSerializerDistributedSuite extends FunSuite {
+
+  test("kryo objects are serialised consistently in different processes") {
+    val conf = new SparkConf(false)
+    conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
+    conf.set("spark.kryo.registrator", classOf[AppJarRegistrator].getName)
+    conf.set("spark.task.maxFailures", "1")
+
+    val jar = 
TestUtils.createJarWithClasses(List(AppJarRegistrator.customClassName))
+    conf.setJars(List(jar.getPath))
+
+    val sc = new SparkContext("local-cluster[2,1,512]", "test", conf)
+    val original = Thread.currentThread.getContextClassLoader
+    val loader = new java.net.URLClassLoader(Array(jar), 
Utils.getContextOrSparkClassLoader)
+    SparkEnv.get.serializer.setDefaultClassLoader(loader)
+
+    val cachedRDD = sc.parallelize((0 until 10).map((_, new MyCustomClass)), 
3).cache()
+
+    // Randomly mix the keys so that the join below will require a shuffle 
with each partition
+    // sending data to multiple other partitions.
+    val shuffledRDD = cachedRDD.map { case (i, o) => (i * i * i - 10 * i * i, 
o)}
+
+    // Join the two RDDs, and force evaluation
+    assert(shuffledRDD.join(cachedRDD).collect().size == 1)
+
+    LocalSparkContext.stop(sc)
+  }
+}
+
+object KryoDistributedTest {
+  class MyCustomClass
+
+  class AppJarRegistrator extends KryoRegistrator {
+    override def registerClasses(k: Kryo) {
+      val classLoader = Thread.currentThread.getContextClassLoader
+      k.register(Class.forName(AppJarRegistrator.customClassName, true, 
classLoader))
+    }
+  }
+
+  object AppJarRegistrator {
+    val customClassName = "KryoSerializerDistributedSuiteCustomClass"
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/cc364877/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala
----------------------------------------------------------------------
diff --git 
a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala 
b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala
index 3bf9efe..a579fd5 100644
--- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala
@@ -23,7 +23,7 @@ import scala.reflect.ClassTag
 import com.esotericsoftware.kryo.Kryo
 import org.scalatest.FunSuite
 
-import org.apache.spark.SharedSparkContext
+import org.apache.spark.{SparkConf, SharedSparkContext}
 import org.apache.spark.serializer.KryoTest._
 
 class KryoSerializerSuite extends FunSuite with SharedSparkContext {
@@ -217,8 +217,29 @@ class KryoSerializerSuite extends FunSuite with 
SharedSparkContext {
     val thrown = intercept[SparkException](new 
KryoSerializer(conf).newInstance())
     assert(thrown.getMessage.contains("Failed to invoke 
this.class.does.not.exist"))
   }
+
+  test("default class loader can be set by a different thread") {
+    val ser = new KryoSerializer(new SparkConf)
+
+    // First serialize the object
+    val serInstance = ser.newInstance()
+    val bytes = serInstance.serialize(new ClassLoaderTestingObject)
+
+    // Deserialize the object to make sure normal deserialization works
+    serInstance.deserialize[ClassLoaderTestingObject](bytes)
+
+    // Set a special, broken ClassLoader and make sure we get an exception on 
deserialization
+    ser.setDefaultClassLoader(new ClassLoader() {
+      override def loadClass(name: String) = throw new 
UnsupportedOperationException
+    })
+    intercept[UnsupportedOperationException] {
+      ser.newInstance().deserialize[ClassLoaderTestingObject](bytes)
+    }
+  }
 }
 
+class ClassLoaderTestingObject
+
 class KryoSerializerResizableOutputSuite extends FunSuite {
   import org.apache.spark.SparkConf
   import org.apache.spark.SparkContext


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to