diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 593d74bca5fff1382b72e8dcbcbc0f95efdb95d8..64f6a3ca6bf4cca25aad7e534f3b39ee909bcf6c 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -319,7 +319,7 @@ class SparkContext(object):
         # Make sure we distribute data evenly if it's smaller than self.batchSize
         if "__len__" not in dir(c):
             c = list(c)    # Make it a list so we can compute its length
-        batchSize = max(1, min(len(c) // numSlices, self._batchSize))
+        batchSize = max(1, min(len(c) // numSlices, self._batchSize or 1024))
         serializer = BatchedSerializer(self._unbatched_serializer, batchSize)
         serializer.dump_stream(c, tempFile)
         tempFile.close()
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index bd08c9a6d20d6c073d50181cde50440406410f00..b8bda835174b2bb16c6ec69e948b280f4996ac9d 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -181,6 +181,10 @@ class BatchedSerializer(Serializer):
     def _batched(self, iterator):
         if self.batchSize == self.UNLIMITED_BATCH_SIZE:
             yield list(iterator)
+        elif hasattr(iterator, "__len__") and hasattr(iterator, "__getslice__"):
+            n = len(iterator)
+            for i in xrange(0, n, self.batchSize):
+                yield iterator[i: i + self.batchSize]
         else:
             items = []
             count = 0
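
Note on the first hunk: `self._batchSize or 1024` substitutes 1024 whenever `_batchSize` is falsy (0 is presumably the sentinel for unlimited/auto batching), so `min(...)` is no longer clamped to 0 and `max(1, ...)` no longer degenerates to a batch size of 1. A minimal sketch of the arithmetic, with illustrative stand-ins for len(c), numSlices and self._batchSize:

    # Sketch of the batch-size arithmetic before and after the patch.
    # len_c, num_slices and batch_size_setting are hypothetical names
    # standing in for len(c), numSlices and self._batchSize.

    def old_batch_size(len_c, num_slices, batch_size_setting):
        return max(1, min(len_c // num_slices, batch_size_setting))

    def new_batch_size(len_c, num_slices, batch_size_setting):
        # Falsy settings (e.g. 0, meaning auto/unlimited) fall back to 1024.
        return max(1, min(len_c // num_slices, batch_size_setting or 1024))

    # With an explicit batch size, both versions agree.
    print(old_batch_size(10000, 4, 100), new_batch_size(10000, 4, 100))  # 100 100
    # With batch size 0, the old code collapsed to one-element batches,
    # while the new code caps batches at 1024.
    print(old_batch_size(10000, 4, 0), new_batch_size(10000, 4, 0))      # 1 1024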
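
The second hunk adds a fast path to `BatchedSerializer._batched`: when the input already supports `__len__` and `__getslice__` (a concrete sequence such as a list, under Python 2 where `__getslice__` exists), batches are produced by slicing instead of appending items one at a time. A simplified, standalone sketch of the same batching logic, using `range` and `__getitem__` slicing as the Python 3 equivalents of the `xrange`/`__getslice__` calls in the patch; the real method also handles the UNLIMITED_BATCH_SIZE case:

    def batched(items, batch_size):
        """Yield successive batches, mirroring the patched _batched logic."""
        if hasattr(items, "__len__") and hasattr(items, "__getitem__"):
            # Fast path: slice the sequence directly instead of building
            # each batch element by element.
            for i in range(0, len(items), batch_size):
                yield items[i:i + batch_size]
        else:
            # Fallback for plain iterators: accumulate up to batch_size items.
            batch = []
            for item in items:
                batch.append(item)
                if len(batch) == batch_size:
                    yield batch
                    batch = []
            if batch:
                yield batch

    # Both paths produce the same batches.
    print(list(batched([1, 2, 3, 4, 5], 2)))        # [[1, 2], [3, 4], [5]]
    print(list(batched(iter([1, 2, 3, 4, 5]), 2)))  # [[1, 2], [3, 4], [5]]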