Skip to content
Snippets Groups Projects
Commit 1df5a65a authored by Matei Zaharia's avatar Matei Zaharia
Browse files

Pass cache locations correctly to DAGScheduler.

parent e1436f1e
No related branches found
No related tags found
No related merge requests found
...@@ -33,20 +33,14 @@ private abstract class DAGScheduler extends Scheduler with Logging { ...@@ -33,20 +33,14 @@ private abstract class DAGScheduler extends Scheduler with Logging {
val shuffleToMapStage = new HashMap[ShuffleDependency[_,_,_], Stage] val shuffleToMapStage = new HashMap[ShuffleDependency[_,_,_], Stage]
val cacheLocs = new HashMap[RDD[_], Array[List[String]]] var cacheLocs = new HashMap[Int, Array[List[String]]]
def getCacheLocs(rdd: RDD[_]): Array[List[String]] = { def getCacheLocs(rdd: RDD[_]): Array[List[String]] = {
cacheLocs.getOrElseUpdate(rdd, Array.fill[List[String]](rdd.splits.size)(Nil)) cacheLocs(rdd.id)
} }
def addCacheLoc(rdd: RDD[_], partition: Int, host: String) { def updateCacheLocs() {
val locs = getCacheLocs(rdd) cacheLocs = RDDCache.getLocationsSnapshot()
locs(partition) = host :: locs(partition)
}
def removeCacheLoc(rdd: RDD[_], partition: Int, host: String) {
val locs = getCacheLocs(rdd)
locs(partition) -= host
} }
def getShuffleMapStage(shuf: ShuffleDependency[_,_,_]): Stage = { def getShuffleMapStage(shuf: ShuffleDependency[_,_,_]): Stage = {
...@@ -60,6 +54,9 @@ private abstract class DAGScheduler extends Scheduler with Logging { ...@@ -60,6 +54,9 @@ private abstract class DAGScheduler extends Scheduler with Logging {
} }
def newStage(rdd: RDD[_], shuffleDep: Option[ShuffleDependency[_,_,_]]): Stage = { def newStage(rdd: RDD[_], shuffleDep: Option[ShuffleDependency[_,_,_]]): Stage = {
// Kind of ugly: we need to register RDDs with the cache here, since we
// can't do it in the RDD's constructor because the number of splits is unknown there
RDDCache.registerRDD(rdd.id, rdd.splits.size)
val id = newStageId() val id = newStageId()
val stage = new Stage(id, rdd, shuffleDep, getParentStages(rdd)) val stage = new Stage(id, rdd, shuffleDep, getParentStages(rdd))
idToStage(id) = stage idToStage(id) = stage
...@@ -113,10 +110,10 @@ private abstract class DAGScheduler extends Scheduler with Logging { ...@@ -113,10 +110,10 @@ private abstract class DAGScheduler extends Scheduler with Logging {
missing.toList missing.toList
} }
override def runJob[T, U](rdd: RDD[T], func: Iterator[T] => U)(implicit m: ClassManifest[U]) override def runJob[T, U](finalRdd: RDD[T], func: Iterator[T] => U)(implicit m: ClassManifest[U])
: Array[U] = { : Array[U] = {
val numOutputParts: Int = rdd.splits.size val numOutputParts: Int = finalRdd.splits.size
val finalStage = newStage(rdd, None) val finalStage = newStage(finalRdd, None)
val results = new Array[U](numOutputParts) val results = new Array[U](numOutputParts)
val finished = new Array[Boolean](numOutputParts) val finished = new Array[Boolean](numOutputParts)
var numFinished = 0 var numFinished = 0
...@@ -125,6 +122,8 @@ private abstract class DAGScheduler extends Scheduler with Logging { ...@@ -125,6 +122,8 @@ private abstract class DAGScheduler extends Scheduler with Logging {
val running = new HashSet[Stage] val running = new HashSet[Stage]
val pendingTasks = new HashMap[Stage, HashSet[Task[_]]] val pendingTasks = new HashMap[Stage, HashSet[Task[_]]]
updateCacheLocs()
logInfo("Final stage: " + finalStage) logInfo("Final stage: " + finalStage)
logInfo("Parents of final stage: " + finalStage.parents) logInfo("Parents of final stage: " + finalStage.parents)
logInfo("Missing parents: " + getMissingParentStages(finalStage)) logInfo("Missing parents: " + getMissingParentStages(finalStage))
...@@ -145,12 +144,13 @@ private abstract class DAGScheduler extends Scheduler with Logging { ...@@ -145,12 +144,13 @@ private abstract class DAGScheduler extends Scheduler with Logging {
} }
def submitMissingTasks(stage: Stage) { def submitMissingTasks(stage: Stage) {
// Get our pending tasks and remember them in our pendingTasks entry
val myPending = pendingTasks.getOrElseUpdate(stage, new HashSet) val myPending = pendingTasks.getOrElseUpdate(stage, new HashSet)
var tasks = ArrayBuffer[Task[_]]() var tasks = ArrayBuffer[Task[_]]()
if (stage == finalStage) { if (stage == finalStage) {
for (p <- 0 until numOutputParts if (!finished(p))) { for (p <- 0 until numOutputParts if (!finished(p))) {
val locs = getPreferredLocs(rdd, p) val locs = getPreferredLocs(finalRdd, p)
tasks += new ResultTask(finalStage.id, rdd, func, p, locs) tasks += new ResultTask(finalStage.id, finalRdd, func, p, locs)
} }
} else { } else {
for (p <- 0 until stage.numPartitions if stage.outputLocs(p) == Nil) { for (p <- 0 until stage.numPartitions if stage.outputLocs(p) == Nil) {
...@@ -186,6 +186,7 @@ private abstract class DAGScheduler extends Scheduler with Logging { ...@@ -186,6 +186,7 @@ private abstract class DAGScheduler extends Scheduler with Logging {
if (pending.isEmpty) { if (pending.isEmpty) {
logInfo(stage + " finished; looking for newly runnable stages") logInfo(stage + " finished; looking for newly runnable stages")
running -= stage running -= stage
updateCacheLocs()
val newlyRunnable = new ArrayBuffer[Stage] val newlyRunnable = new ArrayBuffer[Stage]
for (stage <- waiting if getMissingParentStages(stage) == Nil) { for (stage <- waiting if getMissingParentStages(stage) == Nil) {
newlyRunnable += stage newlyRunnable += stage
......
...@@ -11,7 +11,7 @@ class MapOutputTracker extends DaemonActor with Logging { ...@@ -11,7 +11,7 @@ class MapOutputTracker extends DaemonActor with Logging {
val port = System.getProperty("spark.master.port", "50501").toInt val port = System.getProperty("spark.master.port", "50501").toInt
RemoteActor.alive(port) RemoteActor.alive(port)
RemoteActor.register('MapOutputTracker, self) RemoteActor.register('MapOutputTracker, self)
logInfo("Started on port " + port) logInfo("Registered actor on port " + port)
} }
} }
......
...@@ -3,31 +3,57 @@ package spark ...@@ -3,31 +3,57 @@ package spark
import scala.actors._ import scala.actors._
import scala.actors.Actor._ import scala.actors.Actor._
import scala.actors.remote._ import scala.actors.remote._
import scala.collection.mutable.HashMap
import scala.collection.mutable.HashSet
sealed trait CacheMessage sealed trait CacheMessage
case class CacheEntryAdded(rddId: Int, partition: Int, host: String) case class AddedToCache(rddId: Int, partition: Int, host: String) extends CacheMessage
case class CacheEntryRemoved(rddId: Int, partition: Int, host: String) case class DroppedFromCache(rddId: Int, partition: Int, host: String) extends CacheMessage
case class MemoryCacheLost(host: String) extends CacheMessage
case class RegisterRDD(rddId: Int, numPartitions: Int) extends CacheMessage
case object GetCacheLocations extends CacheMessage
class RDDCacheTracker extends DaemonActor with Logging { class RDDCacheTracker extends DaemonActor with Logging {
val locs = new HashMap[Int, Array[List[String]]]
// TODO: Should probably store (String, CacheType) tuples
def act() { def act() {
val port = System.getProperty("spark.master.port", "50501").toInt val port = System.getProperty("spark.master.port", "50501").toInt
RemoteActor.alive(port) RemoteActor.alive(port)
RemoteActor.register('RDDCacheTracker, self) RemoteActor.register('RDDCacheTracker, self)
logInfo("Started on port " + port) logInfo("Registered actor on port " + port)
loop { loop {
react { react {
case CacheEntryAdded(rddId, partition, host) => case RegisterRDD(rddId: Int, numPartitions: Int) =>
logInfo("Cache entry added: %s, %s, %s".format(rddId, partition, host)) logInfo("Registering RDD " + rddId + " with " + numPartitions + " partitions")
locs(rddId) = Array.fill[List[String]](numPartitions)(Nil)
reply("")
case AddedToCache(rddId, partition, host) =>
logInfo("Cache entry added: (%s, %s) on %s".format(rddId, partition, host))
locs(rddId)(partition) = host :: locs(rddId)(partition)
case CacheEntryRemoved(rddId, partition, host) => case DroppedFromCache(rddId, partition, host) =>
logInfo("Cache entry removed: %s, %s, %s".format(rddId, partition, host)) logInfo("Cache entry removed: (%s, %s) on %s".format(rddId, partition, host))
locs(rddId)(partition) -= host
case MemoryCacheLost(host) =>
logInfo("Memory cache lost on " + host)
// TODO: Drop host from the memory locations list of all RDDs
case GetCacheLocations =>
logInfo("Asked for current cache locations")
val locsCopy = new HashMap[Int, Array[List[String]]]
for ((rddId, array) <- locs) {
locsCopy(rddId) = array.clone()
}
reply(locsCopy)
} }
} }
} }
} }
import scala.collection.mutable.HashSet
private object RDDCache extends Logging { private object RDDCache extends Logging {
// Stores map results for various splits locally // Stores map results for various splits locally
val cache = Cache.newKeySpace() val cache = Cache.newKeySpace()
...@@ -38,6 +64,8 @@ private object RDDCache extends Logging { ...@@ -38,6 +64,8 @@ private object RDDCache extends Logging {
// Tracker actor on the master, or remote reference to it on workers // Tracker actor on the master, or remote reference to it on workers
var trackerActor: AbstractActor = null var trackerActor: AbstractActor = null
val registeredRddIds = new HashSet[Int]
def initialize(isMaster: Boolean) { def initialize(isMaster: Boolean) {
if (isMaster) { if (isMaster) {
val tracker = new RDDCacheTracker val tracker = new RDDCacheTracker
...@@ -50,16 +78,34 @@ private object RDDCache extends Logging { ...@@ -50,16 +78,34 @@ private object RDDCache extends Logging {
} }
} }
// Registers an RDD with the cache tracker (on master only).
//
// rddId:         ID of the RDD to register
// numPartitions: number of partitions, sent to the tracker in the RegisterRDD message
//
// Each ID is sent to the tracker at most once; calls for an ID that is already
// recorded in registeredRddIds are no-ops. The lock on registeredRddIds is held
// across the blocking send so concurrent callers cannot register the same ID twice.
def registerRDD(rddId: Int, numPartitions: Int) {
  registeredRddIds.synchronized {
    if (!registeredRddIds.contains(rddId)) {
      // Wait for the tracker's reply before recording the ID locally. If the
      // send throws, the ID stays unrecorded and a later call can retry, rather
      // than the RDD being marked as registered without the tracker knowing it.
      trackerActor !? RegisterRDD(rddId, numPartitions)
      registeredRddIds += rddId
    }
  }
}
// Get a snapshot of the currently known cache locations.
//
// Sends GetCacheLocations to the tracker actor and blocks until it replies.
// Returns the replied map, typed as RDD ID -> per-partition list of host names.
// Throws SparkException if the reply is not a HashMap.
def getLocationsSnapshot(): HashMap[Int, Array[List[String]]] = {
  (trackerActor !? GetCacheLocations) match {
    // Type arguments are erased at runtime, so only the HashMap class itself
    // can be checked here. The key/value types are trusted to match what the
    // tracker sends, and that trust is made explicit with the cast below.
    case h: HashMap[_, _] =>
      h.asInstanceOf[HashMap[Int, Array[List[String]]]]
    case _ => throw new SparkException(
      "Internal error: RDDCache did not reply with a HashMap")
  }
}
// Gets or computes an RDD split // Gets or computes an RDD split
def getOrCompute[T](rdd: RDD[T], split: Split)(implicit m: ClassManifest[T]) def getOrCompute[T](rdd: RDD[T], split: Split)(implicit m: ClassManifest[T])
: Iterator[T] = { : Iterator[T] = {
val key = (rdd.id, split.index) val key = (rdd.id, split.index)
logInfo("CachedRDD split key is " + key) logInfo("CachedRDD partition key is " + key)
val cache = RDDCache.cache
val loading = RDDCache.loading
val cachedVal = cache.get(key) val cachedVal = cache.get(key)
if (cachedVal != null) { if (cachedVal != null) {
// Split is in cache, so just return its values // Split is in cache, so just return its values
logInfo("Found partition in cache!")
return Iterator.fromArray(cachedVal.asInstanceOf[Array[T]]) return Iterator.fromArray(cachedVal.asInstanceOf[Array[T]])
} else { } else {
// Mark the split as loading (unless someone else marks it first) // Mark the split as loading (unless someone else marks it first)
...@@ -73,13 +119,13 @@ private object RDDCache extends Logging { ...@@ -73,13 +119,13 @@ private object RDDCache extends Logging {
loading.add(key) loading.add(key)
} }
} }
val host = System.getProperty("spark.hostname", Utils.localHostName)
trackerActor ! CacheEntryAdded(rdd.id, split.index, host)
// If we got here, we have to load the split // If we got here, we have to load the split
// Tell the master that we're doing so
val host = System.getProperty("spark.hostname", Utils.localHostName)
trackerActor ! AddedToCache(rdd.id, split.index, host)
// TODO: fetch any remote copy of the split that may be available // TODO: fetch any remote copy of the split that may be available
// TODO: also notify the master that we're loading it
// TODO: also register a listener for when it unloads // TODO: also register a listener for when it unloads
logInfo("Computing and caching " + split) logInfo("Computing partition " + split)
val array = rdd.compute(split).toArray(m) val array = rdd.compute(split).toArray(m)
cache.put(key, array) cache.put(key, array)
loading.synchronized { loading.synchronized {
......
...@@ -175,6 +175,7 @@ extends Logging { ...@@ -175,6 +175,7 @@ extends Logging {
private var nextRddId = new AtomicInteger(0) private var nextRddId = new AtomicInteger(0)
// Register a new RDD, returning its RDD ID
private[spark] def newRddId(): Int = { private[spark] def newRddId(): Int = {
nextRddId.getAndIncrement() nextRddId.getAndIncrement()
} }
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment