diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index ca8eef5f99edfdb1fef0dbbb2172b29b43ad5c2e..d5002fa02992baded7ce80486ac43caf1873c567 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -124,6 +124,10 @@ private[spark] class PythonRDD(
               val total = finishTime - startTime
               logInfo("Times: total = %s, boot = %s, init = %s, finish = %s".format(total, boot,
                 init, finish))
+              val memoryBytesSpilled = stream.readLong()
+              val diskBytesSpilled = stream.readLong()
+              context.taskMetrics.memoryBytesSpilled += memoryBytesSpilled
+              context.taskMetrics.diskBytesSpilled += diskBytesSpilled
               read()
             case SpecialLengths.PYTHON_EXCEPTION_THROWN =>
               // Signals that an exception has been thrown in python
diff --git a/python/pyspark/shuffle.py b/python/pyspark/shuffle.py
index 49829f5280a5fd0095a2ec4ce77bdf99240e2053..ce597cbe91e152e2a2efc4179dd5a90bd09aecb6 100644
--- a/python/pyspark/shuffle.py
+++ b/python/pyspark/shuffle.py
@@ -68,6 +68,11 @@ def _get_local_dirs(sub):
     return [os.path.join(d, "python", str(os.getpid()), sub) for d in dirs]
 
 
+# global stats
+MemoryBytesSpilled = 0L
+DiskBytesSpilled = 0L
+
+
 class Aggregator(object):
 
     """
@@ -313,10 +318,12 @@ class ExternalMerger(Merger):
 
         It will dump the data in batch for better performance.
         """
+        global MemoryBytesSpilled, DiskBytesSpilled
         path = self._get_spill_dir(self.spills)
         if not os.path.exists(path):
             os.makedirs(path)
 
+        used_memory = get_used_memory()
         if not self.pdata:
             # The data has not been partitioned, it will iterator the
             # dataset once, write them into different files, has no
@@ -334,6 +341,7 @@ class ExternalMerger(Merger):
                 self.serializer.dump_stream([(k, v)], streams[h])
 
             for s in streams:
+                DiskBytesSpilled += s.tell()
                 s.close()
 
             self.data.clear()
@@ -346,9 +354,11 @@ class ExternalMerger(Merger):
                     # dump items in batch
                     self.serializer.dump_stream(self.pdata[i].iteritems(), f)
                 self.pdata[i].clear()
+                DiskBytesSpilled += os.path.getsize(p)
 
         self.spills += 1
         gc.collect()  # release the memory as much as possible
+        MemoryBytesSpilled += (used_memory - get_used_memory()) << 20
 
     def iteritems(self):
         """ Return all merged items as iterator """
@@ -462,7 +472,6 @@ class ExternalSorter(object):
         self.memory_limit = memory_limit
         self.local_dirs = _get_local_dirs("sort")
         self.serializer = serializer or BatchedSerializer(PickleSerializer(), 1024)
-        self._spilled_bytes = 0
 
     def _get_path(self, n):
         """ Choose one directory for spill by number n """
@@ -476,6 +485,7 @@ class ExternalSorter(object):
         Sort the elements in iterator, do external sort when the memory
         goes above the limit.
         """
+        global MemoryBytesSpilled, DiskBytesSpilled
         batch = 10
         chunks, current_chunk = [], []
         iterator = iter(iterator)
@@ -486,15 +496,18 @@ class ExternalSorter(object):
             if len(chunk) < batch:
                 break
 
-            if get_used_memory() > self.memory_limit:
+            used_memory = get_used_memory()
+            if used_memory > self.memory_limit:
                 # sort them inplace will save memory
                 current_chunk.sort(key=key, reverse=reverse)
                 path = self._get_path(len(chunks))
                 with open(path, 'w') as f:
                     self.serializer.dump_stream(current_chunk, f)
-                self._spilled_bytes += os.path.getsize(path)
                 chunks.append(self.serializer.load_stream(open(path)))
                 current_chunk = []
+                gc.collect()
+                MemoryBytesSpilled += (used_memory - get_used_memory()) << 20
+                DiskBytesSpilled += os.path.getsize(path)
 
             elif not chunks:
                 batch = min(batch * 2, 10000)
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 747cd1767de7bee0a228f08b4b9d69c1b3bec901..f3309a20fcffb72ff579b2108249ea23f7ea900d 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -46,6 +46,7 @@ from pyspark.serializers import read_int, BatchedSerializer, MarshalSerializer,
     CloudPickleSerializer
 from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, ExternalSorter
 from pyspark.sql import SQLContext, IntegerType
+from pyspark import shuffle
 
 _have_scipy = False
 _have_numpy = False
@@ -138,17 +139,17 @@ class TestSorter(unittest.TestCase):
         random.shuffle(l)
         sorter = ExternalSorter(1)
         self.assertEquals(sorted(l), list(sorter.sorted(l)))
-        self.assertGreater(sorter._spilled_bytes, 0)
-        last = sorter._spilled_bytes
+        self.assertGreater(shuffle.DiskBytesSpilled, 0)
+        last = shuffle.DiskBytesSpilled
         self.assertEquals(sorted(l, reverse=True), list(sorter.sorted(l, reverse=True)))
-        self.assertGreater(sorter._spilled_bytes, last)
-        last = sorter._spilled_bytes
+        self.assertGreater(shuffle.DiskBytesSpilled, last)
+        last = shuffle.DiskBytesSpilled
         self.assertEquals(sorted(l, key=lambda x: -x), list(sorter.sorted(l, key=lambda x: -x)))
-        self.assertGreater(sorter._spilled_bytes, last)
-        last = sorter._spilled_bytes
+        self.assertGreater(shuffle.DiskBytesSpilled, last)
+        last = shuffle.DiskBytesSpilled
         self.assertEquals(sorted(l, key=lambda x: -x, reverse=True),
                           list(sorter.sorted(l, key=lambda x: -x, reverse=True)))
-        self.assertGreater(sorter._spilled_bytes, last)
+        self.assertGreater(shuffle.DiskBytesSpilled, last)
 
     def test_external_sort_in_rdd(self):
         conf = SparkConf().set("spark.python.worker.memory", "1m")
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 61b8a74d060e81f16866f3ee5ee770a9acdab14d..252176ac65fec4daa5cee0553c2b1e3166b5f5dd 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -23,16 +23,14 @@ import sys
 import time
 import socket
 import traceback
-# CloudPickler needs to be imported so that depicklers are registered using the
-# copy_reg module.
+
 from pyspark.accumulators import _accumulatorRegistry
 from pyspark.broadcast import Broadcast, _broadcastRegistry
-from pyspark.cloudpickle import CloudPickler
 from pyspark.files import SparkFiles
 from pyspark.serializers import write_with_length, write_int, read_long, \
     write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, \
     CompressedSerializer
-
+from pyspark import shuffle
 
 pickleSer = PickleSerializer()
 utf8_deserializer = UTF8Deserializer()
@@ -52,6 +50,11 @@ def main(infile, outfile):
         if split_index == -1:  # for unit tests
             return
 
+        # initialize global state
+        shuffle.MemoryBytesSpilled = 0
+        shuffle.DiskBytesSpilled = 0
+        _accumulatorRegistry.clear()
+
         # fetch name of workdir
         spark_files_dir = utf8_deserializer.loads(infile)
         SparkFiles._root_directory = spark_files_dir
@@ -97,6 +100,9 @@ def main(infile, outfile):
         exit(-1)
     finish_time = time.time()
     report_times(outfile, boot_time, init_time, finish_time)
+    write_long(shuffle.MemoryBytesSpilled, outfile)
+    write_long(shuffle.DiskBytesSpilled, outfile)
+
     # Mark the beginning of the accumulators section of the output
     write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
     write_int(len(_accumulatorRegistry), outfile)