diff --git a/pyspark/pyspark/context.py b/pyspark/pyspark/context.py
index b90596ecc2bee41cba9e0852c3dd92eb651028b1..6172d69dcff97fba5546494732f0199f2c12e021 100644
--- a/pyspark/pyspark/context.py
+++ b/pyspark/pyspark/context.py
@@ -4,7 +4,7 @@ from tempfile import NamedTemporaryFile
 
 from pyspark.broadcast import Broadcast
 from pyspark.java_gateway import launch_gateway
-from pyspark.serializers import dump_pickle, write_with_length
+from pyspark.serializers import dump_pickle, write_with_length, batched
 from pyspark.rdd import RDD
 
 from py4j.java_collections import ListConverter
@@ -91,6 +91,8 @@ class SparkContext(object):
         # objects are written to a file and loaded through textFile().
         tempFile = NamedTemporaryFile(delete=False)
         atexit.register(lambda: os.unlink(tempFile.name))
+        if self.batchSize != 1:
+            c = batched(c, self.batchSize)
         for x in c:
             write_with_length(dump_pickle(x), tempFile)
         tempFile.close()
diff --git a/pyspark/pyspark/rdd.py b/pyspark/pyspark/rdd.py
index 20f84b2dd08e3394db6bfffbc10dd591630993bb..203f7377d2c049182afdc5d114392ab355e9a22c 100644
--- a/pyspark/pyspark/rdd.py
+++ b/pyspark/pyspark/rdd.py
@@ -2,7 +2,7 @@ import atexit
 from base64 import standard_b64encode as b64enc
 import copy
 from collections import defaultdict
-from itertools import chain, ifilter, imap
+from itertools import chain, ifilter, imap, product
 import operator
 import os
 import shlex
@@ -123,12 +123,6 @@ class RDD(object):
         >>> rdd = sc.parallelize([1, 1, 2, 3])
         >>> rdd.union(rdd).collect()
         [1, 1, 2, 3, 1, 1, 2, 3]
-
-        Union of batched and unbatched RDDs (internal test):
-
-        >>> batchedRDD = sc.parallelize([Batch([1, 2, 3, 4, 5])])
-        >>> rdd.union(batchedRDD).collect()
-        [1, 1, 2, 3, 1, 2, 3, 4, 5]
         """
         return RDD(self._jrdd.union(other._jrdd), self.ctx)
 
@@ -168,7 +162,18 @@ class RDD(object):
         >>> sorted(rdd.cartesian(rdd).collect())
         [(1, 1), (1, 2), (2, 1), (2, 2)]
         """
-        return RDD(self._jrdd.cartesian(other._jrdd), self.ctx)
+        # Due to batching, we can't use the Java cartesian method.
+        java_cartesian = RDD(self._jrdd.cartesian(other._jrdd), self.ctx)
+        def unpack_batches(pair):
+            (x, y) = pair
+            if type(x) == Batch or type(y) == Batch:
+                xs = x.items if type(x) == Batch else [x]
+                ys = y.items if type(y) == Batch else [y]
+                for pair in product(xs, ys):
+                    yield pair
+            else:
+                yield pair
+        return java_cartesian.flatMap(unpack_batches)
 
     def groupBy(self, f, numSplits=None):
         """
@@ -293,8 +298,6 @@ class RDD(object):
 
         >>> sc.parallelize([2, 3, 4]).count()
         3
-        >>> sc.parallelize([Batch([2, 3, 4])]).count()
-        3
         """
         return self.mapPartitions(lambda i: [sum(1 for _ in i)]).sum()
 
@@ -667,12 +670,8 @@ class PipelinedRDD(RDD):
         if not self._bypass_serializer and self.ctx.batchSize != 1:
             oldfunc = self.func
             batchSize = self.ctx.batchSize
-            if batchSize == -1: # unlimited batch size
-                def batched_func(iterator):
-                    yield Batch(list(oldfunc(iterator)))
-            else:
-                def batched_func(iterator):
-                    return batched(oldfunc(iterator), batchSize)
+            def batched_func(iterator):
+                return batched(oldfunc(iterator), batchSize)
             func = batched_func
         cmds = [func, self._bypass_serializer]
         pipe_command = ' '.join(b64enc(cloudpickle.dumps(f)) for f in cmds)
diff --git a/pyspark/pyspark/serializers.py b/pyspark/pyspark/serializers.py
index 8b08f7ef0f982173774d1d59abdfb8da28ebb9b5..9a5151ea00341459f83b85938b82a8555a91945b 100644
--- a/pyspark/pyspark/serializers.py
+++ b/pyspark/pyspark/serializers.py
@@ -16,17 +16,20 @@ class Batch(object):
 
 
 def batched(iterator, batchSize):
-    items = []
-    count = 0
-    for item in iterator:
-        items.append(item)
-        count += 1
-        if count == batchSize:
+    if batchSize == -1: # unlimited batch size
+        yield Batch(list(iterator))
+    else:
+        items = []
+        count = 0
+        for item in iterator:
+            items.append(item)
+            count += 1
+            if count == batchSize:
+                yield Batch(items)
+                items = []
+                count = 0
+        if items:
             yield Batch(items)
-            items = []
-            count = 0
-    if items:
-        yield Batch(items)
 
 
 def dump_pickle(obj):
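
Illustrative usage sketch (not part of the patch above): how the revised batched() helper is expected to behave after this change, assuming the Batch class and pyspark.serializers module layout shown in the diff.

    from pyspark.serializers import Batch, batched

    # A positive batch size groups the iterator into Batch objects of at most
    # that many items; a final, smaller Batch holds any leftover elements.
    groups = [b.items for b in batched(iter([1, 2, 3, 4, 5]), 2)]
    assert groups == [[1, 2], [3, 4], [5]]

    # batchSize == -1 now means "unlimited": the entire iterator is consumed
    # into a single Batch, handling the case that PipelinedRDD previously
    # special-cased in rdd.py before this logic moved into batched() itself.
    assert [b.items for b in batched(iter([1, 2, 3]), -1)] == [[1, 2, 3]]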