diff --git a/core/src/main/scala/org/apache/spark/CacheManager.scala b/core/src/main/scala/org/apache/spark/CacheManager.scala
index 5ddda4d6953fad24817813b4d2526a5f6792e3ab..f8584b90cabe60b46fc243350f70aefa96c494e0 100644
--- a/core/src/main/scala/org/apache/spark/CacheManager.scala
+++ b/core/src/main/scala/org/apache/spark/CacheManager.scala
@@ -68,7 +68,9 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging {
           // Otherwise, cache the values and keep track of any updates in block statuses
           val updatedBlocks = new ArrayBuffer[(BlockId, BlockStatus)]
           val cachedValues = putInBlockManager(key, computedValues, storageLevel, updatedBlocks)
-          context.taskMetrics.updatedBlocks = Some(updatedBlocks)
+          val metrics = context.taskMetrics
+          val lastUpdatedBlocks = metrics.updatedBlocks.getOrElse(Seq[(BlockId, BlockStatus)]())
+          metrics.updatedBlocks = Some(lastUpdatedBlocks ++ updatedBlocks.toSeq)
           new InterruptibleIterator(context, cachedValues)
 
         } finally {
diff --git a/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala b/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala
index 67f72a94f02690946302346685a6687e5768dc21..76097f1c51f8e0d79e3f2bf50a6294141e4b397b 100644
--- a/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala
+++ b/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala
@@ -70,8 +70,11 @@ class StorageListener(storageStatusListener: StorageStatusListener) extends Spar
   }
 
   override def onStageCompleted(stageCompleted: SparkListenerStageCompleted) = synchronized {
-    // Remove all partitions that are no longer cached
-    _rddInfoMap.retain { case (_, info) => info.numCachedPartitions > 0 }
+    // Remove all partitions that are no longer cached in the current completed stage
+    val completedRddIds = stageCompleted.stageInfo.rddInfos.map(r => r.id).toSet
+    _rddInfoMap.retain { case (id, info) =>
+      !completedRddIds.contains(id) || info.numCachedPartitions > 0
+    }
   }
 
   override def onUnpersistRDD(unpersistRDD: SparkListenerUnpersistRDD) = synchronized {
diff --git a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala
index 9c5f394d3899d56206f58f261405abbaed042063..90dcadcffd091183040a0eb99922cf8960c0fb36 100644
--- a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala
@@ -32,6 +32,8 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
   var split: Partition = _
   /** An RDD which returns the values [1, 2, 3, 4]. */
   var rdd: RDD[Int] = _
+  var rdd2: RDD[Int] = _
+  var rdd3: RDD[Int] = _
 
   before {
     sc = new SparkContext("local", "test")
@@ -43,6 +45,16 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
       override val getDependencies = List[Dependency[_]]()
       override def compute(split: Partition, context: TaskContext) = Array(1, 2, 3, 4).iterator
     }
+    rdd2 = new RDD[Int](sc, List(new OneToOneDependency(rdd))) {
+      override def getPartitions: Array[Partition] = firstParent[Int].partitions
+      override def compute(split: Partition, context: TaskContext) =
+        firstParent[Int].iterator(split, context)
+    }.cache()
+    rdd3 = new RDD[Int](sc, List(new OneToOneDependency(rdd2))) {
+      override def getPartitions: Array[Partition] = firstParent[Int].partitions
+      override def compute(split: Partition, context: TaskContext) =
+        firstParent[Int].iterator(split, context)
+    }.cache()
   }
 
   after {
@@ -87,4 +99,11 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
       assert(value.toList === List(1, 2, 3, 4))
     }
   }
+
+  test("verify task metrics updated correctly") {
+    cacheManager = sc.env.cacheManager
+    val context = new TaskContext(0, 0, 0)
+    cacheManager.getOrCompute(rdd3, split, context, StorageLevel.MEMORY_ONLY)
+    assert(context.taskMetrics.updatedBlocks.getOrElse(Seq()).size === 2)
+  }
 }
diff --git a/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala b/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala
index b860177705d8415feadad1cc06c0667f58bf973c..a537c72ce7ab549e6cefd6188185e5af9a9135cb 100644
--- a/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala
@@ -34,6 +34,7 @@ class StorageTabSuite extends FunSuite with BeforeAndAfter {
   private val memOnly = StorageLevel.MEMORY_ONLY
   private val none = StorageLevel.NONE
   private val taskInfo = new TaskInfo(0, 0, 0, 0, "big", "dog", TaskLocality.ANY, false)
+  private val taskInfo1 = new TaskInfo(1, 1, 1, 1, "big", "cat", TaskLocality.ANY, false)
   private def rddInfo0 = new RDDInfo(0, "freedom", 100, memOnly)
   private def rddInfo1 = new RDDInfo(1, "hostage", 200, memOnly)
   private def rddInfo2 = new RDDInfo(2, "sanity", 300, memAndDisk)
@@ -162,4 +163,30 @@ class StorageTabSuite extends FunSuite with BeforeAndAfter {
     assert(storageListener._rddInfoMap(2).numCachedPartitions === 0)
   }
 
+  test("verify StorageTab contains all cached rdds") {
+
+    val rddInfo0 = new RDDInfo(0, "rdd0", 1, memOnly)
+    val rddInfo1 = new RDDInfo(1, "rdd1", 1, memOnly)
+    val stageInfo0 = new StageInfo(0, 0, "stage0", 1, Seq(rddInfo0), "details")
+    val stageInfo1 = new StageInfo(1, 0, "stage1", 1, Seq(rddInfo1), "details")
+    val taskMetrics0 = new TaskMetrics
+    val taskMetrics1 = new TaskMetrics
+    val block0 = (RDDBlockId(0, 1), BlockStatus(memOnly, 100L, 0L, 0L))
+    val block1 = (RDDBlockId(1, 1), BlockStatus(memOnly, 200L, 0L, 0L))
+    taskMetrics0.updatedBlocks = Some(Seq(block0))
+    taskMetrics1.updatedBlocks = Some(Seq(block1))
+    bus.postToAll(SparkListenerBlockManagerAdded(bm1, 1000L))
+    bus.postToAll(SparkListenerStageSubmitted(stageInfo0))
+    assert(storageListener.rddInfoList.size === 0)
+    bus.postToAll(SparkListenerTaskEnd(0, 0, "big", Success, taskInfo, taskMetrics0))
+    assert(storageListener.rddInfoList.size === 1)
+    bus.postToAll(SparkListenerStageSubmitted(stageInfo1))
+    assert(storageListener.rddInfoList.size === 1)
+    bus.postToAll(SparkListenerStageCompleted(stageInfo0))
+    assert(storageListener.rddInfoList.size === 1)
+    bus.postToAll(SparkListenerTaskEnd(1, 0, "small", Success, taskInfo1, taskMetrics1))
+    assert(storageListener.rddInfoList.size === 2)
+    bus.postToAll(SparkListenerStageCompleted(stageInfo1))
+    assert(storageListener.rddInfoList.size === 2)
+  }
 }