Commit 1df5a65a authored by Matei Zaharia

Pass cache locations correctly to DAGScheduler.
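In outline, the change works like this: the DAGScheduler used to keep its own table of cache locations keyed by RDD object and mutate it through addCacheLoc/removeCacheLoc. The authoritative table now lives in the new RDDCacheTracker actor on the master, keyed by RDD id. newStage() registers each RDD's partition count with the tracker, workers report AddedToCache/DroppedFromCache events as they load or drop partitions, and the scheduler pulls a consistent snapshot via updateCacheLocs() before scheduling a job and again whenever a stage finishes.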

parent e1436f1e
@@ -33,20 +33,14 @@ private abstract class DAGScheduler extends Scheduler with Logging {
   val shuffleToMapStage = new HashMap[ShuffleDependency[_,_,_], Stage]
 
-  val cacheLocs = new HashMap[RDD[_], Array[List[String]]]
+  var cacheLocs = new HashMap[Int, Array[List[String]]]
 
   def getCacheLocs(rdd: RDD[_]): Array[List[String]] = {
-    cacheLocs.getOrElseUpdate(rdd, Array.fill[List[String]](rdd.splits.size)(Nil))
+    cacheLocs(rdd.id)
   }
 
-  def addCacheLoc(rdd: RDD[_], partition: Int, host: String) {
-    val locs = getCacheLocs(rdd)
-    locs(partition) = host :: locs(partition)
-  }
-
-  def removeCacheLoc(rdd: RDD[_], partition: Int, host: String) {
-    val locs = getCacheLocs(rdd)
-    locs(partition) -= host
+  def updateCacheLocs() {
+    cacheLocs = RDDCache.getLocationsSnapshot()
   }
 
   def getShuffleMapStage(shuf: ShuffleDependency[_,_,_]): Stage = {
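Keying the table by rdd.id and swapping it out wholesale replaces the old piecemeal mutation. A minimal, self-contained sketch of this snapshot-replacement pattern (the names are illustrative, not the commit's code; the real updateCacheLocs fetches its snapshot from the tracker actor):

    import scala.collection.mutable.HashMap

    object CacheLocsSketch {
      // Locations keyed by RDD id: one List of hosts per partition.
      var cacheLocs = new HashMap[Int, Array[List[String]]]

      // Swap in a freshly fetched snapshot, replacing the whole table.
      def updateCacheLocs(snapshot: HashMap[Int, Array[List[String]]]) {
        cacheLocs = snapshot
      }

      def getCacheLocs(rddId: Int): Array[List[String]] = cacheLocs(rddId)

      def main(args: Array[String]) {
        val snap = new HashMap[Int, Array[List[String]]]
        snap(0) = Array(List("host1"), Nil) // RDD 0: partition 0 cached on host1
        updateCacheLocs(snap)
        println(getCacheLocs(0)(0)) // List(host1)
      }
    }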
@@ -60,6 +54,9 @@ private abstract class DAGScheduler extends Scheduler with Logging {
   }
 
   def newStage(rdd: RDD[_], shuffleDep: Option[ShuffleDependency[_,_,_]]): Stage = {
+    // Kind of ugly: need to register RDDs with the cache here since
+    // we can't do it in its constructor because # of splits is unknown
+    RDDCache.registerRDD(rdd.id, rdd.splits.size)
     val id = newStageId()
     val stage = new Stage(id, rdd, shuffleDep, getParentStages(rdd))
     idToStage(id) = stage
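Per the comment above, registration has to happen in newStage() rather than in the RDD constructor because the number of splits is not known until the RDD is fully constructed. The registeredRddIds guard in RDDCache (further down) makes the call idempotent, so building several stages over the same RDD registers it only once.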
@@ -113,10 +110,10 @@ private abstract class DAGScheduler extends Scheduler with Logging {
     missing.toList
   }
 
-  override def runJob[T, U](rdd: RDD[T], func: Iterator[T] => U)(implicit m: ClassManifest[U])
+  override def runJob[T, U](finalRdd: RDD[T], func: Iterator[T] => U)(implicit m: ClassManifest[U])
       : Array[U] = {
-    val numOutputParts: Int = rdd.splits.size
-    val finalStage = newStage(rdd, None)
+    val numOutputParts: Int = finalRdd.splits.size
+    val finalStage = newStage(finalRdd, None)
     val results = new Array[U](numOutputParts)
     val finished = new Array[Boolean](numOutputParts)
     var numFinished = 0
@@ -125,6 +122,8 @@ private abstract class DAGScheduler extends Scheduler with Logging {
     val running = new HashSet[Stage]
     val pendingTasks = new HashMap[Stage, HashSet[Task[_]]]
 
+    updateCacheLocs()
+
     logInfo("Final stage: " + finalStage)
     logInfo("Parents of final stage: " + finalStage.parents)
    logInfo("Missing parents: " + getMissingParentStages(finalStage))
@@ -145,12 +144,13 @@ private abstract class DAGScheduler extends Scheduler with Logging {
     }
 
     def submitMissingTasks(stage: Stage) {
+      // Get our pending tasks and remember them in our pendingTasks entry
       val myPending = pendingTasks.getOrElseUpdate(stage, new HashSet)
       var tasks = ArrayBuffer[Task[_]]()
       if (stage == finalStage) {
         for (p <- 0 until numOutputParts if (!finished(p))) {
-          val locs = getPreferredLocs(rdd, p)
-          tasks += new ResultTask(finalStage.id, rdd, func, p, locs)
+          val locs = getPreferredLocs(finalRdd, p)
+          tasks += new ResultTask(finalStage.id, finalRdd, func, p, locs)
         }
       } else {
         for (p <- 0 until stage.numPartitions if stage.outputLocs(p) == Nil) {
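The rename from rdd to finalRdd is more than cosmetic here: submitMissingTasks is a nested function with several RDDs in scope, and giving the job's final RDD a distinct name makes it much harder for these closures to capture the wrong one when computing preferred locations.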
@@ -186,6 +186,7 @@ private abstract class DAGScheduler extends Scheduler with Logging {
       if (pending.isEmpty) {
         logInfo(stage + " finished; looking for newly runnable stages")
         running -= stage
+        updateCacheLocs()
         val newlyRunnable = new ArrayBuffer[Stage]
         for (stage <- waiting if getMissingParentStages(stage) == Nil) {
           newlyRunnable += stage
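Refreshing the snapshot when a stage completes ensures that partitions cached while it ran are visible when placing the newly runnable stages that depend on it.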
@@ -11,7 +11,7 @@ class MapOutputTracker extends DaemonActor with Logging {
     val port = System.getProperty("spark.master.port", "50501").toInt
     RemoteActor.alive(port)
     RemoteActor.register('MapOutputTracker, self)
-    logInfo("Started on port " + port)
+    logInfo("Registered actor on port " + port)
   }
 }
@@ -3,31 +3,57 @@ package spark
 import scala.actors._
 import scala.actors.Actor._
 import scala.actors.remote._
+import scala.collection.mutable.HashMap
+import scala.collection.mutable.HashSet
 
-case class CacheEntryAdded(rddId: Int, partition: Int, host: String)
-case class CacheEntryRemoved(rddId: Int, partition: Int, host: String)
+sealed trait CacheMessage
+case class AddedToCache(rddId: Int, partition: Int, host: String) extends CacheMessage
+case class DroppedFromCache(rddId: Int, partition: Int, host: String) extends CacheMessage
+case class MemoryCacheLost(host: String) extends CacheMessage
+case class RegisterRDD(rddId: Int, numPartitions: Int) extends CacheMessage
+case object GetCacheLocations extends CacheMessage
 
 class RDDCacheTracker extends DaemonActor with Logging {
+  val locs = new HashMap[Int, Array[List[String]]]
+  // TODO: Should probably store (String, CacheType) tuples
+
   def act() {
     val port = System.getProperty("spark.master.port", "50501").toInt
     RemoteActor.alive(port)
     RemoteActor.register('RDDCacheTracker, self)
-    logInfo("Started on port " + port)
+    logInfo("Registered actor on port " + port)
 
     loop {
       react {
-        case CacheEntryAdded(rddId, partition, host) =>
-          logInfo("Cache entry added: %s, %s, %s".format(rddId, partition, host))
+        case RegisterRDD(rddId: Int, numPartitions: Int) =>
+          logInfo("Registering RDD " + rddId + " with " + numPartitions + " partitions")
+          locs(rddId) = Array.fill[List[String]](numPartitions)(Nil)
+          reply("")
+
+        case AddedToCache(rddId, partition, host) =>
+          logInfo("Cache entry added: (%s, %s) on %s".format(rddId, partition, host))
+          locs(rddId)(partition) = host :: locs(rddId)(partition)
 
-        case CacheEntryRemoved(rddId, partition, host) =>
-          logInfo("Cache entry removed: %s, %s, %s".format(rddId, partition, host))
+        case DroppedFromCache(rddId, partition, host) =>
+          logInfo("Cache entry removed: (%s, %s) on %s".format(rddId, partition, host))
+          locs(rddId)(partition) -= host
+
+        case MemoryCacheLost(host) =>
+          logInfo("Memory cache lost on " + host)
+          // TODO: Drop host from the memory locations list of all RDDs
+
+        case GetCacheLocations =>
+          logInfo("Asked for current cache locations")
+          val locsCopy = new HashMap[Int, Array[List[String]]]
+          for ((rddId, array) <- locs) {
+            locsCopy(rddId) = array.clone()
+          }
+          reply(locsCopy)
       }
     }
   }
 }
 
-import scala.collection.mutable.HashSet
-
 private object RDDCache extends Logging {
   // Stores map results for various splits locally
   val cache = Cache.newKeySpace()
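The tracker's messages form a small sealed ADT handled in one react loop, so the compiler can check the match for exhaustiveness. A self-contained miniature of the same pattern, with hypothetical message names and plain method dispatch instead of remote actors:

    import scala.collection.mutable.HashMap

    sealed trait TrackerMsg
    case class Added(rddId: Int, partition: Int, host: String) extends TrackerMsg
    case class Dropped(rddId: Int, partition: Int, host: String) extends TrackerMsg

    object ProtocolSketch {
      val locs = new HashMap[Int, Array[List[String]]]

      // Sealed trait + match: a forgotten message type is a compile-time warning.
      def handle(msg: TrackerMsg) {
        msg match {
          case Added(id, p, h)   => locs(id)(p) = h :: locs(id)(p)
          case Dropped(id, p, h) => locs(id)(p) = locs(id)(p).filterNot(_ == h)
        }
      }

      def main(args: Array[String]) {
        locs(0) = Array.fill[List[String]](2)(Nil)
        handle(Added(0, 0, "host1"))
        println(locs(0)(0)) // List(host1)
      }
    }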
@@ -38,6 +64,8 @@ private object RDDCache extends Logging {
   // Tracker actor on the master, or remote reference to it on workers
   var trackerActor: AbstractActor = null
 
+  val registeredRddIds = new HashSet[Int]
+
   def initialize(isMaster: Boolean) {
     if (isMaster) {
       val tracker = new RDDCacheTracker
@@ -50,16 +78,34 @@ private object RDDCache extends Logging {
     }
   }
 
+  // Registers an RDD (on master only)
+  def registerRDD(rddId: Int, numPartitions: Int) {
+    registeredRddIds.synchronized {
+      if (!registeredRddIds.contains(rddId)) {
+        registeredRddIds += rddId
+        trackerActor !? RegisterRDD(rddId, numPartitions)
+      }
+    }
+  }
+
+  // Get a snapshot of the currently known locations
+  def getLocationsSnapshot(): HashMap[Int, Array[List[String]]] = {
+    (trackerActor !? GetCacheLocations) match {
+      case h: HashMap[Int, Array[List[String]]] => h
+      case _ => throw new SparkException(
+          "Internal error: RDDCache did not reply with a HashMap")
+    }
+  }
+
   // Gets or computes an RDD split
   def getOrCompute[T](rdd: RDD[T], split: Split)(implicit m: ClassManifest[T])
       : Iterator[T] = {
     val key = (rdd.id, split.index)
-    logInfo("CachedRDD split key is " + key)
-    val cache = RDDCache.cache
-    val loading = RDDCache.loading
+    logInfo("CachedRDD partition key is " + key)
     val cachedVal = cache.get(key)
     if (cachedVal != null) {
       // Split is in cache, so just return its values
+      logInfo("Found partition in cache!")
       return Iterator.fromArray(cachedVal.asInstanceOf[Array[T]])
     } else {
       // Mark the split as loading (unless someone else marks it first)
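Two details in the additions above are easy to miss. First, the pattern match on HashMap[Int, Array[List[String]]] in getLocationsSnapshot is unchecked: type erasure means any HashMap matches at runtime, so the type parameters are a promise rather than a test. Second, GetCacheLocations replies with cloned arrays, so the snapshot handed to the scheduler cannot be mutated behind its back by cache events that arrive later.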
@@ -73,13 +119,13 @@ private object RDDCache extends Logging {
           loading.add(key)
         }
       }
-
-      val host = System.getProperty("spark.hostname", Utils.localHostName)
-      trackerActor ! CacheEntryAdded(rdd.id, split.index, host)
+      // If we got here, we have to load the split
+      // Tell the master that we're doing so
+      val host = System.getProperty("spark.hostname", Utils.localHostName)
+      trackerActor ! AddedToCache(rdd.id, split.index, host)
       // TODO: fetch any remote copy of the split that may be available
-      // TODO: also notify the master that we're loading it
       // TODO: also register a listener for when it unloads
-      logInfo("Computing and caching " + split)
+      logInfo("Computing partition " + split)
       val array = rdd.compute(split).toArray(m)
       cache.put(key, array)
       loading.synchronized {
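getOrCompute pairs the cache with a loading set so that two threads asking for the same partition compute it only once: later arrivals wait until the first finishes, then read the cached value. A simplified single-JVM sketch of that handshake (hypothetical names; no actors, cache keyspace, or master notification):

    import scala.collection.mutable.{HashMap, HashSet}

    object GetOrComputeSketch {
      val cache = new HashMap[(Int, Int), Array[Int]] // (rddId, partition) -> data
      val loading = new HashSet[(Int, Int)]           // keys being computed right now

      def getOrCompute(key: (Int, Int))(compute: => Array[Int]): Array[Int] = {
        loading.synchronized {
          // Wait while another thread computes this key, then re-check the cache.
          while (loading.contains(key)) loading.wait()
          cache.get(key) match {
            case Some(v) => return v         // already cached (maybe while we waited)
            case None    => loading.add(key) // claim the right to compute it
          }
        }
        val value = compute // run the expensive computation outside the lock
        loading.synchronized {
          cache(key) = value
          loading.remove(key)
          loading.notifyAll() // wake waiters; they will now hit the cache
        }
        value
      }

      def main(args: Array[String]) {
        println(getOrCompute((0, 0)) { Array(1, 2, 3) }.mkString(","))               // computes
        println(getOrCompute((0, 0)) { sys.error("not reached") }.mkString(","))     // cached
      }
    }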
@@ -175,6 +175,7 @@ extends Logging {
   private var nextRddId = new AtomicInteger(0)
 
+  // Register a new RDD, returning its RDD ID
   private[spark] def newRddId(): Int = {
     nextRddId.getAndIncrement()
   }
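Allocating ids from an AtomicInteger keeps them unique without locking even when RDDs are created from several threads, and these stable Int ids are what the cache tracker and scheduler now key on in place of RDD object references.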