diff --git a/core/src/main/scala/spark/SparkEnv.scala b/core/src/main/scala/spark/SparkEnv.scala index 0adbf1d96e9d38f2ba502e85287297902d3b4da8..bca90886a32692add76c479b182bd479d8673166 100644 --- a/core/src/main/scala/spark/SparkEnv.scala +++ b/core/src/main/scala/spark/SparkEnv.scala @@ -54,11 +54,7 @@ class SparkEnv ( val connectionManager: ConnectionManager, val httpFileServer: HttpFileServer, val sparkFilesDir: String, - val metricsSystem: MetricsSystem, - // To be set only as part of initialization of SparkContext. - // (executorId, defaultHostPort) => executorHostPort - // If executorId is NOT found, return defaultHostPort - var executorIdToHostPort: Option[(String, String) => String]) { + val metricsSystem: MetricsSystem) { private val pythonWorkers = mutable.HashMap[(String, Map[String, String]), PythonWorkerFactory]() @@ -83,16 +79,6 @@ class SparkEnv ( pythonWorkers.getOrElseUpdate(key, new PythonWorkerFactory(pythonExec, envVars)).create() } } - - def resolveExecutorIdToHostPort(executorId: String, defaultHostPort: String): String = { - val env = SparkEnv.get - if (env.executorIdToHostPort.isEmpty) { - // default to using host, not host port. Relevant to non cluster modes. - return defaultHostPort - } - - env.executorIdToHostPort.get(executorId, defaultHostPort) - } } object SparkEnv extends Logging { @@ -236,7 +222,6 @@ object SparkEnv extends Logging { connectionManager, httpFileServer, sparkFilesDir, - metricsSystem, - None) + metricsSystem) } } diff --git a/core/src/main/scala/spark/Utils.scala b/core/src/main/scala/spark/Utils.scala index 885a7391d63d9bbd81d79e8c6d68bad24eb82748..a05dcdcd97da3b7790ada73b0cd415de979cd524 100644 --- a/core/src/main/scala/spark/Utils.scala +++ b/core/src/main/scala/spark/Utils.scala @@ -393,41 +393,14 @@ private object Utils extends Logging { retval } -/* - // Used by DEBUG code : remove when all testing done - private val ipPattern = Pattern.compile("^[0-9]+(\\.[0-9]+)*$") def checkHost(host: String, message: String = "") { - // Currently catches only ipv4 pattern, this is just a debugging tool - not rigourous ! - // if (host.matches("^[0-9]+(\\.[0-9]+)*$")) { - if (ipPattern.matcher(host).matches()) { - Utils.logErrorWithStack("Unexpected to have host " + host + " which matches IP pattern. Message " + message) - } - if (Utils.parseHostPort(host)._2 != 0){ - Utils.logErrorWithStack("Unexpected to have host " + host + " which has port in it. Message " + message) - } + assert(host.indexOf(':') == -1, message) } - // Used by DEBUG code : remove when all testing done def checkHostPort(hostPort: String, message: String = "") { - val (host, port) = Utils.parseHostPort(hostPort) - checkHost(host) - if (port <= 0){ - Utils.logErrorWithStack("Unexpected to have port " + port + " which is not valid in " + hostPort + ". Message " + message) - } + assert(hostPort.indexOf(':') != -1, message) } - // Used by DEBUG code : remove when all testing done - def logErrorWithStack(msg: String) { - try { throw new Exception } catch { case ex: Exception => { logError(msg, ex) } } - // temp code for debug - System.exit(-1) - } -*/ - - // Once testing is complete in various modes, replace with this ? - def checkHost(host: String, message: String = "") {} - def checkHostPort(hostPort: String, message: String = "") {} - // Used by DEBUG code : remove when all testing done def logErrorWithStack(msg: String) { try { throw new Exception } catch { case ex: Exception => { logError(msg, ex) } } diff --git a/core/src/main/scala/spark/deploy/DeployMessage.scala b/core/src/main/scala/spark/deploy/DeployMessage.scala index 31861f3ac2d985b5615eed7bb08adca40a6e02af..0db13ffc98723ff59183e02212a1e13a5575a5b9 100644 --- a/core/src/main/scala/spark/deploy/DeployMessage.scala +++ b/core/src/main/scala/spark/deploy/DeployMessage.scala @@ -80,6 +80,7 @@ private[deploy] object DeployMessages { case class RegisteredApplication(appId: String) extends DeployMessage + // TODO(matei): replace hostPort with host case class ExecutorAdded(id: Int, workerId: String, hostPort: String, cores: Int, memory: Int) { Utils.checkHostPort(hostPort, "Required hostport") } diff --git a/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala index f661accd2ff2f37d64c8c59e085546b7410cd990..5e53d95ac2f364dd01e601200e4d361b4a0596f2 100644 --- a/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala +++ b/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala @@ -40,13 +40,11 @@ private[spark] class ExecutorRunner( val memory: Int, val worker: ActorRef, val workerId: String, - val hostPort: String, + val host: String, val sparkHome: File, val workDir: File) extends Logging { - Utils.checkHostPort(hostPort, "Expected hostport") - val fullId = appId + "/" + execId var workerThread: Thread = null var process: Process = null @@ -92,7 +90,7 @@ private[spark] class ExecutorRunner( /** Replace variables such as {{EXECUTOR_ID}} and {{CORES}} in a command argument passed to us */ def substituteVariables(argument: String): String = argument match { case "{{EXECUTOR_ID}}" => execId.toString - case "{{HOSTNAME}}" => Utils.parseHostPort(hostPort)._1 + case "{{HOSTNAME}}" => host case "{{CORES}}" => cores.toString case other => other } diff --git a/core/src/main/scala/spark/deploy/worker/Worker.scala b/core/src/main/scala/spark/deploy/worker/Worker.scala index d4b58fc34e88deacb1670d2ef7b9c9c5b88005d5..053ac5522607b0d9bd57df2d4501163adae0806c 100644 --- a/core/src/main/scala/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/spark/deploy/worker/Worker.scala @@ -132,7 +132,7 @@ private[spark] class Worker( case LaunchExecutor(appId, execId, appDesc, cores_, memory_, execSparkHome_) => logInfo("Asked to launch executor %s/%d for %s".format(appId, execId, appDesc.name)) val manager = new ExecutorRunner( - appId, execId, appDesc, cores_, memory_, self, workerId, host + ":" + port, new File(execSparkHome_), workDir) + appId, execId, appDesc, cores_, memory_, self, workerId, host, new File(execSparkHome_), workDir) executors(appId + "/" + execId) = manager manager.start() coresUsed += cores_ diff --git a/core/src/main/scala/spark/rdd/CartesianRDD.scala b/core/src/main/scala/spark/rdd/CartesianRDD.scala index 150e5bca29d4ffc4ebcea95c5f98082808153837..91b3e69d6fb3e759916c917be1d4e8b004466b53 100644 --- a/core/src/main/scala/spark/rdd/CartesianRDD.scala +++ b/core/src/main/scala/spark/rdd/CartesianRDD.scala @@ -64,7 +64,7 @@ class CartesianRDD[T: ClassManifest, U:ClassManifest]( override def getPreferredLocations(split: Partition): Seq[String] = { val currSplit = split.asInstanceOf[CartesianPartition] - rdd1.preferredLocations(currSplit.s1) ++ rdd2.preferredLocations(currSplit.s2) + (rdd1.preferredLocations(currSplit.s1) ++ rdd2.preferredLocations(currSplit.s2)).distinct } override def compute(split: Partition, context: TaskContext) = { diff --git a/core/src/main/scala/spark/rdd/ZippedPartitionsRDD.scala b/core/src/main/scala/spark/rdd/ZippedPartitionsRDD.scala index 51f5cc3251786080f57191d9b2157ccc31b76117..9a0831bd899e2ff48e7db6601e8fdc72ac279208 100644 --- a/core/src/main/scala/spark/rdd/ZippedPartitionsRDD.scala +++ b/core/src/main/scala/spark/rdd/ZippedPartitionsRDD.scala @@ -55,29 +55,15 @@ abstract class ZippedPartitionsBaseRDD[V: ClassManifest]( } override def getPreferredLocations(s: Partition): Seq[String] = { - // TODO(matei): Fix this for hostPort - - // Note that as number of rdd's increase and/or number of slaves in cluster increase, the computed preferredLocations below - // become diminishingly small : so we might need to look at alternate strategies to alleviate this. - // If there are no (or very small number of preferred locations), we will end up transferred the blocks to 'any' node in the - // cluster - paying with n/w and cache cost. - // Maybe pick a node which figures max amount of time ? - // Choose node which is hosting 'larger' of some subset of blocks ? - // Look at rack locality to ensure chosen host is atleast rack local to both hosting node ?, etc (would be good to defer this if possible) - val splits = s.asInstanceOf[ZippedPartitionsPartition].partitions - val rddSplitZip = rdds.zip(splits) - - // exact match. - val exactMatchPreferredLocations = rddSplitZip.map(x => x._1.preferredLocations(x._2)) - val exactMatchLocations = exactMatchPreferredLocations.reduce((x, y) => x.intersect(y)) - - // Remove exact match and then do host local match. - val exactMatchHosts = exactMatchLocations.map(Utils.parseHostPort(_)._1) - val matchPreferredHosts = exactMatchPreferredLocations.map(locs => locs.map(Utils.parseHostPort(_)._1)) - .reduce((x, y) => x.intersect(y)) - val otherNodeLocalLocations = matchPreferredHosts.filter { s => !exactMatchHosts.contains(s) } - - otherNodeLocalLocations ++ exactMatchLocations + val parts = s.asInstanceOf[ZippedPartitionsPartition].partitions + val prefs = rdds.zip(parts).map { case (rdd, p) => rdd.preferredLocations(p) } + // Check whether there are any hosts that match all RDDs; otherwise return the union + val exactMatchLocations = prefs.reduce((x, y) => x.intersect(y)) + if (!exactMatchLocations.isEmpty) { + exactMatchLocations + } else { + prefs.flatten.distinct + } } override def clearDependencies() { diff --git a/core/src/main/scala/spark/rdd/ZippedRDD.scala b/core/src/main/scala/spark/rdd/ZippedRDD.scala index b1c43b3195ab63b0be8a5ae9ed423d984e198319..4074e50e44040b96b756efeb69c238b7165ff142 100644 --- a/core/src/main/scala/spark/rdd/ZippedRDD.scala +++ b/core/src/main/scala/spark/rdd/ZippedRDD.scala @@ -65,27 +65,16 @@ class ZippedRDD[T: ClassManifest, U: ClassManifest]( } override def getPreferredLocations(s: Partition): Seq[String] = { - // Note that as number of slaves in cluster increase, the computed preferredLocations can become small : so we might need - // to look at alternate strategies to alleviate this. (If there are no (or very small number of preferred locations), we - // will end up transferred the blocks to 'any' node in the cluster - paying with n/w and cache cost. - // Maybe pick one or the other ? (so that atleast one block is local ?). - // Choose node which is hosting 'larger' of the blocks ? - // Look at rack locality to ensure chosen host is atleast rack local to both hosting node ?, etc (would be good to defer this if possible) val (partition1, partition2) = s.asInstanceOf[ZippedPartition[T, U]].partitions val pref1 = rdd1.preferredLocations(partition1) val pref2 = rdd2.preferredLocations(partition2) - - // exact match - instance local and host local. + // Check whether there are any hosts that match both RDDs; otherwise return the union val exactMatchLocations = pref1.intersect(pref2) - - // remove locations which are already handled via exactMatchLocations, and intersect where both partitions are node local. - val otherNodeLocalPref1 = pref1.filter(loc => ! exactMatchLocations.contains(loc)).map(loc => Utils.parseHostPort(loc)._1) - val otherNodeLocalPref2 = pref2.filter(loc => ! exactMatchLocations.contains(loc)).map(loc => Utils.parseHostPort(loc)._1) - val otherNodeLocalLocations = otherNodeLocalPref1.intersect(otherNodeLocalPref2) - - - // Can have mix of instance local (hostPort) and node local (host) locations as preference ! - exactMatchLocations ++ otherNodeLocalLocations + if (!exactMatchLocations.isEmpty) { + exactMatchLocations + } else { + (pref1 ++ pref2).distinct + } } override def clearDependencies() { diff --git a/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala b/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala index 036e36bca07a076279e14adb88a133fd6d1757ab..ec76e9018552cae9284295cf040ce4d0f5a2def4 100644 --- a/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala +++ b/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala @@ -45,39 +45,6 @@ private[spark] class ClusterScheduler(val sc: SparkContext) // Threshold above which we warn user initial TaskSet may be starved val STARVATION_TIMEOUT = System.getProperty("spark.starvation.timeout", "15000").toLong - // How often to revive offers in case there are pending tasks - that is how often to try to get - // tasks scheduled in case there are nodes available : default 0 is to disable it - to preserve existing behavior - // Note that this is required due to delay scheduling due to data locality waits, etc. - // TODO(matei): move to StandaloneSchedulerBackend? - val TASK_REVIVAL_INTERVAL = System.getProperty("spark.scheduler.revival.interval", "1000").toLong - - // TODO(matei): replace this with something that only affects levels past PROCESS_LOCAL; - // basically it can be a "cliff" for locality - /* - This property controls how aggressive we should be to modulate waiting for node local task scheduling. - To elaborate, currently there is a time limit (3 sec def) to ensure that spark attempts to wait for node locality of tasks before - scheduling on other nodes. We have modified this in yarn branch such that offers to task set happen in prioritized order : - node-local, rack-local and then others - But once all available node local (and no pref) tasks are scheduled, instead of waiting for 3 sec before - scheduling to other nodes (which degrades performance for time sensitive tasks and on larger clusters), we can - modulate that : to also allow rack local nodes or any node. The default is still set to HOST - so that previous behavior is - maintained. This is to allow tuning the tension between pulling rdd data off node and scheduling computation asap. - - TODO: rename property ? The value is one of - - NODE_LOCAL (default, no change w.r.t current behavior), - - RACK_LOCAL and - - ANY - - Note that this property makes more sense when used in conjugation with spark.tasks.revive.interval > 0 : else it is not very effective. - - Additional Note: For non trivial clusters, there is a 4x - 5x reduction in running time (in some of our experiments) based on whether - it is left at default NODE_LOCAL, RACK_LOCAL (if cluster is configured to be rack aware) or ANY. - If cluster is rack aware, then setting it to RACK_LOCAL gives best tradeoff and a 3x - 4x performance improvement while minimizing IO impact. - Also, it brings down the variance in running time drastically. - */ - val TASK_SCHEDULING_AGGRESSION = TaskLocality.withName( - System.getProperty("spark.tasks.schedule.aggression", "NODE_LOCAL")) - val activeTaskSets = new HashMap[String, TaskSetManager] val taskIdToTaskSetId = new HashMap[Long, String] @@ -136,14 +103,6 @@ private[spark] class ClusterScheduler(val sc: SparkContext) } } schedulableBuilder.buildPools() - // resolve executorId to hostPort mapping. - def executorToHostPort(executorId: String, defaultHostPort: String): String = { - executorIdToHost.getOrElse(executorId, defaultHostPort) - } - - // Unfortunately, this means that SparkEnv is indirectly referencing ClusterScheduler - // Will that be a design violation ? - SparkEnv.get.executorIdToHostPort = Some(executorToHostPort) } def newTaskId(): Long = nextTaskId.getAndIncrement() @@ -168,28 +127,6 @@ private[spark] class ClusterScheduler(val sc: SparkContext) } }.start() } - - - // Change to always run with some default if TASK_REVIVAL_INTERVAL <= 0 ? - // TODO(matei): remove this thread - if (TASK_REVIVAL_INTERVAL > 0) { - new Thread("ClusterScheduler task offer revival check") { - setDaemon(true) - - override def run() { - logInfo("Starting speculative task offer revival thread") - while (true) { - try { - Thread.sleep(TASK_REVIVAL_INTERVAL) - } catch { - case e: InterruptedException => {} - } - - if (hasPendingTasks()) backend.reviveOffers() - } - } - }.start() - } } override def submitTasks(taskSet: TaskSet) { @@ -329,7 +266,6 @@ private[spark] class ClusterScheduler(val sc: SparkContext) backend.reviveOffers() } if (taskFailed) { - // Also revive offers if a task had failed for some reason other than host lost backend.reviveOffers() } @@ -384,7 +320,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext) } // Check for pending tasks in all our active jobs. - def hasPendingTasks(): Boolean = { + def hasPendingTasks: Boolean = { synchronized { rootPool.hasPendingTasks() } @@ -416,10 +352,8 @@ private[spark] class ClusterScheduler(val sc: SparkContext) /** Remove an executor from all our data structures and mark it as lost */ private def removeExecutor(executorId: String) { - // TODO(matei): remove HostPort activeExecutorIds -= executorId val host = executorIdToHost(executorId) - val execs = executorsByHost.getOrElse(host, new HashSet) execs -= executorId if (execs.isEmpty) { diff --git a/core/src/main/scala/spark/scheduler/cluster/ClusterTaskSetManager.scala b/core/src/main/scala/spark/scheduler/cluster/ClusterTaskSetManager.scala index 1947c516dbc7bfa8a1e47e3e9ba277cdf935babc..cf406f876f5f136d5d5279f6469b6e6caac29dfd 100644 --- a/core/src/main/scala/spark/scheduler/cluster/ClusterTaskSetManager.scala +++ b/core/src/main/scala/spark/scheduler/cluster/ClusterTaskSetManager.scala @@ -42,9 +42,6 @@ import spark.TaskResultTooBigFailure private[spark] class ClusterTaskSetManager(sched: ClusterScheduler, val taskSet: TaskSet) extends TaskSetManager with Logging { - // Maximum time to wait to run a task in a preferred location (in ms) - val LOCALITY_WAIT = System.getProperty("spark.locality.wait", "3000").toLong - // CPUs to request per task val CPUS_PER_TASK = System.getProperty("spark.task.cpus", "1").toDouble @@ -74,8 +71,6 @@ private[spark] class ClusterTaskSetManager(sched: ClusterScheduler, val taskSet: var stageId = taskSet.stageId var name = "TaskSet_"+taskSet.stageId.toString var parent: Schedulable = null - // Last time when we launched a preferred task (for delay scheduling) - var lastPreferredLaunchTime = System.currentTimeMillis // Set of pending tasks for each executor. These collections are actually // treated as stacks, in which new tasks are added to the end of the @@ -114,11 +109,9 @@ private[spark] class ClusterTaskSetManager(sched: ClusterScheduler, val taskSet: val EXCEPTION_PRINT_INTERVAL = System.getProperty("spark.logging.exceptionPrintInterval", "10000").toLong - // Map of recent exceptions (identified by string representation and - // top stack frame) to duplicate count (how many times the same - // exception has appeared) and time the full exception was - // printed. This should ideally be an LRU map that can drop old - // exceptions automatically. + // Map of recent exceptions (identified by string representation and top stack frame) to + // duplicate count (how many times the same exception has appeared) and time the full exception + // was printed. This should ideally be an LRU map that can drop old exceptions automatically. val recentExceptions = HashMap[String, (Int, Long)]() // Figure out the current map output tracker epoch and set it on all tasks @@ -134,6 +127,16 @@ private[spark] class ClusterTaskSetManager(sched: ClusterScheduler, val taskSet: addPendingTask(i) } + // Figure out which locality levels we have in our TaskSet, so we can do delay scheduling + val myLocalityLevels = computeValidLocalityLevels() + val localityWaits = myLocalityLevels.map(getLocalityWait) // spark.locality.wait + + // Delay scheduling variables: we keep track of our current locality level and the time we + // last launched a task at that level, and move up a level when localityWaits[curLevel] expires. + // We then move down if we manage to launch a "more local" task. + var currentLocalityIndex = 0 // Index of our current locality level in validLocalityLevels + var lastLaunchTime = System.currentTimeMillis() // Time we last launched a task at this level + /** * Add a task to all the pending-task lists that it should be on. If readding is set, we are * re-adding the task so only include it in each list if it's not already there. @@ -169,7 +172,9 @@ private[spark] class ClusterTaskSetManager(sched: ClusterScheduler, val taskSet: addTo(pendingTasksWithNoPrefs) } - addTo(allPendingTasks) + if (!readding) { + allPendingTasks += index // No point scanning this whole list to find the old task there + } } /** @@ -324,18 +329,9 @@ private[spark] class ClusterTaskSetManager(sched: ClusterScheduler, val taskSet: : Option[TaskDescription] = { if (tasksFinished < numTasks && availableCpus >= CPUS_PER_TASK) { - val curTime = System.currentTimeMillis - - // If explicitly specified, use that - val locality = { - // expand only if we have waited for more than LOCALITY_WAIT for a host local task ... - // TODO(matei): Multi-level delay scheduling - if (curTime - lastPreferredLaunchTime < LOCALITY_WAIT) { - TaskLocality.NODE_LOCAL - } else { - TaskLocality.ANY - } - } + val curTime = System.currentTimeMillis() + + val locality = getAllowedLocalityLevel(curTime) findTask(execId, host, locality) match { case Some((index, taskLocality)) => { @@ -350,16 +346,16 @@ private[spark] class ClusterTaskSetManager(sched: ClusterScheduler, val taskSet: val info = new TaskInfo(taskId, index, curTime, execId, host, taskLocality) taskInfos(taskId) = info taskAttempts(index) = info :: taskAttempts(index) - if (taskLocality == TaskLocality.PROCESS_LOCAL || taskLocality == TaskLocality.NODE_LOCAL) { - lastPreferredLaunchTime = curTime - } + // Update our locality level for delay scheduling + currentLocalityIndex = getLocalityIndex(locality) + lastLaunchTime = curTime // Serialize and return the task - val startTime = System.currentTimeMillis + val startTime = System.currentTimeMillis() // We rely on the DAGScheduler to catch non-serializable closures and RDDs, so in here // we assume the task can be serialized without exceptions. val serializedTask = Task.serializeWithDependencies( task, sched.sc.addedFiles, sched.sc.addedJars, ser) - val timeTaken = System.currentTimeMillis - startTime + val timeTaken = System.currentTimeMillis() - startTime increaseRunningTasks(1) logInfo("Serialized task %s:%d as %d bytes in %d ms".format( taskSet.id, index, serializedTask.limit, timeTaken)) @@ -374,6 +370,34 @@ private[spark] class ClusterTaskSetManager(sched: ClusterScheduler, val taskSet: return None } + /** + * Get the level we can launch tasks according to delay scheduling, based on current wait time. + */ + private def getAllowedLocalityLevel(curTime: Long): TaskLocality.TaskLocality = { + while (curTime - lastLaunchTime >= localityWaits(currentLocalityIndex) && + currentLocalityIndex < myLocalityLevels.length - 1) + { + // Jump to the next locality level, and remove our waiting time for the current one since + // we don't want to count it again on the next one + lastLaunchTime += localityWaits(currentLocalityIndex) + currentLocalityIndex += 1 + } + myLocalityLevels(currentLocalityIndex) + } + + /** + * Find the index in myLocalityLevels for a given locality. This is also designed to work with + * localities that are not in myLocalityLevels (in case we somehow get those) by returning the + * next-biggest level we have. Uses the fact that the last value in myLocalityLevels is ANY. + */ + def getLocalityIndex(locality: TaskLocality.TaskLocality): Int = { + var index = 0 + while (locality > myLocalityLevels(index)) { + index += 1 + } + index + } + /** Called by cluster scheduler when one of our tasks changes state */ override def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) { SparkEnv.set(env) @@ -467,7 +491,7 @@ private[spark] class ClusterTaskSetManager(sched: ClusterScheduler, val taskSet: case ef: ExceptionFailure => sched.listener.taskEnded(tasks(index), ef, null, null, info, ef.metrics.getOrElse(null)) val key = ef.description - val now = System.currentTimeMillis + val now = System.currentTimeMillis() val (printFull, dupCount) = { if (recentExceptions.contains(key)) { val (dupCount, printTime) = recentExceptions(key) @@ -631,4 +655,38 @@ private[spark] class ClusterTaskSetManager(sched: ClusterScheduler, val taskSet: override def hasPendingTasks(): Boolean = { numTasks > 0 && tasksFinished < numTasks } + + private def getLocalityWait(level: TaskLocality.TaskLocality): Long = { + val defaultWait = System.getProperty("spark.locality.wait", "3000") + level match { + case TaskLocality.PROCESS_LOCAL => + System.getProperty("spark.locality.wait.process", defaultWait).toLong + case TaskLocality.NODE_LOCAL => + System.getProperty("spark.locality.wait.node", defaultWait).toLong + case TaskLocality.RACK_LOCAL => + System.getProperty("spark.locality.wait.rack", defaultWait).toLong + case TaskLocality.ANY => + 0L + } + } + + /** + * Compute the locality levels used in this TaskSet. Assumes that all tasks have already been + * added to queues using addPendingTask. + */ + private def computeValidLocalityLevels(): Array[TaskLocality.TaskLocality] = { + import TaskLocality.{PROCESS_LOCAL, NODE_LOCAL, RACK_LOCAL, ANY} + val levels = new ArrayBuffer[TaskLocality.TaskLocality] + if (!pendingTasksForExecutor.isEmpty && getLocalityWait(PROCESS_LOCAL) != 0) { + levels += PROCESS_LOCAL + } + if (!pendingTasksForHost.isEmpty && getLocalityWait(NODE_LOCAL) != 0) { + levels += NODE_LOCAL + } + if (!pendingTasksForRack.isEmpty && getLocalityWait(RACK_LOCAL) != 0) { + levels += RACK_LOCAL + } + levels += ANY + levels.toArray + } } diff --git a/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala b/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala index 3b49af12585d1812bc7970b3c3d558920183dd05..3203be10294d375d860e0e14ed24912a2edb31d0 100644 --- a/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala +++ b/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala @@ -26,6 +26,7 @@ import akka.dispatch.Await import akka.pattern.ask import akka.remote.{RemoteClientShutdown, RemoteClientDisconnected, RemoteClientLifeCycleEvent} import akka.util.Duration +import akka.util.duration._ import spark.{Utils, SparkException, Logging, TaskState} import spark.scheduler.cluster.StandaloneClusterMessages._ @@ -39,8 +40,6 @@ private[spark] class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: ActorSystem) extends SchedulerBackend with Logging { - // TODO(matei): periodically revive offers as in MesosScheduler - // Use an atomic variable to track total number of cores in the cluster for simplicity and speed var totalCoreCount = new AtomicInteger(0) @@ -55,6 +54,10 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor override def preStart() { // Listen for remote client disconnection events, since they don't go through Akka's watch() context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent]) + + // Periodically revive offers to allow delay scheduling to work + val reviveInterval = System.getProperty("spark.scheduler.revive.interval", "1000").toLong + context.system.scheduler.schedule(0.millis, reviveInterval.millis, self, ReviveOffers) } def receive = {