diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala index 39758e94f46fefaa0eccabfbb43bcc1e10491d8f..ab8351e55e9efa614533fa2cd93947efe688b2b5 100644 --- a/core/src/main/scala/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/spark/api/python/PythonRDD.scala @@ -238,6 +238,11 @@ private[spark] object PythonRDD { } def writeIteratorToPickleFile[T](items: java.util.Iterator[T], filename: String) { + import scala.collection.JavaConverters._ + writeIteratorToPickleFile(items.asScala, filename) + } + + def writeIteratorToPickleFile[T](items: Iterator[T], filename: String) { val file = new DataOutputStream(new FileOutputStream(filename)) for (item <- items) { writeAsPickle(item, file) @@ -245,8 +250,10 @@ private[spark] object PythonRDD { file.close() } - def takePartition[T](rdd: RDD[T], partition: Int): java.util.Iterator[T] = - rdd.context.runJob(rdd, ((x: Iterator[T]) => x), Seq(partition), true).head + def takePartition[T](rdd: RDD[T], partition: Int): Iterator[T] = { + implicit val cm : ClassManifest[T] = rdd.elementClassManifest + rdd.context.runJob(rdd, ((x: Iterator[T]) => x.toArray), Seq(partition), true).head.iterator + } } private object Pickle { diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py index 61fcbbd37679fa806ff380b07cf1a941ff71bbd7..3e9d7d36da8a1b0fb81661a8ee0d02d076072939 100644 --- a/python/pyspark/accumulators.py +++ b/python/pyspark/accumulators.py @@ -196,12 +196,3 @@ def _start_update_server(): thread.daemon = True thread.start() return server - - -def _test(): - import doctest - doctest.testmod() - - -if __name__ == "__main__": - _test() diff --git a/python/pyspark/broadcast.py b/python/pyspark/broadcast.py index 93876fa738f6336f0e9a21c484acf7dd74ce0477..def810dd461dab770ef5b5b351beaf6446b381e7 100644 --- a/python/pyspark/broadcast.py +++ b/python/pyspark/broadcast.py @@ -37,12 +37,3 @@ class Broadcast(object): def __reduce__(self): self._pickle_registry.add(self) return (_from_id, (self.bid, )) - - -def _test(): - import doctest - doctest.testmod() - - -if __name__ == "__main__": - _test() diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 6831f9b7f8b95aac5e82f7d16cb0597289a086a8..657fe6f98975bfa2ea86bc6435a868f0a97d4661 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -256,8 +256,10 @@ def _test(): globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2) globs['tempdir'] = tempfile.mkdtemp() atexit.register(lambda: shutil.rmtree(globs['tempdir'])) - doctest.testmod(globs=globs) + (failure_count, test_count) = doctest.testmod(globs=globs) globs['sc'].stop() + if failure_count: + exit(-1) if __name__ == "__main__": diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 41ea6e6e14c07b9c044f9e54372a80947dd46349..4cda6cf661197f662351bb5f08154b7aae1241f2 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -372,6 +372,10 @@ class RDD(object): items = [] for partition in range(self._jrdd.splits().size()): iterator = self.ctx._takePartition(self._jrdd.rdd(), partition) + # Each item in the iterator is a string, Python object, batch of + # Python objects. Regardless, it is sufficient to take `num` + # of these objects in order to collect `num` Python objects: + iterator = iterator.take(num) items.extend(self._collect_iterator_through_file(iterator)) if len(items) >= num: break @@ -748,8 +752,10 @@ def _test(): # The small batch size here ensures that we see multiple batches, # even in these small test examples: globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2) - doctest.testmod(globs=globs) + (failure_count, test_count) = doctest.testmod(globs=globs) globs['sc'].stop() + if failure_count: + exit(-1) if __name__ == "__main__":