From 16461e2edab5253e1cc3b9b8a74cc8adbb6b9be3 Mon Sep 17 00:00:00 2001
From: Reynold Xin <rxin@cs.berkeley.edu>
Date: Tue, 15 May 2012 00:31:52 -0700
Subject: [PATCH] Updated Cache's put method to use a case class for response.
 Previously it was pretty ugly that put() returned -1 for failures.
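
A rough usage sketch of the new CachePutResponse type (illustrative only;
this snippet is not part of the diff below, and uses names from
CacheTracker.scala):

    cache.put(rdd.id, split.index, array) match {
      case CachePutSuccess(size) =>
        // The cache accepted the value; size is the estimated size in bytes.
        trackerActor !? AddedToCache(rdd.id, split.index, host, size)
      case CachePutFailure() =>
        // The cache declined the value (e.g. the replacement policy forbids it).
        logInfo("Could not cache partition " + split.index)
    }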

---
 .../main/scala/spark/BoundedMemoryCache.scala |  6 +--
 core/src/main/scala/spark/Cache.scala         | 16 +++++---
 core/src/main/scala/spark/CacheTracker.scala  | 39 ++++++++++++-------
 .../main/scala/spark/DiskSpillingCache.scala  |  2 +-
 .../main/scala/spark/SerializingCache.scala   |  2 +-
 .../main/scala/spark/SoftReferenceCache.scala |  4 +-
 core/src/main/scala/spark/Utils.scala         |  2 +-
 core/src/test/scala/spark/Utils.scala         | 14 +++----
 8 files changed, 50 insertions(+), 35 deletions(-)

diff --git a/core/src/main/scala/spark/BoundedMemoryCache.scala b/core/src/main/scala/spark/BoundedMemoryCache.scala
index c25d0a62df..f778d8cc17 100644
--- a/core/src/main/scala/spark/BoundedMemoryCache.scala
+++ b/core/src/main/scala/spark/BoundedMemoryCache.scala
@@ -30,7 +30,7 @@ class BoundedMemoryCache extends Cache with Logging {
     }
   }
 
-  override def put(datasetId: Any, partition: Int, value: Any): Long = {
+  override def put(datasetId: Any, partition: Int, value: Any): CachePutResponse = {
     val key = (datasetId, partition)
     logInfo("Asked to add key " + key)
     val startTime = System.currentTimeMillis
@@ -44,10 +44,10 @@ class BoundedMemoryCache extends Cache with Logging {
         map.put(key, new Entry(value, size))
         currentBytes += size
         logInfo("Number of entries is now " + map.size)
-        return size
+        return CachePutSuccess(size)
       } else {
         logInfo("Didn't add key " + key + " because we would have evicted part of same dataset")
-        return -1L
+        return CachePutFailure()
       }
     }
   }
diff --git a/core/src/main/scala/spark/Cache.scala b/core/src/main/scala/spark/Cache.scala
index a65d3b478d..aeff205884 100644
--- a/core/src/main/scala/spark/Cache.scala
+++ b/core/src/main/scala/spark/Cache.scala
@@ -2,6 +2,10 @@ package spark
 
 import java.util.concurrent.atomic.AtomicInteger
 
+sealed trait CachePutResponse
+case class CachePutSuccess(size: Long) extends CachePutResponse
+case class CachePutFailure extends CachePutResponse
+
 /**
  * An interface for caches in Spark, to allow for multiple implementations. Caches are used to store
  * both partitions of cached RDDs and broadcast variables on Spark executors. Caches are also aware
@@ -31,12 +35,12 @@ abstract class Cache {
   def get(datasetId: Any, partition: Int): Any
 
   /**
-   * Attempt to put a value in the cache; returns a negative number if this was
-   * not successful (e.g. because the cache replacement policy forbids it). If
-   * size estimation is available, the cache implementation should return the
-   * estimated size of the partition if the partition is successfully cached.
+   * Attempt to put a value in the cache; returns CachePutFailure if this was
+   * not successful (e.g. because the cache replacement policy forbids it), and
+   * CachePutSuccess if successful. If size estimation is available, the cache
+   * implementation should set the size field in CachePutSuccess.
    */
-  def put(datasetId: Any, partition: Int, value: Any): Long
+  def put(datasetId: Any, partition: Int, value: Any): CachePutResponse
 
   /**
    * Report the capacity of the cache partition. By default this just reports
@@ -52,7 +56,7 @@ class KeySpace(cache: Cache, val keySpaceId: Int) {
   def get(datasetId: Any, partition: Int): Any =
     cache.get((keySpaceId, datasetId), partition)
 
-  def put(datasetId: Any, partition: Int, value: Any): Long =
+  def put(datasetId: Any, partition: Int, value: Any): CachePutResponse =
     cache.put((keySpaceId, datasetId), partition, value)
 
   def getCapacity: Long = cache.getCapacity
diff --git a/core/src/main/scala/spark/CacheTracker.scala b/core/src/main/scala/spark/CacheTracker.scala
index b472dc8070..5b5831b2de 100644
--- a/core/src/main/scala/spark/CacheTracker.scala
+++ b/core/src/main/scala/spark/CacheTracker.scala
@@ -20,15 +20,19 @@ case object StopCacheTracker extends CacheTrackerMessage
 
 
 class CacheTrackerActor extends DaemonActor with Logging {
-  val locs = new HashMap[Int, Array[List[String]]]
+  private val locs = new HashMap[Int, Array[List[String]]]
 
   /**
    * A map from the slave's host name to its cache size.
    */
-  val slaveCapacity = new HashMap[String, Long]
-  val slaveUsage = new HashMap[String, Long]
+  private val slaveCapacity = new HashMap[String, Long]
+  private val slaveUsage = new HashMap[String, Long]
 
   // TODO: Should probably store (String, CacheType) tuples
+
+  private def getCacheUsage(host: String): Long = slaveUsage.getOrElse(host, 0L)
+  private def getCacheCapacity(host: String): Long = slaveCapacity.getOrElse(host, 0L)
+  private def getCacheAvailable(host: String): Long = getCacheCapacity(host) - getCacheUsage(host)
   
   def act() {
     val port = System.getProperty("spark.master.port").toInt
@@ -39,7 +43,8 @@ class CacheTrackerActor extends DaemonActor with Logging {
     loop {
       react {
         case SlaveCacheStarted(host: String, size: Long) =>
-          logInfo("Started slave cache (size %s) on %s".format(Utils.sizeWithSuffix(size), host))
+          logInfo("Started slave cache (size %s) on %s".format(
+            Utils.memoryBytesToString(size), host))
           slaveCapacity.put(host, size)
           slaveUsage.put(host, 0)
           reply('OK)
@@ -51,9 +56,10 @@ class CacheTrackerActor extends DaemonActor with Logging {
         
         case AddedToCache(rddId, partition, host, size) =>
           if (size > 0) {
-            logInfo("Cache entry added: (%s, %s) on %s, size: %s".format(
-              rddId, partition, host, Utils.sizeWithSuffix(size)))
             slaveUsage.put(host, slaveUsage.getOrElse(host, 0L) + size)
+            logInfo("Cache entry added: (%s, %s) on %s (size added: %s, available: %s)".format(
+              rddId, partition, host, Utils.memoryBytesToString(size),
+              Utils.memoryBytesToString(getCacheAvailable(host))))
           } else {
             logInfo("Cache entry added: (%s, %s) on %s".format(rddId, partition, host))
           }
@@ -62,8 +68,9 @@ class CacheTrackerActor extends DaemonActor with Logging {
           
         case DroppedFromCache(rddId, partition, host, size) =>
           if (size > 0) {
-            logInfo("Cache entry removed: (%s, %s) on %s, size: %s".format(
-              rddId, partition, host, Utils.sizeWithSuffix(size)))
+            logInfo("Cache entry removed: (%s, %s) on %s (size dropped: %s, available: %s)".format(
+              rddId, partition, host, Utils.memoryBytesToString(size),
+              Utils.memoryBytesToString(getCacheAvailable(host))))
             slaveUsage.put(host, slaveUsage.getOrElse(host, 0L) - size)
 
             // Do a sanity check to make sure usage is greater than 0.
@@ -199,10 +206,10 @@ class CacheTracker(isMaster: Boolean, theCache: Cache) extends Logging {
       // TODO: fetch any remote copy of the split that may be available
       logInfo("Computing partition " + split)
       var array: Array[T] = null
-      var putRetval: Long = -1L
+      var putResponse: CachePutResponse = null
       try {
         array = rdd.compute(split).toArray(m)
-        putRetval = cache.put(rdd.id, split.index, array)
+        putResponse = cache.put(rdd.id, split.index, array)
       } finally {
         // Tell other threads that we've finished our attempt to load the key (whether or not
         // we've actually succeeded to put it in the map)
@@ -211,10 +218,14 @@ class CacheTracker(isMaster: Boolean, theCache: Cache) extends Logging {
           loading.notifyAll()
         }
       }
-      if (putRetval >= 0) {
-        // Tell the master that we added the entry. Don't return until it replies so it can
-        // properly schedule future tasks that use this RDD.
-        trackerActor !? AddedToCache(rdd.id, split.index, host, putRetval)
+
+      putResponse match {
+        case CachePutSuccess(size) => {
+          // Tell the master that we added the entry. Don't return until it
+          // replies so it can properly schedule future tasks that use this RDD.
+          trackerActor !? AddedToCache(rdd.id, split.index, host, size)
+        }
+        case _ => null
       }
       return array.iterator
     }
diff --git a/core/src/main/scala/spark/DiskSpillingCache.scala b/core/src/main/scala/spark/DiskSpillingCache.scala
index 037ed78688..e11466eb64 100644
--- a/core/src/main/scala/spark/DiskSpillingCache.scala
+++ b/core/src/main/scala/spark/DiskSpillingCache.scala
@@ -44,7 +44,7 @@ class DiskSpillingCache extends BoundedMemoryCache {
     }
   }
 
-  override def put(datasetId: Any, partition: Int, value: Any): Long = {
+  override def put(datasetId: Any, partition: Int, value: Any): CachePutResponse = {
     var ser = SparkEnv.get.serializer.newInstance()
     super.put(datasetId, partition, ser.serialize(value))
   }
diff --git a/core/src/main/scala/spark/SerializingCache.scala b/core/src/main/scala/spark/SerializingCache.scala
index 17dc735d5e..3d192f2403 100644
--- a/core/src/main/scala/spark/SerializingCache.scala
+++ b/core/src/main/scala/spark/SerializingCache.scala
@@ -9,7 +9,7 @@ import java.io._
 class SerializingCache extends Cache with Logging {
   val bmc = new BoundedMemoryCache
 
-  override def put(datasetId: Any, partition: Int, value: Any): Long = {
+  override def put(datasetId: Any, partition: Int, value: Any): CachePutResponse = {
     val ser = SparkEnv.get.serializer.newInstance()
     bmc.put(datasetId, partition, ser.serialize(value))
   }
diff --git a/core/src/main/scala/spark/SoftReferenceCache.scala b/core/src/main/scala/spark/SoftReferenceCache.scala
index cd2386eb83..ce9370c5d7 100644
--- a/core/src/main/scala/spark/SoftReferenceCache.scala
+++ b/core/src/main/scala/spark/SoftReferenceCache.scala
@@ -11,8 +11,8 @@ class SoftReferenceCache extends Cache {
   override def get(datasetId: Any, partition: Int): Any =
     map.get((datasetId, partition))
 
-  override def put(datasetId: Any, partition: Int, value: Any): Long = {
+  override def put(datasetId: Any, partition: Int, value: Any): CachePutResponse = {
     map.put((datasetId, partition), value)
-    return 0
+    return CachePutSuccess(0)
   }
 }
diff --git a/core/src/main/scala/spark/Utils.scala b/core/src/main/scala/spark/Utils.scala
index 5aecbdde7d..d108c14f6b 100644
--- a/core/src/main/scala/spark/Utils.scala
+++ b/core/src/main/scala/spark/Utils.scala
@@ -180,7 +180,7 @@ object Utils {
    * Petabyte) in order to reduce the number of digits to four or less. For
    * example, 4,000,000 is returned as 4MB.
    */
-  def sizeWithSuffix(size: Long): String = {
+  def memoryBytesToString(size: Long): String = {
     val GB = 1L << 30
     val MB = 1L << 20
     val KB = 1L << 10
diff --git a/core/src/test/scala/spark/Utils.scala b/core/src/test/scala/spark/Utils.scala
index b78b638bb1..4e852903be 100644
--- a/core/src/test/scala/spark/Utils.scala
+++ b/core/src/test/scala/spark/Utils.scala
@@ -5,13 +5,13 @@ import org.scalatest.FunSuite
 
 class UtilsSuite extends FunSuite {
 
-  test("sizeWithSuffix") {
-    assert(Utils.sizeWithSuffix(10) === "10.0B")
-    assert(Utils.sizeWithSuffix(1500) === "1500.0B")
-    assert(Utils.sizeWithSuffix(2000000) === "1953.1KB")
-    assert(Utils.sizeWithSuffix(2097152) === "2.0MB")
-    assert(Utils.sizeWithSuffix(2306867) === "2.2MB")
-    assert(Utils.sizeWithSuffix(5368709120L) === "5.0GB")
+  test("memoryBytesToString") {
+    assert(Utils.memoryBytesToString(10) === "10.0B")
+    assert(Utils.memoryBytesToString(1500) === "1500.0B")
+    assert(Utils.memoryBytesToString(2000000) === "1953.1KB")
+    assert(Utils.memoryBytesToString(2097152) === "2.0MB")
+    assert(Utils.memoryBytesToString(2306867) === "2.2MB")
+    assert(Utils.memoryBytesToString(5368709120L) === "5.0GB")
   }
 
 }
-- 
GitLab