Repository: spark
Updated Branches:
  refs/heads/master b162cc0c2 -> fd48d80a6


[SPARK-17822][R] Make JVMObjectTracker a member variable of RBackend

## What changes were proposed in this pull request?

* This PR changes `JVMObjectTracker` from an `object` to a `class` and associates one 
instance with each `RBackend`, so we can manage the lifecycle of JVM objects when there 
are multiple `RBackend` sessions. `RBackend.close` clears the object tracker explicitly. 
(A usage sketch follows this list.)
* I assume that `SQLUtils` and `RRunner` do not need to track JVM instances, 
which could be wrong.
* Small refactor of `SerDe.sqlSerDe` to increase readability.
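
For reference, here is a minimal sketch of how the per-backend tracker is meant to be used (illustration only, not part of this commit; `someObj` is a placeholder object, and `jvmObjectTracker` is `private[r]`, so this only compiles inside the `org.apache.spark.api.r` package):

```scala
// Each RBackend now owns its own JVMObjectTracker instead of sharing a global singleton.
val backend = new RBackend
val tracker = backend.jvmObjectTracker        // per-backend tracker

val someObj = new Object                      // placeholder for an object handed back to R
val id = tracker.addAndGetId(someObj)         // e.g. JVMObjectId("0"), unique within this tracker
assert(tracker(id) eq someObj)                // apply() throws NoSuchElementException for unknown IDs

tracker.remove(id)                            // drop a single tracked reference
backend.close()                               // close() now clears any remaining tracked objects
```

On the SQL side, `SQLUtils` now registers its callbacks via `SerDe.setSQLReadObject(readSqlObject).setSQLWriteObject(writeSqlObject)` instead of the old tuple-based `registerSqlSerDe`.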

## How was this patch tested?

* Added unit tests for `JVMObjectTracker`.
* Wait for Jenkins to run full tests.

Author: Xiangrui Meng <m...@databricks.com>

Closes #16154 from mengxr/SPARK-17822.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/fd48d80a
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/fd48d80a
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/fd48d80a

Branch: refs/heads/master
Commit: fd48d80a6145ea94f03e7fc6e4d724a0fbccac58
Parents: b162cc0
Author: Xiangrui Meng <m...@databricks.com>
Authored: Fri Dec 9 07:51:46 2016 -0800
Committer: Xiangrui Meng <m...@databricks.com>
Committed: Fri Dec 9 07:51:46 2016 -0800

----------------------------------------------------------------------
 .../apache/spark/api/r/JVMObjectTracker.scala   | 87 ++++++++++++++++++
 .../scala/org/apache/spark/api/r/RBackend.scala |  6 +-
 .../apache/spark/api/r/RBackendHandler.scala    | 54 ++----------
 .../scala/org/apache/spark/api/r/RRunner.scala  |  2 +-
 .../scala/org/apache/spark/api/r/SerDe.scala    | 92 ++++++++++++--------
 .../spark/api/r/JVMObjectTrackerSuite.scala     | 73 ++++++++++++++++
 .../org/apache/spark/api/r/RBackendSuite.scala  | 31 +++++++
 .../org/apache/spark/sql/api/r/SQLUtils.scala   | 12 +--
 8 files changed, 265 insertions(+), 92 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/fd48d80a/core/src/main/scala/org/apache/spark/api/r/JVMObjectTracker.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/api/r/JVMObjectTracker.scala b/core/src/main/scala/org/apache/spark/api/r/JVMObjectTracker.scala
new file mode 100644
index 0000000..3432700
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/api/r/JVMObjectTracker.scala
@@ -0,0 +1,87 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.r
+
+import java.util.concurrent.atomic.AtomicInteger
+import java.util.concurrent.ConcurrentHashMap
+
+/** JVM object ID wrapper */
+private[r] case class JVMObjectId(id: String) {
+  require(id != null, "Object ID cannot be null.")
+}
+
+/**
+ * Counter that tracks JVM objects returned to R.
+ * This is useful for referencing these objects in RPC calls.
+ */
+private[r] class JVMObjectTracker {
+
+  private[this] val objMap = new ConcurrentHashMap[JVMObjectId, Object]()
+  private[this] val objCounter = new AtomicInteger()
+
+  /**
+   * Returns the JVM object associated with the input key or None if not found.
+   */
+  final def get(id: JVMObjectId): Option[Object] = this.synchronized {
+    if (objMap.containsKey(id)) {
+      Some(objMap.get(id))
+    } else {
+      None
+    }
+  }
+
+  /**
+   * Returns the JVM object associated with the input key or throws an exception if not found.
+   */
+  @throws[NoSuchElementException]("if key does not exist.")
+  final def apply(id: JVMObjectId): Object = {
+    get(id).getOrElse(
+      throw new NoSuchElementException(s"$id does not exist.")
+    )
+  }
+
+  /**
+   * Adds a JVM object to track and returns assigned ID, which is unique within this tracker.
+   */
+  final def addAndGetId(obj: Object): JVMObjectId = {
+    val id = JVMObjectId(objCounter.getAndIncrement().toString)
+    objMap.put(id, obj)
+    id
+  }
+
+  /**
+   * Removes and returns a JVM object with the specific ID from the tracker, or None if not found.
+   */
+  final def remove(id: JVMObjectId): Option[Object] = this.synchronized {
+    if (objMap.containsKey(id)) {
+      Some(objMap.remove(id))
+    } else {
+      None
+    }
+  }
+
+  /**
+   * Number of JVM objects being tracked.
+   */
+  final def size: Int = objMap.size()
+
+  /**
+   * Clears the tracker.
+   */
+  final def clear(): Unit = objMap.clear()
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/fd48d80a/core/src/main/scala/org/apache/spark/api/r/RBackend.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackend.scala b/core/src/main/scala/org/apache/spark/api/r/RBackend.scala
index 550746c..2d1152a 100644
--- a/core/src/main/scala/org/apache/spark/api/r/RBackend.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/RBackend.scala
@@ -22,7 +22,7 @@ import java.net.{InetAddress, InetSocketAddress, ServerSocket}
 import java.util.concurrent.TimeUnit
 
 import io.netty.bootstrap.ServerBootstrap
-import io.netty.channel.{ChannelFuture, ChannelInitializer, ChannelOption, EventLoopGroup}
+import io.netty.channel.{ChannelFuture, ChannelInitializer, EventLoopGroup}
 import io.netty.channel.nio.NioEventLoopGroup
 import io.netty.channel.socket.SocketChannel
 import io.netty.channel.socket.nio.NioServerSocketChannel
@@ -42,6 +42,9 @@ private[spark] class RBackend {
   private[this] var bootstrap: ServerBootstrap = null
   private[this] var bossGroup: EventLoopGroup = null
 
+  /** Tracks JVM objects returned to R for this RBackend instance. */
+  private[r] val jvmObjectTracker = new JVMObjectTracker
+
   def init(): Int = {
     val conf = new SparkConf()
     val backendConnectionTimeout = conf.getInt(
@@ -94,6 +97,7 @@ private[spark] class RBackend {
       bootstrap.childGroup().shutdownGracefully()
     }
     bootstrap = null
+    jvmObjectTracker.clear()
   }
 
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/fd48d80a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala
index 9f5afa2..cfd37ac 100644
--- a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala
@@ -20,7 +20,6 @@ package org.apache.spark.api.r
 import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
 import java.util.concurrent.TimeUnit
 
-import scala.collection.mutable.HashMap
 import scala.language.existentials
 
 import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler}
@@ -62,7 +61,7 @@ private[r] class RBackendHandler(server: RBackend)
           assert(numArgs == 1)
 
           writeInt(dos, 0)
-          writeObject(dos, args(0))
+          writeObject(dos, args(0), server.jvmObjectTracker)
         case "stopBackend" =>
           writeInt(dos, 0)
           writeType(dos, "void")
@@ -72,9 +71,9 @@ private[r] class RBackendHandler(server: RBackend)
             val t = readObjectType(dis)
             assert(t == 'c')
             val objToRemove = readString(dis)
-            JVMObjectTracker.remove(objToRemove)
+            server.jvmObjectTracker.remove(JVMObjectId(objToRemove))
             writeInt(dos, 0)
-            writeObject(dos, null)
+            writeObject(dos, null, server.jvmObjectTracker)
           } catch {
             case e: Exception =>
               logError(s"Removing $objId failed", e)
@@ -143,12 +142,8 @@ private[r] class RBackendHandler(server: RBackend)
       val cls = if (isStatic) {
         Utils.classForName(objId)
       } else {
-        JVMObjectTracker.get(objId) match {
-          case None => throw new IllegalArgumentException("Object not found " + objId)
-          case Some(o) =>
-            obj = o
-            o.getClass
-        }
+        obj = server.jvmObjectTracker(JVMObjectId(objId))
+        obj.getClass
       }
 
       val args = readArgs(numArgs, dis)
@@ -173,7 +168,7 @@ private[r] class RBackendHandler(server: RBackend)
 
         // Write status bit
         writeInt(dos, 0)
-        writeObject(dos, ret.asInstanceOf[AnyRef])
+        writeObject(dos, ret.asInstanceOf[AnyRef], server.jvmObjectTracker)
       } else if (methodName == "<init>") {
         // methodName should be "<init>" for constructor
         val ctors = cls.getConstructors
@@ -193,7 +188,7 @@ private[r] class RBackendHandler(server: RBackend)
         val obj = ctors(index.get).newInstance(args : _*)
 
         writeInt(dos, 0)
-        writeObject(dos, obj.asInstanceOf[AnyRef])
+        writeObject(dos, obj.asInstanceOf[AnyRef], server.jvmObjectTracker)
       } else {
         throw new IllegalArgumentException("invalid method " + methodName + " 
for object " + objId)
       }
@@ -210,7 +205,7 @@ private[r] class RBackendHandler(server: RBackend)
   // Read a number of arguments from the data input stream
   def readArgs(numArgs: Int, dis: DataInputStream): Array[java.lang.Object] = {
     (0 until numArgs).map { _ =>
-      readObject(dis)
+      readObject(dis, server.jvmObjectTracker)
     }.toArray
   }
 
@@ -286,37 +281,4 @@ private[r] class RBackendHandler(server: RBackend)
   }
 }
 
-/**
- * Helper singleton that tracks JVM objects returned to R.
- * This is useful for referencing these objects in RPC calls.
- */
-private[r] object JVMObjectTracker {
-
-  // TODO: This map should be thread-safe if we want to support multiple
-  // connections at the same time
-  private[this] val objMap = new HashMap[String, Object]
-
-  // TODO: We support only one connection now, so an integer is fine.
-  // Investigate using use atomic integer in the future.
-  private[this] var objCounter: Int = 0
-
-  def getObject(id: String): Object = {
-    objMap(id)
-  }
-
-  def get(id: String): Option[Object] = {
-    objMap.get(id)
-  }
-
-  def put(obj: Object): String = {
-    val objId = objCounter.toString
-    objCounter = objCounter + 1
-    objMap.put(objId, obj)
-    objId
-  }
 
-  def remove(id: String): Option[Object] = {
-    objMap.remove(id)
-  }
-
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/fd48d80a/core/src/main/scala/org/apache/spark/api/r/RRunner.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/api/r/RRunner.scala b/core/src/main/scala/org/apache/spark/api/r/RRunner.scala
index 7ef6472..29e21b3 100644
--- a/core/src/main/scala/org/apache/spark/api/r/RRunner.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/RRunner.scala
@@ -152,7 +152,7 @@ private[spark] class RRunner[U](
           dataOut.writeInt(mode)
 
           if (isDataFrame) {
-            SerDe.writeObject(dataOut, colNames)
+            SerDe.writeObject(dataOut, colNames, jvmObjectTracker = null)
           }
 
           if (!iter.hasNext) {

http://git-wip-us.apache.org/repos/asf/spark/blob/fd48d80a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala
index 550e075..dad928c 100644
--- a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala
@@ -28,13 +28,20 @@ import scala.collection.mutable.WrappedArray
  * Utility functions to serialize, deserialize objects to / from R
  */
 private[spark] object SerDe {
-  type ReadObject = (DataInputStream, Char) => Object
-  type WriteObject = (DataOutputStream, Object) => Boolean
+  type SQLReadObject = (DataInputStream, Char) => Object
+  type SQLWriteObject = (DataOutputStream, Object) => Boolean
 
-  var sqlSerDe: (ReadObject, WriteObject) = _
+  private[this] var sqlReadObject: SQLReadObject = _
+  private[this] var sqlWriteObject: SQLWriteObject = _
 
-  def registerSqlSerDe(sqlSerDe: (ReadObject, WriteObject)): Unit = {
-    this.sqlSerDe = sqlSerDe
+  def setSQLReadObject(value: SQLReadObject): this.type = {
+    sqlReadObject = value
+    this
+  }
+
+  def setSQLWriteObject(value: SQLWriteObject): this.type = {
+    sqlWriteObject = value
+    this
   }
 
   // Type mapping from R to Java
@@ -56,32 +63,33 @@ private[spark] object SerDe {
     dis.readByte().toChar
   }
 
-  def readObject(dis: DataInputStream): Object = {
+  def readObject(dis: DataInputStream, jvmObjectTracker: JVMObjectTracker): Object = {
     val dataType = readObjectType(dis)
-    readTypedObject(dis, dataType)
+    readTypedObject(dis, dataType, jvmObjectTracker)
   }
 
   def readTypedObject(
       dis: DataInputStream,
-      dataType: Char): Object = {
+      dataType: Char,
+      jvmObjectTracker: JVMObjectTracker): Object = {
     dataType match {
       case 'n' => null
       case 'i' => new java.lang.Integer(readInt(dis))
       case 'd' => new java.lang.Double(readDouble(dis))
       case 'b' => new java.lang.Boolean(readBoolean(dis))
       case 'c' => readString(dis)
-      case 'e' => readMap(dis)
+      case 'e' => readMap(dis, jvmObjectTracker)
       case 'r' => readBytes(dis)
-      case 'a' => readArray(dis)
-      case 'l' => readList(dis)
+      case 'a' => readArray(dis, jvmObjectTracker)
+      case 'l' => readList(dis, jvmObjectTracker)
       case 'D' => readDate(dis)
       case 't' => readTime(dis)
-      case 'j' => JVMObjectTracker.getObject(readString(dis))
+      case 'j' => jvmObjectTracker(JVMObjectId(readString(dis)))
       case _ =>
-        if (sqlSerDe == null || sqlSerDe._1 == null) {
+        if (sqlReadObject == null) {
           throw new IllegalArgumentException (s"Invalid type $dataType")
         } else {
-          val obj = (sqlSerDe._1)(dis, dataType)
+          val obj = sqlReadObject(dis, dataType)
           if (obj == null) {
             throw new IllegalArgumentException (s"Invalid type $dataType")
           } else {
@@ -181,28 +189,28 @@ private[spark] object SerDe {
   }
 
   // All elements of an array must be of the same type
-  def readArray(dis: DataInputStream): Array[_] = {
+  def readArray(dis: DataInputStream, jvmObjectTracker: JVMObjectTracker): Array[_] = {
     val arrType = readObjectType(dis)
     arrType match {
       case 'i' => readIntArr(dis)
       case 'c' => readStringArr(dis)
       case 'd' => readDoubleArr(dis)
       case 'b' => readBooleanArr(dis)
-      case 'j' => readStringArr(dis).map(x => JVMObjectTracker.getObject(x))
+      case 'j' => readStringArr(dis).map(x => jvmObjectTracker(JVMObjectId(x)))
       case 'r' => readBytesArr(dis)
       case 'a' =>
         val len = readInt(dis)
-        (0 until len).map(_ => readArray(dis)).toArray
+        (0 until len).map(_ => readArray(dis, jvmObjectTracker)).toArray
       case 'l' =>
         val len = readInt(dis)
-        (0 until len).map(_ => readList(dis)).toArray
+        (0 until len).map(_ => readList(dis, jvmObjectTracker)).toArray
       case _ =>
-        if (sqlSerDe == null || sqlSerDe._1 == null) {
+        if (sqlReadObject == null) {
           throw new IllegalArgumentException (s"Invalid array type $arrType")
         } else {
           val len = readInt(dis)
           (0 until len).map { _ =>
-            val obj = (sqlSerDe._1)(dis, arrType)
+            val obj = sqlReadObject(dis, arrType)
             if (obj == null) {
               throw new IllegalArgumentException (s"Invalid array type 
$arrType")
             } else {
@@ -215,17 +223,19 @@ private[spark] object SerDe {
 
   // Each element of a list can be of different type. They are all represented
   // as Object on JVM side
-  def readList(dis: DataInputStream): Array[Object] = {
+  def readList(dis: DataInputStream, jvmObjectTracker: JVMObjectTracker): Array[Object] = {
     val len = readInt(dis)
-    (0 until len).map(_ => readObject(dis)).toArray
+    (0 until len).map(_ => readObject(dis, jvmObjectTracker)).toArray
   }
 
-  def readMap(in: DataInputStream): java.util.Map[Object, Object] = {
+  def readMap(
+      in: DataInputStream,
+      jvmObjectTracker: JVMObjectTracker): java.util.Map[Object, Object] = {
     val len = readInt(in)
     if (len > 0) {
       // Keys is an array of String
-      val keys = readArray(in).asInstanceOf[Array[Object]]
-      val values = readList(in)
+      val keys = readArray(in, jvmObjectTracker).asInstanceOf[Array[Object]]
+      val values = readList(in, jvmObjectTracker)
 
       keys.zip(values).toMap.asJava
     } else {
@@ -272,7 +282,11 @@ private[spark] object SerDe {
     }
   }
 
-  private def writeKeyValue(dos: DataOutputStream, key: Object, value: Object): Unit = {
+  private def writeKeyValue(
+      dos: DataOutputStream,
+      key: Object,
+      value: Object,
+      jvmObjectTracker: JVMObjectTracker): Unit = {
     if (key == null) {
       throw new IllegalArgumentException("Key in map can't be null.")
     } else if (!key.isInstanceOf[String]) {
@@ -280,10 +294,10 @@ private[spark] object SerDe {
     }
 
     writeString(dos, key.asInstanceOf[String])
-    writeObject(dos, value)
+    writeObject(dos, value, jvmObjectTracker)
   }
 
-  def writeObject(dos: DataOutputStream, obj: Object): Unit = {
-  def writeObject(dos: DataOutputStream, obj: Object, jvmObjectTracker: JVMObjectTracker): Unit = {
     if (obj == null) {
       writeType(dos, "void")
     } else {
@@ -373,14 +387,14 @@ private[spark] object SerDe {
         case v: Array[Object] =>
           writeType(dos, "list")
           writeInt(dos, v.length)
-          v.foreach(elem => writeObject(dos, elem))
+          v.foreach(elem => writeObject(dos, elem, jvmObjectTracker))
 
         // Handle Properties
         // This must be above the case java.util.Map below.
         // (Properties implements Map<Object,Object> and will be serialized as map otherwise)
         case v: java.util.Properties =>
           writeType(dos, "jobj")
-          writeJObj(dos, value)
+          writeJObj(dos, value, jvmObjectTracker)
 
         // Handle map
         case v: java.util.Map[_, _] =>
@@ -392,19 +406,21 @@ private[spark] object SerDe {
             val key = entry.getKey
             val value = entry.getValue
 
-            writeKeyValue(dos, key.asInstanceOf[Object], value.asInstanceOf[Object])
+            writeKeyValue(
+              dos, key.asInstanceOf[Object], value.asInstanceOf[Object], jvmObjectTracker)
           }
         case v: scala.collection.Map[_, _] =>
           writeType(dos, "map")
           writeInt(dos, v.size)
-          v.foreach { case (key, value) =>
-            writeKeyValue(dos, key.asInstanceOf[Object], value.asInstanceOf[Object])
+          v.foreach { case (k1, v1) =>
+            writeKeyValue(dos, k1.asInstanceOf[Object], v1.asInstanceOf[Object], jvmObjectTracker)
           }
 
         case _ =>
-          if (sqlSerDe == null || sqlSerDe._2 == null || !(sqlSerDe._2)(dos, value)) {
+          val sqlWriteSucceeded = sqlWriteObject != null && sqlWriteObject(dos, value)
+          if (!sqlWriteSucceeded) {
             writeType(dos, "jobj")
-            writeJObj(dos, value)
+            writeJObj(dos, value, jvmObjectTracker)
           }
       }
     }
@@ -447,9 +463,9 @@ private[spark] object SerDe {
     out.write(value)
   }
 
-  def writeJObj(out: DataOutputStream, value: Object): Unit = {
-    val objId = JVMObjectTracker.put(value)
-    writeString(out, objId)
+  def writeJObj(out: DataOutputStream, value: Object, jvmObjectTracker: JVMObjectTracker): Unit = {
+    val JVMObjectId(id) = jvmObjectTracker.addAndGetId(value)
+    writeString(out, id)
   }
 
   def writeIntArr(out: DataOutputStream, value: Array[Int]): Unit = {

http://git-wip-us.apache.org/repos/asf/spark/blob/fd48d80a/core/src/test/scala/org/apache/spark/api/r/JVMObjectTrackerSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/api/r/JVMObjectTrackerSuite.scala b/core/src/test/scala/org/apache/spark/api/r/JVMObjectTrackerSuite.scala
new file mode 100644
index 0000000..6a979ae
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/api/r/JVMObjectTrackerSuite.scala
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.r
+
+import org.apache.spark.SparkFunSuite
+
+class JVMObjectTrackerSuite extends SparkFunSuite {
+  test("JVMObjectId does not take null IDs") {
+    intercept[IllegalArgumentException] {
+      JVMObjectId(null)
+    }
+  }
+
+  test("JVMObjectTracker") {
+    val tracker = new JVMObjectTracker
+    assert(tracker.size === 0)
+    withClue("an empty tracker can be cleared") {
+      tracker.clear()
+    }
+    val none = JVMObjectId("none")
+    assert(tracker.get(none) === None)
+    intercept[NoSuchElementException] {
+      tracker(JVMObjectId("none"))
+    }
+
+    val obj1 = new Object
+    val id1 = tracker.addAndGetId(obj1)
+    assert(id1 != null)
+    assert(tracker.size === 1)
+    assert(tracker.get(id1).get.eq(obj1))
+    assert(tracker(id1).eq(obj1))
+
+    val obj2 = new Object
+    val id2 = tracker.addAndGetId(obj2)
+    assert(id1 !== id2)
+    assert(tracker.size === 2)
+    assert(tracker(id2).eq(obj2))
+
+    val Some(obj1Removed) = tracker.remove(id1)
+    assert(obj1Removed.eq(obj1))
+    assert(tracker.get(id1) === None)
+    assert(tracker.size === 1)
+    assert(tracker(id2).eq(obj2))
+
+    val obj3 = new Object
+    val id3 = tracker.addAndGetId(obj3)
+    assert(tracker.size === 2)
+    assert(id3 != id1)
+    assert(id3 != id2)
+    assert(tracker(id3).eq(obj3))
+
+    tracker.clear()
+    assert(tracker.size === 0)
+    assert(tracker.get(id1) === None)
+    assert(tracker.get(id2) === None)
+    assert(tracker.get(id3) === None)
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/fd48d80a/core/src/test/scala/org/apache/spark/api/r/RBackendSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/api/r/RBackendSuite.scala b/core/src/test/scala/org/apache/spark/api/r/RBackendSuite.scala
new file mode 100644
index 0000000..085cc26
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/api/r/RBackendSuite.scala
@@ -0,0 +1,31 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.r
+
+import org.apache.spark.SparkFunSuite
+
+class RBackendSuite extends SparkFunSuite {
+  test("close() clears jvmObjectTracker") {
+    val backend = new RBackend
+    val tracker = backend.jvmObjectTracker
+    val id = tracker.addAndGetId(new Object)
+    backend.close()
+    assert(tracker.get(id) === None)
+    assert(tracker.size === 0)
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/fd48d80a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
index 9de6510..80bbad4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
@@ -36,7 +36,7 @@ import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION
 import org.apache.spark.sql.types._
 
 private[sql] object SQLUtils extends Logging {
-  SerDe.registerSqlSerDe((readSqlObject, writeSqlObject))
+  SerDe.setSQLReadObject(readSqlObject).setSQLWriteObject(writeSqlObject)
 
   private[this] def withHiveExternalCatalog(sc: SparkContext): SparkContext = {
     sc.conf.set(CATALOG_IMPLEMENTATION.key, "hive")
@@ -158,7 +158,7 @@ private[sql] object SQLUtils extends Logging {
     val dis = new DataInputStream(bis)
     val num = SerDe.readInt(dis)
     Row.fromSeq((0 until num).map { i =>
-      doConversion(SerDe.readObject(dis), schema.fields(i).dataType)
+      doConversion(SerDe.readObject(dis, jvmObjectTracker = null), schema.fields(i).dataType)
     })
   }
 
@@ -167,7 +167,7 @@ private[sql] object SQLUtils extends Logging {
     val dos = new DataOutputStream(bos)
 
     val cols = (0 until row.length).map(row(_).asInstanceOf[Object]).toArray
-    SerDe.writeObject(dos, cols)
+    SerDe.writeObject(dos, cols, jvmObjectTracker = null)
     bos.toByteArray()
   }
 
@@ -247,7 +247,7 @@ private[sql] object SQLUtils extends Logging {
     dataType match {
       case 's' =>
         // Read StructType for DataFrame
-        val fields = SerDe.readList(dis).asInstanceOf[Array[Object]]
+        val fields = SerDe.readList(dis, jvmObjectTracker = null).asInstanceOf[Array[Object]]
         Row.fromSeq(fields)
       case _ => null
     }
@@ -258,8 +258,8 @@ private[sql] object SQLUtils extends Logging {
       // Handle struct type in DataFrame
       case v: GenericRowWithSchema =>
         dos.writeByte('s')
-        SerDe.writeObject(dos, v.schema.fieldNames)
-        SerDe.writeObject(dos, v.values)
+        SerDe.writeObject(dos, v.schema.fieldNames, jvmObjectTracker = null)
+        SerDe.writeObject(dos, v.values, jvmObjectTracker = null)
         true
       case _ =>
         false

