From 77cc0d67d5a7ea526f8efd37b2590923953cb8e0 Mon Sep 17 00:00:00 2001 From: Bryan Cutler <cutlerb@gmail.com> Date: Wed, 2 Aug 2017 07:12:23 +0900 Subject: [PATCH] [SPARK-12717][PYTHON] Adding thread-safe broadcast pickle registry ## What changes were proposed in this pull request? When using PySpark broadcast variables in a multi-threaded environment, `SparkContext._pickled_broadcast_vars` becomes a shared resource. A race condition can occur when broadcast variables that are pickled from one thread get added to the shared ` _pickled_broadcast_vars` and become part of the python command from another thread. This PR introduces a thread-safe pickled registry using thread local storage so that when python command is pickled (causing the broadcast variable to be pickled and added to the registry) each thread will have their own view of the pickle registry to retrieve and clear the broadcast variables used. ## How was this patch tested? Added a unit test that causes this race condition using another thread. Author: Bryan Cutler <cutlerb@gmail.com> Closes #18695 from BryanCutler/pyspark-bcast-threadsafe-SPARK-12717. --- python/pyspark/broadcast.py | 19 ++++++++++++++++ python/pyspark/context.py | 4 ++-- python/pyspark/tests.py | 44 +++++++++++++++++++++++++++++++++++++ 3 files changed, 65 insertions(+), 2 deletions(-) diff --git a/python/pyspark/broadcast.py b/python/pyspark/broadcast.py index b1b59f73d6..02fc515fb8 100644 --- a/python/pyspark/broadcast.py +++ b/python/pyspark/broadcast.py @@ -19,6 +19,7 @@ import os import sys import gc from tempfile import NamedTemporaryFile +import threading from pyspark.cloudpickle import print_exec from pyspark.util import _exception_message @@ -139,6 +140,24 @@ class Broadcast(object): return _from_id, (self._jbroadcast.id(),) +class BroadcastPickleRegistry(threading.local): + """ Thread-local registry for broadcast variables that have been pickled + """ + + def __init__(self): + self.__dict__.setdefault("_registry", set()) + + def __iter__(self): + for bcast in self._registry: + yield bcast + + def add(self, bcast): + self._registry.add(bcast) + + def clear(self): + self._registry.clear() + + if __name__ == "__main__": import doctest (failure_count, test_count) = doctest.testmod() diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 80cb48fb82..a7046043e0 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -30,7 +30,7 @@ from py4j.protocol import Py4JError from pyspark import accumulators from pyspark.accumulators import Accumulator -from pyspark.broadcast import Broadcast +from pyspark.broadcast import Broadcast, BroadcastPickleRegistry from pyspark.conf import SparkConf from pyspark.files import SparkFiles from pyspark.java_gateway import launch_gateway @@ -195,7 +195,7 @@ class SparkContext(object): # This allows other code to determine which Broadcast instances have # been pickled, so it can determine which Java broadcast objects to # send. - self._pickled_broadcast_vars = set() + self._pickled_broadcast_vars = BroadcastPickleRegistry() SparkFiles._sc = self root_dir = SparkFiles.getRootDirectory() diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 73ab442dfd..000dd1eb8e 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -858,6 +858,50 @@ class RDDTests(ReusedPySparkTestCase): self.assertEqual(N, size) self.assertEqual(checksum, csum) + def test_multithread_broadcast_pickle(self): + import threading + + b1 = self.sc.broadcast(list(range(3))) + b2 = self.sc.broadcast(list(range(3))) + + def f1(): + return b1.value + + def f2(): + return b2.value + + funcs_num_pickled = {f1: None, f2: None} + + def do_pickle(f, sc): + command = (f, None, sc.serializer, sc.serializer) + ser = CloudPickleSerializer() + ser.dumps(command) + + def process_vars(sc): + broadcast_vars = list(sc._pickled_broadcast_vars) + num_pickled = len(broadcast_vars) + sc._pickled_broadcast_vars.clear() + return num_pickled + + def run(f, sc): + do_pickle(f, sc) + funcs_num_pickled[f] = process_vars(sc) + + # pickle f1, adds b1 to sc._pickled_broadcast_vars in main thread local storage + do_pickle(f1, self.sc) + + # run all for f2, should only add/count/clear b2 from worker thread local storage + t = threading.Thread(target=run, args=(f2, self.sc)) + t.start() + t.join() + + # count number of vars pickled in main thread, only b1 should be counted and cleared + funcs_num_pickled[f1] = process_vars(self.sc) + + self.assertEqual(funcs_num_pickled[f1], 1) + self.assertEqual(funcs_num_pickled[f2], 1) + self.assertEqual(len(list(self.sc._pickled_broadcast_vars)), 0) + def test_large_closure(self): N = 200000 data = [float(i) for i in xrange(N)] -- GitLab