diff --git a/python/pyspark/broadcast.py b/python/pyspark/broadcast.py
index b1b59f73d67185f3590b03c5978fd47c068c8ea0..02fc515fb824a23b83bfadfa285afadfe1d25de8 100644
--- a/python/pyspark/broadcast.py
+++ b/python/pyspark/broadcast.py
@@ -19,6 +19,7 @@ import os
 import sys
 import gc
 from tempfile import NamedTemporaryFile
+import threading
 
 from pyspark.cloudpickle import print_exec
 from pyspark.util import _exception_message
@@ -139,6 +140,26 @@ class Broadcast(object):
         return _from_id, (self._jbroadcast.id(),)
 
 
+class BroadcastPickleRegistry(threading.local):
+    """ Thread-local registry for broadcast variables that have been pickled
+    """
+
+    def __init__(self):
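+        # threading.local invokes __init__ on first access from each new
+        # thread, so every thread lazily gets its own private registry set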
+        self.__dict__.setdefault("_registry", set())
+
+    def __iter__(self):
+        for bcast in self._registry:
+            yield bcast
+
+    def add(self, bcast):
+        self._registry.add(bcast)
+
+    def clear(self):
+        self._registry.clear()
+
+
 if __name__ == "__main__":
     import doctest
     (failure_count, test_count) = doctest.testmod()
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 80cb48fb8209e0019d72fb9f5749cd077c718934..a7046043e03764b87e7f240be6aafd2dc083a72a 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -30,7 +30,7 @@ from py4j.protocol import Py4JError
 
 from pyspark import accumulators
 from pyspark.accumulators import Accumulator
-from pyspark.broadcast import Broadcast
+from pyspark.broadcast import Broadcast, BroadcastPickleRegistry
 from pyspark.conf import SparkConf
 from pyspark.files import SparkFiles
 from pyspark.java_gateway import launch_gateway
@@ -195,7 +195,9 @@ class SparkContext(object):
         # This allows other code to determine which Broadcast instances have
         # been pickled, so it can determine which Java broadcast objects to
         # send.
-        self._pickled_broadcast_vars = set()
+        self._pickled_broadcast_vars = BroadcastPickleRegistry()
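+        # The registry is thread-local, so jobs submitted concurrently from
+        # different threads cannot see or clear one another's entries.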
 
         SparkFiles._sc = self
         root_dir = SparkFiles.getRootDirectory()
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 73ab442dfd791f2ae7f3de78adba48b2a3cb7393..000dd1eb8e481584b7605722e84cce1f528f3a19 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -858,6 +858,57 @@ class RDDTests(ReusedPySparkTestCase):
         self.assertEqual(N, size)
         self.assertEqual(checksum, csum)
 
+    def test_multithread_broadcast_pickle(self):
+        import threading
+
+        b1 = self.sc.broadcast(list(range(3)))
+        b2 = self.sc.broadcast(list(range(3)))
+
+        def f1():
+            return b1.value
+
+        def f2():
+            return b2.value
+
+        funcs_num_pickled = {f1: None, f2: None}
+
+        def do_pickle(f, sc):
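+            # mimic how PySpark pickles a task command: serializing the
+            # closure triggers Broadcast.__reduce__, which records every
+            # captured broadcast variable in sc._pickled_broadcast_vars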
+            command = (f, None, sc.serializer, sc.serializer)
+            ser = CloudPickleSerializer()
+            ser.dumps(command)
+
+        def process_vars(sc):
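+            # snapshot this thread's registry, then clear it, returning how
+            # many broadcast variables were recorded during pickling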
+            broadcast_vars = list(sc._pickled_broadcast_vars)
+            num_pickled = len(broadcast_vars)
+            sc._pickled_broadcast_vars.clear()
+            return num_pickled
+
+        def run(f, sc):
+            do_pickle(f, sc)
+            funcs_num_pickled[f] = process_vars(sc)
+
+        # pickle f1, which adds b1 to the main thread's sc._pickled_broadcast_vars
+        do_pickle(f1, self.sc)
+
+        # run the whole cycle for f2 in a worker thread; it should add, count, and clear only b2
+        t = threading.Thread(target=run, args=(f2, self.sc))
+        t.start()
+        t.join()
+
+        # count the vars pickled in the main thread; only b1 should be counted and cleared
+        funcs_num_pickled[f1] = process_vars(self.sc)
+
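+        # each thread should have recorded exactly one broadcast variable,
+        # and the main thread's registry should now be empty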
+        self.assertEqual(funcs_num_pickled[f1], 1)
+        self.assertEqual(funcs_num_pickled[f2], 1)
+        self.assertEqual(len(list(self.sc._pickled_broadcast_vars)), 0)
+
     def test_large_closure(self):
         N = 200000
         data = [float(i) for i in xrange(N)]