From b5c51c8df480f1a82a82e4d597d8eea631bffb4e Mon Sep 17 00:00:00 2001
From: Davies Liu <davies.liu@gmail.com>
Date: Thu, 9 Apr 2015 17:07:23 -0700
Subject: [PATCH] [SPARK-3074] [PySpark] support groupByKey() with single huge
 key

This patch changes groupByKey() to use an external-sort-based approach, so that it can support a single huge key.

For example, it can group a dataset that includes one hot key with 40 million values (strings), using 500M of memory for the Python worker, and finish in about 2 minutes (the hash-based approach would need 6G of memory).

During groupByKey(), it does an in-memory groupBy first. If the dataset cannot fit in memory, the data is partitioned by hash. If one partition still cannot fit in memory, it switches to a sort-based groupBy().
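
As a usage illustration (not part of the patch), here is a minimal sketch of grouping a skewed dataset under a bounded Python worker memory budget; the key names, record count, app name, and memory setting are made up for this example:

    from pyspark import SparkConf, SparkContext

    # Cap the memory for each Python worker; with this change, groupByKey()
    # spills to disk and falls back to sort-based grouping instead of
    # blowing up on the hot key.
    conf = SparkConf().set("spark.python.worker.memory", "512m")
    sc = SparkContext(appName="skewed-groupby", conf=conf)

    # one hot key ("hot") receives about 90% of all values
    pairs = sc.parallelize(xrange(1000000)).map(
        lambda i: ("hot" if i % 10 else "k%d" % (i % 100), str(i)))

    # the hot key's values no longer need to fit in worker memory at once
    sizes = pairs.groupByKey().mapValues(len).collect()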

Author: Davies Liu <davies.liu@gmail.com>
Author: Davies Liu <davies@databricks.com>

Closes #1977 from davies/groupby and squashes the following commits:

af3713a [Davies Liu] make sure it's iterator
67772dd [Davies Liu] fix tests
e78c15c [Davies Liu] address comments
0b0fde8 [Davies Liu] address comments
0dcf320 [Davies Liu] address comments, rollback changes in ResultIterable
e3b8eab [Davies Liu] fix narrow dependency
2a1857a [Davies Liu] typo
d2f053b [Davies Liu] add repr for FlattedValuesSerializer
c6a2f8d [Davies Liu] address comments
9e2df24 [Davies Liu] Merge branch 'master' of github.com:apache/spark into groupby
2b9c261 [Davies Liu] fix typo in comments
70aadcd [Davies Liu] Merge branch 'master' of github.com:apache/spark into groupby
a14b4bd [Davies Liu] Merge branch 'master' of github.com:apache/spark into groupby
ab5515b [Davies Liu] Merge branch 'master' into groupby
651f891 [Davies Liu] simplify GroupByKey
1578f2e [Davies Liu] Merge branch 'master' of github.com:apache/spark into groupby
1f69f93 [Davies Liu] fix tests
0d3395f [Davies Liu] Merge branch 'master' of github.com:apache/spark into groupby
341f1e0 [Davies Liu] add comments, refactor
47918b8 [Davies Liu] remove unused code
6540948 [Davies Liu] address comments:
17f4ec6 [Davies Liu] Merge branch 'master' of github.com:apache/spark into groupby
4d4bc86 [Davies Liu] bugfix
8ef965e [Davies Liu] Merge branch 'master' into groupby
fbc504a [Davies Liu] Merge branch 'master' into groupby
779ed03 [Davies Liu] fix merge conflict
2c1d05b [Davies Liu] refactor, minor turning
b48cda5 [Davies Liu] Merge branch 'master' into groupby
85138e6 [Davies Liu] Merge branch 'master' into groupby
acd8e1b [Davies Liu] fix memory when groupByKey().count()
905b233 [Davies Liu] Merge branch 'sort' into groupby
1f075ed [Davies Liu] Merge branch 'master' into sort
4b07d39 [Davies Liu] compress the data while spilling
0a081c6 [Davies Liu] Merge branch 'master' into groupby
f157fe7 [Davies Liu] Merge branch 'sort' into groupby
eb53ca6 [Davies Liu] Merge branch 'master' into sort
b2dc3bf [Davies Liu] Merge branch 'sort' into groupby
644abaf [Davies Liu] add license in LICENSE
19f7873 [Davies Liu] improve tests
11ba318 [Davies Liu] typo
085aef8 [Davies Liu] Merge branch 'master' into groupby
3ee58e5 [Davies Liu] switch to sort based groupBy, based on size of data
1ea0669 [Davies Liu] choose sort based groupByKey() automatically
b40bae7 [Davies Liu] bugfix
efa23df [Davies Liu] refactor, add spark.shuffle.sort=False
250be4e [Davies Liu] flatten the combined values when dumping into disks
d05060d [Davies Liu] group the same key before shuffle, reduce the comparison during sorting
083d842 [Davies Liu] sorted based groupByKey()
55602ee [Davies Liu] use external sort in sortBy() and sortByKey()
---
 python/pyspark/join.py           |  13 +-
 python/pyspark/rdd.py            |  48 ++-
 python/pyspark/resultiterable.py |   7 +-
 python/pyspark/serializers.py    |  25 +-
 python/pyspark/shuffle.py        | 531 ++++++++++++++++++++++++-------
 python/pyspark/tests.py          |  50 ++-
 6 files changed, 531 insertions(+), 143 deletions(-)

diff --git a/python/pyspark/join.py b/python/pyspark/join.py
index efc1ef9396..c3491defb2 100644
--- a/python/pyspark/join.py
+++ b/python/pyspark/join.py
@@ -48,7 +48,7 @@ def python_join(rdd, other, numPartitions):
                 vbuf.append(v)
             elif n == 2:
                 wbuf.append(v)
-        return [(v, w) for v in vbuf for w in wbuf]
+        return ((v, w) for v in vbuf for w in wbuf)
     return _do_python_join(rdd, other, numPartitions, dispatch)
 
 
@@ -62,7 +62,7 @@ def python_right_outer_join(rdd, other, numPartitions):
                 wbuf.append(v)
         if not vbuf:
             vbuf.append(None)
-        return [(v, w) for v in vbuf for w in wbuf]
+        return ((v, w) for v in vbuf for w in wbuf)
     return _do_python_join(rdd, other, numPartitions, dispatch)
 
 
@@ -76,7 +76,7 @@ def python_left_outer_join(rdd, other, numPartitions):
                 wbuf.append(v)
         if not wbuf:
             wbuf.append(None)
-        return [(v, w) for v in vbuf for w in wbuf]
+        return ((v, w) for v in vbuf for w in wbuf)
     return _do_python_join(rdd, other, numPartitions, dispatch)
 
 
@@ -104,8 +104,9 @@ def python_cogroup(rdds, numPartitions):
     rdd_len = len(vrdds)
 
     def dispatch(seq):
-        bufs = [[] for i in range(rdd_len)]
-        for (n, v) in seq:
+        bufs = [[] for _ in range(rdd_len)]
+        for n, v in seq:
             bufs[n].append(v)
-        return tuple(map(ResultIterable, bufs))
+        return tuple(ResultIterable(vs) for vs in bufs)
+
     return union_vrdds.groupByKey(numPartitions).mapValues(dispatch)
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 2d05611321..1b18789040 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -41,7 +41,7 @@ from pyspark.rddsampler import RDDSampler, RDDRangeSampler, RDDStratifiedSampler
 from pyspark.storagelevel import StorageLevel
 from pyspark.resultiterable import ResultIterable
 from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, \
-    get_used_memory, ExternalSorter
+    get_used_memory, ExternalSorter, ExternalGroupBy
 from pyspark.traceback_utils import SCCallSiteSync
 
 from py4j.java_collections import ListConverter, MapConverter
@@ -573,8 +573,8 @@ class RDD(object):
         if numPartitions is None:
             numPartitions = self._defaultReducePartitions()
 
-        spill = (self.ctx._conf.get("spark.shuffle.spill", 'True').lower() == 'true')
-        memory = _parse_memory(self.ctx._conf.get("spark.python.worker.memory", "512m"))
+        spill = self._can_spill()
+        memory = self._memory_limit()
         serializer = self._jrdd_deserializer
 
         def sortPartition(iterator):
@@ -1699,10 +1699,8 @@ class RDD(object):
             numPartitions = self._defaultReducePartitions()
 
         serializer = self.ctx.serializer
-        spill = (self.ctx._conf.get("spark.shuffle.spill", 'True').lower()
-                 == 'true')
-        memory = _parse_memory(self.ctx._conf.get(
-            "spark.python.worker.memory", "512m"))
+        spill = self._can_spill()
+        memory = self._memory_limit()
         agg = Aggregator(createCombiner, mergeValue, mergeCombiners)
 
         def combineLocally(iterator):
@@ -1755,21 +1753,28 @@ class RDD(object):
 
         return self.combineByKey(lambda v: func(createZero(), v), func, func, numPartitions)
 
+    def _can_spill(self):
+        return self.ctx._conf.get("spark.shuffle.spill", "True").lower() == "true"
+
+    def _memory_limit(self):
+        return _parse_memory(self.ctx._conf.get("spark.python.worker.memory", "512m"))
+
     # TODO: support variant with custom partitioner
     def groupByKey(self, numPartitions=None):
         """
         Group the values for each key in the RDD into a single sequence.
-        Hash-partitions the resulting RDD with into numPartitions partitions.
+        Hash-partitions the resulting RDD with numPartitions partitions.
 
         Note: If you are grouping in order to perform an aggregation (such as a
         sum or average) over each key, using reduceByKey or aggregateByKey will
         provide much better performance.
 
         >>> x = sc.parallelize([("a", 1), ("b", 1), ("a", 1)])
-        >>> map((lambda (x,y): (x, list(y))), sorted(x.groupByKey().collect()))
+        >>> sorted(x.groupByKey().mapValues(len).collect())
+        [('a', 2), ('b', 1)]
+        >>> sorted(x.groupByKey().mapValues(list).collect())
         [('a', [1, 1]), ('b', [1])]
         """
-
         def createCombiner(x):
             return [x]
 
@@ -1781,8 +1786,27 @@ class RDD(object):
             a.extend(b)
             return a
 
-        return self.combineByKey(createCombiner, mergeValue, mergeCombiners,
-                                 numPartitions).mapValues(lambda x: ResultIterable(x))
+        spill = self._can_spill()
+        memory = self._memory_limit()
+        serializer = self._jrdd_deserializer
+        agg = Aggregator(createCombiner, mergeValue, mergeCombiners)
+
+        def combine(iterator):
+            merger = ExternalMerger(agg, memory * 0.9, serializer) \
+                if spill else InMemoryMerger(agg)
+            merger.mergeValues(iterator)
+            return merger.iteritems()
+
+        locally_combined = self.mapPartitions(combine, preservesPartitioning=True)
+        shuffled = locally_combined.partitionBy(numPartitions)
+
+        def groupByKey(it):
+            merger = ExternalGroupBy(agg, memory, serializer)\
+                if spill else InMemoryMerger(agg)
+            merger.mergeCombiners(it)
+            return merger.iteritems()
+
+        return shuffled.mapPartitions(groupByKey, True).mapValues(ResultIterable)
 
     def flatMapValues(self, f):
         """
diff --git a/python/pyspark/resultiterable.py b/python/pyspark/resultiterable.py
index ef04c82866..1ab5ce14c3 100644
--- a/python/pyspark/resultiterable.py
+++ b/python/pyspark/resultiterable.py
@@ -15,15 +15,16 @@
 # limitations under the License.
 #
 
-__all__ = ["ResultIterable"]
-
 import collections
 
+__all__ = ["ResultIterable"]
+
 
 class ResultIterable(collections.Iterable):
 
     """
-    A special result iterable. This is used because the standard iterator can not be pickled
+    A special result iterable. This is used because the standard
+    iterator can not be pickled
     """
 
     def __init__(self, data):
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index 0ffb41d02f..4afa82f4b2 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -220,6 +220,29 @@ class BatchedSerializer(Serializer):
         return "BatchedSerializer(%s, %d)" % (str(self.serializer), self.batchSize)
 
 
+class FlattenedValuesSerializer(BatchedSerializer):
+
+    """
+    Serializes a stream of lists of pairs, splitting each list of values
+    that contains more than a certain number of objects so that the
+    resulting batches have similar sizes.
+    """
+    def __init__(self, serializer, batchSize=10):
+        BatchedSerializer.__init__(self, serializer, batchSize)
+
+    def _batched(self, iterator):
+        n = self.batchSize
+        for key, values in iterator:
+            for i in xrange(0, len(values), n):
+                yield key, values[i:i + n]
+
+    def load_stream(self, stream):
+        return self.serializer.load_stream(stream)
+
+    def __repr__(self):
+        return "FlattenedValuesSerializer(%d)" % self.batchSize
+
+
 class AutoBatchedSerializer(BatchedSerializer):
     """
     Choose the size of batch automatically based on the size of object
@@ -251,7 +274,7 @@ class AutoBatchedSerializer(BatchedSerializer):
         return (isinstance(other, AutoBatchedSerializer) and
                 other.serializer == self.serializer and other.bestSize == self.bestSize)
 
-    def __str__(self):
+    def __repr__(self):
         return "AutoBatchedSerializer(%s)" % str(self.serializer)
 
 
diff --git a/python/pyspark/shuffle.py b/python/pyspark/shuffle.py
index 10a7ccd502..8a6fc627eb 100644
--- a/python/pyspark/shuffle.py
+++ b/python/pyspark/shuffle.py
@@ -16,28 +16,35 @@
 #
 
 import os
-import sys
 import platform
 import shutil
 import warnings
 import gc
 import itertools
+import operator
 import random
 
 import pyspark.heapq3 as heapq
-from pyspark.serializers import AutoBatchedSerializer, PickleSerializer
+from pyspark.serializers import BatchedSerializer, PickleSerializer, FlattenedValuesSerializer, \
+    CompressedSerializer, AutoBatchedSerializer
+
 
 try:
     import psutil
 
+    process = None
+
     def get_used_memory():
         """ Return the used memory in MB """
-        process = psutil.Process(os.getpid())
+        global process
+        if process is None or process._pid != os.getpid():
+            process = psutil.Process(os.getpid())
         if hasattr(process, "memory_info"):
             info = process.memory_info()
         else:
             info = process.get_memory_info()
         return info.rss >> 20
+
 except ImportError:
 
     def get_used_memory():
@@ -46,6 +53,7 @@ except ImportError:
             for line in open('/proc/self/status'):
                 if line.startswith('VmRSS:'):
                     return int(line.split()[1]) >> 10
+
         else:
             warnings.warn("Please install psutil to have better "
                           "support with spilling")
@@ -54,6 +62,7 @@ except ImportError:
                 rss = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
                 return rss >> 20
             # TODO: support windows
+
         return 0
 
 
@@ -148,10 +157,16 @@ class InMemoryMerger(Merger):
             d[k] = comb(d[k], v) if k in d else v
 
     def iteritems(self):
-        """ Return the merged items ad iterator """
+        """ Return the merged items as iterator """
         return self.data.iteritems()
 
 
+def _compressed_serializer(self, serializer=None):
+    # always use PickleSerializer to simplify implementation
+    ser = PickleSerializer()
+    return AutoBatchedSerializer(CompressedSerializer(ser))
+
+
 class ExternalMerger(Merger):
 
     """
@@ -173,7 +188,7 @@ class ExternalMerger(Merger):
       dict. Repeat this again until combine all the items.
 
     - Before return any items, it will load each partition and
-      combine them seperately. Yield them before loading next
+      combine them separately. Yield them before loading next
       partition.
 
     - During loading a partition, if the memory goes over limit,
@@ -182,7 +197,7 @@ class ExternalMerger(Merger):
 
     `data` and `pdata` are used to hold the merged items in memory.
     At first, all the data are merged into `data`. Once the used
-    memory goes over limit, the items in `data` are dumped indo
+    memory goes over limit, the items in `data` are dumped into
     disks, `data` will be cleared, all rest of items will be merged
     into `pdata` and then dumped into disks. Before returning, all
     the items in `pdata` will be dumped into disks.
@@ -193,16 +208,16 @@ class ExternalMerger(Merger):
     >>> agg = SimpleAggregator(lambda x, y: x + y)
     >>> merger = ExternalMerger(agg, 10)
     >>> N = 10000
-    >>> merger.mergeValues(zip(xrange(N), xrange(N)) * 10)
+    >>> merger.mergeValues(zip(xrange(N), xrange(N)))
     >>> assert merger.spills > 0
     >>> sum(v for k,v in merger.iteritems())
-    499950000
+    49995000
 
     >>> merger = ExternalMerger(agg, 10)
-    >>> merger.mergeCombiners(zip(xrange(N), xrange(N)) * 10)
+    >>> merger.mergeCombiners(zip(xrange(N), xrange(N)))
     >>> assert merger.spills > 0
     >>> sum(v for k,v in merger.iteritems())
-    499950000
+    49995000
     """
 
     # the max total partitions created recursively
@@ -212,8 +227,7 @@ class ExternalMerger(Merger):
                  localdirs=None, scale=1, partitions=59, batch=1000):
         Merger.__init__(self, aggregator)
         self.memory_limit = memory_limit
-        # default serializer is only used for tests
-        self.serializer = serializer or AutoBatchedSerializer(PickleSerializer())
+        self.serializer = _compressed_serializer(serializer)
         self.localdirs = localdirs or _get_local_dirs(str(id(self)))
         # number of partitions when spill data into disks
         self.partitions = partitions
@@ -221,7 +235,7 @@ class ExternalMerger(Merger):
         self.batch = batch
         # scale is used to scale down the hash of key for recursive hash map
         self.scale = scale
-        # unpartitioned merged data
+        # un-partitioned merged data
         self.data = {}
         # partitioned merged data, list of dicts
         self.pdata = []
@@ -244,72 +258,63 @@ class ExternalMerger(Merger):
 
     def mergeValues(self, iterator):
         """ Combine the items by creator and combiner """
-        iterator = iter(iterator)
         # speedup attribute lookup
         creator, comb = self.agg.createCombiner, self.agg.mergeValue
-        d, c, batch = self.data, 0, self.batch
+        c, data, pdata, hfun, batch = 0, self.data, self.pdata, self._partition, self.batch
+        limit = self.memory_limit
 
         for k, v in iterator:
+            d = pdata[hfun(k)] if pdata else data
             d[k] = comb(d[k], v) if k in d else creator(v)
 
             c += 1
-            if c % batch == 0 and get_used_memory() > self.memory_limit:
-                self._spill()
-                self._partitioned_mergeValues(iterator, self._next_limit())
-                break
+            if c >= batch:
+                if get_used_memory() >= limit:
+                    self._spill()
+                    limit = self._next_limit()
+                    batch /= 2
+                    c = 0
+                else:
+                    batch *= 1.5
+
+        if get_used_memory() >= limit:
+            self._spill()
 
     def _partition(self, key):
         """ Return the partition for key """
         return hash((key, self._seed)) % self.partitions
 
-    def _partitioned_mergeValues(self, iterator, limit=0):
-        """ Partition the items by key, then combine them """
-        # speedup attribute lookup
-        creator, comb = self.agg.createCombiner, self.agg.mergeValue
-        c, pdata, hfun, batch = 0, self.pdata, self._partition, self.batch
-
-        for k, v in iterator:
-            d = pdata[hfun(k)]
-            d[k] = comb(d[k], v) if k in d else creator(v)
-            if not limit:
-                continue
-
-            c += 1
-            if c % batch == 0 and get_used_memory() > limit:
-                self._spill()
-                limit = self._next_limit()
+    def _object_size(self, obj):
+        """ How much of memory for this obj, assume that all the objects
+        consume similar bytes of memory
+        """
+        return 1
 
-    def mergeCombiners(self, iterator, check=True):
+    def mergeCombiners(self, iterator, limit=None):
         """ Merge (K,V) pair by mergeCombiner """
-        iterator = iter(iterator)
+        if limit is None:
+            limit = self.memory_limit
         # speedup attribute lookup
-        d, comb, batch = self.data, self.agg.mergeCombiners, self.batch
-        c = 0
-        for k, v in iterator:
-            d[k] = comb(d[k], v) if k in d else v
-            if not check:
-                continue
-
-            c += 1
-            if c % batch == 0 and get_used_memory() > self.memory_limit:
-                self._spill()
-                self._partitioned_mergeCombiners(iterator, self._next_limit())
-                break
-
-    def _partitioned_mergeCombiners(self, iterator, limit=0):
-        """ Partition the items by key, then merge them """
-        comb, pdata = self.agg.mergeCombiners, self.pdata
-        c, hfun = 0, self._partition
+        comb, hfun, objsize = self.agg.mergeCombiners, self._partition, self._object_size
+        c, data, pdata, batch = 0, self.data, self.pdata, self.batch
         for k, v in iterator:
-            d = pdata[hfun(k)]
+            d = pdata[hfun(k)] if pdata else data
             d[k] = comb(d[k], v) if k in d else v
             if not limit:
                 continue
 
-            c += 1
-            if c % self.batch == 0 and get_used_memory() > limit:
-                self._spill()
-                limit = self._next_limit()
+            c += objsize(v)
+            if c > batch:
+                if get_used_memory() > limit:
+                    self._spill()
+                    limit = self._next_limit()
+                    batch /= 2
+                    c = 0
+                else:
+                    batch *= 1.5
+
+        if limit and get_used_memory() >= limit:
+            self._spill()
 
     def _spill(self):
         """
@@ -335,7 +340,7 @@ class ExternalMerger(Merger):
 
             for k, v in self.data.iteritems():
                 h = self._partition(k)
-                # put one item in batch, make it compatitable with load_stream
+                # put one item in batch, make it compatible with load_stream
                 # it will increase the memory if dump them in batch
                 self.serializer.dump_stream([(k, v)], streams[h])
 
@@ -344,7 +349,7 @@ class ExternalMerger(Merger):
                 s.close()
 
             self.data.clear()
-            self.pdata = [{} for i in range(self.partitions)]
+            self.pdata.extend([{} for i in range(self.partitions)])
 
         else:
             for i in range(self.partitions):
@@ -370,29 +375,12 @@ class ExternalMerger(Merger):
         assert not self.data
         if any(self.pdata):
             self._spill()
-        hard_limit = self._next_limit()
+        # disable partitioning and spilling when merging combiners from disk
+        self.pdata = []
 
         try:
             for i in range(self.partitions):
-                self.data = {}
-                for j in range(self.spills):
-                    path = self._get_spill_dir(j)
-                    p = os.path.join(path, str(i))
-                    # do not check memory during merging
-                    self.mergeCombiners(self.serializer.load_stream(open(p)),
-                                        False)
-
-                    # limit the total partitions
-                    if (self.scale * self.partitions < self.MAX_TOTAL_PARTITIONS
-                            and j < self.spills - 1
-                            and get_used_memory() > hard_limit):
-                        self.data.clear()  # will read from disk again
-                        gc.collect()  # release the memory as much as possible
-                        for v in self._recursive_merged_items(i):
-                            yield v
-                        return
-
-                for v in self.data.iteritems():
+                for v in self._merged_items(i):
                     yield v
                 self.data.clear()
 
@@ -400,53 +388,56 @@ class ExternalMerger(Merger):
                 for j in range(self.spills):
                     path = self._get_spill_dir(j)
                     os.remove(os.path.join(path, str(i)))
-
         finally:
             self._cleanup()
 
-    def _cleanup(self):
-        """ Clean up all the files in disks """
-        for d in self.localdirs:
-            shutil.rmtree(d, True)
+    def _merged_items(self, index):
+        self.data = {}
+        limit = self._next_limit()
+        for j in range(self.spills):
+            path = self._get_spill_dir(j)
+            p = os.path.join(path, str(index))
+            # do not check memory during merging
+            self.mergeCombiners(self.serializer.load_stream(open(p)), 0)
+
+            # limit the total partitions
+            if (self.scale * self.partitions < self.MAX_TOTAL_PARTITIONS
+                    and j < self.spills - 1
+                    and get_used_memory() > limit):
+                self.data.clear()  # will read from disk again
+                gc.collect()  # release the memory as much as possible
+                return self._recursive_merged_items(index)
 
-    def _recursive_merged_items(self, start):
+        return self.data.iteritems()
+
+    def _recursive_merged_items(self, index):
         """
         merge the partitioned items and return the as iterator
 
         If one partition can not be fit in memory, then them will be
         partitioned and merged recursively.
         """
-        # make sure all the data are dumps into disks.
-        assert not self.data
-        if any(self.pdata):
-            self._spill()
-        assert self.spills > 0
-
-        for i in range(start, self.partitions):
-            subdirs = [os.path.join(d, "parts", str(i))
-                       for d in self.localdirs]
-            m = ExternalMerger(self.agg, self.memory_limit, self.serializer,
-                               subdirs, self.scale * self.partitions, self.partitions)
-            m.pdata = [{} for _ in range(self.partitions)]
-            limit = self._next_limit()
-
-            for j in range(self.spills):
-                path = self._get_spill_dir(j)
-                p = os.path.join(path, str(i))
-                m._partitioned_mergeCombiners(
-                    self.serializer.load_stream(open(p)))
-
-                if get_used_memory() > limit:
-                    m._spill()
-                    limit = self._next_limit()
+        subdirs = [os.path.join(d, "parts", str(index)) for d in self.localdirs]
+        m = ExternalMerger(self.agg, self.memory_limit, self.serializer, subdirs,
+                           self.scale * self.partitions, self.partitions, self.batch)
+        m.pdata = [{} for _ in range(self.partitions)]
+        limit = self._next_limit()
+
+        for j in range(self.spills):
+            path = self._get_spill_dir(j)
+            p = os.path.join(path, str(index))
+            m.mergeCombiners(self.serializer.load_stream(open(p)), 0)
+
+            if get_used_memory() > limit:
+                m._spill()
+                limit = self._next_limit()
 
-            for v in m._external_items():
-                yield v
+        return m._external_items()
 
-            # remove the merged partition
-            for j in range(self.spills):
-                path = self._get_spill_dir(j)
-                os.remove(os.path.join(path, str(i)))
+    def _cleanup(self):
+        """ Clean up all the files in disks """
+        for d in self.localdirs:
+            shutil.rmtree(d, True)
 
 
 class ExternalSorter(object):
@@ -457,6 +448,7 @@ class ExternalSorter(object):
     The spilling will only happen when the used memory goes above
     the limit.
 
+
     >>> sorter = ExternalSorter(1)  # 1M
     >>> import random
     >>> l = range(1024)
@@ -469,7 +461,7 @@ class ExternalSorter(object):
     def __init__(self, memory_limit, serializer=None):
         self.memory_limit = memory_limit
         self.local_dirs = _get_local_dirs("sort")
-        self.serializer = serializer or AutoBatchedSerializer(PickleSerializer())
+        self.serializer = _compressed_serializer(serializer)
 
     def _get_path(self, n):
         """ Choose one directory for spill by number n """
@@ -515,6 +507,7 @@ class ExternalSorter(object):
                 limit = self._next_limit()
                 MemoryBytesSpilled += (used_memory - get_used_memory()) << 20
                 DiskBytesSpilled += os.path.getsize(path)
+                os.unlink(path)  # data will be deleted after close
 
             elif not chunks:
                 batch = min(batch * 2, 10000)
@@ -529,6 +522,310 @@ class ExternalSorter(object):
         return heapq.merge(chunks, key=key, reverse=reverse)
 
 
+class ExternalList(object):
+    """
+    ExternalList can hold many items that cannot all be kept in memory
+    at the same time.
+
+    >>> l = ExternalList(range(100))
+    >>> len(l)
+    100
+    >>> l.append(10)
+    >>> len(l)
+    101
+    >>> for i in range(20240):
+    ...     l.append(i)
+    >>> len(l)
+    20341
+    >>> import pickle
+    >>> l2 = pickle.loads(pickle.dumps(l))
+    >>> len(l2)
+    20341
+    >>> list(l2)[100]
+    10
+    """
+    LIMIT = 10240
+
+    def __init__(self, values):
+        self.values = values
+        self.count = len(values)
+        self._file = None
+        self._ser = None
+
+    def __getstate__(self):
+        if self._file is not None:
+            self._file.flush()
+            f = os.fdopen(os.dup(self._file.fileno()))
+            f.seek(0)
+            serialized = f.read()
+        else:
+            serialized = ''
+        return self.values, self.count, serialized
+
+    def __setstate__(self, item):
+        self.values, self.count, serialized = item
+        if serialized:
+            self._open_file()
+            self._file.write(serialized)
+        else:
+            self._file = None
+            self._ser = None
+
+    def __iter__(self):
+        if self._file is not None:
+            self._file.flush()
+            # read all the items from disk first
+            with os.fdopen(os.dup(self._file.fileno()), 'r') as f:
+                f.seek(0)
+                for v in self._ser.load_stream(f):
+                    yield v
+
+        for v in self.values:
+            yield v
+
+    def __len__(self):
+        return self.count
+
+    def append(self, value):
+        self.values.append(value)
+        self.count += 1
+        # dump the values to disk if the list grows too large
+        if len(self.values) >= self.LIMIT:
+            self._spill()
+
+    def _open_file(self):
+        dirs = _get_local_dirs("objects")
+        d = dirs[id(self) % len(dirs)]
+        if not os.path.exists(d):
+            os.makedirs(d)
+        p = os.path.join(d, str(id))
+        self._file = open(p, "w+", 65536)
+        self._ser = BatchedSerializer(CompressedSerializer(PickleSerializer()), 1024)
+        os.unlink(p)
+
+    def _spill(self):
+        """ dump the values into disk """
+        global MemoryBytesSpilled, DiskBytesSpilled
+        if self._file is None:
+            self._open_file()
+
+        used_memory = get_used_memory()
+        pos = self._file.tell()
+        self._ser.dump_stream(self.values, self._file)
+        self.values = []
+        gc.collect()
+        DiskBytesSpilled += self._file.tell() - pos
+        MemoryBytesSpilled += (used_memory - get_used_memory()) << 20
+
+
+class ExternalListOfList(ExternalList):
+    """
+    An external list of lists.
+
+    >>> l = ExternalListOfList([[i, i] for i in range(100)])
+    >>> len(l)
+    200
+    >>> l.append(range(10))
+    >>> len(l)
+    210
+    >>> len(list(l))
+    210
+    """
+
+    def __init__(self, values):
+        ExternalList.__init__(self, values)
+        self.count = sum(len(i) for i in values)
+
+    def append(self, value):
+        ExternalList.append(self, value)
+        # already counted 1 in ExternalList.append
+        self.count += len(value) - 1
+
+    def __iter__(self):
+        for values in ExternalList.__iter__(self):
+            for v in values:
+                yield v
+
+
+class GroupByKey(object):
+    """
+    Group a sorted iterator as [(k1, it1), (k2, it2), ...]
+
+    >>> k = [i/3 for i in range(6)]
+    >>> v = [[i] for i in range(6)]
+    >>> g = GroupByKey(iter(zip(k, v)))
+    >>> [(k, list(it)) for k, it in g]
+    [(0, [0, 1, 2]), (1, [3, 4, 5])]
+    """
+
+    def __init__(self, iterator):
+        self.iterator = iter(iterator)
+        self.next_item = None
+
+    def __iter__(self):
+        return self
+
+    def next(self):
+        key, value = self.next_item if self.next_item else next(self.iterator)
+        values = ExternalListOfList([value])
+        try:
+            while True:
+                k, v = next(self.iterator)
+                if k != key:
+                    self.next_item = (k, v)
+                    break
+                values.append(v)
+        except StopIteration:
+            self.next_item = None
+        return key, values
+
+
+class ExternalGroupBy(ExternalMerger):
+
+    """
+    Group the items by key. If any partition of them cannot be
+    held in memory, it falls back to a sort-based group by.
+
+    This class works as follows:
+
+    - It repeatedly groups the items by key and saves them in one dict
+      in memory.
+
+    - When the used memory goes above the memory limit, it splits
+      the combined data into partitions by hash code and dumps them
+      to disk, one file per partition. If the number of keys in
+      a partition is smaller than 1000, it sorts the items by key
+      before dumping them to disk.
+
+    - Then it goes through the rest of the iterator, grouping items
+      by key into different dicts by hash. Whenever the used memory
+      goes over the memory limit, it dumps all the dicts to disk,
+      one file per dict, and repeats until all the items have been
+      combined. It also tries to sort the items by key in each
+      partition before dumping them to disk.
+
+    - It yields the grouped items partition by partition. If the
+      data in one partition can be held in memory, it loads and
+      combines the items in memory, then yields them.
+
+    - If the dataset in one partition cannot be held in memory,
+      it sorts the items first. If all the files are already sorted,
+      it merges them with heapq.merge(), so it performs an external
+      sort over all the files.
+
+    - After sorting, the `GroupByKey` class groups all consecutive
+      items with the same key and yields the values as
+      an iterator.
+    """
+    SORT_KEY_LIMIT = 1000
+
+    def flattened_serializer(self):
+        assert isinstance(self.serializer, BatchedSerializer)
+        ser = self.serializer
+        return FlattenedValuesSerializer(ser, 20)
+
+    def _object_size(self, obj):
+        return len(obj)
+
+    def _spill(self):
+        """
+        dump the partitioned data to disk, one file per partition.
+        """
+        global MemoryBytesSpilled, DiskBytesSpilled
+        path = self._get_spill_dir(self.spills)
+        if not os.path.exists(path):
+            os.makedirs(path)
+
+        used_memory = get_used_memory()
+        if not self.pdata:
+            # The data has not been partitioned; iterate over the
+            # data once and write the items into different files,
+            # using no additional memory. This branch is only taken
+            # the first time memory goes above the limit.
+
+            # open all the files for writing
+            streams = [open(os.path.join(path, str(i)), 'w')
+                       for i in range(self.partitions)]
+
+            # If the number of keys is small, the overhead of sorting is
+            # small too, so sort them before dumping to disk
+            self._sorted = len(self.data) < self.SORT_KEY_LIMIT
+            if self._sorted:
+                self.serializer = self.flattened_serializer()
+                for k in sorted(self.data.keys()):
+                    h = self._partition(k)
+                    self.serializer.dump_stream([(k, self.data[k])], streams[h])
+            else:
+                for k, v in self.data.iteritems():
+                    h = self._partition(k)
+                    self.serializer.dump_stream([(k, v)], streams[h])
+
+            for s in streams:
+                DiskBytesSpilled += s.tell()
+                s.close()
+
+            self.data.clear()
+            # self.pdata is cached in `mergeValues` and `mergeCombiners`
+            self.pdata.extend([{} for i in range(self.partitions)])
+
+        else:
+            for i in range(self.partitions):
+                p = os.path.join(path, str(i))
+                with open(p, "w") as f:
+                    # dump items in batch
+                    if self._sorted:
+                        # sort by key only (stable)
+                        sorted_items = sorted(self.pdata[i].iteritems(), key=operator.itemgetter(0))
+                        self.serializer.dump_stream(sorted_items, f)
+                    else:
+                        self.serializer.dump_stream(self.pdata[i].iteritems(), f)
+                self.pdata[i].clear()
+                DiskBytesSpilled += os.path.getsize(p)
+
+        self.spills += 1
+        gc.collect()  # release the memory as much as possible
+        MemoryBytesSpilled += (used_memory - get_used_memory()) << 20
+
+    def _merged_items(self, index):
+        size = sum(os.path.getsize(os.path.join(self._get_spill_dir(j), str(index)))
+                   for j in range(self.spills))
+        # If memory cannot hold the whole partition, use a sort-based
+        # merge. Because of compression, the data on disk will be
+        # much smaller than the memory needed to hold it
+        if (size >> 20) >= self.memory_limit / 10:
+            return self._merge_sorted_items(index)
+
+        self.data = {}
+        for j in range(self.spills):
+            path = self._get_spill_dir(j)
+            p = os.path.join(path, str(index))
+            # do not check memory during merging
+            self.mergeCombiners(self.serializer.load_stream(open(p)), 0)
+        return self.data.iteritems()
+
+    def _merge_sorted_items(self, index):
+        """ load a partition from disk, then sort and group by key """
+        def load_partition(j):
+            path = self._get_spill_dir(j)
+            p = os.path.join(path, str(index))
+            return self.serializer.load_stream(open(p, 'r', 65536))
+
+        disk_items = [load_partition(j) for j in range(self.spills)]
+
+        if self._sorted:
+            # all the partitions are already sorted
+            sorted_items = heapq.merge(disk_items, key=operator.itemgetter(0))
+
+        else:
+            # Flatten the combined values so the merge sort does not
+            # consume huge amounts of memory.
+            ser = self.flattened_serializer()
+            sorter = ExternalSorter(self.memory_limit, ser)
+            sorted_items = sorter.sorted(itertools.chain(*disk_items),
+                                         key=operator.itemgetter(0))
+        return ((k, vs) for k, vs in GroupByKey(sorted_items))
+
+
 if __name__ == "__main__":
     import doctest
     doctest.testmod()
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index dd8d3b1c53..0bd5d20f78 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -31,6 +31,7 @@ import tempfile
 import time
 import zipfile
 import random
+import itertools
 import threading
 import hashlib
 
@@ -76,7 +77,7 @@ SPARK_HOME = os.environ["SPARK_HOME"]
 class MergerTests(unittest.TestCase):
 
     def setUp(self):
-        self.N = 1 << 14
+        self.N = 1 << 12
         self.l = [i for i in xrange(self.N)]
         self.data = zip(self.l, self.l)
         self.agg = Aggregator(lambda x: [x],
@@ -108,7 +109,7 @@ class MergerTests(unittest.TestCase):
                          sum(xrange(self.N)))
 
     def test_medium_dataset(self):
-        m = ExternalMerger(self.agg, 10)
+        m = ExternalMerger(self.agg, 30)
         m.mergeValues(self.data)
         self.assertTrue(m.spills >= 1)
         self.assertEqual(sum(sum(v) for k, v in m.iteritems()),
@@ -124,10 +125,36 @@ class MergerTests(unittest.TestCase):
         m = ExternalMerger(self.agg, 10, partitions=3)
         m.mergeCombiners(map(lambda (k, v): (k, [str(v)]), self.data * 10))
         self.assertTrue(m.spills >= 1)
-        self.assertEqual(sum(len(v) for k, v in m._recursive_merged_items(0)),
+        self.assertEqual(sum(len(v) for k, v in m.iteritems()),
                          self.N * 10)
         m._cleanup()
 
+    def test_group_by_key(self):
+
+        def gen_data(N, step):
+            for i in range(1, N + 1, step):
+                for j in range(i):
+                    yield (i, [j])
+
+        def gen_gs(N, step=1):
+            return shuffle.GroupByKey(gen_data(N, step))
+
+        self.assertEqual(1, len(list(gen_gs(1))))
+        self.assertEqual(2, len(list(gen_gs(2))))
+        self.assertEqual(100, len(list(gen_gs(100))))
+        self.assertEqual(range(1, 101), [k for k, _ in gen_gs(100)])
+        self.assertTrue(all(range(k) == list(vs) for k, vs in gen_gs(100)))
+
+        for k, vs in gen_gs(50002, 10000):
+            self.assertEqual(k, len(vs))
+            self.assertEqual(range(k), list(vs))
+
+        ser = PickleSerializer()
+        l = ser.loads(ser.dumps(list(gen_gs(50002, 30000))))
+        for k, vs in l:
+            self.assertEqual(k, len(vs))
+            self.assertEqual(range(k), list(vs))
+
 
 class SorterTests(unittest.TestCase):
     def test_in_memory_sort(self):
@@ -702,6 +729,21 @@ class RDDTests(ReusedPySparkTestCase):
         self.assertEquals(result.getNumPartitions(), 5)
         self.assertEquals(result.count(), 3)
 
+    def test_external_group_by_key(self):
+        self.sc._conf.set("spark.python.worker.memory", "5m")
+        N = 200001
+        kv = self.sc.parallelize(range(N)).map(lambda x: (x % 3, x))
+        gkv = kv.groupByKey().cache()
+        self.assertEqual(3, gkv.count())
+        filtered = gkv.filter(lambda (k, vs): k == 1)
+        self.assertEqual(1, filtered.count())
+        self.assertEqual([(1, N/3)], filtered.mapValues(len).collect())
+        self.assertEqual([(N/3, N/3)],
+                         filtered.values().map(lambda x: (len(x), len(list(x)))).collect())
+        result = filtered.collect()[0][1]
+        self.assertEqual(N/3, len(result))
+        self.assertTrue(isinstance(result.data, shuffle.ExternalList))
+
     def test_sort_on_empty_rdd(self):
         self.assertEqual([], self.sc.parallelize(zip([], [])).sortByKey().collect())
 
@@ -752,9 +794,9 @@ class RDDTests(ReusedPySparkTestCase):
         self.assertEqual(rdd.getNumPartitions() + 2, parted.union(rdd).getNumPartitions())
         self.assertEqual(rdd.getNumPartitions() + 2, rdd.union(parted).getNumPartitions())
 
-        self.sc.setJobGroup("test1", "test", True)
         tracker = self.sc.statusTracker()
 
+        self.sc.setJobGroup("test1", "test", True)
         d = sorted(parted.join(parted).collect())
         self.assertEqual(10, len(d))
         self.assertEqual((0, (0, 0)), d[0])
-- 
GitLab