diff --git a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala
index 4cc5bcb7f9bafa600a8fa42b87c723dffbddfa09..e3556b72ad23ef23bc86a36aa6f27d020ec1b807 100644
--- a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala
+++ b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala
@@ -23,6 +23,7 @@ import java.util.LinkedHashMap
 
 import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.SortedSet
 import scala.reflect.ClassTag
 
 import com.google.common.io.ByteStreams
@@ -88,6 +89,20 @@ private[spark] class MemoryStore(
   // Note: all changes to memory allocations, notably putting blocks, evicting blocks, and
   // acquiring or releasing unroll memory, must be synchronized on `memoryManager`!
 
+  private class BlockIdAndSize(
+      val blockId: BlockId,
+      val size: Long) extends Ordered[BlockIdAndSize] {
+    // Order by size, breaking ties on block id: SortedSet treats compare == 0
+    // as equality and would otherwise collapse equal-sized distinct blocks.
+    def compare(other: BlockIdAndSize): Int = {
+      val bySize = this.size.compare(other.size)
+      if (bySize != 0) bySize else this.blockId.name.compareTo(other.blockId.name)
+    }
+  }
+
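+  // Size-ordered mirror of `entries`: every update below happens together with
+  // the matching `entries` mutation while holding `entries.synchronized`.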
+  private val blockIdAndSizeSet = SortedSet[BlockIdAndSize]()
   private val entries = new LinkedHashMap[BlockId, MemoryEntry[_]](32, 0.75f, true)
 
   // A mapping from taskAttemptId to amount of memory used for unrolling a block (in bytes)
@@ -151,6 +166,8 @@ private[spark] class MemoryStore(
       assert(bytes.size == size)
       val entry = new SerializedMemoryEntry[T](bytes, memoryMode, implicitly[ClassTag[T]])
       entries.synchronized {
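+        // Track the block in the size-ordered set under the same lock as entries.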
+        blockIdAndSizeSet += new BlockIdAndSize(blockId, entry.size)
         entries.put(blockId, entry)
       }
       logInfo("Block %s stored as bytes in memory (estimated size %s, free %s)".format(
@@ -260,6 +277,8 @@
         }
 
         entries.synchronized {
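+          // Track the block in the size-ordered set under the same lock as entries.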
+          blockIdAndSizeSet += new BlockIdAndSize(blockId, entry.size)
           entries.put(blockId, entry)
         }
 
@@ -386,6 +405,11 @@
 
   def remove(blockId: BlockId): Boolean = memoryManager.synchronized {
     val entry = entries.synchronized {
+      val origEntry = entries.get(blockId)
+      // `entries.get` returns null for an absent block, so guard against NPE.
+      if (origEntry != null) {
+        blockIdAndSizeSet -= new BlockIdAndSize(blockId, origEntry.size)
+      }
       entries.remove(blockId)
     }
     if (entry != null) {
@@ -404,6 +428,8 @@
 
   def clear(): Unit = memoryManager.synchronized {
     entries.synchronized {
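+      // Drop the size-ordered mirror together with the entries map.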
+      blockIdAndSizeSet.clear()
       entries.clear()
     }
     onHeapUnrollMemoryMap.clear()
@@ -446,18 +472,22 @@
       // (because of getValue or getBytes) while traversing the iterator, as that
       // can lead to exceptions.
       entries.synchronized {
-        val iterator = entries.entrySet().iterator()
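+        // Iterate in ascending size order via the SortedSet instead of the
+        // access-ordered `entries` map, making eviction size-aware.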
+        val iterator = blockIdAndSizeSet.iterator
         while (freedMemory < space && iterator.hasNext) {
-          val pair = iterator.next()
-          val blockId = pair.getKey
-          val entry = pair.getValue
+          val idAndSize = iterator.next()
+          val blockId = idAndSize.blockId
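+          // The set mirrors `entries` under the same lock, so this lookup is
+          // expected to find a live entry.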
+          val entry = entries.get(blockId)
           if (blockIsEvictable(blockId, entry)) {
             // We don't want to evict blocks which are currently being read, so we need to obtain
             // an exclusive write lock on blocks which are candidates for eviction. We perform a
             // non-blocking "tryLock" here in order to ignore blocks which are locked for reading:
             if (blockInfoManager.lockForWriting(blockId, blocking = false).isDefined) {
               selectedBlocks += blockId
-              freedMemory += pair.getValue.size
+              freedMemory += idAndSize.size
             }
           }
         }