Repository: spark
Updated Branches:
  refs/heads/master 58da1a245 -> 77cc0d67d


[SPARK-12717][PYTHON] Adding thread-safe broadcast pickle registry

## What changes were proposed in this pull request?

When using PySpark broadcast variables in a multi-threaded environment,
`SparkContext._pickled_broadcast_vars` becomes a shared resource.  A race
condition can occur when broadcast variables pickled in one thread are added
to the shared `_pickled_broadcast_vars` and end up in the Python command built
by another thread.  This PR introduces a thread-safe pickle registry backed by
thread-local storage, so that when a Python command is pickled (which pickles
the broadcast variables and adds them to the registry), each thread has its
own view of the registry from which to retrieve and clear the broadcast
variables it used.
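
For context, a minimal standalone sketch (not part of the patch; the `LocalRegistry` name and its helper methods are illustrative) of the `threading.local` behavior the new registry relies on: the first time a thread touches the object, `__init__` runs for that thread and creates a fresh `_registry` set, so items added in one thread never appear in another thread's view.

```python
# Hypothetical sketch of per-thread isolation via threading.local;
# LocalRegistry stands in for the BroadcastPickleRegistry added by this PR.
import threading


class LocalRegistry(threading.local):
    def __init__(self):
        # Runs once per thread on first access, giving each thread its own set.
        self.__dict__.setdefault("_registry", set())

    def add(self, item):
        self._registry.add(item)

    def snapshot_and_clear(self):
        items = set(self._registry)
        self._registry.clear()
        return items


registry = LocalRegistry()
registry.add("b1")  # added in the main thread

seen_in_worker = []


def worker():
    registry.add("b2")  # only visible to the worker thread
    seen_in_worker.append(registry.snapshot_and_clear())


t = threading.Thread(target=worker)
t.start()
t.join()

print(seen_in_worker[0])              # {'b2'} -- the worker never saw b1
print(registry.snapshot_and_clear())  # {'b1'} -- the main thread never saw b2
```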

## How was this patch tested?

Added a unit test that reproduces this race condition by pickling broadcast variables from a second thread.

Author: Bryan Cutler <cutl...@gmail.com>

Closes #18695 from BryanCutler/pyspark-bcast-threadsafe-SPARK-12717.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/77cc0d67
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/77cc0d67
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/77cc0d67

Branch: refs/heads/master
Commit: 77cc0d67d5a7ea526f8efd37b2590923953cb8e0
Parents: 58da1a2
Author: Bryan Cutler <cutl...@gmail.com>
Authored: Wed Aug 2 07:12:23 2017 +0900
Committer: hyukjinkwon <gurwls...@gmail.com>
Committed: Wed Aug 2 07:12:23 2017 +0900

----------------------------------------------------------------------
 python/pyspark/broadcast.py | 19 +++++++++++++++++
 python/pyspark/context.py   |  4 ++--
 python/pyspark/tests.py     | 44 ++++++++++++++++++++++++++++++++++++++++
 3 files changed, 65 insertions(+), 2 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/77cc0d67/python/pyspark/broadcast.py
----------------------------------------------------------------------
diff --git a/python/pyspark/broadcast.py b/python/pyspark/broadcast.py
index b1b59f7..02fc515 100644
--- a/python/pyspark/broadcast.py
+++ b/python/pyspark/broadcast.py
@@ -19,6 +19,7 @@ import os
 import sys
 import gc
 from tempfile import NamedTemporaryFile
+import threading
 
 from pyspark.cloudpickle import print_exec
 from pyspark.util import _exception_message
@@ -139,6 +140,24 @@ class Broadcast(object):
         return _from_id, (self._jbroadcast.id(),)
 
 
+class BroadcastPickleRegistry(threading.local):
+    """ Thread-local registry for broadcast variables that have been pickled
+    """
+
+    def __init__(self):
+        self.__dict__.setdefault("_registry", set())
+
+    def __iter__(self):
+        for bcast in self._registry:
+            yield bcast
+
+    def add(self, bcast):
+        self._registry.add(bcast)
+
+    def clear(self):
+        self._registry.clear()
+
+
 if __name__ == "__main__":
     import doctest
     (failure_count, test_count) = doctest.testmod()

http://git-wip-us.apache.org/repos/asf/spark/blob/77cc0d67/python/pyspark/context.py
----------------------------------------------------------------------
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 80cb48f..a704604 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -30,7 +30,7 @@ from py4j.protocol import Py4JError
 
 from pyspark import accumulators
 from pyspark.accumulators import Accumulator
-from pyspark.broadcast import Broadcast
+from pyspark.broadcast import Broadcast, BroadcastPickleRegistry
 from pyspark.conf import SparkConf
 from pyspark.files import SparkFiles
 from pyspark.java_gateway import launch_gateway
@@ -195,7 +195,7 @@ class SparkContext(object):
         # This allows other code to determine which Broadcast instances have
         # been pickled, so it can determine which Java broadcast objects to
         # send.
-        self._pickled_broadcast_vars = set()
+        self._pickled_broadcast_vars = BroadcastPickleRegistry()
 
         SparkFiles._sc = self
         root_dir = SparkFiles.getRootDirectory()

http://git-wip-us.apache.org/repos/asf/spark/blob/77cc0d67/python/pyspark/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 73ab442..000dd1e 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -858,6 +858,50 @@ class RDDTests(ReusedPySparkTestCase):
         self.assertEqual(N, size)
         self.assertEqual(checksum, csum)
 
+    def test_multithread_broadcast_pickle(self):
+        import threading
+
+        b1 = self.sc.broadcast(list(range(3)))
+        b2 = self.sc.broadcast(list(range(3)))
+
+        def f1():
+            return b1.value
+
+        def f2():
+            return b2.value
+
+        funcs_num_pickled = {f1: None, f2: None}
+
+        def do_pickle(f, sc):
+            command = (f, None, sc.serializer, sc.serializer)
+            ser = CloudPickleSerializer()
+            ser.dumps(command)
+
+        def process_vars(sc):
+            broadcast_vars = list(sc._pickled_broadcast_vars)
+            num_pickled = len(broadcast_vars)
+            sc._pickled_broadcast_vars.clear()
+            return num_pickled
+
+        def run(f, sc):
+            do_pickle(f, sc)
+            funcs_num_pickled[f] = process_vars(sc)
+
+        # pickle f1, adds b1 to sc._pickled_broadcast_vars in main thread local storage
+        do_pickle(f1, self.sc)
+
+        # run all for f2, should only add/count/clear b2 from worker thread local storage
+        t = threading.Thread(target=run, args=(f2, self.sc))
+        t.start()
+        t.join()
+
+        # count number of vars pickled in main thread, only b1 should be counted and cleared
+        funcs_num_pickled[f1] = process_vars(self.sc)
+
+        self.assertEqual(funcs_num_pickled[f1], 1)
+        self.assertEqual(funcs_num_pickled[f2], 1)
+        self.assertEqual(len(list(self.sc._pickled_broadcast_vars)), 0)
+
     def test_large_closure(self):
         N = 200000
         data = [float(i) for i in xrange(N)]

