Repository: spark Updated Branches: refs/heads/master 58da1a245 -> 77cc0d67d
This PR introduces a thread-safe pickled registry using thread-local storage so that when the python command is pickled (causing the broadcast variable to be pickled and added to the registry) each thread will have its own view of the pickle registry from which to retrieve and clear the broadcast variables used.
class BroadcastPickleRegistry(threading.local):
    """Thread-local registry for broadcast variables that have been pickled.

    Because this subclasses ``threading.local``, each thread gets its own
    independent ``_registry`` set, so broadcast variables pickled by one
    thread are never visible to (or cleared by) another thread.
    """

    def __init__(self):
        # For a threading.local subclass, __init__ runs once per thread that
        # touches the instance; setdefault guards against a re-run wiping an
        # already-populated registry for the current thread.
        self.__dict__.setdefault("_registry", set())

    def __iter__(self):
        """Iterate over the broadcast variables registered by this thread."""
        # Idiomatic delegation instead of a hand-rolled generator loop.
        return iter(self._registry)

    def add(self, bcast):
        """Register a pickled broadcast variable for the current thread."""
        self._registry.add(bcast)

    def clear(self):
        """Remove all broadcast variables registered by the current thread."""
        self._registry.clear()
def test_multithread_broadcast_pickle(self):
    """Pickled-broadcast registry must be isolated per thread (SPARK-12717)."""
    import threading

    bcast_a = self.sc.broadcast(list(range(3)))
    bcast_b = self.sc.broadcast(list(range(3)))

    def use_a():
        return bcast_a.value

    def use_b():
        return bcast_b.value

    pickled_counts = {use_a: None, use_b: None}

    def pickle_command(func, sc):
        # Serialize a python command the same way job submission does; this
        # is what registers referenced broadcast vars in the registry.
        CloudPickleSerializer().dumps((func, None, sc.serializer, sc.serializer))

    def drain_registry(sc):
        # Count the broadcast vars visible to the calling thread, then clear.
        visible = list(sc._pickled_broadcast_vars)
        sc._pickled_broadcast_vars.clear()
        return len(visible)

    def pickle_and_drain(func, sc):
        pickle_command(func, sc)
        pickled_counts[func] = drain_registry(sc)

    # Pickle use_a first: bcast_a lands in the MAIN thread's local registry.
    pickle_command(use_a, self.sc)

    # Worker thread pickles use_b; it must see/clear only bcast_b.
    worker = threading.Thread(target=pickle_and_drain, args=(use_b, self.sc))
    worker.start()
    worker.join()

    # Draining in the main thread must count only bcast_a.
    pickled_counts[use_a] = drain_registry(self.sc)

    self.assertEqual(pickled_counts[use_a], 1)
    self.assertEqual(pickled_counts[use_b], 1)
    self.assertEqual(len(list(self.sc._pickled_broadcast_vars)), 0)
commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org