diff --git a/core/src/main/scala/spark/DAGScheduler.scala b/core/src/main/scala/spark/DAGScheduler.scala
index 3eee0970b6a0294af44d925d130007dabf267ab3..423510d883229b7b49ca31e86fbc397c7ce7e0c7 100644
--- a/core/src/main/scala/spark/DAGScheduler.scala
+++ b/core/src/main/scala/spark/DAGScheduler.scala
@@ -1,19 +1,19 @@
 package spark
 
-import java.util.concurrent.LinkedBlockingQueue
-import java.util.concurrent.TimeUnit
-import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map}
+import java.util.concurrent.atomic.AtomicInteger
+
+import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Queue, Map}
 
 /**
  * A task created by the DAG scheduler. Knows its stage ID and map ouput tracker generation.
  */
-abstract class DAGTask[T](val stageId: Int) extends Task[T] {
+abstract class DAGTask[T](val runId: Int, val stageId: Int) extends Task[T] {
   val gen = SparkEnv.get.mapOutputTracker.getGeneration
   override def generation: Option[Long] = Some(gen)
 }
 
 /**
- * A completion event passed by the underlying task scheduler to the DAG scheduler
+ * A completion event passed by the underlying task scheduler to the DAG scheduler.
  */
 case class CompletionEvent(
     task: DAGTask[_],
@@ -39,13 +39,22 @@ case class OtherFailure(message: String) extends TaskEndReason
  * and to report fetch failures (the submitTasks method, and code to add CompletionEvents).
  */
 private trait DAGScheduler extends Scheduler with Logging {
-  // Must be implemented by subclasses to start running a set of tasks
-  def submitTasks(tasks: Seq[Task[_]]): Unit
+  // Must be implemented by subclasses to start running a set of tasks. The subclass should also
+  // attempt to run different sets of tasks in the order given by runId (lower values first).
+  def submitTasks(tasks: Seq[Task[_]], runId: Int): Unit
 
-  // Must be called by subclasses to report task completions or failures
+  // Must be called by subclasses to report task completions or failures.
   def taskEnded(task: Task[_], reason: TaskEndReason, result: Any, accumUpdates: Map[Long, Any]) {
-    val dagTask = task.asInstanceOf[DAGTask[_]]
-    completionEvents.put(CompletionEvent(dagTask, reason, result, accumUpdates))
+    lock.synchronized {
+      val dagTask = task.asInstanceOf[DAGTask[_]]
+      eventQueues.get(dagTask.runId) match {
+        case Some(queue) =>
+          queue += CompletionEvent(dagTask, reason, result, accumUpdates)
+          lock.notifyAll()
+        case None =>
+          logInfo("Ignoring completion event for DAG job " + dagTask.runId + " because it's gone")
+      }
+    }
   }
 
   // The time, in millis, to wait for fetch failure events to stop coming in after one is detected;
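
The new taskEnded no longer pushes into one shared LinkedBlockingQueue: it routes each completion event to the queue of the run that created the task (the eventQueues map introduced in the next hunk) and wakes any thread waiting on the scheduler lock. Below is a minimal, self-contained sketch of that per-run dispatch pattern; the PerRunDispatchSketch and Event names are illustrative only and not part of the patch.

import scala.collection.mutable.{HashMap, Queue}

// Minimal sketch of the per-run dispatch idea used by taskEnded above.
object PerRunDispatchSketch {
  case class Event(runId: Int, payload: String)

  private val lock = new Object
  private val eventQueues = new HashMap[Int, Queue[Event]]  // indexed by run ID

  // Called by the thread that starts a run, before any events can arrive.
  def register(runId: Int): Unit = lock.synchronized {
    eventQueues(runId) = new Queue[Event]
  }

  // Mirrors taskEnded: enqueue the event for its run if the run still exists,
  // and wake up any thread blocked in a wait() on the same lock.
  def post(event: Event): Unit = lock.synchronized {
    eventQueues.get(event.runId) match {
      case Some(queue) =>
        queue += event
        lock.notifyAll()
      case None =>
        println("Ignoring event for run " + event.runId + " because it's gone")
    }
  }

  // Called when a run finishes, so late events are dropped rather than leaked.
  def unregister(runId: Int): Unit = lock.synchronized {
    eventQueues -= runId
  }
}
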
@@ -57,16 +66,13 @@ private trait DAGScheduler extends Scheduler with Logging {
   // resubmit failed stages
   val POLL_TIMEOUT = 500L
 
-  private val completionEvents = new LinkedBlockingQueue[CompletionEvent]
-  private val lock = new Object
+  private val lock = new Object  // Used for access to the entire DAGScheduler
 
-  var nextStageId = 0
+  private val eventQueues = new HashMap[Int, Queue[CompletionEvent]]  // Indexed by run ID
 
-  def newStageId() = {
-    var res = nextStageId
-    nextStageId += 1
-    res
-  }
+  val nextRunId = new AtomicInteger(0)
+
+  val nextStageId = new AtomicInteger(0)
 
   val idToStage = new HashMap[Int, Stage]
 
@@ -103,7 +109,7 @@ private trait DAGScheduler extends Scheduler with Logging {
     if (shuffleDep != None) {
       mapOutputTracker.registerShuffle(shuffleDep.get.shuffleId, rdd.splits.size)
     }
-    val id = newStageId()
+    val id = nextStageId.getAndIncrement()
     val stage = new Stage(id, rdd, shuffleDep, getParentStages(rdd))
     idToStage(id) = stage
     stage
@@ -167,6 +173,8 @@ private trait DAGScheduler extends Scheduler with Logging {
       allowLocal: Boolean)
       (implicit m: ClassManifest[U]): Array[U] = {
     lock.synchronized {
+      val runId = nextRunId.getAndIncrement()
+
       val outputParts = partitions.toArray
       val numOutputParts: Int = partitions.size
       val finalStage = newStage(finalRdd, None)
@@ -196,6 +204,9 @@ private trait DAGScheduler extends Scheduler with Logging {
         val taskContext = new TaskContext(finalStage.id, outputParts(0), 0)
         return Array(func(taskContext, finalRdd.iterator(split)))
       }
+
+      // Register the job ID so that we can get completion events for it
+      eventQueues(runId) = new Queue[CompletionEvent]
 
       def submitStage(stage: Stage) {
         if (!waiting(stage) && !running(stage)) {
@@ -221,26 +232,27 @@ private trait DAGScheduler extends Scheduler with Logging {
           for (id <- 0 until numOutputParts if (!finished(id))) {
             val part = outputParts(id)
             val locs = getPreferredLocs(finalRdd, part)
-            tasks += new ResultTask(finalStage.id, finalRdd, func, part, locs, id)
+            tasks += new ResultTask(runId, finalStage.id, finalRdd, func, part, locs, id)
           }
         } else {
           for (p <- 0 until stage.numPartitions if stage.outputLocs(p) == Nil) {
             val locs = getPreferredLocs(stage.rdd, p)
-            tasks += new ShuffleMapTask(stage.id, stage.rdd, stage.shuffleDep.get, p, locs)
+            tasks += new ShuffleMapTask(runId, stage.id, stage.rdd, stage.shuffleDep.get, p, locs)
           }
         }
         myPending ++= tasks
-        submitTasks(tasks)
+        submitTasks(tasks, runId)
       }
 
       submitStage(finalStage)
 
       while (numFinished != numOutputParts) {
-        val evt = completionEvents.poll(POLL_TIMEOUT, TimeUnit.MILLISECONDS)
+        val eventOption = waitForEvent(runId, POLL_TIMEOUT)
         val time = System.currentTimeMillis // TODO: use a pluggable clock for testability
 
         // If we got an event off the queue, mark the task done or react to a fetch failure
-        if (evt != null) {
+        if (eventOption != None) {
+          val evt = eventOption.get
           val stage = idToStage(evt.task.stageId)
           pendingTasks(stage) -= evt.task
           if (evt.reason == Success) {
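
The polling loop above now asks waitForEvent (added near the end of this file, below) for an Option[CompletionEvent] instead of taking a possibly-null result from LinkedBlockingQueue.poll. Here is a rough standalone sketch of that timed monitor wait; TimedEventBuffer is an illustrative name, not part of the patch.

import scala.collection.mutable.Queue

// Standalone sketch of the timed wait used by waitForEvent: sleep on the lock
// until either an item arrives or the deadline passes, re-checking the queue
// after every wakeup because wait() may return early.
class TimedEventBuffer[T] {
  private val lock = new Object
  private val queue = new Queue[T]

  def put(item: T): Unit = lock.synchronized {
    queue += item
    lock.notifyAll()
  }

  // Returns Some(item) if one arrives within `timeout` milliseconds, else None.
  def poll(timeout: Long): Option[T] = lock.synchronized {
    val endTime = System.currentTimeMillis() + timeout
    var remaining = endTime - System.currentTimeMillis()
    while (queue.isEmpty && remaining > 0) {
      lock.wait(remaining)
      remaining = endTime - System.currentTimeMillis()
    }
    if (queue.isEmpty) None else Some(queue.dequeue())
  }
}
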
@@ -315,6 +327,7 @@ private trait DAGScheduler extends Scheduler with Logging {
         }
       }
 
+      eventQueues -= runId
       return results
     }
   }
@@ -344,4 +357,18 @@ private trait DAGScheduler extends Scheduler with Logging {
     })
     return Nil
   }
+
+  // Assumes that lock is held on entrance, but will release it to wait for the next event.
+  def waitForEvent(runId: Int, timeout: Long): Option[CompletionEvent] = {
+    val endTime = System.currentTimeMillis() + timeout   // TODO: Use pluggable clock for testing
+    while (eventQueues(runId).isEmpty) {
+      val time = System.currentTimeMillis()
+      if (time > endTime) {
+        return None
+      } else {
+        lock.wait(endTime - time)
+      }
+    }
+    return Some(eventQueues(runId).dequeue())
+  }
 }
diff --git a/core/src/main/scala/spark/Job.scala b/core/src/main/scala/spark/Job.scala
index 9846e918738d1d3ee1b395f0f00f23bbd602b107..0d68470c03a17859b3d576a31312ffdd51967d55 100644
--- a/core/src/main/scala/spark/Job.scala
+++ b/core/src/main/scala/spark/Job.scala
@@ -7,12 +7,10 @@ import org.apache.mesos.Protos._
  * Class representing a parallel job in MesosScheduler. Schedules the job by implementing various
  * callbacks.
  */
-abstract class Job(jobId: Int) {
+abstract class Job(val runId: Int, val jobId: Int) {
   def slaveOffer(s: Offer, availableCpus: Double): Option[TaskDescription]
 
   def statusUpdate(t: TaskStatus): Unit
 
   def error(code: Int, message: String): Unit
-
-  def getId(): Int = jobId
 }
diff --git a/core/src/main/scala/spark/LocalScheduler.scala b/core/src/main/scala/spark/LocalScheduler.scala
index c4919c6516be090ada544612ee1b026acbd1a5f8..0cbc68ffc50ac850be768214ad5b14ddcb00161f 100644
--- a/core/src/main/scala/spark/LocalScheduler.scala
+++ b/core/src/main/scala/spark/LocalScheduler.scala
@@ -12,11 +12,13 @@ private class LocalScheduler(threads: Int, maxFailures: Int) extends DAGSchedule
   var attemptId = new AtomicInteger(0)
   var threadPool = Executors.newFixedThreadPool(threads, DaemonThreadFactory)
 
+  // TODO: Need to take into account stage priority in scheduling
+
   override def start() {}
 
   override def waitForRegister() {}
 
-  override def submitTasks(tasks: Seq[Task[_]]) {
+  override def submitTasks(tasks: Seq[Task[_]], runId: Int) {
     val failCount = new Array[Int](tasks.size)
 
     def submitTask(task: Task[_], idInJob: Int) {
diff --git a/core/src/main/scala/spark/MesosScheduler.scala b/core/src/main/scala/spark/MesosScheduler.scala
index 618ee724f9beeaa5bd0fbab09750ebd766bee801..ee14d091ce931647d38f3922681eda17231bfdba 100644
--- a/core/src/main/scala/spark/MesosScheduler.scala
+++ b/core/src/main/scala/spark/MesosScheduler.scala
@@ -9,8 +9,9 @@ import scala.collection.mutable.ArrayBuffer
 import scala.collection.mutable.HashMap
 import scala.collection.mutable.HashSet
 import scala.collection.mutable.Map
-import scala.collection.mutable.Queue
+import scala.collection.mutable.PriorityQueue
 import scala.collection.JavaConversions._
+import scala.math.Ordering
 
 import com.google.protobuf.ByteString
 
@@ -53,7 +54,7 @@ private class MesosScheduler(
   private val registeredLock = new Object()
 
   private val activeJobs = new HashMap[Int, Job]
-  private val activeJobsQueue = new Queue[Job]
+  private var activeJobsQueue = new PriorityQueue[Job]()(jobOrdering)
 
   private val taskIdToJobId = new HashMap[String, Int]
   private val taskIdToSlaveId = new HashMap[String, String]
@@ -74,6 +75,13 @@ private class MesosScheduler(
 
   // URIs of JARs to pass to executor
   var jarUris: String = ""
+
+  // Sorts jobs in reverse order of run ID for use in our priority queue (so lower IDs run first)
+  private val jobOrdering = new Ordering[Job] {
+    override def compare(j1: Job, j2: Job): Int = {
+      return j2.runId - j1.runId
+    }
+  }
 
   def newJobId(): Int = this.synchronized {
     val id = nextJobId
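
scala.collection.mutable.PriorityQueue dequeues its largest element first, so the comparison above is deliberately reversed (j2.runId - j1.runId) to make the job with the lowest run ID come out first. A quick self-contained check of that behaviour, using a stand-in FakeJob case class rather than the real spark.Job:

import scala.collection.mutable.PriorityQueue
import scala.math.Ordering

// Stand-in for spark.Job, only to demonstrate the ordering defined above.
case class FakeJob(runId: Int, jobId: Int)

object JobOrderingSketch {
  // PriorityQueue hands back the "largest" element first, so the comparison is
  // reversed to make the lowest runId the highest priority.
  private val jobOrdering = new Ordering[FakeJob] {
    override def compare(j1: FakeJob, j2: FakeJob): Int = j2.runId - j1.runId
  }

  def main(args: Array[String]): Unit = {
    val queue = new PriorityQueue[FakeJob]()(jobOrdering)
    queue += FakeJob(runId = 2, jobId = 10)
    queue += FakeJob(runId = 0, jobId = 11)
    queue += FakeJob(runId = 1, jobId = 12)
    while (queue.nonEmpty) {
      println(queue.dequeue())  // prints the jobs in runId order: 0, 1, 2
    }
  }
}

The subtraction trick is fine while run IDs stay small and non-negative; Integer.compare(j2.runId, j1.runId) would be the overflow-safe general form.
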
@@ -138,14 +146,13 @@ private class MesosScheduler(
       .addResources(memory)
       .build()
   }
 
-
-  def submitTasks(tasks: Seq[Task[_]]) {
+  def submitTasks(tasks: Seq[Task[_]], runId: Int) {
     logInfo("Got a job with " + tasks.size + " tasks")
     waitForRegister()
     this.synchronized {
       val jobId = newJobId()
-      val myJob = new SimpleJob(this, tasks, jobId)
+      val myJob = new SimpleJob(this, tasks, runId, jobId)
       activeJobs(jobId) = myJob
       activeJobsQueue += myJob
       logInfo("Adding job with ID " + jobId)
@@ -156,11 +163,11 @@ private class MesosScheduler(
 
   def jobFinished(job: Job) {
     this.synchronized {
-      activeJobs -= job.getId
-      activeJobsQueue.dequeueAll(x => (x == job))
-      taskIdToJobId --= jobTasks(job.getId)
-      taskIdToSlaveId --= jobTasks(job.getId)
-      jobTasks.remove(job.getId)
+      activeJobs -= job.jobId
+      activeJobsQueue = activeJobsQueue.filterNot(_ == job)
+      taskIdToJobId --= jobTasks(job.jobId)
+      taskIdToSlaveId --= jobTasks(job.jobId)
+      jobTasks.remove(job.jobId)
     }
   }
 
@@ -204,8 +211,8 @@ private class MesosScheduler(
             tasks(i).add(task)
             val tid = task.getTaskId.getValue
             val sid = offers(i).getSlaveId.getValue
-            taskIdToJobId(tid) = job.getId
-            jobTasks(job.getId) += tid
+            taskIdToJobId(tid) = job.jobId
+            jobTasks(job.jobId) += tid
             taskIdToSlaveId(tid) = sid
             slavesWithExecutors += sid
             availableCpus(i) -= getResource(task.getResourcesList(), "cpus")
diff --git a/core/src/main/scala/spark/ResultTask.scala b/core/src/main/scala/spark/ResultTask.scala
index 25d85b7e0ced19366b5a2172220382ce5edeef7f..3952bf85b2cdb89f83aaed4bbca8c73086e08f5d 100644
--- a/core/src/main/scala/spark/ResultTask.scala
+++ b/core/src/main/scala/spark/ResultTask.scala
@@ -1,13 +1,14 @@
 package spark
 
 class ResultTask[T, U](
+    runId: Int,
     stageId: Int,
     rdd: RDD[T],
     func: (TaskContext, Iterator[T]) => U,
     val partition: Int,
     locs: Seq[String],
     val outputId: Int)
-  extends DAGTask[U](stageId) {
+  extends DAGTask[U](runId, stageId) {
 
   val split = rdd.splits(partition)
 
diff --git a/core/src/main/scala/spark/ShuffleMapTask.scala b/core/src/main/scala/spark/ShuffleMapTask.scala
index d7c488109737bcbdfe9a4b56e242d18d49c3a112..5fc59af06c039f6d74638c63cea13ad824058e40 100644
--- a/core/src/main/scala/spark/ShuffleMapTask.scala
+++ b/core/src/main/scala/spark/ShuffleMapTask.scala
@@ -8,12 +8,13 @@ import java.util.{HashMap => JHashMap}
 import it.unimi.dsi.fastutil.io.FastBufferedOutputStream
 
 class ShuffleMapTask(
+    runId: Int,
     stageId: Int,
     rdd: RDD[_],
     dep: ShuffleDependency[_,_,_],
     val partition: Int,
     locs: Seq[String])
-  extends DAGTask[String](stageId)
+  extends DAGTask[String](runId, stageId)
   with Logging {
 
   val split = rdd.splits(partition)
diff --git a/core/src/main/scala/spark/SimpleJob.scala b/core/src/main/scala/spark/SimpleJob.scala
index 6eee8b45cea8d741f1299a82ee1a68ada9dc9206..5e42ae6ecd00521d76a61d043853496e7d309ffe 100644
--- a/core/src/main/scala/spark/SimpleJob.scala
+++ b/core/src/main/scala/spark/SimpleJob.scala
@@ -16,8 +16,9 @@ import org.apache.mesos.Protos._
 class SimpleJob(
     sched: MesosScheduler,
     tasksSeq: Seq[Task[_]],
-    val jobId: Int)
-  extends Job(jobId)
+    runId: Int,
+    jobId: Int)
+  extends Job(runId, jobId)
   with Logging {
 
   // Maximum time to wait to run a task in a preferred location (in ms)
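
The task and job classes above change mechanically: ResultTask, ShuffleMapTask and SimpleJob each gain a run ID in their constructors so that every task, and every Mesos-level job, can be traced back to the runJob call that created it. A stripped-down sketch of that constructor chain, using hypothetical Sketch* types rather than the real classes:

// Hypothetical, simplified mirror of the DAGTask / ResultTask relationship:
// the run ID is threaded through every constructor so a completion event can
// later be matched back to the run that submitted the task.
abstract class SketchTask[T](val runId: Int, val stageId: Int) {
  def run(): T
}

class SketchResultTask[T, U](
    runId: Int,
    stageId: Int,
    data: Seq[T],
    func: Iterator[T] => U,
    val partition: Int)
  extends SketchTask[U](runId, stageId) {

  override def run(): U = func(data.iterator)
}

object SketchTaskDemo {
  def main(args: Array[String]): Unit = {
    val task = new SketchResultTask[Int, Int](
      runId = 0, stageId = 1, data = Seq(1, 2, 3), func = _.sum, partition = 0)
    println("run " + task.runId + " produced " + task.run())  // run 0 produced 6
  }
}
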
diff --git a/core/src/test/scala/spark/ThreadingSuite.scala b/core/src/test/scala/spark/ThreadingSuite.scala
index f47f6ffc95f4319a4de28e2ab9de92e6a4a576da..cadf01432f9a3e1b81c81173730342e2c903eef2 100644
--- a/core/src/test/scala/spark/ThreadingSuite.scala
+++ b/core/src/test/scala/spark/ThreadingSuite.scala
@@ -1,11 +1,26 @@
 package spark
 
 import java.util.concurrent.Semaphore
+import java.util.concurrent.atomic.AtomicBoolean
+import java.util.concurrent.atomic.AtomicInteger
 
 import org.scalatest.FunSuite
 
 import SparkContext._
 
+/**
+ * Holds state shared across task threads in some ThreadingSuite tests.
+ */
+object ThreadingSuiteState {
+  val runningThreads = new AtomicInteger
+  val failed = new AtomicBoolean
+
+  def clear() {
+    runningThreads.set(0)
+    failed.set(false)
+  }
+}
+
 class ThreadingSuite extends FunSuite {
   test("accessing SparkContext form a different thread") {
     val sc = new SparkContext("local", "test")
@@ -54,4 +69,69 @@
     }
     sc.stop()
   }
+
+  test("accessing multi-threaded SparkContext from multiple threads") {
+    val sc = new SparkContext("local[4]", "test")
+    val nums = sc.parallelize(1 to 10, 2)
+    val sem = new Semaphore(0)
+    @volatile var ok = true
+    for (i <- 0 until 10) {
+      new Thread {
+        override def run() {
+          val answer1 = nums.reduce(_ + _)
+          if (answer1 != 55) {
+            printf("In thread %d: answer1 was %d\n", i, answer1);
+            ok = false;
+          }
+          val answer2 = nums.first // This will run "locally" in the current thread
+          if (answer2 != 1) {
+            printf("In thread %d: answer2 was %d\n", i, answer2);
+            ok = false;
+          }
+          sem.release()
+        }
+      }.start()
+    }
+    sem.acquire(10)
+    if (!ok) {
+      fail("One or more threads got the wrong answer from an RDD operation")
+    }
+    sc.stop()
+  }
+
+  test("parallel job execution") {
+    // This test launches two jobs with two threads each on a 4-core local cluster. Each thread
+    // waits until there are 4 threads running at once, to test that both jobs have been launched.
+    val sc = new SparkContext("local[4]", "test")
+    val nums = sc.parallelize(1 to 2, 2)
+    val sem = new Semaphore(0)
+    ThreadingSuiteState.clear()
+    for (i <- 0 until 2) {
+      new Thread {
+        override def run() {
+          val ans = nums.map(number => {
+            val running = ThreadingSuiteState.runningThreads
+            running.getAndIncrement()
+            val time = System.currentTimeMillis()
+            while (running.get() != 4 && System.currentTimeMillis() < time + 1000) {
+              Thread.sleep(100)
+            }
+            if (running.get() != 4) {
+              println("Waited 1 second without seeing runningThreads = 4 (it was " +
+                running.get() + "); failing test")
+              ThreadingSuiteState.failed.set(true)
+            }
+            number
+          }).collect()
+          assert(ans.toList === List(1, 2))
+          sem.release()
+        }
+      }.start()
+    }
+    sem.acquire(2)
+    if (ThreadingSuiteState.failed.get()) {
+      fail("One or more threads didn't see runningThreads = 4")
+    }
+    sc.stop()
+  }
 }
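
Both new tests coordinate worker threads with the test thread through a counting Semaphore: each thread releases one permit when it finishes, and the main thread acquires one permit per thread before asserting. A framework-free sketch of that pattern, with illustrative names and without a SparkContext:

import java.util.concurrent.Semaphore
import java.util.concurrent.atomic.AtomicBoolean

// Sketch of the coordination pattern the tests above rely on: workers record
// failures in shared atomic state and release one semaphore permit when done;
// the main thread acquires all permits before checking the outcome.
object SemaphoreCoordinationSketch {
  def main(args: Array[String]): Unit = {
    val threadCount = 4
    val sem = new Semaphore(0)
    val failed = new AtomicBoolean(false)

    for (_ <- 0 until threadCount) {
      new Thread {
        override def run(): Unit = {
          try {
            if ((1 to 10).sum != 55) {
              failed.set(true)
            }
          } finally {
            sem.release()  // always signal completion, even if the body throws
          }
        }
      }.start()
    }

    sem.acquire(threadCount)  // block until every worker has released a permit
    assert(!failed.get(), "one or more worker threads got the wrong answer")
  }
}
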