From 1df5a65a0158716c5634c55d57578fd00d3f5f1f Mon Sep 17 00:00:00 2001
From: Matei Zaharia <matei@eecs.berkeley.edu>
Date: Sun, 6 Mar 2011 12:16:38 -0800
Subject: [PATCH] Pass cache locations correctly to DAGScheduler.

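The DAGScheduler previously tracked cache locations in a private map
keyed by RDD object and updated through addCacheLoc/removeCacheLoc,
while the RDDCacheTracker actor only logged the cache events it
received, so the two were never in sync. This change makes the tracker
the source of truth: RDDs are registered with their partition counts
when stages are created, workers report to the master as they load
partitions, e.g.

    trackerActor ! AddedToCache(rdd.id, split.index, host)

and the scheduler replaces its map (now keyed by RDD ID) with a
snapshot from RDDCache.getLocationsSnapshot() at the start of each job
and whenever a stage completes.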
---
 core/src/main/scala/spark/DAGScheduler.scala  | 33 ++++----
 .../main/scala/spark/MapOutputTracker.scala   |  2 +-
 core/src/main/scala/spark/RDDCache.scala      | 76 +++++++++++++++----
 core/src/main/scala/spark/SparkContext.scala  |  1 +
 4 files changed, 80 insertions(+), 32 deletions(-)

diff --git a/core/src/main/scala/spark/DAGScheduler.scala b/core/src/main/scala/spark/DAGScheduler.scala
index ee3fda25a8..734cbea822 100644
--- a/core/src/main/scala/spark/DAGScheduler.scala
+++ b/core/src/main/scala/spark/DAGScheduler.scala
@@ -33,20 +33,14 @@ private abstract class DAGScheduler extends Scheduler with Logging {
 
   val shuffleToMapStage = new HashMap[ShuffleDependency[_,_,_], Stage]
 
-  val cacheLocs = new HashMap[RDD[_], Array[List[String]]]
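+  // Cache locations for each RDD's partitions, keyed by RDD ID and
+  // fetched as a snapshot from the RDDCache tracker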
+  var cacheLocs = new HashMap[Int, Array[List[String]]]
 
   def getCacheLocs(rdd: RDD[_]): Array[List[String]] = {
-    cacheLocs.getOrElseUpdate(rdd, Array.fill[List[String]](rdd.splits.size)(Nil))
+    cacheLocs(rdd.id)
   }
-
-  def addCacheLoc(rdd: RDD[_], partition: Int, host: String) {
-    val locs = getCacheLocs(rdd)
-    locs(partition) = host :: locs(partition)
-  }
-
-  def removeCacheLoc(rdd: RDD[_], partition: Int, host: String) {
-    val locs = getCacheLocs(rdd)
-    locs(partition) -= host
+  
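+  // Replace cacheLocs with a fresh snapshot from the RDDCache tracker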
+  def updateCacheLocs() {
+    cacheLocs = RDDCache.getLocationsSnapshot()
   }
 
   def getShuffleMapStage(shuf: ShuffleDependency[_,_,_]): Stage = {
@@ -60,6 +54,9 @@ private abstract class DAGScheduler extends Scheduler with Logging {
   }
 
   def newStage(rdd: RDD[_], shuffleDep: Option[ShuffleDependency[_,_,_]]): Stage = {
+    // Kind of ugly: we need to register RDDs with the cache here since
+    // we can't do it in the RDD constructor, where the number of splits isn't known yet
+    RDDCache.registerRDD(rdd.id, rdd.splits.size)
     val id = newStageId()
     val stage = new Stage(id, rdd, shuffleDep, getParentStages(rdd))
     idToStage(id) = stage
@@ -113,10 +110,10 @@ private abstract class DAGScheduler extends Scheduler with Logging {
     missing.toList
   }
 
-  override def runJob[T, U](rdd: RDD[T], func: Iterator[T] => U)(implicit m: ClassManifest[U])
+  override def runJob[T, U](finalRdd: RDD[T], func: Iterator[T] => U)(implicit m: ClassManifest[U])
       : Array[U] = {
-    val numOutputParts: Int = rdd.splits.size
-    val finalStage = newStage(rdd, None)
+    val numOutputParts: Int = finalRdd.splits.size
+    val finalStage = newStage(finalRdd, None)
     val results = new Array[U](numOutputParts)
     val finished = new Array[Boolean](numOutputParts)
     var numFinished = 0
@@ -125,6 +122,8 @@ private abstract class DAGScheduler extends Scheduler with Logging {
     val running = new HashSet[Stage]
     val pendingTasks = new HashMap[Stage, HashSet[Task[_]]]
 
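+    // Get an up-to-date view of cache locations before planning the job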
+    updateCacheLocs()
+    
     logInfo("Final stage: " + finalStage)
     logInfo("Parents of final stage: " + finalStage.parents)
     logInfo("Missing parents: " + getMissingParentStages(finalStage))
@@ -145,12 +144,13 @@ private abstract class DAGScheduler extends Scheduler with Logging {
     }
 
     def submitMissingTasks(stage: Stage) {
+      // Get this stage's pending task set, creating and registering it in pendingTasks if needed
       val myPending = pendingTasks.getOrElseUpdate(stage, new HashSet)
       var tasks = ArrayBuffer[Task[_]]()
       if (stage == finalStage) {
         for (p <- 0 until numOutputParts if (!finished(p))) {
-          val locs = getPreferredLocs(rdd, p)
-          tasks += new ResultTask(finalStage.id, rdd, func, p, locs)
+          val locs = getPreferredLocs(finalRdd, p)
+          tasks += new ResultTask(finalStage.id, finalRdd, func, p, locs)
         }
       } else {
         for (p <- 0 until stage.numPartitions if stage.outputLocs(p) == Nil) {
@@ -186,6 +186,7 @@ private abstract class DAGScheduler extends Scheduler with Logging {
             if (pending.isEmpty) {
               logInfo(stage + " finished; looking for newly runnable stages")
               running -= stage
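+              // The finished stage may have cached newly computed partitions, so refresh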
+              updateCacheLocs()
               val newlyRunnable = new ArrayBuffer[Stage]
               for (stage <- waiting if getMissingParentStages(stage) == Nil) {
                 newlyRunnable += stage
diff --git a/core/src/main/scala/spark/MapOutputTracker.scala b/core/src/main/scala/spark/MapOutputTracker.scala
index ac62c6e411..a253176169 100644
--- a/core/src/main/scala/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/spark/MapOutputTracker.scala
@@ -11,7 +11,7 @@ class MapOutputTracker extends DaemonActor with Logging {
     val port = System.getProperty("spark.master.port", "50501").toInt
     RemoteActor.alive(port)
     RemoteActor.register('MapOutputTracker, self)
-    logInfo("Started on port " + port)
+    logInfo("Registered actor on port " + port)
   }
 }
 
diff --git a/core/src/main/scala/spark/RDDCache.scala b/core/src/main/scala/spark/RDDCache.scala
index 2f2ec9d237..aae2d74900 100644
--- a/core/src/main/scala/spark/RDDCache.scala
+++ b/core/src/main/scala/spark/RDDCache.scala
@@ -3,31 +3,57 @@ package spark
 import scala.actors._
 import scala.actors.Actor._
 import scala.actors.remote._
+import scala.collection.mutable.HashMap
+import scala.collection.mutable.HashSet
 
 sealed trait CacheMessage
-case class CacheEntryAdded(rddId: Int, partition: Int, host: String)
-case class CacheEntryRemoved(rddId: Int, partition: Int, host: String)
+case class AddedToCache(rddId: Int, partition: Int, host: String) extends CacheMessage
+case class DroppedFromCache(rddId: Int, partition: Int, host: String) extends CacheMessage
+case class MemoryCacheLost(host: String) extends CacheMessage
+case class RegisterRDD(rddId: Int, numPartitions: Int) extends CacheMessage
+case object GetCacheLocations extends CacheMessage
 
 class RDDCacheTracker extends DaemonActor with Logging {
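+  // For each registered RDD ID, an array holding the list of hosts caching each partition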
+  val locs = new HashMap[Int, Array[List[String]]]
+  // TODO: Should probably store (String, CacheType) tuples
+  
   def act() {
     val port = System.getProperty("spark.master.port", "50501").toInt
     RemoteActor.alive(port)
     RemoteActor.register('RDDCacheTracker, self)
-    logInfo("Started on port " + port)
+    logInfo("Registered actor on port " + port)
     
     loop {
       react {
-        case CacheEntryAdded(rddId, partition, host) =>
-          logInfo("Cache entry added: %s, %s, %s".format(rddId, partition, host))
+        case RegisterRDD(rddId: Int, numPartitions: Int) =>
+          logInfo("Registering RDD " + rddId + " with " + numPartitions + " partitions")
+          locs(rddId) = Array.fill[List[String]](numPartitions)(Nil)
+          reply("")
+        
+        case AddedToCache(rddId, partition, host) =>
+          logInfo("Cache entry added: (%s, %s) on %s".format(rddId, partition, host))
+          locs(rddId)(partition) = host :: locs(rddId)(partition)
           
-        case CacheEntryRemoved(rddId, partition, host) =>
-          logInfo("Cache entry removed: %s, %s, %s".format(rddId, partition, host))
+        case DroppedFromCache(rddId, partition, host) =>
+          logInfo("Cache entry removed: (%s, %s) on %s".format(rddId, partition, host))
+          locs(rddId)(partition) -= host
+        
+        case MemoryCacheLost(host) =>
+          logInfo("Memory cache lost on " + host)
+          // TODO: Drop host from the memory locations list of all RDDs
+        
+        case GetCacheLocations =>
+          logInfo("Asked for current cache locations")
+          val locsCopy = new HashMap[Int, Array[List[String]]]
+          for ((rddId, array) <- locs) {
+            locsCopy(rddId) = array.clone()
+          }
+          reply(locsCopy)
       }
     }
   }
 }
 
-import scala.collection.mutable.HashSet
 private object RDDCache extends Logging {
   // Stores map results for various splits locally
   val cache = Cache.newKeySpace()
@@ -38,6 +64,8 @@ private object RDDCache extends Logging {
   // Tracker actor on the master, or remote reference to it on workers
   var trackerActor: AbstractActor = null
   
+  val registeredRddIds = new HashSet[Int]
+  
   def initialize(isMaster: Boolean) {
     if (isMaster) {
       val tracker = new RDDCacheTracker
@@ -50,16 +78,34 @@ private object RDDCache extends Logging {
     }
   }
   
+  // Registers an RDD (on master only)
+  def registerRDD(rddId: Int, numPartitions: Int) {
+    registeredRddIds.synchronized {
+      if (!registeredRddIds.contains(rddId)) {
+        registeredRddIds += rddId
+        trackerActor !? RegisterRDD(rddId, numPartitions)
+      }
+    }
+  }
+  
+  // Get a snapshot of the currently known locations
+  def getLocationsSnapshot(): HashMap[Int, Array[List[String]]] = {
+    (trackerActor !? GetCacheLocations) match {
+      case h: HashMap[Int, Array[List[String]]] => h
+      case _ => throw new SparkException(
+          "Internal error: RDDCache did not reply with a HashMap")
+    }
+  }
+  
   // Gets or computes an RDD split
   def getOrCompute[T](rdd: RDD[T], split: Split)(implicit m: ClassManifest[T])
       : Iterator[T] = {
     val key = (rdd.id, split.index)
-    logInfo("CachedRDD split key is " + key)
-    val cache = RDDCache.cache
-    val loading = RDDCache.loading
+    logInfo("CachedRDD partition key is " + key)
     val cachedVal = cache.get(key)
     if (cachedVal != null) {
       // Split is in cache, so just return its values
+      logInfo("Found partition in cache!")
       return Iterator.fromArray(cachedVal.asInstanceOf[Array[T]])
     } else {
       // Mark the split as loading (unless someone else marks it first)
@@ -73,13 +119,13 @@ private object RDDCache extends Logging {
           loading.add(key)
         }
       }
-      val host = System.getProperty("spark.hostname", Utils.localHostName)
-      trackerActor ! CacheEntryAdded(rdd.id, split.index, host)
       // If we got here, we have to load the split
+      // Tell the master that we're doing so
+      val host = System.getProperty("spark.hostname", Utils.localHostName)
+      trackerActor ! AddedToCache(rdd.id, split.index, host)
       // TODO: fetch any remote copy of the split that may be available
-      // TODO: also notify the master that we're loading it
       // TODO: also register a listener for when it unloads
-      logInfo("Computing and caching " + split)
+      logInfo("Computing partition " + split)
       val array = rdd.compute(split).toArray(m)
       cache.put(key, array)
       loading.synchronized {
diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala
index fda2ee3be7..5cce873c72 100644
--- a/core/src/main/scala/spark/SparkContext.scala
+++ b/core/src/main/scala/spark/SparkContext.scala
@@ -175,6 +175,7 @@ extends Logging {
   
   private var nextRddId = new AtomicInteger(0)
 
+  // Allocate a unique ID for a newly created RDD
   private[spark] def newRddId(): Int = {
     nextRddId.getAndIncrement()
   }
-- 
GitLab