From e2dad15621f5dc15275b300df05483afde5025a0 Mon Sep 17 00:00:00 2001
From: Josh Rosen <joshrosen@eecs.berkeley.edu>
Date: Wed, 26 Dec 2012 17:34:24 -0800
Subject: [PATCH] Add support for batched serialization of Python objects in
 PySpark.

---
 pyspark/pyspark/context.py     |  3 +-
 pyspark/pyspark/rdd.py         | 57 +++++++++++++++++++++++-----------
 pyspark/pyspark/serializers.py | 34 +++++++++++++++++++-
 3 files changed, 74 insertions(+), 20 deletions(-)

diff --git a/pyspark/pyspark/context.py b/pyspark/pyspark/context.py
index 19f9f9e133..032619693a 100644
--- a/pyspark/pyspark/context.py
+++ b/pyspark/pyspark/context.py
@@ -17,13 +17,14 @@ class SparkContext(object):
     readRDDFromPickleFile = jvm.PythonRDD.readRDDFromPickleFile
     writeArrayToPickleFile = jvm.PythonRDD.writeArrayToPickleFile
 
-    def __init__(self, master, name, defaultParallelism=None):
+    def __init__(self, master, name, defaultParallelism=None, batchSize=-1):
         self.master = master
         self.name = name
         self._jsc = self.jvm.JavaSparkContext(master, name)
         self.defaultParallelism = \
             defaultParallelism or self._jsc.sc().defaultParallelism()
         self.pythonExec = os.environ.get("PYSPARK_PYTHON_EXEC", 'python')
+        self.batchSize = batchSize  # -1 represents a unlimited batch size
         # Broadcast's __reduce__ method stores Broadcast instances here.
         # This allows other code to determine which Broadcast instances have
         # been pickled, so it can determine which Java broadcast objects to
diff --git a/pyspark/pyspark/rdd.py b/pyspark/pyspark/rdd.py
index 01908cff96..d7081dffd2 100644
--- a/pyspark/pyspark/rdd.py
+++ b/pyspark/pyspark/rdd.py
@@ -2,6 +2,7 @@ import atexit
 from base64 import standard_b64encode as b64enc
 from collections import defaultdict
 from itertools import chain, ifilter, imap
+import operator
 import os
 import shlex
 from subprocess import Popen, PIPE
@@ -9,7 +10,8 @@ from tempfile import NamedTemporaryFile
 from threading import Thread
 
 from pyspark import cloudpickle
-from pyspark.serializers import dump_pickle, load_pickle, read_from_pickle_file
+from pyspark.serializers import batched, Batch, dump_pickle, load_pickle, \
+    read_from_pickle_file
 from pyspark.join import python_join, python_left_outer_join, \
     python_right_outer_join, python_cogroup
 
@@ -83,6 +85,11 @@ class RDD(object):
         >>> rdd = sc.parallelize([1, 1, 2, 3])
         >>> rdd.union(rdd).collect()
         [1, 1, 2, 3, 1, 1, 2, 3]
+
+        # Union of batched and unbatched RDDs:
+        >>> batchedRDD = sc.parallelize([Batch([1, 2, 3, 4, 5])])
+        >>> rdd.union(batchedRDD).collect()
+        [1, 1, 2, 3, 1, 2, 3, 4, 5]
         """
         return RDD(self._jrdd.union(other._jrdd), self.ctx)
 
@@ -147,13 +154,8 @@ class RDD(object):
         self.map(f).collect()  # Force evaluation
 
     def collect(self):
-        # To minimize the number of transfers between Python and Java, we'll
-        # flatten each partition into a list before collecting it.  Due to
-        # pipelining, this should add minimal overhead.
-        def asList(iterator):
-            yield list(iterator)
-        picklesInJava = self.mapPartitions(asList)._jrdd.rdd().collect()
-        return list(chain.from_iterable(self._collect_array_through_file(picklesInJava)))
+        picklesInJava = self._jrdd.rdd().collect()
+        return list(self._collect_array_through_file(picklesInJava))
 
     def _collect_array_through_file(self, array):
         # Transferring lots of data through Py4J can be slow because
@@ -214,12 +216,21 @@ class RDD(object):
 
     # TODO: aggregate
 
+    def sum(self):
+        """
+        >>> sc.parallelize([1.0, 2.0, 3.0]).sum()
+        6.0
+        """
+        return self.mapPartitions(lambda x: [sum(x)]).reduce(operator.add)
+
     def count(self):
         """
         >>> sc.parallelize([2, 3, 4]).count()
-        3L
+        3
+        >>> sc.parallelize([Batch([2, 3, 4])]).count()
+        3
         """
-        return self._jrdd.count()
+        return self.mapPartitions(lambda i: [sum(1 for _ in i)]).sum()
 
     def countByValue(self):
         """
@@ -342,24 +353,23 @@ class RDD(object):
         """
         if numSplits is None:
             numSplits = self.ctx.defaultParallelism
+        # Transferring O(n) objects to Java is too expensive.  Instead, we'll
+        # form the hash buckets in Python, transferring O(numSplits) objects
+        # to Java.  Each object is a (splitNumber, [objects]) pair.
         def add_shuffle_key(iterator):
             buckets = defaultdict(list)
             for (k, v) in iterator:
                 buckets[hashFunc(k) % numSplits].append((k, v))
             for (split, items) in buckets.iteritems():
                 yield str(split)
-                yield dump_pickle(items)
+                yield dump_pickle(Batch(items))
         keyed = PipelinedRDD(self, add_shuffle_key)
         keyed._bypass_serializer = True
         pairRDD = self.ctx.jvm.PairwiseRDD(keyed._jrdd.rdd()).asJavaPairRDD()
         partitioner = self.ctx.jvm.spark.api.python.PythonPartitioner(numSplits)
-        # Transferring O(n) objects to Java is too expensive.  Instead, we'll
-        # form the hash buckets in Python, transferring O(numSplits) objects
-        # to Java.  Each object is a (splitNumber, [objects]) pair.
         jrdd = pairRDD.partitionBy(partitioner)
         jrdd = jrdd.map(self.ctx.jvm.ExtractValue())
-        # Flatten the resulting RDD:
-        return RDD(jrdd, self.ctx).flatMap(lambda items: items)
+        return RDD(jrdd, self.ctx)
 
     def combineByKey(self, createCombiner, mergeValue, mergeCombiners,
                      numSplits=None):
@@ -478,8 +488,19 @@ class PipelinedRDD(RDD):
     def _jrdd(self):
         if self._jrdd_val:
             return self._jrdd_val
-        funcs = [self.func, self._bypass_serializer]
-        pipe_command = ' '.join(b64enc(cloudpickle.dumps(f)) for f in funcs)
+        func = self.func
+        if not self._bypass_serializer and self.ctx.batchSize != 1:
+            oldfunc = self.func
+            batchSize = self.ctx.batchSize
+            if batchSize == -1:  # unlimited batch size
+                def batched_func(iterator):
+                    yield Batch(list(oldfunc(iterator)))
+            else:
+                def batched_func(iterator):
+                    return batched(oldfunc(iterator), batchSize)
+            func = batched_func
+        cmds = [func, self._bypass_serializer]
+        pipe_command = ' '.join(b64enc(cloudpickle.dumps(f)) for f in cmds)
         broadcast_vars = ListConverter().convert(
             [x._jbroadcast for x in self.ctx._pickled_broadcast_vars],
             self.ctx.gateway._gateway_client)
diff --git a/pyspark/pyspark/serializers.py b/pyspark/pyspark/serializers.py
index bfcdda8f12..4ed925697c 100644
--- a/pyspark/pyspark/serializers.py
+++ b/pyspark/pyspark/serializers.py
@@ -2,6 +2,33 @@ import struct
 import cPickle
 
 
+class Batch(object):
+    """
+    Used to store multiple RDD entries as a single Java object.
+
+    This relieves us from having to explicitly track whether an RDD
+    is stored as batches of objects and avoids problems when processing
+    the union() of batched and unbatched RDDs (e.g. the union() of textFile()
+    with another RDD).
+    """
+    def __init__(self, items):
+        self.items = items
+
+
+def batched(iterator, batchSize):
+    items = []
+    count = 0
+    for item in iterator:
+        items.append(item)
+        count += 1
+        if count == batchSize:
+            yield Batch(items)
+            items = []
+            count = []
+    if items:
+        yield Batch(items)
+
+
 def dump_pickle(obj):
     return cPickle.dumps(obj, 2)
 
@@ -38,6 +65,11 @@ def read_with_length(stream):
 def read_from_pickle_file(stream):
     try:
         while True:
-            yield load_pickle(read_with_length(stream))
+            obj = load_pickle(read_with_length(stream))
+            if type(obj) == Batch:  # We don't care about inheritance
+                for item in obj.items:
+                    yield item
+            else:
+                yield obj
     except EOFError:
         return
-- 
GitLab