Commit aa04f87c authored by Matei Zaharia

Added support for parallel execution of jobs in DAGScheduler.

parent 2587ce16
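As a rough usage sketch of what this change enables (it mirrors the new ThreadingSuite tests further down): two driver threads share one SparkContext and submit jobs at the same time, and the DAGScheduler now tracks each submission under its own run ID instead of handling one job at a time. The object name, thread count, and println output below are illustrative only; SparkContext, parallelize, and reduce are the existing APIs already used by the tests.

import java.util.concurrent.Semaphore
import spark.SparkContext

object ParallelJobsSketch {
  def main(args: Array[String]) {
    val sc = new SparkContext("local[4]", "sketch")
    val nums = sc.parallelize(1 to 10, 2)
    val sem = new Semaphore(0)
    for (i <- 0 until 2) {
      new Thread {
        override def run() {
          // Each action becomes its own run in the DAGScheduler, identified
          // by a fresh runId, so the two jobs can proceed concurrently.
          println("Thread " + i + ": sum = " + nums.reduce(_ + _))
          sem.release()
        }
      }.start()
    }
    sem.acquire(2) // wait for both jobs to finish
    sc.stop()
  }
}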
package spark
import java.util.concurrent.LinkedBlockingQueue
import java.util.concurrent.TimeUnit
import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map}
import java.util.concurrent.atomic.AtomicInteger
import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Queue, Map}
/**
* A task created by the DAG scheduler. Knows its stage ID and map output tracker generation.
*/
abstract class DAGTask[T](val stageId: Int) extends Task[T] {
abstract class DAGTask[T](val runId: Int, val stageId: Int) extends Task[T] {
val gen = SparkEnv.get.mapOutputTracker.getGeneration
override def generation: Option[Long] = Some(gen)
}
/**
* A completion event passed by the underlying task scheduler to the DAG scheduler
* A completion event passed by the underlying task scheduler to the DAG scheduler.
*/
case class CompletionEvent(
task: DAGTask[_],
@@ -39,13 +39,22 @@ case class OtherFailure(message: String) extends TaskEndReason
* and to report fetch failures (the submitTasks method, and code to add CompletionEvents).
*/
private trait DAGScheduler extends Scheduler with Logging {
// Must be implemented by subclasses to start running a set of tasks
def submitTasks(tasks: Seq[Task[_]]): Unit
// Must be implemented by subclasses to start running a set of tasks. The subclass should also
// attempt to run different sets of tasks in the order given by runId (lower values first).
def submitTasks(tasks: Seq[Task[_]], runId: Int): Unit
// Must be called by subclasses to report task completions or failures
// Must be called by subclasses to report task completions or failures.
def taskEnded(task: Task[_], reason: TaskEndReason, result: Any, accumUpdates: Map[Long, Any]) {
val dagTask = task.asInstanceOf[DAGTask[_]]
completionEvents.put(CompletionEvent(dagTask, reason, result, accumUpdates))
lock.synchronized {
val dagTask = task.asInstanceOf[DAGTask[_]]
eventQueues.get(dagTask.runId) match {
case Some(queue) =>
queue += CompletionEvent(dagTask, reason, result, accumUpdates)
lock.notifyAll()
case None =>
logInfo("Ignoring completion event for DAG job " + dagTask.runId + " because it's gone")
}
}
}
// The time, in millis, to wait for fetch failure events to stop coming in after one is detected;
@@ -57,16 +66,13 @@ private trait DAGScheduler extends Scheduler with Logging {
// resubmit failed stages
val POLL_TIMEOUT = 500L
private val completionEvents = new LinkedBlockingQueue[CompletionEvent]
private val lock = new Object
private val lock = new Object // Used to guard access to all DAGScheduler state
var nextStageId = 0
private val eventQueues = new HashMap[Int, Queue[CompletionEvent]] // Indexed by run ID
def newStageId() = {
var res = nextStageId
nextStageId += 1
res
}
val nextRunId = new AtomicInteger(0)
val nextStageId = new AtomicInteger(0)
val idToStage = new HashMap[Int, Stage]
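A quick aside on why the unsynchronized newStageId() counter above is replaced with AtomicIntegers: once several runJob calls can be in flight from different threads, getAndIncrement is what keeps the run and stage IDs unique. The sketch below is a standalone illustration (the thread count and ID range are arbitrary), not part of the scheduler code.

import java.util.concurrent.atomic.AtomicInteger

object UniqueIdSketch {
  def main(args: Array[String]) {
    val nextStageId = new AtomicInteger(0)
    val results = new Array[Array[Int]](4)
    val threads = (0 until 4).map { t =>
      new Thread {
        override def run() {
          // Grab 1000 IDs from the shared counter.
          results(t) = Array.fill(1000)(nextStageId.getAndIncrement())
        }
      }
    }
    threads.foreach(_.start())
    threads.foreach(_.join())
    // All 4000 IDs come out distinct; an unsynchronized "var" counter could
    // hand the same ID to two concurrent callers.
    println(results.flatten.distinct.length == 4000)
  }
}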
@@ -103,7 +109,7 @@ private trait DAGScheduler extends Scheduler with Logging {
if (shuffleDep != None) {
mapOutputTracker.registerShuffle(shuffleDep.get.shuffleId, rdd.splits.size)
}
val id = newStageId()
val id = nextStageId.getAndIncrement()
val stage = new Stage(id, rdd, shuffleDep, getParentStages(rdd))
idToStage(id) = stage
stage
@@ -167,6 +173,8 @@ private trait DAGScheduler extends Scheduler with Logging {
allowLocal: Boolean)
(implicit m: ClassManifest[U]): Array[U] = {
lock.synchronized {
val runId = nextRunId.getAndIncrement()
val outputParts = partitions.toArray
val numOutputParts: Int = partitions.size
val finalStage = newStage(finalRdd, None)
@@ -196,6 +204,9 @@ private trait DAGScheduler extends Scheduler with Logging {
val taskContext = new TaskContext(finalStage.id, outputParts(0), 0)
return Array(func(taskContext, finalRdd.iterator(split)))
}
// Register this run's ID so that we can get completion events for it
eventQueues(runId) = new Queue[CompletionEvent]
def submitStage(stage: Stage) {
if (!waiting(stage) && !running(stage)) {
@@ -221,26 +232,27 @@ private trait DAGScheduler extends Scheduler with Logging {
for (id <- 0 until numOutputParts if (!finished(id))) {
val part = outputParts(id)
val locs = getPreferredLocs(finalRdd, part)
tasks += new ResultTask(finalStage.id, finalRdd, func, part, locs, id)
tasks += new ResultTask(runId, finalStage.id, finalRdd, func, part, locs, id)
}
} else {
for (p <- 0 until stage.numPartitions if stage.outputLocs(p) == Nil) {
val locs = getPreferredLocs(stage.rdd, p)
tasks += new ShuffleMapTask(stage.id, stage.rdd, stage.shuffleDep.get, p, locs)
tasks += new ShuffleMapTask(runId, stage.id, stage.rdd, stage.shuffleDep.get, p, locs)
}
}
myPending ++= tasks
submitTasks(tasks)
submitTasks(tasks, runId)
}
submitStage(finalStage)
while (numFinished != numOutputParts) {
val evt = completionEvents.poll(POLL_TIMEOUT, TimeUnit.MILLISECONDS)
val eventOption = waitForEvent(runId, POLL_TIMEOUT)
val time = System.currentTimeMillis // TODO: use a pluggable clock for testability
// If we got an event off the queue, mark the task done or react to a fetch failure
if (evt != null) {
if (eventOption != None) {
val evt = eventOption.get
val stage = idToStage(evt.task.stageId)
pendingTasks(stage) -= evt.task
if (evt.reason == Success) {
@@ -315,6 +327,7 @@ private trait DAGScheduler extends Scheduler with Logging {
}
}
eventQueues -= runId
return results
}
}
@@ -344,4 +357,18 @@ private trait DAGScheduler extends Scheduler with Logging {
})
return Nil
}
// Assumes that lock is held on entrance, but will release it to wait for the next event.
def waitForEvent(runId: Int, timeout: Long): Option[CompletionEvent] = {
val endTime = System.currentTimeMillis() + timeout // TODO: Use pluggable clock for testing
while (eventQueues(runId).isEmpty) {
val time = System.currentTimeMillis()
if (time > endTime) {
return None
} else {
lock.wait(endTime - time)
}
}
return Some(eventQueues(runId).dequeue())
}
}
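Taken together, taskEnded and waitForEvent implement a small producer/consumer handshake on the shared lock: completion events are routed into a per-run queue, and the thread driving that run blocks on the lock until an event shows up or the poll timeout expires. The standalone sketch below restates that pattern with illustrative names (EventMailbox, post, poll); it is not the DAGScheduler code itself.

import scala.collection.mutable.{HashMap, Queue}

class EventMailbox[E] {
  private val lock = new Object
  private val queues = new HashMap[Int, Queue[E]]

  def register(runId: Int) { lock.synchronized { queues(runId) = new Queue[E] } }

  def unregister(runId: Int) { lock.synchronized { queues -= runId } }

  // Producer side (like taskEnded): enqueue the event and wake up any waiters.
  def post(runId: Int, event: E) {
    lock.synchronized {
      queues.get(runId) match {
        case Some(q) =>
          q += event
          lock.notifyAll()
        case None =>
          // The run already finished and unregistered; drop the event.
      }
    }
  }

  // Consumer side (like waitForEvent): block until an event arrives for this
  // run or the timeout expires.
  def poll(runId: Int, timeout: Long): Option[E] = lock.synchronized {
    val endTime = System.currentTimeMillis() + timeout
    while (queues(runId).isEmpty) {
      val remaining = endTime - System.currentTimeMillis()
      if (remaining <= 0) {
        return None
      }
      lock.wait(remaining)
    }
    Some(queues(runId).dequeue())
  }
}

A run would call register before submitting its tasks, poll in its event loop, and unregister when it finishes, which matches the eventQueues bookkeeping that runJob does above.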
@@ -7,12 +7,10 @@ import org.apache.mesos.Protos._
* Class representing a parallel job in MesosScheduler. Schedules the job by implementing various
* callbacks.
*/
abstract class Job(jobId: Int) {
abstract class Job(val runId: Int, val jobId: Int) {
def slaveOffer(s: Offer, availableCpus: Double): Option[TaskDescription]
def statusUpdate(t: TaskStatus): Unit
def error(code: Int, message: String): Unit
def getId(): Int = jobId
}
@@ -12,11 +12,13 @@ private class LocalScheduler(threads: Int, maxFailures: Int) extends DAGSchedule
var attemptId = new AtomicInteger(0)
var threadPool = Executors.newFixedThreadPool(threads, DaemonThreadFactory)
// TODO: Need to take into account stage priority in scheduling
override def start() {}
override def waitForRegister() {}
override def submitTasks(tasks: Seq[Task[_]]) {
override def submitTasks(tasks: Seq[Task[_]], runId: Int) {
val failCount = new Array[Int](tasks.size)
def submitTask(task: Task[_], idInJob: Int) {
......
@@ -9,8 +9,9 @@ import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
import scala.collection.mutable.HashSet
import scala.collection.mutable.Map
import scala.collection.mutable.Queue
import scala.collection.mutable.PriorityQueue
import scala.collection.JavaConversions._
import scala.math.Ordering
import com.google.protobuf.ByteString
@@ -53,7 +54,7 @@ private class MesosScheduler(
private val registeredLock = new Object()
private val activeJobs = new HashMap[Int, Job]
private val activeJobsQueue = new Queue[Job]
private var activeJobsQueue = new PriorityQueue[Job]()(jobOrdering)
private val taskIdToJobId = new HashMap[String, Int]
private val taskIdToSlaveId = new HashMap[String, String]
@@ -74,6 +75,13 @@ private class MesosScheduler(
// URIs of JARs to pass to executor
var jarUris: String = ""
// Sorts jobs in reverse order of run ID for use in our priority queue (so lower IDs run first)
private val jobOrdering = new Ordering[Job] {
override def compare(j1: Job, j2: Job): Int = {
return j2.runId - j1.runId
}
}
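A note on the ordering above: Scala's mutable PriorityQueue dequeues its largest element first, so comparing j2.runId - j1.runId reverses the natural order and makes the job with the lowest run ID (the oldest run) come off the queue first. A tiny standalone check, with FakeJob as an illustrative stand-in for the real Job class:

import scala.collection.mutable.PriorityQueue
import scala.math.Ordering

object JobOrderingSketch {
  case class FakeJob(runId: Int)

  val jobOrdering = new Ordering[FakeJob] {
    override def compare(j1: FakeJob, j2: FakeJob): Int = j2.runId - j1.runId
  }

  def main(args: Array[String]) {
    val queue = new PriorityQueue[FakeJob]()(jobOrdering)
    queue += FakeJob(2)
    queue += FakeJob(0)
    queue += FakeJob(1)
    // Prints FakeJob(0), FakeJob(1), FakeJob(2): lower run IDs get offers first.
    while (queue.nonEmpty) {
      println(queue.dequeue())
    }
  }
}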
def newJobId(): Int = this.synchronized {
val id = nextJobId
@@ -138,14 +146,13 @@ private class MesosScheduler(
.addResources(memory)
.build()
}
def submitTasks(tasks: Seq[Task[_]]) {
def submitTasks(tasks: Seq[Task[_]], runId: Int) {
logInfo("Got a job with " + tasks.size + " tasks")
waitForRegister()
this.synchronized {
val jobId = newJobId()
val myJob = new SimpleJob(this, tasks, jobId)
val myJob = new SimpleJob(this, tasks, runId, jobId)
activeJobs(jobId) = myJob
activeJobsQueue += myJob
logInfo("Adding job with ID " + jobId)
@@ -156,11 +163,11 @@ private class MesosScheduler(
def jobFinished(job: Job) {
this.synchronized {
activeJobs -= job.getId
activeJobsQueue.dequeueAll(x => (x == job))
taskIdToJobId --= jobTasks(job.getId)
taskIdToSlaveId --= jobTasks(job.getId)
jobTasks.remove(job.getId)
activeJobs -= job.jobId
activeJobsQueue = activeJobsQueue.filterNot(_ == job)
taskIdToJobId --= jobTasks(job.jobId)
taskIdToSlaveId --= jobTasks(job.jobId)
jobTasks.remove(job.jobId)
}
}
@@ -204,8 +211,8 @@ private class MesosScheduler(
tasks(i).add(task)
val tid = task.getTaskId.getValue
val sid = offers(i).getSlaveId.getValue
taskIdToJobId(tid) = job.getId
jobTasks(job.getId) += tid
taskIdToJobId(tid) = job.jobId
jobTasks(job.jobId) += tid
taskIdToSlaveId(tid) = sid
slavesWithExecutors += sid
availableCpus(i) -= getResource(task.getResourcesList(), "cpus")
......
package spark
class ResultTask[T, U](
runId: Int,
stageId: Int,
rdd: RDD[T],
func: (TaskContext, Iterator[T]) => U,
val partition: Int,
locs: Seq[String],
val outputId: Int)
extends DAGTask[U](stageId) {
extends DAGTask[U](runId, stageId) {
val split = rdd.splits(partition)
......
@@ -8,12 +8,13 @@ import java.util.{HashMap => JHashMap}
import it.unimi.dsi.fastutil.io.FastBufferedOutputStream
class ShuffleMapTask(
runId: Int,
stageId: Int,
rdd: RDD[_],
dep: ShuffleDependency[_,_,_],
val partition: Int,
locs: Seq[String])
extends DAGTask[String](stageId)
extends DAGTask[String](runId, stageId)
with Logging {
val split = rdd.splits(partition)
......
@@ -16,8 +16,9 @@ import org.apache.mesos.Protos._
class SimpleJob(
sched: MesosScheduler,
tasksSeq: Seq[Task[_]],
val jobId: Int)
extends Job(jobId)
runId: Int,
jobId: Int)
extends Job(runId, jobId)
with Logging {
// Maximum time to wait to run a task in a preferred location (in ms)
......
package spark
import java.util.concurrent.Semaphore
import java.util.concurrent.atomic.AtomicBoolean
import java.util.concurrent.atomic.AtomicInteger
import org.scalatest.FunSuite
import SparkContext._
/**
* Holds state shared across task threads in some ThreadingSuite tests.
*/
object ThreadingSuiteState {
val runningThreads = new AtomicInteger
val failed = new AtomicBoolean
def clear() {
runningThreads.set(0)
failed.set(false)
}
}
class ThreadingSuite extends FunSuite {
test("accessing SparkContext form a different thread") {
val sc = new SparkContext("local", "test")
@@ -54,4 +69,69 @@ class ThreadingSuite extends FunSuite {
}
sc.stop()
}
test("accessing multi-threaded SparkContext form multiple threads") {
val sc = new SparkContext("local[4]", "test")
val nums = sc.parallelize(1 to 10, 2)
val sem = new Semaphore(0)
@volatile var ok = true
for (i <- 0 until 10) {
new Thread {
override def run() {
val answer1 = nums.reduce(_ + _)
if (answer1 != 55) {
printf("In thread %d: answer1 was %d\n", i, answer1);
ok = false;
}
val answer2 = nums.first // This will run "locally" in the current thread
if (answer2 != 1) {
printf("In thread %d: answer2 was %d\n", i, answer2);
ok = false;
}
sem.release()
}
}.start()
}
sem.acquire(10)
if (!ok) {
fail("One or more threads got the wrong answer from an RDD operation")
}
sc.stop()
}
test("parallel job execution") {
// This test launches two jobs with two threads each on a 4-core local cluster. Each thread
// waits until there are 4 threads running at once, to test that both jobs have been launched.
val sc = new SparkContext("local[4]", "test")
val nums = sc.parallelize(1 to 2, 2)
val sem = new Semaphore(0)
ThreadingSuiteState.clear()
for (i <- 0 until 2) {
new Thread {
override def run() {
val ans = nums.map(number => {
val running = ThreadingSuiteState.runningThreads
running.getAndIncrement()
val time = System.currentTimeMillis()
while (running.get() != 4 && System.currentTimeMillis() < time + 1000) {
Thread.sleep(100)
}
if (running.get() != 4) {
println("Waited 1 second without seeing runningThreads = 4 (it was " +
running.get() + "); failing test")
ThreadingSuiteState.failed.set(true)
}
number
}).collect()
assert(ans.toList === List(1, 2))
sem.release()
}
}.start()
}
sem.acquire(2)
if (ThreadingSuiteState.failed.get()) {
fail("One or more threads didn't see runningThreads = 4")
}
sc.stop()
}
}