diff --git a/core/src/main/scala/spark/scheduler/cluster/ClusterTaskSetManager.scala b/core/src/main/scala/spark/scheduler/cluster/ClusterTaskSetManager.scala
index 14e87af653b32e7b5c230150cbb6320376befe56..860a38e9f8a01d35230e05f99ebe8221ceb350c8 100644
--- a/core/src/main/scala/spark/scheduler/cluster/ClusterTaskSetManager.scala
+++ b/core/src/main/scala/spark/scheduler/cluster/ClusterTaskSetManager.scala
@@ -17,7 +17,8 @@ package spark.scheduler.cluster
-import java.util.{HashMap => JHashMap, NoSuchElementException, Arrays}
+import java.nio.ByteBuffer
+import java.util.{Arrays, NoSuchElementException}
 import scala.collection.mutable.ArrayBuffer
 import scala.collection.mutable.HashMap
@@ -25,12 +26,14 @@ import scala.collection.mutable.HashSet
 import scala.math.max
 import scala.math.min
-import spark._
-import spark.scheduler._
+import spark.{FetchFailed, Logging, Resubmitted, SparkEnv, Success, TaskEndReason, TaskState, Utils}
+import spark.{ExceptionFailure, SparkException, TaskResultTooBigFailure}
 import spark.TaskState.TaskState
-import java.nio.ByteBuffer
+import spark.scheduler.{ShuffleMapTask, Task, TaskResult, TaskSet}
-private[spark] object TaskLocality extends Enumeration("PROCESS_LOCAL", "NODE_LOCAL", "RACK_LOCAL", "ANY") with Logging {
+
+private[spark] object TaskLocality
+  extends Enumeration("PROCESS_LOCAL", "NODE_LOCAL", "RACK_LOCAL", "ANY") with Logging {
   // process local is expected to be used ONLY within tasksetmanager for now.
   val PROCESS_LOCAL, NODE_LOCAL, RACK_LOCAL, ANY = Value
@@ -43,8 +46,10 @@ private[spark] object TaskLocality extends Enumeration("PROCESS_LOCAL", "NODE_LO
     assert (constraint != TaskLocality.PROCESS_LOCAL)
     constraint match {
-      case TaskLocality.NODE_LOCAL => condition == TaskLocality.NODE_LOCAL
-      case TaskLocality.RACK_LOCAL => condition == TaskLocality.NODE_LOCAL || condition == TaskLocality.RACK_LOCAL
+      case TaskLocality.NODE_LOCAL =>
+        condition == TaskLocality.NODE_LOCAL
+      case TaskLocality.RACK_LOCAL =>
+        condition == TaskLocality.NODE_LOCAL || condition == TaskLocality.RACK_LOCAL
       // For anything else, allow
       case _ => true
     }
@@ -56,11 +61,10 @@ private[spark] object TaskLocality extends Enumeration("PROCESS_LOCAL", "NODE_LO
       val retval = TaskLocality.withName(str)
       // Must not specify PROCESS_LOCAL !
       assert (retval != TaskLocality.PROCESS_LOCAL)
-
       retval
     } catch {
       case nEx: NoSuchElementException => {
-        logWarning("Invalid task locality specified '" + str + "', defaulting to NODE_LOCAL");
+        logWarning("Invalid task locality specified '" + str + "', defaulting to NODE_LOCAL")
         // default to preserve earlier behavior
         NODE_LOCAL
       }
@@ -71,11 +75,8 @@ private[spark] object TaskLocality extends Enumeration("PROCESS_LOCAL", "NODE_LO
 /**
  * Schedules the tasks within a single TaskSet in the ClusterScheduler.
 */
-private[spark] class ClusterTaskSetManager(
-    sched: ClusterScheduler,
-    val taskSet: TaskSet)
-  extends TaskSetManager
-  with Logging {
+private[spark] class ClusterTaskSetManager(sched: ClusterScheduler, val taskSet: TaskSet)
+  extends TaskSetManager with Logging {
   // Maximum time to wait to run a task in a preferred location (in ms)
   val LOCALITY_WAIT = System.getProperty("spark.locality.wait", "3000").toLong
@@ -106,13 +107,14 @@ private[spark] class ClusterTaskSetManager(
   var runningTasks = 0
   var priority = taskSet.priority
   var stageId = taskSet.stageId
-  var name = "TaskSet_"+taskSet.stageId.toString
-  var parent:Schedulable = null
+  var name = "TaskSet_" + taskSet.stageId.toString
+  var parent: Schedulable = null
   // Last time when we launched a preferred task (for delay scheduling)
   var lastPreferredLaunchTime = System.currentTimeMillis
-  // List of pending tasks for each node (process local to container). These collections are actually
+  // List of pending tasks for each node (process local to container).
+  // These collections are actually
   // treated as stacks, in which new tasks are added to the end of the
   // ArrayBuffer and removed from the end. This makes it faster to detect
   // tasks that repeatedly fail because whenever a task failed, it is put
@@ -172,9 +174,11 @@ private[spark] class ClusterTaskSetManager(
   // Note that it follows the hierarchy.
   // if we search for NODE_LOCAL, the output will include PROCESS_LOCAL and
   // if we search for RACK_LOCAL, it will include PROCESS_LOCAL & NODE_LOCAL
-  private def findPreferredLocations(_taskPreferredLocations: Seq[String], scheduler: ClusterScheduler,
-      taskLocality: TaskLocality.TaskLocality): HashSet[String] = {
-
+  private def findPreferredLocations(
+      _taskPreferredLocations: Seq[String],
+      scheduler: ClusterScheduler,
+      taskLocality: TaskLocality.TaskLocality): HashSet[String] =
+  {
     if (TaskLocality.PROCESS_LOCAL == taskLocality) {
       // straight forward comparison ! Special case it.
       val retval = new HashSet[String]()
@@ -189,13 +193,14 @@ private[spark] class ClusterTaskSetManager(
       return retval
     }
-    val taskPreferredLocations =
+    val taskPreferredLocations = {
       if (TaskLocality.NODE_LOCAL == taskLocality) {
         _taskPreferredLocations
       } else {
         assert (TaskLocality.RACK_LOCAL == taskLocality)
         // Expand set to include all 'seen' rack local hosts.
-        // This works since container allocation/management happens within master - so any rack locality information is updated in msater.
+        // This works since container allocation/management happens within master -
+        // so any rack locality information is updated in master.
         // Best case effort, and maybe sort of kludge for now ... rework it later ?
         val hosts = new HashSet[String]
         _taskPreferredLocations.foreach(h => {
@@ -213,6 +218,7 @@ private[spark] class ClusterTaskSetManager(
       hosts
     }
+  }
     val retval = new HashSet[String]
     scheduler.synchronized {
@@ -229,11 +235,13 @@ private[spark] class ClusterTaskSetManager(
   // Add a task to all the pending-task lists that it should be on.
   private def addPendingTask(index: Int) {
-    // We can infer hostLocalLocations from rackLocalLocations by joining it against tasks(index).preferredLocations (with appropriate
-    // hostPort <-> host conversion). But not doing it for simplicity sake. If this becomes a performance issue, modify it.
-    val processLocalLocations = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.PROCESS_LOCAL)
-    val hostLocalLocations = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.NODE_LOCAL)
-    val rackLocalLocations = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.RACK_LOCAL)
+    // We can infer hostLocalLocations from rackLocalLocations by joining it against
+    // tasks(index).preferredLocations (with appropriate hostPort <-> host conversion).
+    // But not doing it for simplicity sake. If this becomes a performance issue, modify it.
+    val locs = tasks(index).preferredLocations
+    val processLocalLocations = findPreferredLocations(locs, sched, TaskLocality.PROCESS_LOCAL)
+    val hostLocalLocations = findPreferredLocations(locs, sched, TaskLocality.NODE_LOCAL)
+    val rackLocalLocations = findPreferredLocations(locs, sched, TaskLocality.RACK_LOCAL)
     if (rackLocalLocations.size == 0) {
       // Current impl ensures this.
@@ -298,18 +306,24 @@ private[spark] class ClusterTaskSetManager(
   }
   // Number of pending tasks for a given host Port (which would be process local)
-  def numPendingTasksForHostPort(hostPort: String): Int = {
-    getPendingTasksForHostPort(hostPort).count( index => copiesRunning(index) == 0 && !finished(index) )
+  override def numPendingTasksForHostPort(hostPort: String): Int = {
+    getPendingTasksForHostPort(hostPort).count { index =>
+      copiesRunning(index) == 0 && !finished(index)
+    }
   }
   // Number of pending tasks for a given host (which would be data local)
-  def numPendingTasksForHost(hostPort: String): Int = {
-    getPendingTasksForHost(hostPort).count( index => copiesRunning(index) == 0 && !finished(index) )
+  override def numPendingTasksForHost(hostPort: String): Int = {
+    getPendingTasksForHost(hostPort).count { index =>
+      copiesRunning(index) == 0 && !finished(index)
+    }
   }
   // Number of pending rack local tasks for a given host
-  def numRackLocalPendingTasksForHost(hostPort: String): Int = {
-    getRackLocalPendingTasksForHost(hostPort).count( index => copiesRunning(index) == 0 && !finished(index) )
+  override def numRackLocalPendingTasksForHost(hostPort: String): Int = {
+    getRackLocalPendingTasksForHost(hostPort).count { index =>
+      copiesRunning(index) == 0 && !finished(index)
+    }
   }
@@ -337,12 +351,12 @@ private[spark] class ClusterTaskSetManager(
     speculatableTasks.retain(index => !finished(index)) // Remove finished tasks from set
     if (speculatableTasks.size > 0) {
-      val localTask = speculatableTasks.find {
-        index =>
-          val locations = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.NODE_LOCAL)
-          val attemptLocs = taskAttempts(index).map(_.hostPort)
-          (locations.size == 0 || locations.contains(hostPort)) && !attemptLocs.contains(hostPort)
-      }
+      val localTask = speculatableTasks.find { index =>
+        val locations = findPreferredLocations(tasks(index).preferredLocations, sched,
+          TaskLocality.NODE_LOCAL)
+        val attemptLocs = taskAttempts(index).map(_.hostPort)
+        (locations.size == 0 || locations.contains(hostPort)) && !attemptLocs.contains(hostPort)
+      }
       if (localTask != None) {
         speculatableTasks -= localTask.get
@@ -351,11 +365,11 @@ private[spark] class ClusterTaskSetManager(
       // check for rack locality
       if (TaskLocality.isAllowed(locality, TaskLocality.RACK_LOCAL)) {
-        val rackTask = speculatableTasks.find {
-          index =>
-            val locations = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.RACK_LOCAL)
-            val attemptLocs = taskAttempts(index).map(_.hostPort)
-            locations.contains(hostPort) && !attemptLocs.contains(hostPort)
+        val rackTask = speculatableTasks.find { index =>
+          val locations = findPreferredLocations(tasks(index).preferredLocations, sched,
+            TaskLocality.RACK_LOCAL)
+          val attemptLocs = taskAttempts(index).map(_.hostPort)
+          locations.contains(hostPort) && !attemptLocs.contains(hostPort)
         }
         if (rackTask != None) {
@@ -367,7 +381,9 @@ private[spark] class ClusterTaskSetManager(
     // Any task ...
     if (TaskLocality.isAllowed(locality, TaskLocality.ANY)) {
       // Check for attemptLocs also ?
-      val nonLocalTask = speculatableTasks.find(i => !taskAttempts(i).map(_.hostPort).contains(hostPort))
+      val nonLocalTask = speculatableTasks.find { i =>
+        !taskAttempts(i).map(_.hostPort).contains(hostPort)
+      }
       if (nonLocalTask != None) {
         speculatableTasks -= nonLocalTask.get
         return nonLocalTask
@@ -397,7 +413,8 @@ private[spark] class ClusterTaskSetManager(
       }
     }
-    // Look for no pref tasks AFTER rack local tasks - this has side effect that we will get to failed tasks later rather than sooner.
+    // Look for no pref tasks AFTER rack local tasks - this has side effect that we will get to
+    // failed tasks later rather than sooner.
     // TODO: That code path needs to be revisited (adding to no prefs list when host:port goes down).
     val noPrefTask = findTaskFromList(pendingTasksWithNoPrefs)
     if (noPrefTask != None) {
@@ -433,7 +450,8 @@ private[spark] class ClusterTaskSetManager(
     locs.find(h => Utils.parseHostPort(h)._1 == host).isDefined
   }
-  // Does a host count as a rack local preferred location for a task? (assumes host is NOT preferred location).
+  // Does a host count as a rack local preferred location for a task?
+  // (assumes host is NOT preferred location).
   // This is true if either the task has preferred locations and this host is one, or it has
   // no preferred locations (in which we still count the launch as preferred).
   private def isRackLocalLocation(task: Task[_], hostPort: String): Boolean = {
@@ -454,14 +472,22 @@ private[spark] class ClusterTaskSetManager(
   }
   // Respond to an offer of a single slave from the scheduler by finding a task
-  def slaveOffer(execId: String, hostPort: String, availableCpus: Double, overrideLocality: TaskLocality.TaskLocality = null): Option[TaskDescription] = {
-
+  override def slaveOffer(
+      execId: String,
+      hostPort: String,
+      availableCpus: Double,
+      overrideLocality: TaskLocality.TaskLocality = null): Option[TaskDescription] =
+  {
     if (tasksFinished < numTasks && availableCpus >= CPUS_PER_TASK) {
       // If explicitly specified, use that
       val locality = if (overrideLocality != null) overrideLocality else {
         // expand only if we have waited for more than LOCALITY_WAIT for a host local task ...
         val time = System.currentTimeMillis
-        if (time - lastPreferredLaunchTime < LOCALITY_WAIT) TaskLocality.NODE_LOCAL else TaskLocality.ANY
+        if (time - lastPreferredLaunchTime < LOCALITY_WAIT) {
+          TaskLocality.NODE_LOCAL
+        } else {
+          TaskLocality.ANY
+        }
       }
       findTask(hostPort, locality) match {
@@ -489,6 +515,8 @@ private[spark] class ClusterTaskSetManager(
           }
           // Serialize and return the task
           val startTime = System.currentTimeMillis
+          // We rely on the DAGScheduler to catch non-serializable closures and RDDs, so in here
+          // we assume the task can be serialized without exceptions.
           val serializedTask = Task.serializeWithDependencies(
             task, sched.sc.addedFiles, sched.sc.addedJars, ser)
           val timeTaken = System.currentTimeMillis - startTime
@@ -506,7 +534,7 @@ private[spark] class ClusterTaskSetManager(
     return None
   }
-  def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) {
+  override def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) {
     state match {
       case TaskState.FINISHED =>
         taskFinished(tid, state, serializedData)
@@ -542,7 +570,8 @@ private[spark] class ClusterTaskSetManager(
       try {
         val result = ser.deserialize[TaskResult[_]](serializedData)
         result.metrics.resultSize = serializedData.limit()
-        sched.listener.taskEnded(tasks(index), Success, result.value, result.accumUpdates, info, result.metrics)
+        sched.listener.taskEnded(
+          tasks(index), Success, result.value, result.accumUpdates, info, result.metrics)
       } catch {
         case cnf: ClassNotFoundException =>
           val loader = Thread.currentThread().getContextClassLoader
@@ -588,8 +617,8 @@ private[spark] class ClusterTaskSetManager(
           return
         case taskResultTooBig: TaskResultTooBigFailure =>
-          logInfo("Loss was due to task %s result exceeding Akka frame size; " +
-            "aborting job".format(tid))
+          logInfo("Loss was due to task %s result exceeding Akka frame size; aborting job".format(
+            tid))
           abort("Task %s result exceeded Akka frame size".format(tid))
           return
@@ -640,7 +669,7 @@ private[spark] class ClusterTaskSetManager(
     }
   }
-  def error(message: String) {
+  override def error(message: String) {
     // Save the error message
     abort("Error: " + message)
   }
@@ -668,7 +697,8 @@ private[spark] class ClusterTaskSetManager(
     }
   }
-  //TODO: for now we just find Pool not TaskSetManager, we can extend this function in future if needed
+  // TODO: for now we just find Pool not TaskSetManager,
+  // we can extend this function in future if needed
   override def getSchedulableByName(name: String): Schedulable = {
     return null
   }
@@ -693,13 +723,15 @@ private[spark] class ClusterTaskSetManager(
     // If some task has preferred locations only on hostname, and there are no more executors there,
     // put it in the no-prefs list to avoid the wait from delay scheduling
-    // host local tasks - should we push this to rack local or no pref list ? For now, preserving behavior and moving to
-    // no prefs list. Note, this was done due to impliations related to 'waiting' for data local tasks, etc.
-    // Note: NOT checking process local list - since host local list is super set of that. We need to ad to no prefs only if
-    // there is no host local node for the task (not if there is no process local node for the task)
+    // host local tasks - should we push this to rack local or no pref list ? For now, preserving
+    // behavior and moving to no prefs list. Note, this was done due to implications related to
+    // 'waiting' for data local tasks, etc.
+    // Note: NOT checking process local list - since host local list is super set of that. We need
+    // to add to no prefs only if there is no host local node for the task (not if there is no
+    // process local node for the task)
     for (index <- getPendingTasksForHost(Utils.parseHostPort(hostPort)._1)) {
-      // val newLocs = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.RACK_LOCAL)
-      val newLocs = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.NODE_LOCAL)
+      val newLocs = findPreferredLocations(
+        tasks(index).preferredLocations, sched, TaskLocality.NODE_LOCAL)
       if (newLocs.isEmpty) {
         pendingTasksWithNoPrefs += index
       }
diff --git a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala
index 07c3ddcc7e66a669e4a9d5576bbf0f7535226441..7978a5df7464d316f195b3423983130af1c192a7 100644
--- a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala
+++ b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala
@@ -17,18 +17,28 @@ package spark.scheduler.cluster
-import scala.collection.mutable.ArrayBuffer
-import spark.scheduler._
-import spark.TaskState.TaskState
 import java.nio.ByteBuffer
+import spark.TaskState.TaskState
+import spark.scheduler.TaskSet
+
 private[spark] trait TaskSetManager extends Schedulable {
+
   def taskSet: TaskSet
-  def slaveOffer(execId: String, hostPort: String, availableCpus: Double,
-    overrideLocality: TaskLocality.TaskLocality = null): Option[TaskDescription]
+
+  def slaveOffer(
+      execId: String,
+      hostPort: String,
+      availableCpus: Double,
+      overrideLocality: TaskLocality.TaskLocality = null): Option[TaskDescription]
+
   def numPendingTasksForHostPort(hostPort: String): Int
-  def numRackLocalPendingTasksForHost(hostPort :String): Int
+
+  def numRackLocalPendingTasksForHost(hostPort: String): Int
+
   def numPendingTasksForHost(hostPort: String): Int
+
   def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer)
+
   def error(message: String)
 }
diff --git a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala
index 1f73cb99a78e47834cc7aeb88787f11654d7f17c..edd83d4cb49639c236d0df50ee33fe66fc6b4d4d 100644
--- a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala
+++ b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala
@@ -37,10 +37,15 @@ import akka.actor._
  * testing fault recovery.
 */
-private[spark] case class LocalReviveOffers()
-private[spark] case class LocalStatusUpdate(taskId: Long, state: TaskState, serializedData: ByteBuffer)
+private[spark]
+case class LocalReviveOffers()
+
+private[spark]
+case class LocalStatusUpdate(taskId: Long, state: TaskState, serializedData: ByteBuffer)
+
+private[spark]
+class LocalActor(localScheduler: LocalScheduler, var freeCores: Int) extends Actor with Logging {
-private[spark] class LocalActor(localScheduler: LocalScheduler, var freeCores: Int) extends Actor with Logging {
   def receive = {
     case LocalReviveOffers =>
       launchTask(localScheduler.resourceOffer(freeCores))
@@ -55,7 +60,7 @@ private[spark] class LocalActor(localScheduler: LocalScheduler, var freeCores: I
         freeCores -= 1
         localScheduler.threadPool.submit(new Runnable {
           def run() {
-            localScheduler.runTask(task.taskId,task.serializedTask)
+            localScheduler.runTask(task.taskId, task.serializedTask)
           }
         })
       }
@@ -110,7 +115,7 @@ private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc:
   override def submitTasks(taskSet: TaskSet) {
     synchronized {
-      var manager = new LocalTaskSetManager(this, taskSet)
+      val manager = new LocalTaskSetManager(this, taskSet)
       schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties)
       activeTaskSets(taskSet.id) = manager
       taskSetTaskIds(taskSet.id) = new HashSet[Long]()
@@ -124,14 +129,15 @@ private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc:
     val tasks = new ArrayBuffer[TaskDescription](freeCores)
     val sortedTaskSetQueue = rootPool.getSortedTaskSetQueue()
     for (manager <- sortedTaskSetQueue) {
-      logDebug("parentName:%s,name:%s,runningTasks:%s".format(manager.parent.name, manager.name, manager.runningTasks))
+      logDebug("parentName:%s,name:%s,runningTasks:%s".format(
+        manager.parent.name, manager.name, manager.runningTasks))
     }
     var launchTask = false
     for (manager <- sortedTaskSetQueue) {
       do {
         launchTask = false
-        manager.slaveOffer(null,null,freeCpuCores) match {
+        manager.slaveOffer(null, null, freeCpuCores) match {
           case Some(task) =>
             tasks += task
             taskIdToTaskSetId(task.taskId) = manager.taskSet.id
@@ -139,7 +145,7 @@ private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc:
             freeCpuCores -= 1
             launchTask = true
           case None => {}
-        }
+        }
       } while(launchTask)
     }
     return tasks
diff --git a/core/src/main/scala/spark/scheduler/local/LocalTaskSetManager.scala b/core/src/main/scala/spark/scheduler/local/LocalTaskSetManager.scala
index b500451990397efc28285cb05e5bdc5ee2c156c5..b29740c886e572b7f82413dce53097fe710ec1d4 100644
--- a/core/src/main/scala/spark/scheduler/local/LocalTaskSetManager.scala
+++ b/core/src/main/scala/spark/scheduler/local/LocalTaskSetManager.scala
@@ -17,27 +17,26 @@ package spark.scheduler.local
-import java.io.File
-import java.util.concurrent.atomic.AtomicInteger
 import java.nio.ByteBuffer
 import scala.collection.mutable.ArrayBuffer
 import scala.collection.mutable.HashMap
-import scala.collection.mutable.HashSet
-import spark._
+import spark.{ExceptionFailure, Logging, SparkEnv, Success, TaskState}
 import spark.TaskState.TaskState
-import spark.scheduler._
-import spark.scheduler.cluster._
+import spark.scheduler.{Task, TaskResult, TaskSet}
+import spark.scheduler.cluster.{Schedulable, TaskDescription, TaskInfo, TaskLocality, TaskSetManager}
+
+
+private[spark] class LocalTaskSetManager(sched: LocalScheduler, val taskSet: TaskSet)
+  extends TaskSetManager with Logging {
-private[spark] class LocalTaskSetManager(sched: LocalScheduler, val taskSet: TaskSet) extends TaskSetManager with Logging {
   var parent: Schedulable = null
   var weight: Int = 1
   var minShare: Int = 0
   var runningTasks: Int = 0
   var priority: Int = taskSet.priority
   var stageId: Int = taskSet.stageId
-  var name: String = "TaskSet_"+taskSet.stageId.toString
-
+  var name: String = "TaskSet_" + taskSet.stageId.toString
   var failCount = new Array[Int](taskSet.tasks.size)
   val taskInfos = new HashMap[Long, TaskInfo]
@@ -49,49 +48,45 @@ private[spark] class LocalTaskSetManager(sched: LocalScheduler, val taskSet: Tas
   val numFailures = new Array[Int](numTasks)
   val MAX_TASK_FAILURES = sched.maxFailures
-  def increaseRunningTasks(taskNum: Int): Unit = {
-     runningTasks += taskNum
-     if (parent != null) {
-      parent.increaseRunningTasks(taskNum)
-     }
+  override def increaseRunningTasks(taskNum: Int): Unit = {
+    runningTasks += taskNum
+    if (parent != null) {
+      parent.increaseRunningTasks(taskNum)
+    }
   }
-  def decreaseRunningTasks(taskNum: Int): Unit = {
+  override def decreaseRunningTasks(taskNum: Int): Unit = {
     runningTasks -= taskNum
     if (parent != null) {
       parent.decreaseRunningTasks(taskNum)
     }
   }
-  def addSchedulable(schedulable: Schedulable): Unit = {
+  override def addSchedulable(schedulable: Schedulable): Unit = {
     //nothing
   }
-  def removeSchedulable(schedulable: Schedulable): Unit = {
+  override def removeSchedulable(schedulable: Schedulable): Unit = {
     //nothing
   }
-  def getSchedulableByName(name: String): Schedulable = {
+  override def getSchedulableByName(name: String): Schedulable = {
     return null
  }
-  def executorLost(executorId: String, host: String): Unit = {
+  override def executorLost(executorId: String, host: String): Unit = {
     //nothing
   }
-  def checkSpeculatableTasks(): Boolean = {
-    return true
-  }
+  override def checkSpeculatableTasks() = true
-  def getSortedTaskSetQueue(): ArrayBuffer[TaskSetManager] = {
+  override def getSortedTaskSetQueue(): ArrayBuffer[TaskSetManager] = {
     var sortedTaskSetQueue = new ArrayBuffer[TaskSetManager]
     sortedTaskSetQueue += this
     return sortedTaskSetQueue
   }
-  def hasPendingTasks(): Boolean = {
-    return true
-  }
+  override def hasPendingTasks() = true
   def findTask(): Option[Int] = {
     for (i <- 0 to numTasks-1) {
@@ -102,17 +97,27 @@ private[spark] class LocalTaskSetManager(sched: LocalScheduler, val taskSet: Tas
     return None
   }
-  def slaveOffer(execId: String, hostPort: String, availableCpus: Double, overrideLocality: TaskLocality.TaskLocality = null): Option[TaskDescription] = {
+  override def slaveOffer(
+      execId: String,
+      hostPort: String,
+      availableCpus: Double,
+      overrideLocality: TaskLocality.TaskLocality = null): Option[TaskDescription] =
+  {
     SparkEnv.set(sched.env)
-    logDebug("availableCpus:%d,numFinished:%d,numTasks:%d".format(availableCpus.toInt, numFinished, numTasks))
+    logDebug("availableCpus:%d,numFinished:%d,numTasks:%d".format(
+      availableCpus.toInt, numFinished, numTasks))
     if (availableCpus > 0 && numFinished < numTasks) {
       findTask() match {
         case Some(index) =>
           val taskId = sched.attemptId.getAndIncrement()
           val task = taskSet.tasks(index)
-          val info = new TaskInfo(taskId, index, System.currentTimeMillis(), "local", "local:1", TaskLocality.NODE_LOCAL)
+          val info = new TaskInfo(taskId, index, System.currentTimeMillis(), "local", "local:1",
+            TaskLocality.NODE_LOCAL)
           taskInfos(taskId) = info
-          val bytes = Task.serializeWithDependencies(task, sched.sc.addedFiles, sched.sc.addedJars, ser)
+          // We rely on the DAGScheduler to catch non-serializable closures and RDDs, so in here
+          // we assume the task can be serialized without exceptions.
+          val bytes = Task.serializeWithDependencies(
+            task, sched.sc.addedFiles, sched.sc.addedJars, ser)
           logInfo("Size of task " + taskId + " is " + bytes.limit + " bytes")
           val taskName = "task %s:%d".format(taskSet.id, index)
           copiesRunning(index) += 1
@@ -125,19 +130,19 @@ private[spark] class LocalTaskSetManager(sched: LocalScheduler, val taskSet: Tas
     return None
   }
-  def numPendingTasksForHostPort(hostPort: String): Int = {
+  override def numPendingTasksForHostPort(hostPort: String): Int = {
     return 0
   }
-  def numRackLocalPendingTasksForHost(hostPort :String): Int = {
+  override def numRackLocalPendingTasksForHost(hostPort :String): Int = {
     return 0
   }
-  def numPendingTasksForHost(hostPort: String): Int = {
+  override def numPendingTasksForHost(hostPort: String): Int = {
     return 0
   }
-  def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) {
+  override def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) {
     state match {
       case TaskState.FINISHED =>
         taskEnded(tid, state, serializedData)
@@ -173,15 +178,18 @@ private[spark] class LocalTaskSetManager(sched: LocalScheduler, val taskSet: Tas
     val task = taskSet.tasks(index)
     info.markFailed()
     decreaseRunningTasks(1)
-    val reason: ExceptionFailure = ser.deserialize[ExceptionFailure](serializedData, getClass.getClassLoader)
+    val reason: ExceptionFailure = ser.deserialize[ExceptionFailure](
+      serializedData, getClass.getClassLoader)
     sched.listener.taskEnded(task, reason, null, null, info, reason.metrics.getOrElse(null))
     if (!finished(index)) {
       copiesRunning(index) -= 1
       numFailures(index) += 1
       val locs = reason.stackTrace.map(loc => "\tat %s".format(loc.toString))
-      logInfo("Loss was due to %s\n%s\n%s".format(reason.className, reason.description, locs.mkString("\n")))
+      logInfo("Loss was due to %s\n%s\n%s".format(
+        reason.className, reason.description, locs.mkString("\n")))
       if (numFailures(index) > MAX_TASK_FAILURES) {
-        val errorMessage = "Task %s:%d failed more than %d times; aborting job %s".format(taskSet.id, index, 4, reason.description)
+        val errorMessage = "Task %s:%d failed more than %d times; aborting job %s".format(
+          taskSet.id, index, 4, reason.description)
         decreaseRunningTasks(runningTasks)
         sched.listener.taskSetFailed(taskSet, errorMessage)
         // need to delete failed Taskset from schedule queue
@@ -190,6 +198,6 @@ private[spark] class LocalTaskSetManager(sched: LocalScheduler, val taskSet: Tas
     }
   }
-  def error(message: String) {
+  override def error(message: String) {
   }
 }
diff --git a/core/src/test/scala/spark/scheduler/ClusterSchedulerSuite.scala b/core/src/test/scala/spark/scheduler/ClusterSchedulerSuite.scala
index 8f81d0b6ee559c4cfc9896ff6df4e6b50f1ad237..05afcd656760bff60f54c1111cea4f03cc4b22b9 100644
--- a/core/src/test/scala/spark/scheduler/ClusterSchedulerSuite.scala
+++ b/core/src/test/scala/spark/scheduler/ClusterSchedulerSuite.scala
@@ -33,7 +33,7 @@ class DummyTaskSetManager(
     initNumTasks: Int,
     clusterScheduler: ClusterScheduler,
     taskSet: TaskSet)
-  extends ClusterTaskSetManager(clusterScheduler,taskSet) {
+  extends ClusterTaskSetManager(clusterScheduler, taskSet) {
   parent = null
   weight = 1
diff --git a/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala
index f802b66cf13f914ba1eab1f144650c2aea59c4ab..a8b88d7936e4f5413d540c5c56c0fa094286aaff 100644
--- a/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala
@@ -23,21 +23,14 @@ import org.scalatest.FunSuite
 import org.scalatest.BeforeAndAfter
 import spark.LocalSparkContext
-
-import spark.storage.BlockManager
-import spark.storage.BlockManagerId
-import spark.storage.BlockManagerMaster
-import spark.{Dependency, ShuffleDependency, OneToOneDependency}
-import spark.FetchFailedException
 import spark.MapOutputTracker
 import spark.RDD
 import spark.SparkContext
-import spark.SparkException
 import spark.Partition
 import spark.TaskContext
-import spark.TaskEndReason
-
-import spark.{FetchFailed, Success}
+import spark.{Dependency, ShuffleDependency, OneToOneDependency}
+import spark.{FetchFailed, Success, TaskEndReason}
+import spark.storage.{BlockManagerId, BlockManagerMaster}
 /**
  * Tests for DAGScheduler. These tests directly call the event processing functions in DAGScheduler
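For readers following the TaskLocality hunks at the top of this patch: the hierarchy rule that isAllowed encodes is that a scheduling constraint admits any condition at least as specific as itself. The standalone Scala sketch below mirrors that rule under stated assumptions; the object and method names are hypothetical stand-ins and it is not part of the patched Spark sources.

// Illustrative sketch only -- not the Spark source. It mirrors the hierarchy check that
// TaskLocality.isAllowed (reformatted in the hunks above) performs.
object LocalityHierarchySketch {

  object Level extends Enumeration {
    val PROCESS_LOCAL, NODE_LOCAL, RACK_LOCAL, ANY = Value
  }
  import Level._

  // PROCESS_LOCAL is never used as a constraint, matching the assert in the real code.
  def isAllowed(constraint: Level.Value, condition: Level.Value): Boolean = constraint match {
    case NODE_LOCAL => condition == NODE_LOCAL
    case RACK_LOCAL => condition == NODE_LOCAL || condition == RACK_LOCAL
    case _          => true  // ANY admits everything
  }

  def main(args: Array[String]): Unit = {
    // A node-local candidate satisfies a rack-local constraint, but not the other way around.
    assert(isAllowed(RACK_LOCAL, NODE_LOCAL))
    assert(!isAllowed(NODE_LOCAL, RACK_LOCAL))
  }
}

This is the same asymmetry slaveOffer relies on when it widens the allowed locality from NODE_LOCAL to ANY after LOCALITY_WAIT has elapsed.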