From 7ec3595de28d53839cb3a45e940ec16f81ffdf45 Mon Sep 17 00:00:00 2001
From: Josh Rosen <joshrosen@eecs.berkeley.edu>
Date: Fri, 28 Dec 2012 22:19:12 -0800
Subject: [PATCH] Fix bug (introduced by batching) in PySpark take()

---
 .../scala/spark/api/python/PythonRDD.scala |  2 +-
 pyspark/pyspark/context.py                 |  6 ++---
 pyspark/pyspark/java_gateway.py            |  2 +-
 pyspark/pyspark/rdd.py                     | 27 ++++++++++++-------
 4 files changed, 22 insertions(+), 15 deletions(-)

diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala
index a80a8eea45..f76616a4c4 100644
--- a/core/src/main/scala/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/spark/api/python/PythonRDD.scala
@@ -194,7 +194,7 @@ private[spark] object PythonRDD {
     JavaRDD.fromRDD(sc.sc.parallelize(objs, parallelism))
   }
 
-  def writeArrayToPickleFile[T](items: Array[T], filename: String) {
+  def writeIteratorToPickleFile[T](items: java.util.Iterator[T], filename: String) {
     val file = new DataOutputStream(new FileOutputStream(filename))
     for (item <- items) {
       writeAsPickle(item, file)
diff --git a/pyspark/pyspark/context.py b/pyspark/pyspark/context.py
index 988c81cd5d..b90596ecc2 100644
--- a/pyspark/pyspark/context.py
+++ b/pyspark/pyspark/context.py
@@ -19,8 +19,8 @@ class SparkContext(object):
 
     gateway = launch_gateway()
     jvm = gateway.jvm
-    readRDDFromPickleFile = jvm.PythonRDD.readRDDFromPickleFile
-    writeArrayToPickleFile = jvm.PythonRDD.writeArrayToPickleFile
+    _readRDDFromPickleFile = jvm.PythonRDD.readRDDFromPickleFile
+    _writeIteratorToPickleFile = jvm.PythonRDD.writeIteratorToPickleFile
 
     def __init__(self, master, jobName, sparkHome=None, pyFiles=None,
         environment=None, batchSize=1024):
@@ -94,7 +94,7 @@ class SparkContext(object):
         for x in c:
             write_with_length(dump_pickle(x), tempFile)
         tempFile.close()
-        jrdd = self.readRDDFromPickleFile(self._jsc, tempFile.name, numSlices)
+        jrdd = self._readRDDFromPickleFile(self._jsc, tempFile.name, numSlices)
         return RDD(jrdd, self)
 
     def textFile(self, name, minSplits=None):
diff --git a/pyspark/pyspark/java_gateway.py b/pyspark/pyspark/java_gateway.py
index eb2a875762..2329e536cc 100644
--- a/pyspark/pyspark/java_gateway.py
+++ b/pyspark/pyspark/java_gateway.py
@@ -30,7 +30,7 @@ def launch_gateway():
                 sys.stderr.write(line)
     EchoOutputThread(proc.stdout).start()
     # Connect to the gateway
-    gateway = JavaGateway(GatewayClient(port=port))
+    gateway = JavaGateway(GatewayClient(port=port), auto_convert=False)
     # Import the classes used by PySpark
     java_import(gateway.jvm, "spark.api.java.*")
     java_import(gateway.jvm, "spark.api.python.*")
diff --git a/pyspark/pyspark/rdd.py b/pyspark/pyspark/rdd.py
index bf32472d25..111476d274 100644
--- a/pyspark/pyspark/rdd.py
+++ b/pyspark/pyspark/rdd.py
@@ -152,8 +152,8 @@ class RDD(object):
         into a list.
 
         >>> rdd = sc.parallelize([1, 2, 3, 4], 2)
-        >>> rdd.glom().first()
-        [1, 2]
+        >>> sorted(rdd.glom().collect())
+        [[1, 2], [3, 4]]
         """
         def func(iterator): yield list(iterator)
         return self.mapPartitions(func)
@@ -211,10 +211,10 @@ class RDD(object):
         """
         Return a list that contains all of the elements in this RDD.
         """
-        picklesInJava = self._jrdd.rdd().collect()
-        return list(self._collect_array_through_file(picklesInJava))
+        picklesInJava = self._jrdd.collect().iterator()
+        return list(self._collect_iterator_through_file(picklesInJava))
 
-    def _collect_array_through_file(self, array):
+    def _collect_iterator_through_file(self, iterator):
         # Transferring lots of data through Py4J can be slow because
         # socket.readline() is inefficient. Instead, we'll dump the data to a
         # file and read it back.
@@ -224,7 +224,7 @@ class RDD(object):
             try: os.unlink(tempFile.name)
             except: pass
         atexit.register(clean_up_file)
-        self.ctx.writeArrayToPickleFile(array, tempFile.name)
+        self.ctx._writeIteratorToPickleFile(iterator, tempFile.name)
         # Read the data into Python and deserialize it:
         with open(tempFile.name, 'rb') as tempFile:
             for item in read_from_pickle_file(tempFile):
@@ -325,11 +325,18 @@ class RDD(object):
         a lot of partitions are required. In that case, use L{collect} to get
         the whole RDD instead.
 
-        >>> sc.parallelize([2, 3, 4]).take(2)
+        >>> sc.parallelize([2, 3, 4, 5, 6]).take(2)
         [2, 3]
-        """
-        picklesInJava = self._jrdd.rdd().take(num)
-        return list(self._collect_array_through_file(picklesInJava))
+        >>> sc.parallelize([2, 3, 4, 5, 6]).take(10)
+        [2, 3, 4, 5, 6]
+        """
+        items = []
+        splits = self._jrdd.splits()
+        while len(items) < num and splits:
+            split = splits.pop(0)
+            iterator = self._jrdd.iterator(split)
+            items.extend(self._collect_iterator_through_file(iterator))
+        return items[:num]
 
     def first(self):
         """
--
GitLab