diff --git a/python/pyspark/join.py b/python/pyspark/join.py
index efc1ef93964129ef6a3a8daee03a91212b05132d..c3491defb2b292988963fe0bf65104ccba4ead6c 100644
--- a/python/pyspark/join.py
+++ b/python/pyspark/join.py
@@ -48,7 +48,7 @@ def python_join(rdd, other, numPartitions):
                 vbuf.append(v)
             elif n == 2:
                 wbuf.append(v)
-        return [(v, w) for v in vbuf for w in wbuf]
+        return ((v, w) for v in vbuf for w in wbuf)
     return _do_python_join(rdd, other, numPartitions, dispatch)
 
 
@@ -62,7 +62,7 @@ def python_right_outer_join(rdd, other, numPartitions):
                 wbuf.append(v)
         if not vbuf:
             vbuf.append(None)
-        return [(v, w) for v in vbuf for w in wbuf]
+        return ((v, w) for v in vbuf for w in wbuf)
     return _do_python_join(rdd, other, numPartitions, dispatch)
 
 
@@ -76,7 +76,7 @@ def python_left_outer_join(rdd, other, numPartitions):
                 wbuf.append(v)
         if not wbuf:
             wbuf.append(None)
-        return [(v, w) for v in vbuf for w in wbuf]
+        return ((v, w) for v in vbuf for w in wbuf)
     return _do_python_join(rdd, other, numPartitions, dispatch)
 
 
@@ -104,8 +104,9 @@ def python_cogroup(rdds, numPartitions):
     rdd_len = len(vrdds)
 
     def dispatch(seq):
-        bufs = [[] for i in range(rdd_len)]
-        for (n, v) in seq:
+        bufs = [[] for _ in range(rdd_len)]
+        for n, v in seq:
             bufs[n].append(v)
-        return tuple(map(ResultIterable, bufs))
+        return tuple(ResultIterable(vs) for vs in bufs)
+
     return union_vrdds.groupByKey(numPartitions).mapValues(dispatch)
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 2d05611321ed638d90b8869c73fba265dedfa9d1..1b18789040360a8eb0049b81caa3d39ce2784608 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -41,7 +41,7 @@ from pyspark.rddsampler import RDDSampler, RDDRangeSampler, RDDStratifiedSampler
 from pyspark.storagelevel import StorageLevel
 from pyspark.resultiterable import ResultIterable
 from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, \
-    get_used_memory, ExternalSorter
+    get_used_memory, ExternalSorter, ExternalGroupBy
 from pyspark.traceback_utils import SCCallSiteSync
 
 from py4j.java_collections import ListConverter, MapConverter
@@ -573,8 +573,8 @@ class RDD(object):
         if numPartitions is None:
             numPartitions = self._defaultReducePartitions()
 
-        spill = (self.ctx._conf.get("spark.shuffle.spill", 'True').lower() == 'true')
-        memory = _parse_memory(self.ctx._conf.get("spark.python.worker.memory", "512m"))
+        spill = self._can_spill()
+        memory = self._memory_limit()
         serializer = self._jrdd_deserializer
 
         def sortPartition(iterator):
@@ -1699,10 +1699,8 @@ class RDD(object):
             numPartitions = self._defaultReducePartitions()
 
         serializer = self.ctx.serializer
-        spill = (self.ctx._conf.get("spark.shuffle.spill", 'True').lower()
-                 == 'true')
-        memory = _parse_memory(self.ctx._conf.get(
-            "spark.python.worker.memory", "512m"))
+        spill = self._can_spill()
+        memory = self._memory_limit()
         agg = Aggregator(createCombiner, mergeValue, mergeCombiners)
 
         def combineLocally(iterator):
@@ -1755,21 +1753,28 @@ class RDD(object):
 
         return self.combineByKey(lambda v: func(createZero(), v), func, func, numPartitions)
 
+    def _can_spill(self):
+        return self.ctx._conf.get("spark.shuffle.spill", "True").lower() == "true"
+
+    def _memory_limit(self):
+        return _parse_memory(self.ctx._conf.get("spark.python.worker.memory", "512m"))
+
     # TODO: support variant with custom partitioner
     def groupByKey(self, numPartitions=None):
         """
         Group the values for each key in the RDD into a single sequence.
-        Hash-partitions the resulting RDD with into numPartitions partitions.
+        Hash-partitions the resulting RDD with numPartitions partitions.
 
         Note: If you are grouping in order to perform an aggregation (such as a
         sum or average) over each key, using reduceByKey or aggregateByKey will
         provide much better performance.
 
         >>> x = sc.parallelize([("a", 1), ("b", 1), ("a", 1)])
-        >>> map((lambda (x,y): (x, list(y))), sorted(x.groupByKey().collect()))
+        >>> sorted(x.groupByKey().mapValues(len).collect())
+        [('a', 2), ('b', 1)]
+        >>> sorted(x.groupByKey().mapValues(list).collect())
         [('a', [1, 1]), ('b', [1])]
         """
-
         def createCombiner(x):
             return [x]
 
@@ -1781,8 +1786,27 @@ class RDD(object):
             a.extend(b)
             return a
 
-        return self.combineByKey(createCombiner, mergeValue, mergeCombiners,
-                                 numPartitions).mapValues(lambda x: ResultIterable(x))
+        spill = self._can_spill()
+        memory = self._memory_limit()
+        serializer = self._jrdd_deserializer
+        agg = Aggregator(createCombiner, mergeValue, mergeCombiners)
+
+        def combine(iterator):
+            merger = ExternalMerger(agg, memory * 0.9, serializer) \
+                if spill else InMemoryMerger(agg)
+            merger.mergeValues(iterator)
+            return merger.iteritems()
+
+        locally_combined = self.mapPartitions(combine, preservesPartitioning=True)
+        shuffled = locally_combined.partitionBy(numPartitions)
+
+        def groupByKey(it):
+            merger = ExternalGroupBy(agg, memory, serializer)\
+                if spill else InMemoryMerger(agg)
+            merger.mergeCombiners(it)
+            return merger.iteritems()
+
+        return shuffled.mapPartitions(groupByKey, True).mapValues(ResultIterable)
 
     def flatMapValues(self, f):
         """
diff --git a/python/pyspark/resultiterable.py b/python/pyspark/resultiterable.py
index ef04c82866e6c5d91bc2ebcbf3b681ba0e8e97ff..1ab5ce14c3531c7aaaa7be5dbd31fae31802b47d 100644
--- a/python/pyspark/resultiterable.py
+++ b/python/pyspark/resultiterable.py
@@ -15,15 +15,16 @@
 # limitations under the License.
 #
 
-__all__ = ["ResultIterable"]
-
 import collections
 
+__all__ = ["ResultIterable"]
+
 
 class ResultIterable(collections.Iterable):
 
     """
-    A special result iterable. This is used because the standard iterator can not be pickled
+    A special result iterable. This is used because the standard
+    iterator can not be pickled
     """
 
     def __init__(self, data):
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index 0ffb41d02f6f6343d0376872cb3d83c034f44824..4afa82f4b297316782d1441e8a3cf3130affc8c9 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -220,6 +220,29 @@ class BatchedSerializer(Serializer):
         return "BatchedSerializer(%s, %d)" % (str(self.serializer), self.batchSize)
 
 
+class FlattenedValuesSerializer(BatchedSerializer):
+
+    """
+    Serializes a stream of list of pairs, split the list of values
+    which contain more than a certain number of objects to make them
+    have similar sizes.
+    """
+    def __init__(self, serializer, batchSize=10):
+        BatchedSerializer.__init__(self, serializer, batchSize)
+
+    def _batched(self, iterator):
+        n = self.batchSize
+        for key, values in iterator:
+            for i in xrange(0, len(values), n):
+                yield key, values[i:i + n]
+
+    def load_stream(self, stream):
+        return self.serializer.load_stream(stream)
+
+    def __repr__(self):
+        return "FlattenedValuesSerializer(%d)" % self.batchSize
+
+
 class AutoBatchedSerializer(BatchedSerializer):
     """
     Choose the size of batch automatically based on the size of object
@@ -251,7 +274,7 @@ class AutoBatchedSerializer(BatchedSerializer):
         return (isinstance(other, AutoBatchedSerializer) and
                 other.serializer == self.serializer and other.bestSize == self.bestSize)
 
-    def __str__(self):
+    def __repr__(self):
         return "AutoBatchedSerializer(%s)" % str(self.serializer)
 
 
diff --git a/python/pyspark/shuffle.py b/python/pyspark/shuffle.py
index 10a7ccd5020008be9834700ffdc1ff4096241a5e..8a6fc627eb383b526906ba60f898d7731f2203a9 100644
--- a/python/pyspark/shuffle.py
+++ b/python/pyspark/shuffle.py
@@ -16,28 +16,35 @@
 #
 
 import os
-import sys
 import platform
 import shutil
 import warnings
 import gc
 import itertools
+import operator
 import random
 
 import pyspark.heapq3 as heapq
-from pyspark.serializers import AutoBatchedSerializer, PickleSerializer
+from pyspark.serializers import BatchedSerializer, PickleSerializer, FlattenedValuesSerializer, \
+    CompressedSerializer, AutoBatchedSerializer
+
 
 try:
     import psutil
 
+    process = None
+
     def get_used_memory():
         """ Return the used memory in MB """
-        process = psutil.Process(os.getpid())
+        global process
+        if process is None or process._pid != os.getpid():
+            process = psutil.Process(os.getpid())
         if hasattr(process, "memory_info"):
             info = process.memory_info()
         else:
             info = process.get_memory_info()
         return info.rss >> 20
+
 except ImportError:
 
     def get_used_memory():
@@ -46,6 +53,7 @@ except ImportError:
             for line in open('/proc/self/status'):
                 if line.startswith('VmRSS:'):
                     return int(line.split()[1]) >> 10
+
         else:
             warnings.warn("Please install psutil to have better "
                           "support with spilling")
@@ -54,6 +62,7 @@ except ImportError:
                 rss = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
                 return rss >> 20
             # TODO: support windows
+
         return 0
 
 
@@ -148,10 +157,16 @@ class InMemoryMerger(Merger):
             d[k] = comb(d[k], v) if k in d else v
 
     def iteritems(self):
-        """ Return the merged items ad iterator """
+        """ Return the merged items as iterator """
         return self.data.iteritems()
 
 
+def _compressed_serializer(self, serializer=None):
+    # always use PickleSerializer to simplify implementation
+    ser = PickleSerializer()
+    return AutoBatchedSerializer(CompressedSerializer(ser))
+
+
 class ExternalMerger(Merger):
 
     """
@@ -173,7 +188,7 @@ class ExternalMerger(Merger):
       dict. Repeat this again until combine all the items.
 
     - Before return any items, it will load each partition and
-      combine them seperately. Yield them before loading next
+      combine them separately. Yield them before loading next
       partition.
 
     - During loading a partition, if the memory goes over limit,
@@ -182,7 +197,7 @@ class ExternalMerger(Merger):
 
     `data` and `pdata` are used to hold the merged items in memory.
     At first, all the data are merged into `data`. Once the used
-    memory goes over limit, the items in `data` are dumped indo
+    memory goes over limit, the items in `data` are dumped into
     disks, `data` will be cleared, all rest of items will be merged
     into `pdata` and then dumped into disks. Before returning, all
     the items in `pdata` will be dumped into disks.
@@ -193,16 +208,16 @@ class ExternalMerger(Merger):
     >>> agg = SimpleAggregator(lambda x, y: x + y)
     >>> merger = ExternalMerger(agg, 10)
     >>> N = 10000
-    >>> merger.mergeValues(zip(xrange(N), xrange(N)) * 10)
+    >>> merger.mergeValues(zip(xrange(N), xrange(N)))
     >>> assert merger.spills > 0
     >>> sum(v for k,v in merger.iteritems())
-    499950000
+    49995000
 
     >>> merger = ExternalMerger(agg, 10)
-    >>> merger.mergeCombiners(zip(xrange(N), xrange(N)) * 10)
+    >>> merger.mergeCombiners(zip(xrange(N), xrange(N)))
     >>> assert merger.spills > 0
     >>> sum(v for k,v in merger.iteritems())
-    499950000
+    49995000
     """
 
     # the max total partitions created recursively
@@ -212,8 +227,7 @@ class ExternalMerger(Merger):
                  localdirs=None, scale=1, partitions=59, batch=1000):
         Merger.__init__(self, aggregator)
         self.memory_limit = memory_limit
-        # default serializer is only used for tests
-        self.serializer = serializer or AutoBatchedSerializer(PickleSerializer())
+        self.serializer = _compressed_serializer(serializer)
         self.localdirs = localdirs or _get_local_dirs(str(id(self)))
         # number of partitions when spill data into disks
         self.partitions = partitions
@@ -221,7 +235,7 @@ class ExternalMerger(Merger):
         self.batch = batch
         # scale is used to scale down the hash of key for recursive hash map
         self.scale = scale
-        # unpartitioned merged data
+        # un-partitioned merged data
         self.data = {}
         # partitioned merged data, list of dicts
         self.pdata = []
@@ -244,72 +258,63 @@ class ExternalMerger(Merger):
 
     def mergeValues(self, iterator):
         """ Combine the items by creator and combiner """
-        iterator = iter(iterator)
         # speedup attribute lookup
         creator, comb = self.agg.createCombiner, self.agg.mergeValue
-        d, c, batch = self.data, 0, self.batch
+        c, data, pdata, hfun, batch = 0, self.data, self.pdata, self._partition, self.batch
+        limit = self.memory_limit
 
         for k, v in iterator:
+            d = pdata[hfun(k)] if pdata else data
             d[k] = comb(d[k], v) if k in d else creator(v)
 
             c += 1
-            if c % batch == 0 and get_used_memory() > self.memory_limit:
-                self._spill()
-                self._partitioned_mergeValues(iterator, self._next_limit())
-                break
+            if c >= batch:
+                if get_used_memory() >= limit:
+                    self._spill()
+                    limit = self._next_limit()
+                    batch /= 2
+                    c = 0
+                else:
+                    batch *= 1.5
+
+        if get_used_memory() >= limit:
+            self._spill()
 
     def _partition(self, key):
         """ Return the partition for key """
         return hash((key, self._seed)) % self.partitions
 
-    def _partitioned_mergeValues(self, iterator, limit=0):
-        """ Partition the items by key, then combine them """
-        # speedup attribute lookup
-        creator, comb = self.agg.createCombiner, self.agg.mergeValue
-        c, pdata, hfun, batch = 0, self.pdata, self._partition, self.batch
-
-        for k, v in iterator:
-            d = pdata[hfun(k)]
-            d[k] = comb(d[k], v) if k in d else creator(v)
-            if not limit:
-                continue
-
-            c += 1
-            if c % batch == 0 and get_used_memory() > limit:
-                self._spill()
-                limit = self._next_limit()
+    def _object_size(self, obj):
+        """ How much of memory for this obj, assume that all the objects
+        consume similar bytes of memory
+        """
+        return 1
 
-    def mergeCombiners(self, iterator, check=True):
+    def mergeCombiners(self, iterator, limit=None):
         """ Merge (K,V) pair by mergeCombiner """
-        iterator = iter(iterator)
+        if limit is None:
+            limit = self.memory_limit
         # speedup attribute lookup
-        d, comb, batch = self.data, self.agg.mergeCombiners, self.batch
-        c = 0
-        for k, v in iterator:
-            d[k] = comb(d[k], v) if k in d else v
-            if not check:
-                continue
-
-            c += 1
-            if c % batch == 0 and get_used_memory() > self.memory_limit:
-                self._spill()
-                self._partitioned_mergeCombiners(iterator, self._next_limit())
-                break
-
-    def _partitioned_mergeCombiners(self, iterator, limit=0):
-        """ Partition the items by key, then merge them """
-        comb, pdata = self.agg.mergeCombiners, self.pdata
-        c, hfun = 0, self._partition
+        comb, hfun, objsize = self.agg.mergeCombiners, self._partition, self._object_size
+        c, data, pdata, batch = 0, self.data, self.pdata, self.batch
         for k, v in iterator:
-            d = pdata[hfun(k)]
+            d = pdata[hfun(k)] if pdata else data
             d[k] = comb(d[k], v) if k in d else v
             if not limit:
                 continue
 
-            c += 1
-            if c % self.batch == 0 and get_used_memory() > limit:
-                self._spill()
-                limit = self._next_limit()
+            c += objsize(v)
+            if c > batch:
+                if get_used_memory() > limit:
+                    self._spill()
+                    limit = self._next_limit()
+                    batch /= 2
+                    c = 0
+                else:
+                    batch *= 1.5
+
+        if limit and get_used_memory() >= limit:
+            self._spill()
 
     def _spill(self):
         """
@@ -335,7 +340,7 @@ class ExternalMerger(Merger):
 
             for k, v in self.data.iteritems():
                 h = self._partition(k)
-                # put one item in batch, make it compatitable with load_stream
+                # put one item in batch, make it compatible with load_stream
                 # it will increase the memory if dump them in batch
                 self.serializer.dump_stream([(k, v)], streams[h])
 
@@ -344,7 +349,7 @@ class ExternalMerger(Merger):
                 s.close()
 
             self.data.clear()
-            self.pdata = [{} for i in range(self.partitions)]
+            self.pdata.extend([{} for i in range(self.partitions)])
 
         else:
             for i in range(self.partitions):
@@ -370,29 +375,12 @@ class ExternalMerger(Merger):
         assert not self.data
         if any(self.pdata):
             self._spill()
-        hard_limit = self._next_limit()
+        # disable partitioning and spilling when merge combiners from disk
+        self.pdata = []
 
         try:
             for i in range(self.partitions):
-                self.data = {}
-                for j in range(self.spills):
-                    path = self._get_spill_dir(j)
-                    p = os.path.join(path, str(i))
-                    # do not check memory during merging
-                    self.mergeCombiners(self.serializer.load_stream(open(p)),
-                                        False)
-
-                    # limit the total partitions
-                    if (self.scale * self.partitions < self.MAX_TOTAL_PARTITIONS
-                            and j < self.spills - 1
-                            and get_used_memory() > hard_limit):
-                        self.data.clear()  # will read from disk again
-                        gc.collect()  # release the memory as much as possible
-                        for v in self._recursive_merged_items(i):
-                            yield v
-                        return
-
-                for v in self.data.iteritems():
+                for v in self._merged_items(i):
                     yield v
                 self.data.clear()
 
@@ -400,53 +388,56 @@ class ExternalMerger(Merger):
                 for j in range(self.spills):
                     path = self._get_spill_dir(j)
                     os.remove(os.path.join(path, str(i)))
-
         finally:
             self._cleanup()
 
-    def _cleanup(self):
-        """ Clean up all the files in disks """
-        for d in self.localdirs:
-            shutil.rmtree(d, True)
+    def _merged_items(self, index):
+        self.data = {}
+        limit = self._next_limit()
+        for j in range(self.spills):
+            path = self._get_spill_dir(j)
+            p = os.path.join(path, str(index))
+            # do not check memory during merging
+            self.mergeCombiners(self.serializer.load_stream(open(p)), 0)
+
+            # limit the total partitions
+            if (self.scale * self.partitions < self.MAX_TOTAL_PARTITIONS
+                    and j < self.spills - 1
+                    and get_used_memory() > limit):
+                self.data.clear()  # will read from disk again
+                gc.collect()  # release the memory as much as possible
+                return self._recursive_merged_items(index)
 
-    def _recursive_merged_items(self, start):
+        return self.data.iteritems()
+
+    def _recursive_merged_items(self, index):
         """
         merge the partitioned items and return the as iterator
 
         If one partition can not be fit in memory, then them will be
         partitioned and merged recursively.
         """
-        # make sure all the data are dumps into disks.
-        assert not self.data
-        if any(self.pdata):
-            self._spill()
-        assert self.spills > 0
-
-        for i in range(start, self.partitions):
-            subdirs = [os.path.join(d, "parts", str(i))
-                       for d in self.localdirs]
-            m = ExternalMerger(self.agg, self.memory_limit, self.serializer,
-                               subdirs, self.scale * self.partitions, self.partitions)
-            m.pdata = [{} for _ in range(self.partitions)]
-            limit = self._next_limit()
-
-            for j in range(self.spills):
-                path = self._get_spill_dir(j)
-                p = os.path.join(path, str(i))
-                m._partitioned_mergeCombiners(
-                    self.serializer.load_stream(open(p)))
-
-                if get_used_memory() > limit:
-                    m._spill()
-                    limit = self._next_limit()
+        subdirs = [os.path.join(d, "parts", str(index)) for d in self.localdirs]
+        m = ExternalMerger(self.agg, self.memory_limit, self.serializer, subdirs,
+                           self.scale * self.partitions, self.partitions, self.batch)
+        m.pdata = [{} for _ in range(self.partitions)]
+        limit = self._next_limit()
+
+        for j in range(self.spills):
+            path = self._get_spill_dir(j)
+            p = os.path.join(path, str(index))
+            m.mergeCombiners(self.serializer.load_stream(open(p)), 0)
+
+            if get_used_memory() > limit:
+                m._spill()
+                limit = self._next_limit()
 
-            for v in m._external_items():
-                yield v
+        return m._external_items()
 
-            # remove the merged partition
-            for j in range(self.spills):
-                path = self._get_spill_dir(j)
-                os.remove(os.path.join(path, str(i)))
+    def _cleanup(self):
+        """ Clean up all the files in disks """
+        for d in self.localdirs:
+            shutil.rmtree(d, True)
 
 
 class ExternalSorter(object):
@@ -457,6 +448,7 @@ class ExternalSorter(object):
     The spilling will only happen when the used memory goes above
     the limit.
 
+
     >>> sorter = ExternalSorter(1)  # 1M
     >>> import random
     >>> l = range(1024)
@@ -469,7 +461,7 @@ class ExternalSorter(object):
     def __init__(self, memory_limit, serializer=None):
         self.memory_limit = memory_limit
         self.local_dirs = _get_local_dirs("sort")
-        self.serializer = serializer or AutoBatchedSerializer(PickleSerializer())
+        self.serializer = _compressed_serializer(serializer)
 
     def _get_path(self, n):
         """ Choose one directory for spill by number n """
@@ -515,6 +507,7 @@ class ExternalSorter(object):
                 limit = self._next_limit()
                 MemoryBytesSpilled += (used_memory - get_used_memory()) << 20
                 DiskBytesSpilled += os.path.getsize(path)
+                os.unlink(path)  # data will be deleted after close
 
             elif not chunks:
                 batch = min(batch * 2, 10000)
@@ -529,6 +522,310 @@ class ExternalSorter(object):
         return heapq.merge(chunks, key=key, reverse=reverse)
 
 
+class ExternalList(object):
+    """
+    ExternalList can have many items which cannot be hold in memory in
+    the same time.
+
+    >>> l = ExternalList(range(100))
+    >>> len(l)
+    100
+    >>> l.append(10)
+    >>> len(l)
+    101
+    >>> for i in range(20240):
+    ...     l.append(i)
+    >>> len(l)
+    20341
+    >>> import pickle
+    >>> l2 = pickle.loads(pickle.dumps(l))
+    >>> len(l2)
+    20341
+    >>> list(l2)[100]
+    10
+    """
+    LIMIT = 10240
+
+    def __init__(self, values):
+        self.values = values
+        self.count = len(values)
+        self._file = None
+        self._ser = None
+
+    def __getstate__(self):
+        if self._file is not None:
+            self._file.flush()
+            f = os.fdopen(os.dup(self._file.fileno()))
+            f.seek(0)
+            serialized = f.read()
+        else:
+            serialized = ''
+        return self.values, self.count, serialized
+
+    def __setstate__(self, item):
+        self.values, self.count, serialized = item
+        if serialized:
+            self._open_file()
+            self._file.write(serialized)
+        else:
+            self._file = None
+            self._ser = None
+
+    def __iter__(self):
+        if self._file is not None:
+            self._file.flush()
+            # read all items from disks first
+            with os.fdopen(os.dup(self._file.fileno()), 'r') as f:
+                f.seek(0)
+                for v in self._ser.load_stream(f):
+                    yield v
+
+        for v in self.values:
+            yield v
+
+    def __len__(self):
+        return self.count
+
+    def append(self, value):
+        self.values.append(value)
+        self.count += 1
+        # dump them into disk if the key is huge
+        if len(self.values) >= self.LIMIT:
+            self._spill()
+
+    def _open_file(self):
+        dirs = _get_local_dirs("objects")
+        d = dirs[id(self) % len(dirs)]
+        if not os.path.exists(d):
+            os.makedirs(d)
+        p = os.path.join(d, str(id))
+        self._file = open(p, "w+", 65536)
+        self._ser = BatchedSerializer(CompressedSerializer(PickleSerializer()), 1024)
+        os.unlink(p)
+
+    def _spill(self):
+        """ dump the values into disk """
+        global MemoryBytesSpilled, DiskBytesSpilled
+        if self._file is None:
+            self._open_file()
+
+        used_memory = get_used_memory()
+        pos = self._file.tell()
+        self._ser.dump_stream(self.values, self._file)
+        self.values = []
+        gc.collect()
+        DiskBytesSpilled += self._file.tell() - pos
+        MemoryBytesSpilled += (used_memory - get_used_memory()) << 20
+
+
+class ExternalListOfList(ExternalList):
+    """
+    An external list for list.
+
+    >>> l = ExternalListOfList([[i, i] for i in range(100)])
+    >>> len(l)
+    200
+    >>> l.append(range(10))
+    >>> len(l)
+    210
+    >>> len(list(l))
+    210
+    """
+
+    def __init__(self, values):
+        ExternalList.__init__(self, values)
+        self.count = sum(len(i) for i in values)
+
+    def append(self, value):
+        ExternalList.append(self, value)
+        # already counted 1 in ExternalList.append
+        self.count += len(value) - 1
+
+    def __iter__(self):
+        for values in ExternalList.__iter__(self):
+            for v in values:
+                yield v
+
+
+class GroupByKey(object):
+    """
+    Group a sorted iterator as [(k1, it1), (k2, it2), ...]
+
+    >>> k = [i/3 for i in range(6)]
+    >>> v = [[i] for i in range(6)]
+    >>> g = GroupByKey(iter(zip(k, v)))
+    >>> [(k, list(it)) for k, it in g]
+    [(0, [0, 1, 2]), (1, [3, 4, 5])]
+    """
+
+    def __init__(self, iterator):
+        self.iterator = iter(iterator)
+        self.next_item = None
+
+    def __iter__(self):
+        return self
+
+    def next(self):
+        key, value = self.next_item if self.next_item else next(self.iterator)
+        values = ExternalListOfList([value])
+        try:
+            while True:
+                k, v = next(self.iterator)
+                if k != key:
+                    self.next_item = (k, v)
+                    break
+                values.append(v)
+        except StopIteration:
+            self.next_item = None
+        return key, values
+
+
+class ExternalGroupBy(ExternalMerger):
+
+    """
+    Group by the items by key. If any partition of them can not been
+    hold in memory, it will do sort based group by.
+
+    This class works as follows:
+
+    - It repeatedly group the items by key and save them in one dict in
+      memory.
+
+    - When the used memory goes above memory limit, it will split
+      the combined data into partitions by hash code, dump them
+      into disk, one file per partition. If the number of keys
+      in one partitions is smaller than 1000, it will sort them
+      by key before dumping into disk.
+
+    - Then it goes through the rest of the iterator, group items
+      by key into different dict by hash. Until the used memory goes over
+      memory limit, it dump all the dicts into disks, one file per
+      dict. Repeat this again until combine all the items. It
+      also will try to sort the items by key in each partition
+      before dumping into disks.
+
+    - It will yield the grouped items partitions by partitions.
+      If the data in one partitions can be hold in memory, then it
+      will load and combine them in memory and yield.
+
+    - If the dataset in one partition cannot be hold in memory,
+      it will sort them first. If all the files are already sorted,
+      it merge them by heap.merge(), so it will do external sort
+      for all the files.
+
+    - After sorting, `GroupByKey` class will put all the continuous
+      items with the same key as a group, yield the values as
+      an iterator.
+    """
+    SORT_KEY_LIMIT = 1000
+
+    def flattened_serializer(self):
+        assert isinstance(self.serializer, BatchedSerializer)
+        ser = self.serializer
+        return FlattenedValuesSerializer(ser, 20)
+
+    def _object_size(self, obj):
+        return len(obj)
+
+    def _spill(self):
+        """
+        dump already partitioned data into disks.
+        """
+        global MemoryBytesSpilled, DiskBytesSpilled
+        path = self._get_spill_dir(self.spills)
+        if not os.path.exists(path):
+            os.makedirs(path)
+
+        used_memory = get_used_memory()
+        if not self.pdata:
+            # The data has not been partitioned, it will iterator the
+            # data once, write them into different files, has no
+            # additional memory. It only called when the memory goes
+            # above limit at the first time.
+
+            # open all the files for writing
+            streams = [open(os.path.join(path, str(i)), 'w')
+                       for i in range(self.partitions)]
+
+            # If the number of keys is small, then the overhead of sort is small
+            # sort them before dumping into disks
+            self._sorted = len(self.data) < self.SORT_KEY_LIMIT
+            if self._sorted:
+                self.serializer = self.flattened_serializer()
+                for k in sorted(self.data.keys()):
+                    h = self._partition(k)
+                    self.serializer.dump_stream([(k, self.data[k])], streams[h])
+            else:
+                for k, v in self.data.iteritems():
+                    h = self._partition(k)
+                    self.serializer.dump_stream([(k, v)], streams[h])
+
+            for s in streams:
+                DiskBytesSpilled += s.tell()
+                s.close()
+
+            self.data.clear()
+            # self.pdata is cached in `mergeValues` and `mergeCombiners`
+            self.pdata.extend([{} for i in range(self.partitions)])
+
+        else:
+            for i in range(self.partitions):
+                p = os.path.join(path, str(i))
+                with open(p, "w") as f:
+                    # dump items in batch
+                    if self._sorted:
+                        # sort by key only (stable)
+                        sorted_items = sorted(self.pdata[i].iteritems(), key=operator.itemgetter(0))
+                        self.serializer.dump_stream(sorted_items, f)
+                    else:
+                        self.serializer.dump_stream(self.pdata[i].iteritems(), f)
+                self.pdata[i].clear()
+                DiskBytesSpilled += os.path.getsize(p)
+
+        self.spills += 1
+        gc.collect()  # release the memory as much as possible
+        MemoryBytesSpilled += (used_memory - get_used_memory()) << 20
+
+    def _merged_items(self, index):
+        size = sum(os.path.getsize(os.path.join(self._get_spill_dir(j), str(index)))
+                   for j in range(self.spills))
+        # if the memory can not hold all the partition,
+        # then use sort based merge. Because of compression,
+        # the data on disks will be much smaller than needed memory
+        if (size >> 20) >= self.memory_limit / 10:
+            return self._merge_sorted_items(index)
+
+        self.data = {}
+        for j in range(self.spills):
+            path = self._get_spill_dir(j)
+            p = os.path.join(path, str(index))
+            # do not check memory during merging
+            self.mergeCombiners(self.serializer.load_stream(open(p)), 0)
+        return self.data.iteritems()
+
+    def _merge_sorted_items(self, index):
+        """ load a partition from disk, then sort and group by key """
+        def load_partition(j):
+            path = self._get_spill_dir(j)
+            p = os.path.join(path, str(index))
+            return self.serializer.load_stream(open(p, 'r', 65536))
+
+        disk_items = [load_partition(j) for j in range(self.spills)]
+
+        if self._sorted:
+            # all the partitions are already sorted
+            sorted_items = heapq.merge(disk_items, key=operator.itemgetter(0))
+
+        else:
+            # Flatten the combined values, so it will not consume huge
+            # memory during merging sort.
+            ser = self.flattened_serializer()
+            sorter = ExternalSorter(self.memory_limit, ser)
+            sorted_items = sorter.sorted(itertools.chain(*disk_items),
+                                         key=operator.itemgetter(0))
+        return ((k, vs) for k, vs in GroupByKey(sorted_items))
+
+
 if __name__ == "__main__":
     import doctest
     doctest.testmod()
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index dd8d3b1c53733cd96711b33e7caffb43e35d3e48..0bd5d20f7877f3782eddc4f4f1a75047115ead68 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -31,6 +31,7 @@ import tempfile
 import time
 import zipfile
 import random
+import itertools
 import threading
 import hashlib
 
@@ -76,7 +77,7 @@ SPARK_HOME = os.environ["SPARK_HOME"]
 class MergerTests(unittest.TestCase):
 
     def setUp(self):
-        self.N = 1 << 14
+        self.N = 1 << 12
         self.l = [i for i in xrange(self.N)]
         self.data = zip(self.l, self.l)
         self.agg = Aggregator(lambda x: [x],
@@ -108,7 +109,7 @@ class MergerTests(unittest.TestCase):
                          sum(xrange(self.N)))
 
     def test_medium_dataset(self):
-        m = ExternalMerger(self.agg, 10)
+        m = ExternalMerger(self.agg, 30)
         m.mergeValues(self.data)
         self.assertTrue(m.spills >= 1)
         self.assertEqual(sum(sum(v) for k, v in m.iteritems()),
@@ -124,10 +125,36 @@ class MergerTests(unittest.TestCase):
         m = ExternalMerger(self.agg, 10, partitions=3)
         m.mergeCombiners(map(lambda (k, v): (k, [str(v)]), self.data * 10))
         self.assertTrue(m.spills >= 1)
-        self.assertEqual(sum(len(v) for k, v in m._recursive_merged_items(0)),
+        self.assertEqual(sum(len(v) for k, v in m.iteritems()),
                          self.N * 10)
         m._cleanup()
 
+    def test_group_by_key(self):
+
+        def gen_data(N, step):
+            for i in range(1, N + 1, step):
+                for j in range(i):
+                    yield (i, [j])
+
+        def gen_gs(N, step=1):
+            return shuffle.GroupByKey(gen_data(N, step))
+
+        self.assertEqual(1, len(list(gen_gs(1))))
+        self.assertEqual(2, len(list(gen_gs(2))))
+        self.assertEqual(100, len(list(gen_gs(100))))
+        self.assertEqual(range(1, 101), [k for k, _ in gen_gs(100)])
+        self.assertTrue(all(range(k) == list(vs) for k, vs in gen_gs(100)))
+
+        for k, vs in gen_gs(50002, 10000):
+            self.assertEqual(k, len(vs))
+            self.assertEqual(range(k), list(vs))
+
+        ser = PickleSerializer()
+        l = ser.loads(ser.dumps(list(gen_gs(50002, 30000))))
+        for k, vs in l:
+            self.assertEqual(k, len(vs))
+            self.assertEqual(range(k), list(vs))
+
 
 class SorterTests(unittest.TestCase):
     def test_in_memory_sort(self):
@@ -702,6 +729,21 @@ class RDDTests(ReusedPySparkTestCase):
         self.assertEquals(result.getNumPartitions(), 5)
         self.assertEquals(result.count(), 3)
 
+    def test_external_group_by_key(self):
+        self.sc._conf.set("spark.python.worker.memory", "5m")
+        N = 200001
+        kv = self.sc.parallelize(range(N)).map(lambda x: (x % 3, x))
+        gkv = kv.groupByKey().cache()
+        self.assertEqual(3, gkv.count())
+        filtered = gkv.filter(lambda (k, vs): k == 1)
+        self.assertEqual(1, filtered.count())
+        self.assertEqual([(1, N/3)], filtered.mapValues(len).collect())
+        self.assertEqual([(N/3, N/3)],
+                         filtered.values().map(lambda x: (len(x), len(list(x)))).collect())
+        result = filtered.collect()[0][1]
+        self.assertEqual(N/3, len(result))
+        self.assertTrue(isinstance(result.data, shuffle.ExternalList))
+
     def test_sort_on_empty_rdd(self):
         self.assertEqual([], self.sc.parallelize(zip([], [])).sortByKey().collect())
 
@@ -752,9 +794,9 @@ class RDDTests(ReusedPySparkTestCase):
         self.assertEqual(rdd.getNumPartitions() + 2, parted.union(rdd).getNumPartitions())
         self.assertEqual(rdd.getNumPartitions() + 2, rdd.union(parted).getNumPartitions())
 
-        self.sc.setJobGroup("test1", "test", True)
         tracker = self.sc.statusTracker()
 
+        self.sc.setJobGroup("test1", "test", True)
         d = sorted(parted.join(parted).collect())
         self.assertEqual(10, len(d))
         self.assertEqual((0, (0, 0)), d[0])