Commit ccd075cf authored by Josh Rosen

Reduce object overhead in Pyspark shuffle and collect

parent 2ccf3b66
@@ -145,8 +145,10 @@ class RDD(object):
         self.map(f).collect()  # Force evaluation
 
     def collect(self):
-        pickle = self.ctx.arrayAsPickle(self._jrdd.rdd().collect())
-        return load_pickle(bytes(pickle))
+        def asList(iterator):
+            yield list(iterator)
+        pickles = self.mapPartitions(asList)._jrdd.rdd().collect()
+        return list(chain.from_iterable(load_pickle(bytes(p)) for p in pickles))
 
     def reduce(self, f):
         """
@@ -319,16 +321,23 @@ class RDD(object):
         if numSplits is None:
             numSplits = self.ctx.defaultParallelism
         def add_shuffle_key(iterator):
+            buckets = defaultdict(list)
             for (k, v) in iterator:
-                yield str(hashFunc(k))
-                yield dump_pickle((k, v))
+                buckets[hashFunc(k) % numSplits].append((k, v))
+            for (split, items) in buckets.iteritems():
+                yield str(split)
+                yield dump_pickle(items)
         keyed = PipelinedRDD(self, add_shuffle_key)
         keyed._bypass_serializer = True
         pairRDD = self.ctx.jvm.PairwiseRDD(keyed._jrdd.rdd()).asJavaPairRDD()
         partitioner = self.ctx.jvm.spark.api.python.PythonPartitioner(numSplits)
+        # Transferring O(n) objects to Java is too expensive.  Instead, we'll
+        # form the hash buckets in Python, transferring O(numSplits) objects
+        # to Java.  Each object is a (splitNumber, [objects]) pair.
         jrdd = pairRDD.partitionBy(partitioner)
         jrdd = jrdd.map(self.ctx.jvm.ExtractValue())
-        return RDD(jrdd, self.ctx)
+        # Flatten the resulting RDD:
+        return RDD(jrdd, self.ctx).flatMap(lambda items: items)
 
     def combineByKey(self, createCombiner, mergeValue, mergeCombiners,
                      numSplits=None):
...
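For readers skimming the diff, here is a minimal standalone sketch of the new collect() strategy, outside Spark. The dump_pickle/load_pickle helpers below are simplified stand-ins for PySpark's serializer functions, and plain lists stand in for real partitions; the point is only that the driver now unpickles one object per partition instead of one per record.

    import pickle
    from itertools import chain

    def dump_pickle(obj):
        # Simplified stand-in for PySpark's dump_pickle helper (assumption).
        return pickle.dumps(obj, 2)

    def load_pickle(data):
        # Simplified stand-in for PySpark's load_pickle helper (assumption).
        return pickle.loads(data)

    # Worker side: each partition becomes one pickled list, so the driver
    # receives O(numPartitions) byte strings instead of O(n) pickled objects.
    partitions = [[1, 2, 3], [4, 5], [6]]
    pickles = [dump_pickle(list(part)) for part in partitions]

    # Driver side: unpickle one list per partition and flatten the results.
    collected = list(chain.from_iterable(load_pickle(p) for p in pickles))
    assert collected == [1, 2, 3, 4, 5, 6]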
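The shuffle-side change can be sketched the same way. This is an illustration, not the committed code: it lifts the bucketing logic out of partitionBy, passes numSplits and hashFunc explicitly, uses dict.items() instead of the Python 2 iteritems() from the diff, and just returns the pickled payloads rather than handing them to the JVM.

    import pickle
    from collections import defaultdict

    def add_shuffle_key(iterator, numSplits, hashFunc=hash):
        # Group this partition's records into at most numSplits local
        # buckets, then emit one (split id, pickled list) pair per non-empty
        # bucket. Only O(numSplits) objects cross the Python/Java boundary,
        # instead of O(n) per-record (key, pickle) pairs.
        buckets = defaultdict(list)
        for (k, v) in iterator:
            buckets[hashFunc(k) % numSplits].append((k, v))
        for (split, items) in buckets.items():
            yield str(split)
            yield pickle.dumps(items, 2)

    # Example: six records hashed into three buckets yield at most three
    # (split, payload) pairs in the alternating output stream.
    records = [("a", 1), ("b", 2), ("c", 3), ("a", 4), ("b", 5), ("c", 6)]
    stream = list(add_shuffle_key(iter(records), numSplits=3))
    splits = stream[0::2]                                # split ids, as strings
    payloads = [pickle.loads(p) for p in stream[1::2]]  # lists of (k, v) pairs
    assert set(splits) <= {"0", "1", "2"}
    assert sum(len(p) for p in payloads) == len(records)

The flatMap at the end of partitionBy is the counterpart of this batching: since each shuffled element is now a whole bucket (a list of pairs), the resulting RDD has to be flattened back into individual (k, v) records.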