diff --git a/python/pyspark/broadcast.py b/python/pyspark/broadcast.py
index b1b59f73d67185f3590b03c5978fd47c068c8ea0..02fc515fb824a23b83bfadfa285afadfe1d25de8 100644
--- a/python/pyspark/broadcast.py
+++ b/python/pyspark/broadcast.py
@@ -19,6 +19,7 @@ import os
 import sys
 import gc
 from tempfile import NamedTemporaryFile
+import threading
 
 from pyspark.cloudpickle import print_exec
 from pyspark.util import _exception_message
@@ -139,6 +140,24 @@ class Broadcast(object):
         return _from_id, (self._jbroadcast.id(),)
 
 
+class BroadcastPickleRegistry(threading.local):
+    """ Thread-local registry for broadcast variables that have been pickled
+    """
+
+    def __init__(self):
+        self.__dict__.setdefault("_registry", set())
+
+    def __iter__(self):
+        for bcast in self._registry:
+            yield bcast
+
+    def add(self, bcast):
+        self._registry.add(bcast)
+
+    def clear(self):
+        self._registry.clear()
+
+
 if __name__ == "__main__":
     import doctest
     (failure_count, test_count) = doctest.testmod()
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 80cb48fb8209e0019d72fb9f5749cd077c718934..a7046043e03764b87e7f240be6aafd2dc083a72a 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -30,7 +30,7 @@ from py4j.protocol import Py4JError
 
 from pyspark import accumulators
 from pyspark.accumulators import Accumulator
-from pyspark.broadcast import Broadcast
+from pyspark.broadcast import Broadcast, BroadcastPickleRegistry
 from pyspark.conf import SparkConf
 from pyspark.files import SparkFiles
 from pyspark.java_gateway import launch_gateway
@@ -195,7 +195,7 @@ class SparkContext(object):
         # This allows other code to determine which Broadcast instances have
         # been pickled, so it can determine which Java broadcast objects to
         # send.
-        self._pickled_broadcast_vars = set()
+        self._pickled_broadcast_vars = BroadcastPickleRegistry()
 
         SparkFiles._sc = self
         root_dir = SparkFiles.getRootDirectory()
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 73ab442dfd791f2ae7f3de78adba48b2a3cb7393..000dd1eb8e481584b7605722e84cce1f528f3a19 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -858,6 +858,50 @@ class RDDTests(ReusedPySparkTestCase):
         self.assertEqual(N, size)
         self.assertEqual(checksum, csum)
 
+    def test_multithread_broadcast_pickle(self):
+        import threading
+
+        b1 = self.sc.broadcast(list(range(3)))
+        b2 = self.sc.broadcast(list(range(3)))
+
+        def f1():
+            return b1.value
+
+        def f2():
+            return b2.value
+
+        funcs_num_pickled = {f1: None, f2: None}
+
+        def do_pickle(f, sc):
+            command = (f, None, sc.serializer, sc.serializer)
+            ser = CloudPickleSerializer()
+            ser.dumps(command)
+
+        def process_vars(sc):
+            broadcast_vars = list(sc._pickled_broadcast_vars)
+            num_pickled = len(broadcast_vars)
+            sc._pickled_broadcast_vars.clear()
+            return num_pickled
+
+        def run(f, sc):
+            do_pickle(f, sc)
+            funcs_num_pickled[f] = process_vars(sc)
+
+        # pickle f1, adds b1 to sc._pickled_broadcast_vars in main thread local storage
+        do_pickle(f1, self.sc)
+
+        # run all for f2, should only add/count/clear b2 from worker thread local storage
+        t = threading.Thread(target=run, args=(f2, self.sc))
+        t.start()
+        t.join()
+
+        # count number of vars pickled in main thread, only b1 should be counted and cleared
+        funcs_num_pickled[f1] = process_vars(self.sc)
+
+        self.assertEqual(funcs_num_pickled[f1], 1)
+        self.assertEqual(funcs_num_pickled[f2], 1)
+        self.assertEqual(len(list(self.sc._pickled_broadcast_vars)), 0)
+
     def test_large_closure(self):
         N = 200000
         data = [float(i) for i in xrange(N)]
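
For reviewers following along: the fix works because subclassing threading.local gives each thread its own `_registry` set, so one thread's pickle/clear cycle can no longer wipe out broadcast variables registered by another thread. The snippet below is a standalone sketch, not part of the patch; it mirrors BroadcastPickleRegistry but is exercised without a SparkContext, using plain strings in place of real Broadcast objects, to show that isolation:

# Standalone sketch (assumption: mirrors the patched BroadcastPickleRegistry,
# with strings standing in for Broadcast instances).
import threading

class BroadcastPickleRegistry(threading.local):
    """Each thread that touches the registry sees its own '_registry' set."""

    def __init__(self):
        # threading.local calls __init__ again on first use in every new thread,
        # so each thread starts with an empty set.
        self.__dict__.setdefault("_registry", set())

    def __iter__(self):
        for bcast in self._registry:
            yield bcast

    def add(self, bcast):
        self._registry.add(bcast)

    def clear(self):
        self._registry.clear()

registry = BroadcastPickleRegistry()
registry.add("b1")  # registered in the main thread

def worker():
    registry.add("b2")               # lands in the worker thread's own set
    assert list(registry) == ["b2"]  # the main thread's "b1" is not visible here
    registry.clear()                 # clears only the worker thread's set

t = threading.Thread(target=worker)
t.start()
t.join()

assert list(registry) == ["b1"]  # the main thread's entry survives the other thread's clear()

With the previous plain set(), the worker thread's clear() in the test above would also have dropped "b1", which is exactly the race the new test_multithread_broadcast_pickle guards against.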