diff --git a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala
index 4cc5bcb7f9bafa600a8fa42b87c723dffbddfa09..e3556b72ad23ef23bc86a36aa6f27d020ec1b807 100644
--- a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala
+++ b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala
@@ -23,6 +23,7 @@ import java.util.LinkedHashMap
 
 import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.SortedSet
 import scala.reflect.ClassTag
 
 import com.google.common.io.ByteStreams
@@ -88,6 +89,21 @@ private[spark] class MemoryStore(
   // Note: all changes to memory allocations, notably putting blocks, evicting blocks, and
   // acquiring or releasing unroll memory, must be synchronized on `memoryManager`!
 
+  private class OurBlockIdAndSizeType(
+      var blockId: BlockId,
+      var size: Long) extends Ordered[OurBlockIdAndSizeType] {
+    // Order primarily by size, but break ties on the block id: a sorted set
+    // treats elements that compare as equal as duplicates, so comparing on
+    // size alone would silently collapse distinct blocks of the same size.
+    def compare(other: OurBlockIdAndSizeType): Int = {
+      val sizeCompare = this.size.compare(other.size)
+      if (sizeCompare != 0) sizeCompare else this.blockId.name.compareTo(other.blockId.name)
+    }
+  }
+
+  // Secondary index over `entries`, ordered by block size, used to select
+  // eviction candidates. All updates must happen inside `entries.synchronized`.
+  private val blockIdAndSizeSet = SortedSet[OurBlockIdAndSizeType]()
   private val entries = new LinkedHashMap[BlockId, MemoryEntry[_]](32, 0.75f, true)
 
   // A mapping from taskAttemptId to amount of memory used for unrolling a block (in bytes)
@@ -151,6 +167,7 @@ private[spark] class MemoryStore(
       assert(bytes.size == size)
       val entry = new SerializedMemoryEntry[T](bytes, memoryMode, implicitly[ClassTag[T]])
       entries.synchronized {
+        blockIdAndSizeSet += new OurBlockIdAndSizeType(blockId, entry.size)
         entries.put(blockId, entry)
       }
       logInfo("Block %s stored as bytes in memory (estimated size %s, free %s)".format(
@@ -260,6 +277,7 @@ private[spark] class MemoryStore(
       }
 
       entries.synchronized {
+        blockIdAndSizeSet += new OurBlockIdAndSizeType(blockId, entry.size)
         entries.put(blockId, entry)
       }
 
@@ -386,6 +404,12 @@ private[spark] class MemoryStore(
 
   def remove(blockId: BlockId): Boolean = memoryManager.synchronized {
     val entry = entries.synchronized {
+      // The block may be absent; only drop the size-index element when a
+      // matching entry exists, to avoid a NullPointerException on origEntry.
+      val origEntry = entries.get(blockId)
+      if (origEntry != null) {
+        blockIdAndSizeSet -= new OurBlockIdAndSizeType(blockId, origEntry.size)
+      }
       entries.remove(blockId)
     }
     if (entry != null) {
@@ -404,6 +428,7 @@ private[spark] class MemoryStore(
 
   def clear(): Unit = memoryManager.synchronized {
     entries.synchronized {
+      blockIdAndSizeSet.clear()
       entries.clear()
     }
     onHeapUnrollMemoryMap.clear()
@@ -446,18 +471,18 @@ private[spark] class MemoryStore(
     // (because of getValue or getBytes) while traversing the iterator, as that
     // can lead to exceptions.
     entries.synchronized {
-      val iterator = entries.entrySet().iterator()
+      val iterator = blockIdAndSizeSet.iterator
       while (freedMemory < space && iterator.hasNext) {
-        val pair = iterator.next()
-        val blockId = pair.getKey
-        val entry = pair.getValue
+        val idAndSize = iterator.next()
+        val blockId = idAndSize.blockId
+        val entry = entries.get(blockId)
         if (blockIsEvictable(blockId, entry)) {
           // We don't want to evict blocks which are currently being read, so we need to obtain
           // an exclusive write lock on blocks which are candidates for eviction. We perform a
           // non-blocking "tryLock" here in order to ignore blocks which are locked for reading:
           if (blockInfoManager.lockForWriting(blockId, blocking = false).isDefined) {
             selectedBlocks += blockId
-            freedMemory += idAndSize.size
+            freedMemory += idAndSize.size
           }
         }
       }
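
The tie-breaking ordering is the subtle part of this patch: scala.collection.mutable.SortedSet discards elements that compare as equal, so an ordering based on size alone would silently keep only one of two distinct blocks that happen to have the same size, corrupting the eviction index. Below is a minimal standalone sketch of that pitfall; it is not part of the patch, and the IdAndSize case class and TieBreakDemo object are hypothetical names using a plain String id in place of Spark's BlockId.

import scala.collection.mutable.SortedSet

object TieBreakDemo {
  // Hypothetical stand-in for the patch's OurBlockIdAndSizeType.
  final case class IdAndSize(id: String, size: Long) extends Ordered[IdAndSize] {
    def compare(other: IdAndSize): Int = {
      val bySize = this.size.compare(other.size)
      // Tie-break on id: without this, same-size elements compare as "equal"
      // and the sorted set keeps only one of them.
      if (bySize != 0) bySize else this.id.compareTo(other.id)
    }
  }

  def main(args: Array[String]): Unit = {
    val withTieBreak = SortedSet[IdAndSize]()
    withTieBreak += IdAndSize("rdd_0_0", 1024L)
    withTieBreak += IdAndSize("rdd_0_1", 1024L) // distinct block, same size
    println(withTieBreak.size) // 2: both blocks survive

    // Ordering on size alone: the second insert is treated as a duplicate.
    val sizeOnly = SortedSet[IdAndSize]()(Ordering.by(_.size))
    sizeOnly += IdAndSize("rdd_0_0", 1024L)
    sizeOnly += IdAndSize("rdd_0_1", 1024L)
    println(sizeOnly.size) // 1: one block silently dropped from the index
  }
}

A related caveat, visible in the patch itself: size is a var on OurBlockIdAndSizeType, and a sorted set files an element by the ordering at insertion time, so removal only works with the exact size the element was inserted under. That is why remove() first looks the current entry up in `entries` to reconstruct the size before subtracting from blockIdAndSizeSet.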