Repository: spark
Updated Branches:
  refs/heads/master 0bfd3afb0 -> abf588f47


[SPARK-3749] [PySpark] fix bugs in broadcast large closure of RDD

1. broadcast is triggered unexpectedly (see the sketch below)
2. a file descriptor is leaked in the JVM (also leaked in parallelize())
3. the broadcast is not unpersisted in the JVM after the RDD is no longer used.
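
The first bug comes from comparing the pickled command (a Python 2 str) directly against an int: Python 2 orders mixed types by type name, so the comparison is always true and the command is broadcast no matter how small it is. A minimal standalone sketch of the pitfall and the fix (Python 2, no Spark needed):

    pickled_command = "tiny"
    # In Python 2, a str compared against an int is ordered by type name,
    # so this is always True and the broadcast path is taken for every command.
    print pickled_command > (1 << 20)       # True
    # The fix compares the serialized size instead, so only commands over 1 MB
    # are broadcast.
    print len(pickled_command) > (1 << 20)  # False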

cc JoshRosen, sorry for these stupid bugs.

Author: Davies Liu <[email protected]>

Closes #2603 from davies/fix_broadcast and squashes the following commits:

080a743 [Davies Liu] fix bugs in broadcast large closure of RDD


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/abf588f4
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/abf588f4
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/abf588f4

Branch: refs/heads/master
Commit: abf588f47a26d0066f0b75d52b200a87bb085064
Parents: 0bfd3af
Author: Davies Liu <[email protected]>
Authored: Wed Oct 1 11:21:34 2014 -0700
Committer: Josh Rosen <[email protected]>
Committed: Wed Oct 1 11:21:34 2014 -0700

----------------------------------------------------------------------
 .../org/apache/spark/api/python/PythonRDD.scala | 34 ++++++++++++--------
 python/pyspark/rdd.py                           | 12 +++++--
 python/pyspark/sql.py                           |  2 +-
 python/pyspark/tests.py                         |  8 +++--
 4 files changed, 37 insertions(+), 19 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/abf588f4/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index f9ff4ea..9241414 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -339,26 +339,34 @@ private[spark] object PythonRDD extends Logging {
   def readRDDFromFile(sc: JavaSparkContext, filename: String, parallelism: Int):
   JavaRDD[Array[Byte]] = {
     val file = new DataInputStream(new FileInputStream(filename))
-    val objs = new collection.mutable.ArrayBuffer[Array[Byte]]
     try {
-      while (true) {
-        val length = file.readInt()
-        val obj = new Array[Byte](length)
-        file.readFully(obj)
-        objs.append(obj)
+      val objs = new collection.mutable.ArrayBuffer[Array[Byte]]
+      try {
+        while (true) {
+          val length = file.readInt()
+          val obj = new Array[Byte](length)
+          file.readFully(obj)
+          objs.append(obj)
+        }
+      } catch {
+        case eof: EOFException => {}
       }
-    } catch {
-      case eof: EOFException => {}
+      JavaRDD.fromRDD(sc.sc.parallelize(objs, parallelism))
+    } finally {
+      file.close()
     }
-    JavaRDD.fromRDD(sc.sc.parallelize(objs, parallelism))
   }
 
   def readBroadcastFromFile(sc: JavaSparkContext, filename: String): Broadcast[Array[Byte]] = {
     val file = new DataInputStream(new FileInputStream(filename))
-    val length = file.readInt()
-    val obj = new Array[Byte](length)
-    file.readFully(obj)
-    sc.broadcast(obj)
+    try {
+      val length = file.readInt()
+      val obj = new Array[Byte](length)
+      file.readFully(obj)
+      sc.broadcast(obj)
+    } finally {
+      file.close()
+    }
   }
 
   def writeIteratorToStream[T](iter: Iterator[T], dataOut: DataOutputStream) {

http://git-wip-us.apache.org/repos/asf/spark/blob/abf588f4/python/pyspark/rdd.py
----------------------------------------------------------------------
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 8ed89e2..dc64977 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -2073,6 +2073,12 @@ class PipelinedRDD(RDD):
         self._jrdd_deserializer = self.ctx.serializer
         self._bypass_serializer = False
         self._partitionFunc = prev._partitionFunc if self.preservesPartitioning else None
+        self._broadcast = None
+
+    def __del__(self):
+        if self._broadcast:
+            self._broadcast.unpersist()
+            self._broadcast = None
 
     @property
     def _jrdd(self):
@@ -2087,9 +2093,9 @@ class PipelinedRDD(RDD):
         # the serialized command will be compressed by broadcast
         ser = CloudPickleSerializer()
         pickled_command = ser.dumps(command)
-        if pickled_command > (1 << 20):  # 1M
-            broadcast = self.ctx.broadcast(pickled_command)
-            pickled_command = ser.dumps(broadcast)
+        if len(pickled_command) > (1 << 20):  # 1M
+            self._broadcast = self.ctx.broadcast(pickled_command)
+            pickled_command = ser.dumps(self._broadcast)
         broadcast_vars = ListConverter().convert(
             [x._jbroadcast for x in self.ctx._pickled_broadcast_vars],
             self.ctx._gateway._gateway_client)
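
The new _broadcast attribute and __del__ hook above address the third bug: the command broadcast's lifetime is tied to the PipelinedRDD object, so the JVM-side blocks are released once the Python RDD is garbage collected. A rough sketch of the same ownership pattern in isolation (hypothetical class name, assuming an active SparkContext sc):

    class BroadcastOwner(object):
        """Hypothetical sketch: tie a broadcast's lifetime to an owning object."""
        def __init__(self, sc, payload):
            # Mirror the size check in _jrdd: only broadcast payloads over 1 MB.
            self._broadcast = sc.broadcast(payload) if len(payload) > (1 << 20) else None

        def __del__(self):
            # Release the broadcast blocks in the JVM when the owner is collected.
            if self._broadcast:
                self._broadcast.unpersist()
                self._broadcast = None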

http://git-wip-us.apache.org/repos/asf/spark/blob/abf588f4/python/pyspark/sql.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index d8bdf22..974b5e2 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -965,7 +965,7 @@ class SQLContext(object):
                    BatchedSerializer(PickleSerializer(), 1024))
         ser = CloudPickleSerializer()
         pickled_command = ser.dumps(command)
-        if pickled_command > (1 << 20):  # 1M
+        if len(pickled_command) > (1 << 20):  # 1M
             broadcast = self._sc.broadcast(pickled_command)
             pickled_command = ser.dumps(broadcast)
         broadcast_vars = ListConverter().convert(

http://git-wip-us.apache.org/repos/asf/spark/blob/abf588f4/python/pyspark/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 7e2bbc9..6fb6bc9 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -467,8 +467,12 @@ class TestRDDFunctions(PySparkTestCase):
     def test_large_closure(self):
         N = 1000000
         data = [float(i) for i in xrange(N)]
-        m = self.sc.parallelize(range(1), 1).map(lambda x: len(data)).sum()
-        self.assertEquals(N, m)
+        rdd = self.sc.parallelize(range(1), 1).map(lambda x: len(data))
+        self.assertEquals(N, rdd.first())
+        self.assertTrue(rdd._broadcast is not None)
+        rdd = self.sc.parallelize(range(1), 1).map(lambda x: 1)
+        self.assertEqual(1, rdd.first())
+        self.assertTrue(rdd._broadcast is None)
 
     def test_zip_with_different_serializers(self):
         a = self.sc.parallelize(range(5))
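
The test above verifies that a large closure goes through the broadcast path and a small one does not. A short follow-up sketch of the cleanup side, assuming the same fixture (self.sc and the million-float data list from test_large_closure); when exactly __del__ runs depends on the garbage collector, so this is illustrative rather than something the suite asserts:

    import gc

    rdd = self.sc.parallelize(range(1), 1).map(lambda x: len(data))
    self.assertEquals(N, rdd.first())       # command was pickled and broadcast
    self.assertTrue(rdd._broadcast is not None)
    rdd = None                              # drop the last reference to the RDD
    gc.collect()                            # __del__ unpersists the broadcast in the JVM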

