diff --git a/core/src/main/scala/spark/DAGScheduler.scala b/core/src/main/scala/spark/DAGScheduler.scala
index 734cbea8228f2f0edf72154ea52146ffbbac7189..2e427dcb0cbf21a9562cac3e597d68d41af917a0 100644
--- a/core/src/main/scala/spark/DAGScheduler.scala
+++ b/core/src/main/scala/spark/DAGScheduler.scala
@@ -10,7 +10,7 @@ import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map}
  * only need to implement the code to send a task to the cluster and to report
  * failures from it (the submitTasks method, and code to add completionEvents).
  */
-private abstract class DAGScheduler extends Scheduler with Logging {
+private trait DAGScheduler extends Scheduler with Logging {
   // Must be implemented by subclasses to start running a set of tasks
   def submitTasks(tasks: Seq[Task[_]]): Unit
 
@@ -69,6 +69,9 @@ private abstract class DAGScheduler extends Scheduler with Logging {
     def visit(r: RDD[_]) {
       if (!visited(r)) {
         visited += r
+        // Kind of ugly: need to register RDDs with the cache here since
+        // we can't do it in its constructor because # of splits is unknown
+        RDDCache.registerRDD(r.id, r.splits.size)
         for (dep <- r.dependencies) {
           dep match {
             case shufDep: ShuffleDependency[_,_,_] =>
@@ -180,12 +183,14 @@ private abstract class DAGScheduler extends Scheduler with Logging {
             stage.addOutputLoc(smt.partition, evt.result.asInstanceOf[String])
             val pending = pendingTasks(stage)
             pending -= smt
-            MapOutputTracker.registerMapOutputs(
-              stage.shuffleDep.get.shuffleId,
-              stage.outputLocs.map(_.first).toArray)
             if (pending.isEmpty) {
               logInfo(stage + " finished; looking for newly runnable stages")
               running -= stage
+              if (stage.shuffleDep != None) {
+                MapOutputTracker.registerMapOutputs(
+                  stage.shuffleDep.get.shuffleId,
+                  stage.outputLocs.map(_.first).toArray)
+              }
               updateCacheLocs()
               val newlyRunnable = new ArrayBuffer[Stage]
               for (stage <- waiting if getMissingParentStages(stage) == Nil) {
diff --git a/core/src/main/scala/spark/MapOutputTracker.scala b/core/src/main/scala/spark/MapOutputTracker.scala
index a2531761696ce1a1e04358aea03807a4e53bfc59..07fd605cca0af1ab61c3a6eee734141b1dedc497 100644
--- a/core/src/main/scala/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/spark/MapOutputTracker.scala
@@ -5,22 +5,37 @@ import java.util.concurrent.ConcurrentHashMap
 import scala.actors._
 import scala.actors.Actor._
 import scala.actors.remote._
+import scala.collection.mutable.HashSet
 
-class MapOutputTracker extends DaemonActor with Logging {
+sealed trait MapOutputTrackerMessage
+case class GetMapOutputLocations(shuffleId: Int) extends MapOutputTrackerMessage
+
+class MapOutputTracker(serverUris: ConcurrentHashMap[Int, Array[String]])
+extends DaemonActor with Logging {
   def act() {
     val port = System.getProperty("spark.master.port", "50501").toInt
     RemoteActor.alive(port)
     RemoteActor.register('MapOutputTracker, self)
     logInfo("Registered actor on port " + port)
+
+    loop {
+      react {
+        case GetMapOutputLocations(shuffleId: Int) =>
+          logInfo("Asked to get map output locations for shuffle " + shuffleId)
+          reply(serverUris.get(shuffleId))
+      }
+    }
   }
 }
 
-object MapOutputTracker {
+object MapOutputTracker extends Logging {
   var trackerActor: AbstractActor = null
 
+  private val serverUris = new ConcurrentHashMap[Int, Array[String]]
+
   def initialize(isMaster: Boolean) {
     if (isMaster) {
-      val tracker = new MapOutputTracker
+      val tracker = new MapOutputTracker(serverUris)
       tracker.start
       trackerActor = tracker
     } else {
@@ -30,8 +45,6 @@ object MapOutputTracker {
     }
   }
 
-  private val serverUris = new ConcurrentHashMap[Int, Array[String]]
-
   def registerMapOutput(shuffleId: Int, numMaps: Int, mapId: Int, serverUri: String) {
     var array = serverUris.get(shuffleId)
     if (array == null) {
@@ -45,9 +58,38 @@ object MapOutputTracker {
     serverUris.put(shuffleId, Array[String]() ++ locs)
   }
 
+
+  // Remembers which map output locations are currently being fetched
+  val fetching = new HashSet[Int]
+
   def getServerUris(shuffleId: Int): Array[String] = {
     // TODO: On remote node, fetch locations from master
-    serverUris.get(shuffleId)
+    val locs = serverUris.get(shuffleId)
+    if (locs == null) {
+      logInfo("Don't have map outputs for " + shuffleId + ", fetching them")
+      fetching.synchronized {
+        if (fetching.contains(shuffleId)) {
+          // Someone else is fetching it; wait for them to be done
+          while (fetching.contains(shuffleId)) {
+            try {fetching.wait()} catch {case _ =>}
+          }
+          return serverUris.get(shuffleId)
+        } else {
+          fetching += shuffleId
+        }
+      }
+      // We won the race to fetch the output locs; do so
+      logInfo("Doing the fetch; tracker actor = " + trackerActor)
+      val fetched = (trackerActor !? GetMapOutputLocations(shuffleId)).asInstanceOf[Array[String]]
+      serverUris.put(shuffleId, fetched)
+      fetching.synchronized {
+        fetching -= shuffleId
+        fetching.notifyAll()
+      }
+      return fetched
+    } else {
+      return locs
+    }
   }
 
   def getMapOutputUri(serverUri: String, shuffleId: Int, mapId: Int, reduceId: Int): String = {
diff --git a/core/src/main/scala/spark/MesosScheduler.scala b/core/src/main/scala/spark/MesosScheduler.scala
index 35ad552775fc02ca69c2ffb00efa71a8452d3c01..fc8c111bccf5371b5b498bbce700d795fbc0dfad 100644
--- a/core/src/main/scala/spark/MesosScheduler.scala
+++ b/core/src/main/scala/spark/MesosScheduler.scala
@@ -21,7 +21,7 @@ import mesos._
  */
 private class MesosScheduler(
   sc: SparkContext, master: String, frameworkName: String)
-extends MScheduler with spark.Scheduler with Logging
+extends MScheduler with DAGScheduler with Logging
 {
   // Environment variables to pass to our executors
   val ENV_VARS_TO_SEND_TO_EXECUTORS = Array(
@@ -52,7 +52,7 @@ extends MScheduler with spark.Scheduler with Logging
 
   // URIs of JARs to pass to executor
   var jarUris: String = ""
-  
+
   def newJobId(): Int = this.synchronized {
     val id = nextJobId
     nextJobId += 1
@@ -101,6 +101,31 @@ extends MScheduler with spark.Scheduler with Logging
     new ExecutorInfo(execScript, createExecArg(), params)
   }
 
+
+  def submitTasks(tasks: Seq[Task[_]]) {
+    logInfo("Got a job with " + tasks.size + " tasks")
+    waitForRegister()
+    this.synchronized {
+      val jobId = newJobId()
+      val myJob = new SimpleJob(this, tasks, jobId)
+      activeJobs(jobId) = myJob
+      activeJobsQueue += myJob
+      logInfo("Adding job with ID " + jobId)
+      jobTasks(jobId) = new HashSet()
+    }
+    driver.reviveOffers();
+  }
+
+  def jobFinished(job: Job) {
+    this.synchronized {
+      activeJobs -= job.getId
+      activeJobsQueue.dequeueAll(x => (x == job))
+      taskIdToJobId --= jobTasks(job.getId)
+      jobTasks.remove(job.getId)
+    }
+  }
+
+  /*
   /**
    * The primary means to submit a job to the scheduler. Given a list of tasks,
    * runs them and returns an array of the results.
@@ -126,6 +151,7 @@ extends MScheduler with spark.Scheduler with Logging
       }
     }
   }
+  */
 
   override def registered(d: SchedulerDriver, frameworkId: String) {
     logInfo("Registered as framework ID " + frameworkId)
@@ -199,7 +225,8 @@ extends MScheduler with spark.Scheduler with Logging
           }
           if (isFinished(status.getState)) {
             taskIdToJobId.remove(status.getTaskId)
-            jobTasks(jobId) -= status.getTaskId
+            if (jobTasks.contains(jobId))
+              jobTasks(jobId) -= status.getTaskId
           }
         case None =>
           logInfo("TID " + status.getTaskId + " already finished")
@@ -291,9 +318,4 @@ extends MScheduler with spark.Scheduler with Logging
     // Serialize the map as an array of (String, String) pairs
     return Utils.serialize(props.toArray)
   }
-
-  override def runJob[T, U](rdd: RDD[T], func: Iterator[T] => U)(implicit m: ClassManifest[U])
-      : Array[U] = {
-    new Array[U](0)
-  }
 }
diff --git a/core/src/main/scala/spark/RDDCache.scala b/core/src/main/scala/spark/RDDCache.scala
index aae2d74900666d247d3257babc691ed774fe57e2..5c9b137b4b90546a3c3968bd7f64c62e8fd8ff11 100644
--- a/core/src/main/scala/spark/RDDCache.scala
+++ b/core/src/main/scala/spark/RDDCache.scala
@@ -82,6 +82,7 @@ private object RDDCache extends Logging {
   def registerRDD(rddId: Int, numPartitions: Int) {
     registeredRddIds.synchronized {
       if (!registeredRddIds.contains(rddId)) {
+        logInfo("Registering RDD ID " + rddId + " with cache")
         registeredRddIds += rddId
         trackerActor !? RegisterRDD(rddId, numPartitions)
       }
diff --git a/core/src/main/scala/spark/SimpleJob.scala b/core/src/main/scala/spark/SimpleJob.scala
index 09846ccc34a1f26000b93ea1fc2e7f0034be2a21..afe9093dcf2abdcfbf0badda7140a4653302874b 100644
--- a/core/src/main/scala/spark/SimpleJob.scala
+++ b/core/src/main/scala/spark/SimpleJob.scala
@@ -11,8 +11,8 @@ import mesos._
 /**
  * A Job that runs a set of tasks with no interdependencies.
  */
-class SimpleJob[T: ClassManifest](
-  sched: MesosScheduler, tasks: Array[Task[T]], val jobId: Int)
+class SimpleJob(
+  sched: MesosScheduler, tasksSeq: Seq[Task[_]], val jobId: Int)
 extends Job(jobId) with Logging
 {
   // Maximum time to wait to run a task in a preferred location (in ms)
@@ -26,8 +26,8 @@ extends Job(jobId) with Logging
   val MAX_TASK_FAILURES = 4
 
   val callingThread = currentThread
+  val tasks = tasksSeq.toArray
   val numTasks = tasks.length
-  val results = new Array[T](numTasks)
   val launched = new Array[Boolean](numTasks)
   val finished = new Array[Boolean](numTasks)
   val numFailures = new Array[Int](numTasks)
@@ -87,20 +87,7 @@ extends Job(jobId) with Logging
       allFinished = true
       joinLock.notifyAll()
     }
-  }
-
-  // Wait until the job finishes and return its results
-  def join(): Array[T] = {
-    joinLock.synchronized {
-      while (!allFinished) {
-        joinLock.wait()
-      }
-      if (failed) {
-        throw new SparkException(causeOfFailure)
-      } else {
-        return results
-      }
-    }
+    sched.jobFinished(this)
   }
 
   // Return the pending tasks list for a given host, or an empty list if
@@ -145,7 +132,7 @@ extends Job(jobId) with Logging
   // Does a host count as a preferred location for a task? This is true if
   // either the task has preferred locations and this host is one, or it has
   // no preferred locations (in which we still count the launch as preferred).
-  def isPreferredLocation(task: Task[T], host: String): Boolean = {
+  def isPreferredLocation(task: Task[_], host: String): Boolean = {
     val locs = task.preferredLocations
     return (locs.contains(host) || locs.isEmpty)
   }
@@ -215,10 +202,8 @@ extends Job(jobId) with Logging
         logInfo("Finished TID %d (progress: %d/%d)".format(
           tid, tasksFinished, numTasks))
         // Deserialize task result
-        val result = Utils.deserialize[TaskResult[T]](status.getData)
-        results(index) = result.value
-        // Update accumulators
-        Accumulators.add(callingThread, result.accumUpdates)
+        val result = Utils.deserialize[TaskResult[_]](status.getData)
+        sched.taskEnded(tasks(index), true, result.value, result.accumUpdates)
         // Mark finished and stop if we've finished all the tasks
         finished(index) = true
         if (tasksFinished == numTasks)
diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala
index 5cce873c7271c3a6d3559934ef87924dddc2d3dc..34641918f08b383c38ca7cb59d7f869ed4d9b542 100644
--- a/core/src/main/scala/spark/SparkContext.scala
+++ b/core/src/main/scala/spark/SparkContext.scala
@@ -15,6 +15,12 @@ class SparkContext(
   val sparkHome: String = null,
   val jars: Seq[String] = Nil)
 extends Logging {
+  // Set Spark master host and port system properties
+  if (System.getProperty("spark.master.host") == null)
+    System.setProperty("spark.master.host", Utils.localHostName)
+  if (System.getProperty("spark.master.port") == null)
+    System.setProperty("spark.master.port", "50501")
+
   private var scheduler: Scheduler = {
     // Regular expression used for local[N] master format
     val LOCAL_N_REGEX = """local\[([0-9]+)\]""".r
@@ -30,7 +36,7 @@ extends Logging {
   }
 
   private val isLocal = scheduler.isInstanceOf[LocalScheduler]
-  
+
   // Start the scheduler, the cache and the broadcast system
   scheduler.start()
   Cache.initialize()
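
Note on the new fetch path in MapOutputTracker.getServerUris: workers that do not yet know a shuffle's output locations now ask the master's tracker actor for them, and the fetching set ensures only one thread per shuffle issues the remote request while concurrent callers wait for its result. The sketch below is a minimal, standalone illustration of that coordination pattern under stated assumptions; it is not part of the patch, and FetchOnceCache and fetchFromMaster are hypothetical names, with fetchFromMaster standing in for the synchronous trackerActor !? GetMapOutputLocations(shuffleId) round trip.

// Illustrative sketch only (not part of the patch above): the "fetch once,
// others wait" pattern used by the new MapOutputTracker.getServerUris.
// FetchOnceCache and fetchFromMaster are hypothetical stand-ins.
import java.util.concurrent.ConcurrentHashMap
import scala.collection.mutable.HashSet

class FetchOnceCache(fetchFromMaster: Int => Array[String]) {
  private val cache = new ConcurrentHashMap[Int, Array[String]]
  private val fetching = new HashSet[Int]   // ids with a fetch currently in flight

  def get(id: Int): Array[String] = {
    val cached = cache.get(id)
    if (cached != null) {
      return cached
    }
    fetching.synchronized {
      if (fetching.contains(id)) {
        // Another thread is already fetching this id; wait until it finishes
        while (fetching.contains(id)) {
          fetching.wait()
        }
        return cache.get(id)
      } else {
        fetching += id   // we won the race; this thread will do the fetch
      }
    }
    // Perform the remote call outside the lock so waiters are not blocked on it
    val fetched = fetchFromMaster(id)
    cache.put(id, fetched)
    fetching.synchronized {
      fetching -= id
      fetching.notifyAll()   // wake up threads waiting for this id
    }
    fetched
  }
}

In the patch itself, cache corresponds to the object's serverUris map, fetching plays the same role under the same name, and the remote call is the blocking !? send to the tracker actor registered on spark.master.port.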