From 4608902fb87af64a15b97ab21fe6382cd6e5a644 Mon Sep 17 00:00:00 2001
From: Josh Rosen <joshrosen@eecs.berkeley.edu>
Date: Mon, 24 Dec 2012 17:20:10 -0800
Subject: [PATCH] Use filesystem to collect RDDs in PySpark.

Passing large volumes of data through Py4J seems to be slow. It appears
to be faster to write the data to the local filesystem and read it back
from Python.
---
 .../scala/spark/api/python/PythonRDD.scala | 66 +++++++------------
 pyspark/pyspark/context.py                 |  9 ++-
 pyspark/pyspark/rdd.py                     | 34 ++++++++--
 pyspark/pyspark/serializers.py             |  8 +++
 pyspark/pyspark/worker.py                  | 12 +---
 5 files changed, 66 insertions(+), 63 deletions(-)

diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala
index 50094d6b0f..4f870e837a 100644
--- a/core/src/main/scala/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/spark/api/python/PythonRDD.scala
@@ -1,6 +1,7 @@
 package spark.api.python
 
 import java.io._
+import java.util.{List => JList}
 
 import scala.collection.Map
 import scala.collection.JavaConversions._
@@ -59,36 +60,7 @@ trait PythonRDDBase {
         }
         out.flush()
         for (elem <- parent.iterator(split)) {
-          if (elem.isInstanceOf[Array[Byte]]) {
-            val arr = elem.asInstanceOf[Array[Byte]]
-            dOut.writeInt(arr.length)
-            dOut.write(arr)
-          } else if (elem.isInstanceOf[scala.Tuple2[_, _]]) {
-            val t = elem.asInstanceOf[scala.Tuple2[_, _]]
-            val t1 = t._1.asInstanceOf[Array[Byte]]
-            val t2 = t._2.asInstanceOf[Array[Byte]]
-            val length = t1.length + t2.length - 3 - 3 + 4  // stripPickle() removes 3 bytes
-            dOut.writeInt(length)
-            dOut.writeByte(Pickle.PROTO)
-            dOut.writeByte(Pickle.TWO)
-            dOut.write(PythonRDD.stripPickle(t1))
-            dOut.write(PythonRDD.stripPickle(t2))
-            dOut.writeByte(Pickle.TUPLE2)
-            dOut.writeByte(Pickle.STOP)
-          } else if (elem.isInstanceOf[String]) {
-            // For uniformity, strings are wrapped into Pickles.
-            val s = elem.asInstanceOf[String].getBytes("UTF-8")
-            val length = 2 + 1 + 4 + s.length + 1
-            dOut.writeInt(length)
-            dOut.writeByte(Pickle.PROTO)
-            dOut.writeByte(Pickle.TWO)
-            dOut.writeByte(Pickle.BINUNICODE)
-            dOut.writeInt(Integer.reverseBytes(s.length))
-            dOut.write(s)
-            dOut.writeByte(Pickle.STOP)
-          } else {
-            throw new Exception("Unexpected RDD type")
-          }
+          PythonRDD.writeAsPickle(elem, dOut)
         }
         dOut.flush()
         out.flush()
@@ -174,36 +146,45 @@ object PythonRDD {
     arr.slice(2, arr.length - 1)
   }
 
-  def asPickle(elem: Any) : Array[Byte] = {
-    val baos = new ByteArrayOutputStream();
-    val dOut = new DataOutputStream(baos);
+  /**
+   * Write strings, pickled Python objects, or pairs of pickled objects to a data output stream.
+   * The data format is a 32-bit integer representing the pickled object's length (in bytes),
+   * followed by the pickled data.
+   * @param elem the object to write
+   * @param dOut a data output stream
+   */
+  def writeAsPickle(elem: Any, dOut: DataOutputStream) {
     if (elem.isInstanceOf[Array[Byte]]) {
-      elem.asInstanceOf[Array[Byte]]
+      val arr = elem.asInstanceOf[Array[Byte]]
+      dOut.writeInt(arr.length)
+      dOut.write(arr)
     } else if (elem.isInstanceOf[scala.Tuple2[Array[Byte], Array[Byte]]]) {
       val t = elem.asInstanceOf[scala.Tuple2[Array[Byte], Array[Byte]]]
+      val length = t._1.length + t._2.length - 3 - 3 + 4  // stripPickle() removes 3 bytes
+      dOut.writeInt(length)
       dOut.writeByte(Pickle.PROTO)
       dOut.writeByte(Pickle.TWO)
       dOut.write(PythonRDD.stripPickle(t._1))
       dOut.write(PythonRDD.stripPickle(t._2))
       dOut.writeByte(Pickle.TUPLE2)
       dOut.writeByte(Pickle.STOP)
-      baos.toByteArray()
     } else if (elem.isInstanceOf[String]) {
       // For uniformity, strings are wrapped into Pickles.
       val s = elem.asInstanceOf[String].getBytes("UTF-8")
+      val length = 2 + 1 + 4 + s.length + 1
+      dOut.writeInt(length)
       dOut.writeByte(Pickle.PROTO)
       dOut.writeByte(Pickle.TWO)
       dOut.write(Pickle.BINUNICODE)
       dOut.writeInt(Integer.reverseBytes(s.length))
       dOut.write(s)
       dOut.writeByte(Pickle.STOP)
-      baos.toByteArray()
     } else {
       throw new Exception("Unexpected RDD type")
     }
   }
 
-  def pickleFile(sc: JavaSparkContext, filename: String, parallelism: Int) :
+  def readRDDFromPickleFile(sc: JavaSparkContext, filename: String, parallelism: Int) :
   JavaRDD[Array[Byte]] = {
     val file = new DataInputStream(new FileInputStream(filename))
     val objs = new collection.mutable.ArrayBuffer[Array[Byte]]
@@ -221,11 +202,12 @@ object PythonRDD {
     JavaRDD.fromRDD(sc.sc.parallelize(objs, parallelism))
   }
 
-  def arrayAsPickle(arr : Any) : Array[Byte] = {
-    val pickles : Array[Byte] = arr.asInstanceOf[Array[Any]].map(asPickle).map(stripPickle).flatten
-
-    Array[Byte](Pickle.PROTO, Pickle.TWO, Pickle.EMPTY_LIST, Pickle.MARK) ++ pickles ++
-      Array[Byte] (Pickle.APPENDS, Pickle.STOP)
+  def writeArrayToPickleFile[T](items: Array[T], filename: String) {
+    val file = new DataOutputStream(new FileOutputStream(filename))
+    for (item <- items) {
+      writeAsPickle(item, file)
+    }
+    file.close()
   }
 }
 
diff --git a/pyspark/pyspark/context.py b/pyspark/pyspark/context.py
index 50d57e5317..19f9f9e133 100644
--- a/pyspark/pyspark/context.py
+++ b/pyspark/pyspark/context.py
@@ -14,9 +14,8 @@ class SparkContext(object):
 
     gateway = launch_gateway()
    jvm = gateway.jvm
-    pickleFile = jvm.spark.api.python.PythonRDD.pickleFile
-    asPickle = jvm.spark.api.python.PythonRDD.asPickle
-    arrayAsPickle = jvm.spark.api.python.PythonRDD.arrayAsPickle
+    readRDDFromPickleFile = jvm.PythonRDD.readRDDFromPickleFile
+    writeArrayToPickleFile = jvm.PythonRDD.writeArrayToPickleFile
 
     def __init__(self, master, name, defaultParallelism=None):
         self.master = master
@@ -45,11 +44,11 @@ class SparkContext(object):
         # because it sends O(n) Py4J commands. As an alternative, serialized
         # objects are written to a file and loaded through textFile().
         tempFile = NamedTemporaryFile(delete=False)
+        atexit.register(lambda: os.unlink(tempFile.name))
         for x in c:
             write_with_length(dump_pickle(x), tempFile)
         tempFile.close()
-        atexit.register(lambda: os.unlink(tempFile.name))
-        jrdd = self.pickleFile(self._jsc, tempFile.name, numSlices)
+        jrdd = self.readRDDFromPickleFile(self._jsc, tempFile.name, numSlices)
         return RDD(jrdd, self)
 
     def textFile(self, name, minSplits=None):
diff --git a/pyspark/pyspark/rdd.py b/pyspark/pyspark/rdd.py
index 708ea6eb55..01908cff96 100644
--- a/pyspark/pyspark/rdd.py
+++ b/pyspark/pyspark/rdd.py
@@ -1,13 +1,15 @@
+import atexit
 from base64 import standard_b64encode as b64enc
 from collections import defaultdict
 from itertools import chain, ifilter, imap
 import os
 import shlex
 from subprocess import Popen, PIPE
+from tempfile import NamedTemporaryFile
 from threading import Thread
 
 from pyspark import cloudpickle
-from pyspark.serializers import dump_pickle, load_pickle
+from pyspark.serializers import dump_pickle, load_pickle, read_from_pickle_file
 from pyspark.join import python_join, python_left_outer_join, \
     python_right_outer_join, python_cogroup
 
@@ -145,10 +147,30 @@ class RDD(object):
         self.map(f).collect()  # Force evaluation
 
     def collect(self):
+        # To minimize the number of transfers between Python and Java, we'll
+        # flatten each partition into a list before collecting it. Due to
+        # pipelining, this should add minimal overhead.
         def asList(iterator):
             yield list(iterator)
-        pickles = self.mapPartitions(asList)._jrdd.rdd().collect()
-        return list(chain.from_iterable(load_pickle(bytes(p)) for p in pickles))
+        picklesInJava = self.mapPartitions(asList)._jrdd.rdd().collect()
+        return list(chain.from_iterable(self._collect_array_through_file(picklesInJava)))
+
+    def _collect_array_through_file(self, array):
+        # Transferring lots of data through Py4J can be slow because
+        # socket.readline() is inefficient. Instead, we'll dump the data to a
+        # file and read it back.
+        tempFile = NamedTemporaryFile(delete=False)
+        tempFile.close()
+        def clean_up_file():
+            try: os.unlink(tempFile.name)
+            except: pass
+        atexit.register(clean_up_file)
+        self.ctx.writeArrayToPickleFile(array, tempFile.name)
+        # Read the data into Python and deserialize it:
+        with open(tempFile.name, 'rb') as tempFile:
+            for item in read_from_pickle_file(tempFile):
+                yield item
+        os.unlink(tempFile.name)
 
     def reduce(self, f):
         """
@@ -220,15 +242,15 @@ class RDD(object):
         >>> sc.parallelize([2, 3, 4]).take(2)
         [2, 3]
         """
-        pickle = self.ctx.arrayAsPickle(self._jrdd.rdd().take(num))
-        return load_pickle(bytes(pickle))
+        picklesInJava = self._jrdd.rdd().take(num)
+        return list(self._collect_array_through_file(picklesInJava))
 
     def first(self):
         """
         >>> sc.parallelize([2, 3, 4]).first()
         2
         """
-        return load_pickle(bytes(self.ctx.asPickle(self._jrdd.first())))
+        return self.take(1)[0]
 
     def saveAsTextFile(self, path):
         def func(iterator):
diff --git a/pyspark/pyspark/serializers.py b/pyspark/pyspark/serializers.py
index 21ef8b106c..bfcdda8f12 100644
--- a/pyspark/pyspark/serializers.py
+++ b/pyspark/pyspark/serializers.py
@@ -33,3 +33,11 @@ def read_with_length(stream):
     if obj == "":
         raise EOFError
     return obj
+
+
+def read_from_pickle_file(stream):
+    try:
+        while True:
+            yield load_pickle(read_with_length(stream))
+    except EOFError:
+        return
diff --git a/pyspark/pyspark/worker.py b/pyspark/pyspark/worker.py
index 62824a1c9b..9f6b507dbd 100644
--- a/pyspark/pyspark/worker.py
+++ b/pyspark/pyspark/worker.py
@@ -8,7 +8,7 @@ from base64 import standard_b64decode
 from pyspark.broadcast import Broadcast, _broadcastRegistry
 from pyspark.cloudpickle import CloudPickler
 from pyspark.serializers import write_with_length, read_with_length, \
-    read_long, read_int, dump_pickle, load_pickle
+    read_long, read_int, dump_pickle, load_pickle, read_from_pickle_file
 
 
 # Redirect stdout to stderr so that users must return values from functions.
@@ -20,14 +20,6 @@ def load_obj():
     return load_pickle(standard_b64decode(sys.stdin.readline().strip()))
 
 
-def read_input():
-    try:
-        while True:
-            yield load_pickle(read_with_length(sys.stdin))
-    except EOFError:
-        return
-
-
 def main():
     num_broadcast_variables = read_int(sys.stdin)
     for _ in range(num_broadcast_variables):
@@ -40,7 +32,7 @@ def main():
         dumps = lambda x: x
     else:
         dumps = dump_pickle
-    for obj in func(read_input()):
+    for obj in func(read_from_pickle_file(sys.stdin)):
         write_with_length(dumps(obj), old_stdout)
 
 
-- 
GitLab
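
Note (not part of the patch): after this change, the only contract the JVM and
Python sides share is the framing that writeAsPickle() produces and
read_from_pickle_file() consumes -- a 4-byte big-endian length
(DataOutputStream.writeInt) followed by that many bytes of pickle-protocol-2
data. The Python 2 sketch below round-trips a few objects through a temp file
using that framing, mimicking parallelize() on the write side and
_collect_array_through_file() on the read side. The helpers are simplified
stand-ins written for illustration; they are not the actual pyspark.serializers
module.

# framing_sketch.py -- illustrative only; assumes the length-prefixed pickle
# framing described in writeAsPickle's doc comment (Python 2, like the patch).
import cPickle
import struct
from tempfile import NamedTemporaryFile


def dump_pickle(obj):
    # Protocol 2 matches the Pickle.PROTO / Pickle.TWO opcodes emitted in Scala.
    return cPickle.dumps(obj, 2)


def write_with_length(data, stream):
    # 4-byte big-endian length, then the pickled bytes -- the same layout that
    # DataOutputStream.writeInt() followed by write() produces on the JVM side.
    stream.write(struct.pack("!i", len(data)))
    stream.write(data)


def read_from_pickle_file(stream):
    # Yield unpickled objects until the framed stream is exhausted.
    while True:
        header = stream.read(4)
        if len(header) < 4:
            return
        (length,) = struct.unpack("!i", header)
        yield cPickle.loads(stream.read(length))


if __name__ == "__main__":
    # Write a few framed pickles to a temp file, then read them back.
    tempFile = NamedTemporaryFile(delete=False)
    for x in [1, (u"a", u"b"), {"k": [2, 3]}]:
        write_with_length(dump_pickle(x), tempFile)
    tempFile.close()
    with open(tempFile.name, "rb") as stream:
        print list(read_from_pickle_file(stream))  # [1, (u'a', u'b'), {'k': [2, 3]}]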