From 7ec3595de28d53839cb3a45e940ec16f81ffdf45 Mon Sep 17 00:00:00 2001
From: Josh Rosen <joshrosen@eecs.berkeley.edu>
Date: Fri, 28 Dec 2012 22:19:12 -0800
Subject: [PATCH] Fix bug (introduced by batching) in PySpark take()

---
 .../scala/spark/api/python/PythonRDD.scala |  2 +-
 pyspark/pyspark/context.py                 |  6 ++---
 pyspark/pyspark/java_gateway.py            |  2 +-
 pyspark/pyspark/rdd.py                     | 27 ++++++++++++-------
 4 files changed, 22 insertions(+), 15 deletions(-)

diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala
index a80a8eea45..f76616a4c4 100644
--- a/core/src/main/scala/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/spark/api/python/PythonRDD.scala
@@ -194,7 +194,7 @@ private[spark] object PythonRDD {
     JavaRDD.fromRDD(sc.sc.parallelize(objs, parallelism))
   }
 
-  def writeArrayToPickleFile[T](items: Array[T], filename: String) {
+  def writeIteratorToPickleFile[T](items: java.util.Iterator[T], filename: String) {
     val file = new DataOutputStream(new FileOutputStream(filename))
     for (item <- items) {
       writeAsPickle(item, file)
diff --git a/pyspark/pyspark/context.py b/pyspark/pyspark/context.py
index 988c81cd5d..b90596ecc2 100644
--- a/pyspark/pyspark/context.py
+++ b/pyspark/pyspark/context.py
@@ -19,8 +19,8 @@ class SparkContext(object):
 
     gateway = launch_gateway()
     jvm = gateway.jvm
-    readRDDFromPickleFile = jvm.PythonRDD.readRDDFromPickleFile
-    writeArrayToPickleFile = jvm.PythonRDD.writeArrayToPickleFile
+    _readRDDFromPickleFile = jvm.PythonRDD.readRDDFromPickleFile
+    _writeIteratorToPickleFile = jvm.PythonRDD.writeIteratorToPickleFile
 
     def __init__(self, master, jobName, sparkHome=None, pyFiles=None,
         environment=None, batchSize=1024):
@@ -94,7 +94,7 @@ class SparkContext(object):
         for x in c:
             write_with_length(dump_pickle(x), tempFile)
         tempFile.close()
-        jrdd = self.readRDDFromPickleFile(self._jsc, tempFile.name, numSlices)
+        jrdd = self._readRDDFromPickleFile(self._jsc, tempFile.name, numSlices)
         return RDD(jrdd, self)
 
     def textFile(self, name, minSplits=None):
diff --git a/pyspark/pyspark/java_gateway.py b/pyspark/pyspark/java_gateway.py
index eb2a875762..2329e536cc 100644
--- a/pyspark/pyspark/java_gateway.py
+++ b/pyspark/pyspark/java_gateway.py
@@ -30,7 +30,7 @@ def launch_gateway():
                 sys.stderr.write(line)
     EchoOutputThread(proc.stdout).start()
     # Connect to the gateway
-    gateway = JavaGateway(GatewayClient(port=port))
+    gateway = JavaGateway(GatewayClient(port=port), auto_convert=False)
     # Import the classes used by PySpark
     java_import(gateway.jvm, "spark.api.java.*")
     java_import(gateway.jvm, "spark.api.python.*")
diff --git a/pyspark/pyspark/rdd.py b/pyspark/pyspark/rdd.py
index bf32472d25..111476d274 100644
--- a/pyspark/pyspark/rdd.py
+++ b/pyspark/pyspark/rdd.py
@@ -152,8 +152,8 @@ class RDD(object):
         into a list.
 
         >>> rdd = sc.parallelize([1, 2, 3, 4], 2)
-        >>> rdd.glom().first()
-        [1, 2]
+        >>> sorted(rdd.glom().collect())
+        [[1, 2], [3, 4]]
         """
         def func(iterator): yield list(iterator)
         return self.mapPartitions(func)
@@ -211,10 +211,10 @@ class RDD(object):
         """
         Return a list that contains all of the elements in this RDD.
         """
-        picklesInJava = self._jrdd.rdd().collect()
-        return list(self._collect_array_through_file(picklesInJava))
+        picklesInJava = self._jrdd.collect().iterator()
+        return list(self._collect_iterator_through_file(picklesInJava))
 
-    def _collect_array_through_file(self, array):
+    def _collect_iterator_through_file(self, iterator):
         # Transferring lots of data through Py4J can be slow because
         # socket.readline() is inefficient. Instead, we'll dump the data to a
         # file and read it back.
@@ -224,7 +224,7 @@ class RDD(object):
             try: os.unlink(tempFile.name)
             except: pass
         atexit.register(clean_up_file)
-        self.ctx.writeArrayToPickleFile(array, tempFile.name)
+        self.ctx._writeIteratorToPickleFile(iterator, tempFile.name)
         # Read the data into Python and deserialize it:
         with open(tempFile.name, 'rb') as tempFile:
             for item in read_from_pickle_file(tempFile):
@@ -325,11 +325,18 @@ class RDD(object):
         a lot of partitions are required. In that case, use L{collect} to get
         the whole RDD instead.
 
-        >>> sc.parallelize([2, 3, 4]).take(2)
+        >>> sc.parallelize([2, 3, 4, 5, 6]).take(2)
         [2, 3]
-        """
-        picklesInJava = self._jrdd.rdd().take(num)
-        return list(self._collect_array_through_file(picklesInJava))
+        >>> sc.parallelize([2, 3, 4, 5, 6]).take(10)
+        [2, 3, 4, 5, 6]
+        """
+        items = []
+        splits = self._jrdd.splits()
+        while len(items) < num and splits:
+            split = splits.pop(0)
+            iterator = self._jrdd.iterator(split)
+            items.extend(self._collect_iterator_through_file(iterator))
+        return items[:num]
 
     def first(self):
         """
--
GitLab