From 9cc6ff9c4e7eec2d62261fc166ad2ebade148752 Mon Sep 17 00:00:00 2001
From: Josh Rosen <joshrosen@eecs.berkeley.edu>
Date: Fri, 1 Feb 2013 11:09:56 -0800
Subject: [PATCH] Do not launch JavaGateways on workers (SPARK-674).

The problem was that the gateway was being initialized whenever the
pyspark.context module was loaded.  The fix uses lazy initialization
that occurs only when SparkContext instances are actually constructed.

I also made the gateway and jvm variables private.

This change results in ~3-4x performance improvement when running the
PySpark unit tests.
---
 python/pyspark/context.py | 27 +++++++++++++++++----------
 python/pyspark/files.py   |  2 +-
 python/pyspark/rdd.py     | 12 ++++++------
 python/pyspark/tests.py   |  2 +-
 4 files changed, 25 insertions(+), 18 deletions(-)

diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 783e3dc148..ba6896dda3 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -24,11 +24,10 @@ class SparkContext(object):
     broadcast variables on that cluster.
     """
 
-    gateway = launch_gateway()
-    jvm = gateway.jvm
-    _readRDDFromPickleFile = jvm.PythonRDD.readRDDFromPickleFile
-    _writeIteratorToPickleFile = jvm.PythonRDD.writeIteratorToPickleFile
-    _takePartition = jvm.PythonRDD.takePartition
+    _gateway = None
+    _jvm = None
+    _writeIteratorToPickleFile = None
+    _takePartition = None
     _next_accum_id = 0
     _active_spark_context = None
     _lock = Lock()
@@ -56,6 +55,13 @@ class SparkContext(object):
                 raise ValueError("Cannot run multiple SparkContexts at once")
             else:
                 SparkContext._active_spark_context = self
+                if not SparkContext._gateway:
+                    SparkContext._gateway = launch_gateway()
+                    SparkContext._jvm = SparkContext._gateway.jvm
+                    SparkContext._writeIteratorToPickleFile = \
+                        SparkContext._jvm.PythonRDD.writeIteratorToPickleFile
+                    SparkContext._takePartition = \
+                        SparkContext._jvm.PythonRDD.takePartition
         self.master = master
         self.jobName = jobName
         self.sparkHome = sparkHome or None # None becomes null in Py4J
@@ -63,8 +69,8 @@ class SparkContext(object):
         self.batchSize = batchSize  # -1 represents a unlimited batch size
 
         # Create the Java SparkContext through Py4J
-        empty_string_array = self.gateway.new_array(self.jvm.String, 0)
-        self._jsc = self.jvm.JavaSparkContext(master, jobName, sparkHome,
+        empty_string_array = self._gateway.new_array(self._jvm.String, 0)
+        self._jsc = self._jvm.JavaSparkContext(master, jobName, sparkHome,
                                               empty_string_array)
 
         # Create a single Accumulator in Java that we'll send all our updates through;
@@ -72,8 +78,8 @@ class SparkContext(object):
         self._accumulatorServer = accumulators._start_update_server()
         (host, port) = self._accumulatorServer.server_address
         self._javaAccumulator = self._jsc.accumulator(
-                self.jvm.java.util.ArrayList(),
-                self.jvm.PythonAccumulatorParam(host, port))
+                self._jvm.java.util.ArrayList(),
+                self._jvm.PythonAccumulatorParam(host, port))
 
         self.pythonExec = os.environ.get("PYSPARK_PYTHON", 'python')
         # Broadcast's __reduce__ method stores Broadcast instances here.
@@ -127,7 +133,8 @@ class SparkContext(object):
         for x in c:
             write_with_length(dump_pickle(x), tempFile)
         tempFile.close()
-        jrdd = self._readRDDFromPickleFile(self._jsc, tempFile.name, numSlices)
+        readRDDFromPickleFile = self._jvm.PythonRDD.readRDDFromPickleFile
+        jrdd = readRDDFromPickleFile(self._jsc, tempFile.name, numSlices)
         return RDD(jrdd, self)
 
     def textFile(self, name, minSplits=None):
diff --git a/python/pyspark/files.py b/python/pyspark/files.py
index 98f6a399cc..001b7a28b6 100644
--- a/python/pyspark/files.py
+++ b/python/pyspark/files.py
@@ -35,4 +35,4 @@ class SparkFiles(object):
             return cls._root_directory
         else:
             # This will have to change if we support multiple SparkContexts:
-            return cls._sc.jvm.spark.SparkFiles.getRootDirectory()
+            return cls._sc._jvm.spark.SparkFiles.getRootDirectory()
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index d53355a8f1..d7cad2f372 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -407,7 +407,7 @@ class RDD(object):
             return (str(x).encode("utf-8") for x in iterator)
         keyed = PipelinedRDD(self, func)
         keyed._bypass_serializer = True
-        keyed._jrdd.map(self.ctx.jvm.BytesToString()).saveAsTextFile(path)
+        keyed._jrdd.map(self.ctx._jvm.BytesToString()).saveAsTextFile(path)
 
     # Pair functions
 
@@ -550,8 +550,8 @@ class RDD(object):
                 yield dump_pickle(Batch(items))
         keyed = PipelinedRDD(self, add_shuffle_key)
         keyed._bypass_serializer = True
-        pairRDD = self.ctx.jvm.PairwiseRDD(keyed._jrdd.rdd()).asJavaPairRDD()
-        partitioner = self.ctx.jvm.PythonPartitioner(numSplits,
+        pairRDD = self.ctx._jvm.PairwiseRDD(keyed._jrdd.rdd()).asJavaPairRDD()
+        partitioner = self.ctx._jvm.PythonPartitioner(numSplits,
                                                      id(partitionFunc))
         jrdd = pairRDD.partitionBy(partitioner).values()
         rdd = RDD(jrdd, self.ctx)
@@ -730,13 +730,13 @@ class PipelinedRDD(RDD):
         pipe_command = ' '.join(b64enc(cloudpickle.dumps(f)) for f in cmds)
         broadcast_vars = ListConverter().convert(
             [x._jbroadcast for x in self.ctx._pickled_broadcast_vars],
-            self.ctx.gateway._gateway_client)
+            self.ctx._gateway._gateway_client)
         self.ctx._pickled_broadcast_vars.clear()
         class_manifest = self._prev_jrdd.classManifest()
         env = copy.copy(self.ctx.environment)
         env['PYTHONPATH'] = os.environ.get("PYTHONPATH", "")
-        env = MapConverter().convert(env, self.ctx.gateway._gateway_client)
-        python_rdd = self.ctx.jvm.PythonRDD(self._prev_jrdd.rdd(),
+        env = MapConverter().convert(env, self.ctx._gateway._gateway_client)
+        python_rdd = self.ctx._jvm.PythonRDD(self._prev_jrdd.rdd(),
             pipe_command, env, self.preservesPartitioning, self.ctx.pythonExec,
             broadcast_vars, self.ctx._javaAccumulator, class_manifest)
         self._jrdd_val = python_rdd.asJavaRDD()
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 52297d44e6..6a1962d267 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -26,7 +26,7 @@ class PySparkTestCase(unittest.TestCase):
         sys.path = self._old_sys_path
         # To avoid Akka rebinding to the same port, since it doesn't unbind
         # immediately on shutdown
-        self.sc.jvm.System.clearProperty("spark.driver.port")
+        self.sc._jvm.System.clearProperty("spark.driver.port")
 
 
 class TestCheckpoint(PySparkTestCase):
-- 
GitLab