From f48742683adf8ed18b0d25a724a13c66b3fc12e9 Mon Sep 17 00:00:00 2001 From: Matei Zaharia <matei@eecs.berkeley.edu> Date: Sun, 6 May 2012 20:14:40 -0700 Subject: [PATCH] Made caches dataset-aware so that they won't cyclically evict partitions from the same dataset. --- .../main/scala/spark/BoundedMemoryCache.scala | 51 +++++++++------ core/src/main/scala/spark/Cache.scala | 25 +++++--- core/src/main/scala/spark/CacheTracker.scala | 62 ++++++++++++------- .../main/scala/spark/DiskSpillingCache.scala | 36 +++++------ .../main/scala/spark/SerializingCache.scala | 8 +-- .../main/scala/spark/SoftReferenceCache.scala | 9 ++- .../main/scala/spark/WeakReferenceCache.scala | 14 ----- .../spark/broadcast/BitTorrentBroadcast.scala | 8 +-- .../spark/broadcast/ChainedBroadcast.scala | 8 +-- .../scala/spark/broadcast/DfsBroadcast.scala | 6 +- .../scala/spark/broadcast/TreeBroadcast.scala | 8 +-- 11 files changed, 132 insertions(+), 103 deletions(-) delete mode 100644 core/src/main/scala/spark/WeakReferenceCache.scala diff --git a/core/src/main/scala/spark/BoundedMemoryCache.scala b/core/src/main/scala/spark/BoundedMemoryCache.scala index e8e50ac360..c49be803e4 100644 --- a/core/src/main/scala/spark/BoundedMemoryCache.scala +++ b/core/src/main/scala/spark/BoundedMemoryCache.scala @@ -14,14 +14,14 @@ class BoundedMemoryCache extends Cache with Logging { logInfo("BoundedMemoryCache.maxBytes = " + maxBytes) private var currentBytes = 0L - private val map = new LinkedHashMap[Any, Entry](32, 0.75f, true) + private val map = new LinkedHashMap[(Any, Int), Entry](32, 0.75f, true) // An entry in our map; stores a cached object and its size in bytes class Entry(val value: Any, val size: Long) {} - override def get(key: Any): Any = { + override def get(datasetId: Any, partition: Int): Any = { synchronized { - val entry = map.get(key) + val entry = map.get((datasetId, partition)) if (entry != null) { entry.value } else { @@ -30,7 +30,8 @@ class BoundedMemoryCache extends Cache with Logging { } } - override def put(key: Any, value: Any) { + override def put(datasetId: Any, partition: Int, value: Any): Boolean = { + val key = (datasetId, partition) logInfo("Asked to add key " + key) val startTime = System.currentTimeMillis val size = SizeEstimator.estimate(value.asInstanceOf[AnyRef]) @@ -38,11 +39,16 @@ class BoundedMemoryCache extends Cache with Logging { logInfo("Estimated size for key %s is %d".format(key, size)) logInfo("Size estimation for key %s took %d ms".format(key, timeTaken)) synchronized { - ensureFreeSpace(size) - logInfo("Adding key " + key) - map.put(key, new Entry(value, size)) - currentBytes += size - logInfo("Number of entries is now " + map.size) + if (ensureFreeSpace(datasetId, size)) { + logInfo("Adding key " + key) + map.put(key, new Entry(value, size)) + currentBytes += size + logInfo("Number of entries is now " + map.size) + return true + } else { + logInfo("Didn't add key " + key + " because we would have evicted part of same dataset") + return false + } } } @@ -53,23 +59,32 @@ class BoundedMemoryCache extends Cache with Logging { } /** - * Remove least recently used entries from the map until at least space bytes are free. Assumes + * Remove least recently used entries from the map until at least space bytes are free, in order + * to make space for a partition from the given dataset ID. If this cannot be done without + * evicting other data from the same dataset, returns false; otherwise, returns true. Assumes * that a lock is held on the BoundedMemoryCache. 
*/ - private def ensureFreeSpace(space: Long) { - logInfo("ensureFreeSpace(%d) called with curBytes=%d, maxBytes=%d".format( - space, currentBytes, maxBytes)) - val iter = map.entrySet.iterator + private def ensureFreeSpace(datasetId: Any, space: Long): Boolean = { + logInfo("ensureFreeSpace(%s, %d) called with curBytes=%d, maxBytes=%d".format( + datasetId, space, currentBytes, maxBytes)) + val iter = map.entrySet.iterator // Will give entries in LRU order while (maxBytes - currentBytes < space && iter.hasNext) { val mapEntry = iter.next() - dropEntry(mapEntry.getKey, mapEntry.getValue) + val (entryDatasetId, entryPartition) = mapEntry.getKey + if (entryDatasetId == datasetId) { + // Cannot make space without removing part of the same dataset, or a more recently used one + return false + } + reportEntryDropped(entryDatasetId, entryPartition, mapEntry.getValue) currentBytes -= mapEntry.getValue.size iter.remove() } + return true } - protected def dropEntry(key: Any, entry: Entry) { - logInfo("Dropping key %s of size %d to make space".format(key, entry.size)) - SparkEnv.get.cacheTracker.dropEntry(key) + protected def reportEntryDropped(datasetId: Any, partition: Int, entry: Entry) { + logInfo("Dropping key (%s, %d) of size %d to make space".format( + datasetId, partition, entry.size)) + SparkEnv.get.cacheTracker.dropEntry(datasetId, partition) } } diff --git a/core/src/main/scala/spark/Cache.scala b/core/src/main/scala/spark/Cache.scala index 696fff4e5e..263761bb95 100644 --- a/core/src/main/scala/spark/Cache.scala +++ b/core/src/main/scala/spark/Cache.scala @@ -1,10 +1,12 @@ package spark -import java.util.concurrent.atomic.AtomicLong +import java.util.concurrent.atomic.AtomicInteger /** * An interface for caches in Spark, to allow for multiple implementations. Caches are used to store - * both partitions of cached RDDs and broadcast variables on Spark executors. + * both partitions of cached RDDs and broadcast variables on Spark executors. Caches are also aware + * of which entries are part of the same dataset (for example, partitions in the same RDD). The key + * for each value in a cache is a (datasetID, partition) pair. * * A single Cache instance gets created on each machine and is shared by all caches (i.e. both the * RDD split cache and the broadcast variable cache), to enable global replacement policies. @@ -17,19 +19,26 @@ import java.util.concurrent.atomic.AtomicLong * keys that are unique across modules. */ abstract class Cache { - private val nextKeySpaceId = new AtomicLong(0) + private val nextKeySpaceId = new AtomicInteger(0) private def newKeySpaceId() = nextKeySpaceId.getAndIncrement() def newKeySpace() = new KeySpace(this, newKeySpaceId()) - def get(key: Any): Any - def put(key: Any, value: Any): Unit + // Get the value for a given (datasetId, partition), or null if it is not found. + def get(datasetId: Any, partition: Int): Any + + // Attempt to put a value in the cache; returns false if this was not successful (e.g. because + // the cache replacement policy forbids it). + def put(datasetId: Any, partition: Int, value: Any): Boolean } /** * A key namespace in a Cache. 
*/ -class KeySpace(cache: Cache, id: Long) { - def get(key: Any): Any = cache.get((id, key)) - def put(key: Any, value: Any): Unit = cache.put((id, key), value) +class KeySpace(cache: Cache, val keySpaceId: Int) { + def get(datasetId: Any, partition: Int): Any = + cache.get((keySpaceId, datasetId), partition) + + def put(datasetId: Any, partition: Int, value: Any): Boolean = + cache.put((keySpaceId, datasetId), partition, value) } diff --git a/core/src/main/scala/spark/CacheTracker.scala b/core/src/main/scala/spark/CacheTracker.scala index 5b6eed743f..c399748af3 100644 --- a/core/src/main/scala/spark/CacheTracker.scala +++ b/core/src/main/scala/spark/CacheTracker.scala @@ -106,51 +106,65 @@ class CacheTracker(isMaster: Boolean, theCache: Cache) extends Logging { // Gets or computes an RDD split def getOrCompute[T](rdd: RDD[T], split: Split)(implicit m: ClassManifest[T]): Iterator[T] = { - val key = (rdd.id, split.index) - logInfo("CachedRDD partition key is " + key) - val cachedVal = cache.get(key) + logInfo("Looking for RDD partition %d:%d".format(rdd.id, split.index)) + val cachedVal = cache.get(rdd.id, split.index) if (cachedVal != null) { // Split is in cache, so just return its values logInfo("Found partition in cache!") return cachedVal.asInstanceOf[Array[T]].iterator } else { // Mark the split as loading (unless someone else marks it first) + val key = (rdd.id, split.index) loading.synchronized { - if (loading.contains(key)) { - while (loading.contains(key)) { - try {loading.wait()} catch {case _ =>} - } - return cache.get(key).asInstanceOf[Array[T]].iterator - } else { - loading.add(key) + while (loading.contains(key)) { + // Someone else is loading it; let's wait for them + try { loading.wait() } catch { case _ => } + } + // See whether someone else has successfully loaded it. The main way this would fail + // is for the RDD-level cache eviction policy if someone else has loaded the same RDD + // partition but we didn't want to make space for it. However, that case is unlikely + // because it's unlikely that two threads would work on the same RDD partition. One + // downside of the current code is that threads wait serially if this does happen. + val cachedVal = cache.get(rdd.id, split.index) + if (cachedVal != null) { + return cachedVal.asInstanceOf[Array[T]].iterator } + // Nobody's loading it and it's not in the cache; let's load it ourselves + loading.add(key) } // If we got here, we have to load the split // Tell the master that we're doing so val host = System.getProperty("spark.hostname", Utils.localHostName) - val future = trackerActor !! AddedToCache(rdd.id, split.index, host) // TODO: fetch any remote copy of the split that may be available - // TODO: also register a listener for when it unloads logInfo("Computing partition " + split) - val array = rdd.compute(split).toArray(m) - cache.put(key, array) - loading.synchronized { - loading.remove(key) - loading.notifyAll() + var array: Array[T] = null + var putSuccessful: Boolean = false + try { + array = rdd.compute(split).toArray(m) + putSuccessful = cache.put(rdd.id, split.index, array) + } finally { + // Tell other threads that we've finished our attempt to load the key (whether or not + // we've actually succeeded to put it in the map) + loading.synchronized { + loading.remove(key) + loading.notifyAll() + } + } + if (putSuccessful) { + // Tell the master that we added the entry. Don't return until it replies so it can + // properly schedule future tasks that use this RDD. + trackerActor !? 
AddedToCache(rdd.id, split.index, host) } - future.apply() // Wait for the reply from the cache tracker return array.iterator } } - // Reports that an entry has been dropped from the cache - def dropEntry(key: Any) { - key match { - case (keySpaceId: Long, (rddId: Int, partition: Int)) => + // Called by the Cache to report that an entry has been dropped from it + def dropEntry(datasetId: Any, partition: Int) { + datasetId match { + case (cache.keySpaceId, rddId: Int) => val host = System.getProperty("spark.hostname", Utils.localHostName) trackerActor !! DroppedFromCache(rddId, partition, host) - case _ => - logWarning("Unknown key format: %s".format(key)) } } diff --git a/core/src/main/scala/spark/DiskSpillingCache.scala b/core/src/main/scala/spark/DiskSpillingCache.scala index 157e071c7f..e4d0f991aa 100644 --- a/core/src/main/scala/spark/DiskSpillingCache.scala +++ b/core/src/main/scala/spark/DiskSpillingCache.scala @@ -9,31 +9,31 @@ import java.util.UUID // TODO: cache into a separate directory using Utils.createTempDir // TODO: clean up disk cache afterwards class DiskSpillingCache extends BoundedMemoryCache { - private val diskMap = new LinkedHashMap[Any, File](32, 0.75f, true) + private val diskMap = new LinkedHashMap[(Any, Int), File](32, 0.75f, true) - override def get(key: Any): Any = { + override def get(datasetId: Any, partition: Int): Any = { synchronized { val ser = SparkEnv.get.serializer.newInstance() - super.get(key) match { + super.get(datasetId, partition) match { case bytes: Any => // found in memory ser.deserialize(bytes.asInstanceOf[Array[Byte]]) - case _ => diskMap.get(key) match { + case _ => diskMap.get((datasetId, partition)) match { case file: Any => // found on disk try { val startTime = System.currentTimeMillis val bytes = new Array[Byte](file.length.toInt) new FileInputStream(file).read(bytes) val timeTaken = System.currentTimeMillis - startTime - logInfo("Reading key %s of size %d bytes from disk took %d ms".format( - key, file.length, timeTaken)) - super.put(key, bytes) + logInfo("Reading key (%s, %d) of size %d bytes from disk took %d ms".format( + datasetId, partition, file.length, timeTaken)) + super.put(datasetId, partition, bytes) ser.deserialize(bytes.asInstanceOf[Array[Byte]]) } catch { case e: IOException => - logWarning("Failed to read key %s from disk at %s: %s".format( - key, file.getPath(), e.getMessage())) - diskMap.remove(key) // remove dead entry + logWarning("Failed to read key (%s, %d) from disk at %s: %s".format( + datasetId, partition, file.getPath(), e.getMessage())) + diskMap.remove((datasetId, partition)) // remove dead entry null } @@ -44,18 +44,18 @@ class DiskSpillingCache extends BoundedMemoryCache { } } - override def put(key: Any, value: Any) { + override def put(datasetId: Any, partition: Int, value: Any): Boolean = { var ser = SparkEnv.get.serializer.newInstance() - super.put(key, ser.serialize(value)) + super.put(datasetId, partition, ser.serialize(value)) } /** * Spill the given entry to disk. Assumes that a lock is held on the * DiskSpillingCache. Assumes that entry.value is a byte array. 
*/ - override protected def dropEntry(key: Any, entry: Entry) { - logInfo("Spilling key %s of size %d to make space".format( - key, entry.size)) + override protected def reportEntryDropped(datasetId: Any, partition: Int, entry: Entry) { + logInfo("Spilling key (%s, %d) of size %d to make space".format( + datasetId, partition, entry.size)) val cacheDir = System.getProperty( "spark.diskSpillingCache.cacheDir", System.getProperty("java.io.tmpdir")) @@ -64,11 +64,11 @@ class DiskSpillingCache extends BoundedMemoryCache { val stream = new FileOutputStream(file) stream.write(entry.value.asInstanceOf[Array[Byte]]) stream.close() - diskMap.put(key, file) + diskMap.put((datasetId, partition), file) } catch { case e: IOException => - logWarning("Failed to spill key %s to disk at %s: %s".format( - key, file.getPath(), e.getMessage())) + logWarning("Failed to spill key (%s, %d) to disk at %s: %s".format( + datasetId, partition, file.getPath(), e.getMessage())) // Do nothing and let the entry be discarded } } diff --git a/core/src/main/scala/spark/SerializingCache.scala b/core/src/main/scala/spark/SerializingCache.scala index a74922ec4c..f6964905c7 100644 --- a/core/src/main/scala/spark/SerializingCache.scala +++ b/core/src/main/scala/spark/SerializingCache.scala @@ -9,13 +9,13 @@ import java.io._ class SerializingCache extends Cache with Logging { val bmc = new BoundedMemoryCache - override def put(key: Any, value: Any) { + override def put(datasetId: Any, partition: Int, value: Any): Boolean = { val ser = SparkEnv.get.serializer.newInstance() - bmc.put(key, ser.serialize(value)) + bmc.put(datasetId, partition, ser.serialize(value)) } - override def get(key: Any): Any = { - val bytes = bmc.get(key) + override def get(datasetId: Any, partition: Int): Any = { + val bytes = bmc.get(datasetId, partition) if (bytes != null) { val ser = SparkEnv.get.serializer.newInstance() return ser.deserialize(bytes.asInstanceOf[Array[Byte]]) diff --git a/core/src/main/scala/spark/SoftReferenceCache.scala b/core/src/main/scala/spark/SoftReferenceCache.scala index e84aa57efa..c507df928b 100644 --- a/core/src/main/scala/spark/SoftReferenceCache.scala +++ b/core/src/main/scala/spark/SoftReferenceCache.scala @@ -8,6 +8,11 @@ import com.google.common.collect.MapMaker class SoftReferenceCache extends Cache { val map = new MapMaker().softValues().makeMap[Any, Any]() - override def get(key: Any): Any = map.get(key) - override def put(key: Any, value: Any) = map.put(key, value) + override def get(datasetId: Any, partition: Int): Any = + map.get((datasetId, partition)) + + override def put(datasetId: Any, partition: Int, value: Any): Boolean = { + map.put((datasetId, partition), value) + return true + } } diff --git a/core/src/main/scala/spark/WeakReferenceCache.scala b/core/src/main/scala/spark/WeakReferenceCache.scala deleted file mode 100644 index ddca065454..0000000000 --- a/core/src/main/scala/spark/WeakReferenceCache.scala +++ /dev/null @@ -1,14 +0,0 @@ -package spark - -import com.google.common.collect.MapMaker - -/** - * An implementation of Cache that uses weak references. 
- */ -class WeakReferenceCache extends Cache { - val map = new MapMaker().weakValues().makeMap[Any, Any]() - - override def get(key: Any): Any = map.get(key) - override def put(key: Any, value: Any) = map.put(key, value) -} - diff --git a/core/src/main/scala/spark/broadcast/BitTorrentBroadcast.scala b/core/src/main/scala/spark/broadcast/BitTorrentBroadcast.scala index 6960339bf8..5a873dca3d 100644 --- a/core/src/main/scala/spark/broadcast/BitTorrentBroadcast.scala +++ b/core/src/main/scala/spark/broadcast/BitTorrentBroadcast.scala @@ -16,7 +16,7 @@ extends Broadcast[T] with Logging with Serializable { def value = value_ BitTorrentBroadcast.synchronized { - BitTorrentBroadcast.values.put(uuid, value_) + BitTorrentBroadcast.values.put(uuid, 0, value_) } @transient var arrayOfBlocks: Array[BroadcastBlock] = null @@ -130,7 +130,7 @@ extends Broadcast[T] with Logging with Serializable { private def readObject(in: ObjectInputStream): Unit = { in.defaultReadObject BitTorrentBroadcast.synchronized { - val cachedVal = BitTorrentBroadcast.values.get(uuid) + val cachedVal = BitTorrentBroadcast.values.get(uuid, 0) if (cachedVal != null) { value_ = cachedVal.asInstanceOf[T] @@ -152,12 +152,12 @@ extends Broadcast[T] with Logging with Serializable { // If does not succeed, then get from HDFS copy if (receptionSucceeded) { value_ = Broadcast.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks) - BitTorrentBroadcast.values.put(uuid, value_) + BitTorrentBroadcast.values.put(uuid, 0, value_) } else { // TODO: This part won't work, cause HDFS writing is turned OFF val fileIn = new ObjectInputStream(DfsBroadcast.openFileForReading(uuid)) value_ = fileIn.readObject.asInstanceOf[T] - BitTorrentBroadcast.values.put(uuid, value_) + BitTorrentBroadcast.values.put(uuid, 0, value_) fileIn.close() } diff --git a/core/src/main/scala/spark/broadcast/ChainedBroadcast.scala b/core/src/main/scala/spark/broadcast/ChainedBroadcast.scala index e33ef78e8a..64da650142 100644 --- a/core/src/main/scala/spark/broadcast/ChainedBroadcast.scala +++ b/core/src/main/scala/spark/broadcast/ChainedBroadcast.scala @@ -15,7 +15,7 @@ extends Broadcast[T] with Logging with Serializable { def value = value_ ChainedBroadcast.synchronized { - ChainedBroadcast.values.put(uuid, value_) + ChainedBroadcast.values.put(uuid, 0, value_) } @transient var arrayOfBlocks: Array[BroadcastBlock] = null @@ -101,7 +101,7 @@ extends Broadcast[T] with Logging with Serializable { private def readObject(in: ObjectInputStream): Unit = { in.defaultReadObject ChainedBroadcast.synchronized { - val cachedVal = ChainedBroadcast.values.get(uuid) + val cachedVal = ChainedBroadcast.values.get(uuid, 0) if (cachedVal != null) { value_ = cachedVal.asInstanceOf[T] } else { @@ -121,11 +121,11 @@ extends Broadcast[T] with Logging with Serializable { // If does not succeed, then get from HDFS copy if (receptionSucceeded) { value_ = Broadcast.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks) - ChainedBroadcast.values.put(uuid, value_) + ChainedBroadcast.values.put(uuid, 0, value_) } else { val fileIn = new ObjectInputStream(DfsBroadcast.openFileForReading(uuid)) value_ = fileIn.readObject.asInstanceOf[T] - ChainedBroadcast.values.put(uuid, value_) + ChainedBroadcast.values.put(uuid, 0, value_) fileIn.close() } diff --git a/core/src/main/scala/spark/broadcast/DfsBroadcast.scala b/core/src/main/scala/spark/broadcast/DfsBroadcast.scala index 076f18afac..b053e2b62e 100644 --- a/core/src/main/scala/spark/broadcast/DfsBroadcast.scala +++ 
b/core/src/main/scala/spark/broadcast/DfsBroadcast.scala @@ -17,7 +17,7 @@ extends Broadcast[T] with Logging with Serializable { def value = value_ DfsBroadcast.synchronized { - DfsBroadcast.values.put(uuid, value_) + DfsBroadcast.values.put(uuid, 0, value_) } if (!isLocal) { @@ -34,7 +34,7 @@ extends Broadcast[T] with Logging with Serializable { private def readObject(in: ObjectInputStream): Unit = { in.defaultReadObject DfsBroadcast.synchronized { - val cachedVal = DfsBroadcast.values.get(uuid) + val cachedVal = DfsBroadcast.values.get(uuid, 0) if (cachedVal != null) { value_ = cachedVal.asInstanceOf[T] } else { @@ -43,7 +43,7 @@ extends Broadcast[T] with Logging with Serializable { val fileIn = new ObjectInputStream(DfsBroadcast.openFileForReading(uuid)) value_ = fileIn.readObject.asInstanceOf[T] - DfsBroadcast.values.put(uuid, value_) + DfsBroadcast.values.put(uuid, 0, value_) fileIn.close val time = (System.nanoTime - start) / 1e9 diff --git a/core/src/main/scala/spark/broadcast/TreeBroadcast.scala b/core/src/main/scala/spark/broadcast/TreeBroadcast.scala index 945d8cd8a4..374389def5 100644 --- a/core/src/main/scala/spark/broadcast/TreeBroadcast.scala +++ b/core/src/main/scala/spark/broadcast/TreeBroadcast.scala @@ -15,7 +15,7 @@ extends Broadcast[T] with Logging with Serializable { def value = value_ TreeBroadcast.synchronized { - TreeBroadcast.values.put(uuid, value_) + TreeBroadcast.values.put(uuid, 0, value_) } @transient var arrayOfBlocks: Array[BroadcastBlock] = null @@ -104,7 +104,7 @@ extends Broadcast[T] with Logging with Serializable { private def readObject(in: ObjectInputStream): Unit = { in.defaultReadObject TreeBroadcast.synchronized { - val cachedVal = TreeBroadcast.values.get(uuid) + val cachedVal = TreeBroadcast.values.get(uuid, 0) if (cachedVal != null) { value_ = cachedVal.asInstanceOf[T] } else { @@ -124,11 +124,11 @@ extends Broadcast[T] with Logging with Serializable { // If does not succeed, then get from HDFS copy if (receptionSucceeded) { value_ = Broadcast.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks) - TreeBroadcast.values.put(uuid, value_) + TreeBroadcast.values.put(uuid, 0, value_) } else { val fileIn = new ObjectInputStream(DfsBroadcast.openFileForReading(uuid)) value_ = fileIn.readObject.asInstanceOf[T] - TreeBroadcast.values.put(uuid, value_) + TreeBroadcast.values.put(uuid, 0, value_) fileIn.close() } -- GitLab
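
Note on the eviction policy this patch introduces: cache keys change from an opaque `key: Any` to a `(datasetId, partition)` pair, `put` now returns a Boolean, and `BoundedMemoryCache.ensureFreeSpace` refuses to evict an entry that belongs to the same dataset as the incoming partition, which is what prevents partitions of one RDD from cyclically evicting each other. `KeySpace` folds its id into the dataset id (the effective key is `((keySpaceId, datasetId), partition)`), and the broadcast implementations simply use `(uuid, 0)`. The sketch below is a minimal, standalone illustration of that rule, not the Spark classes themselves: the class name `DatasetAwareLruCache` and the explicit `size` parameter are made up for the example (the real code estimates sizes with `SizeEstimator` and reports drops to the `CacheTracker`).

import java.util.LinkedHashMap

// Simplified sketch of the dataset-aware LRU policy from this patch.
class DatasetAwareLruCache(maxBytes: Long) {
  private case class Entry(value: Any, size: Long)

  // accessOrder = true makes iteration order least-recently-used first,
  // matching the LinkedHashMap used by BoundedMemoryCache.
  private val map = new LinkedHashMap[(Any, Int), Entry](32, 0.75f, true)
  private var currentBytes = 0L

  def get(datasetId: Any, partition: Int): Any = synchronized {
    val entry = map.get((datasetId, partition))
    if (entry != null) entry.value else null
  }

  // Mirrors the new put(): Boolean contract -- returns false when making room
  // would require evicting another partition of the same dataset.
  def put(datasetId: Any, partition: Int, value: Any, size: Long): Boolean = synchronized {
    if (ensureFreeSpace(datasetId, size)) {
      map.put((datasetId, partition), Entry(value, size))
      currentBytes += size
      true
    } else {
      false
    }
  }

  // Evict LRU entries from *other* datasets until `space` bytes are free.
  // Returns false as soon as the next LRU victim belongs to `datasetId`,
  // which is the cyclic-eviction guard this patch adds.
  private def ensureFreeSpace(datasetId: Any, space: Long): Boolean = {
    val iter = map.entrySet.iterator // entries come back in LRU order
    while (maxBytes - currentBytes < space && iter.hasNext) {
      val mapEntry = iter.next()
      if (mapEntry.getKey._1 == datasetId) {
        return false
      }
      currentBytes -= mapEntry.getValue.size
      iter.remove()
    }
    true
  }
}

// Example behaviour under the new policy:
//   val cache = new DatasetAwareLruCache(maxBytes = 1000)
//   cache.put("rddA", 0, "partition 0 data", size = 600)  // true: fits
//   cache.put("rddA", 1, "partition 1 data", size = 600)  // false: would evict rddA's partition 0
//   cache.put("rddB", 0, "other dataset",    size = 300)  // true: 400 bytes still free

A rejected put is not an error: in CacheTracker.getOrCompute the computed array is still returned to the caller, the thread just skips telling the master that the partition was added, so future tasks are not scheduled against a copy that was never cached.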