From ec2e2ed1e1b2fb313f087cc0b0bbb33d3e6c5f75 Mon Sep 17 00:00:00 2001
From: Reynold Xin <rxin@apache.org>
Date: Thu, 10 Oct 2013 18:55:25 -0700
Subject: [PATCH] Use the same Executor in LocalScheduler as in
 ClusterScheduler.

---
 .../org/apache/spark/executor/Executor.scala    |  19 +-
 .../scheduler/local/LocalScheduler.scala        | 177 +++++-------------
 .../scheduler/local/LocalTaskSetManager.scala   |  13 +-
 .../spark/rdd/AsyncRDDActionsSuite.scala        |   2 +-
 4 files changed, 67 insertions(+), 144 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index 3d82790427..4c544275c2 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -36,7 +36,8 @@ import org.apache.spark.util.Utils
 private[spark] class Executor(
     executorId: String,
     slaveHostname: String,
-    properties: Seq[(String, String)])
+    properties: Seq[(String, String)],
+    isLocal: Boolean = false)
   extends Logging {
 
   // Application dependencies (added through SparkContext) that we've fetched so far on this node.
@@ -101,10 +102,17 @@ private[spark] class Executor(
   val executorSource = new ExecutorSource(this, executorId)
 
   // Initialize Spark environment (using system properties read above)
-  private val env = SparkEnv.createFromSystemProperties(executorId, slaveHostname, 0,
-    isDriver = false, isLocal = false)
-  SparkEnv.set(env)
-  env.metricsSystem.registerSource(executorSource)
+  private val env = {
+    if (!isLocal) {
+      val _env = SparkEnv.createFromSystemProperties(executorId, slaveHostname, 0,
+        isDriver = false, isLocal = false)
+      SparkEnv.set(_env)
+      _env.metricsSystem.registerSource(executorSource)
+      _env
+    } else {
+      SparkEnv.get
+    }
+  }
 
   // Akka's message frame size. This is only used to warn the user when the task result is greater
   // than this value, in which case Akka will silently drop the task result message.
@@ -205,6 +213,7 @@ private[spark] class Executor(
         if (task.killed) {
           logInfo("Executor killed task " + taskId)
           execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled))
+          return
         }
 
         for (m <- task.metrics) {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala
index e132182231..cc16c688ca 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala
@@ -17,23 +17,19 @@
 
 package org.apache.spark.scheduler.local
 
-import java.io.File
-import java.lang.management.ManagementFactory
-import java.util.concurrent.atomic.AtomicInteger
 import java.nio.ByteBuffer
+import java.util.concurrent.atomic.AtomicInteger
 
-import scala.collection.JavaConversions._
-import scala.collection.mutable.ArrayBuffer
-import scala.collection.mutable.HashMap
-import scala.collection.mutable.HashSet
+import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet}
+
+import akka.actor._
 
 import org.apache.spark._
 import org.apache.spark.TaskState.TaskState
-import org.apache.spark.executor.ExecutorURLClassLoader
+import org.apache.spark.executor.{Executor, ExecutorBackend}
 import org.apache.spark.scheduler._
 import org.apache.spark.scheduler.SchedulingMode.SchedulingMode
-import akka.actor._
-import org.apache.spark.util.Utils
+
 
 /**
  * A FIFO or Fair TaskScheduler implementation that runs tasks locally in a thread pool. Optionally
@@ -51,7 +47,10 @@
 private[local] case class KillTask(taskId: Long)
 
 private[spark]
-class LocalActor(localScheduler: LocalScheduler, var freeCores: Int) extends Actor with Logging {
+class LocalActor(localScheduler: LocalScheduler, private var freeCores: Int)
+  extends Actor with Logging {
+
+  val executor = new Executor("local", "local", Seq.empty, isLocal = true)
 
   def receive = {
     case LocalReviveOffers =>
@@ -59,36 +58,27 @@ class LocalActor(localScheduler: LocalScheduler, var freeCores: Int) extends Act
 
     case LocalStatusUpdate(taskId, state, serializeData) =>
       freeCores += 1
-      localScheduler.statusUpdate(taskId, state, serializeData)
       launchTask(localScheduler.resourceOffer(freeCores))
 
     case KillTask(taskId) =>
-      killTask(taskId)
+      executor.killTask(taskId)
   }
 
-  def launchTask(tasks : Seq[TaskDescription]) {
+  private def launchTask(tasks: Seq[TaskDescription]) {
    for (task <- tasks) {
      freeCores -= 1
-      localScheduler.threadPool.submit(new Runnable {
-        def run() {
-          localScheduler.runTask(task.taskId, task.serializedTask)
-        }
-      })
+      executor.launchTask(localScheduler, task.taskId, task.serializedTask)
    }
  }
-
-  def killTask(taskId: Long) {
-
-  }
 }
 
 private[spark]
 class LocalScheduler(threads: Int, val maxFailures: Int, val sc: SparkContext)
   extends TaskScheduler
+  with ExecutorBackend
   with Logging {
 
-  var attemptId = new AtomicInteger(0)
-  var threadPool = Utils.newDaemonFixedThreadPool(threads)
   val env = SparkEnv.get
+  val attemptId = new AtomicInteger
   var listener: TaskSchedulerListener = null
 
   // Application dependencies (added through SparkContext) that we've fetched so far on this node.
@@ -96,8 +86,6 @@ private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc:
   val currentFiles: HashMap[String, Long] = new HashMap[String, Long]()
   val currentJars: HashMap[String, Long] = new HashMap[String, Long]()
 
-  val classLoader = new ExecutorURLClassLoader(Array(), Thread.currentThread.getContextClassLoader)
-
   var schedulableBuilder: SchedulableBuilder = null
   var rootPool: Pool = null
   val schedulingMode: SchedulingMode = SchedulingMode.withName(
@@ -139,10 +127,20 @@ private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc:
   }
 
   override def cancelTasks(stageId: Int): Unit = synchronized {
-    schedulableBuilder.getTaskSetManagers(stageId).foreach { sched =>
-      val taskIds = taskSetTaskIds(sched.asInstanceOf[TaskSetManager].taskSet.id)
-      for (tid <- taskIds) {
-        localActor ! KillTask(tid)
+    logInfo("Cancelling stage " + stageId)
+    schedulableBuilder.getTaskSetManagers(stageId).foreach { tsm =>
+      // There are two possible cases here:
+      // 1. The task set manager has been created and some tasks have been scheduled.
+      //    In this case, send a kill signal to the executors to kill the task.
+      // 2. The task set manager has been created but no tasks have been scheduled. In this case,
+      //    simply abort the task set.
+      val taskIds = taskSetTaskIds(tsm.taskSet.id)
+      if (taskIds.size > 0) {
+        taskIds.foreach { tid =>
+          localActor ! KillTask(tid)
+        }
+      } else {
+        tsm.error("Stage %d was cancelled before any task was launched".format(stageId))
       }
     }
   }
@@ -186,107 +184,32 @@ private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc:
     }
   }
 
-  def runTask(taskId: Long, bytes: ByteBuffer) {
-    logInfo("Running " + taskId)
-    val info = new TaskInfo(taskId, 0, System.currentTimeMillis(), "local", "local:1", TaskLocality.NODE_LOCAL)
-    // Set the Spark execution environment for the worker thread
-    SparkEnv.set(env)
-    val ser = SparkEnv.get.closureSerializer.newInstance()
-    val objectSer = SparkEnv.get.serializer.newInstance()
-    var attemptedTask: Option[Task[_]] = None
-    val start = System.currentTimeMillis()
-    var taskStart: Long = 0
-    def getTotalGCTime = ManagementFactory.getGarbageCollectorMXBeans.map(_.getCollectionTime).sum
-    val startGCTime = getTotalGCTime
-
-    try {
-      Accumulators.clear()
-      Thread.currentThread().setContextClassLoader(classLoader)
-
-      // Serialize and deserialize the task so that accumulators are changed to thread-local ones;
-      // this adds a bit of unnecessary overhead but matches how the Mesos Executor works.
-      val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(bytes)
-      updateDependencies(taskFiles, taskJars)   // Download any files added with addFile
-      val deserializedTask = ser.deserialize[Task[_]](
-        taskBytes, Thread.currentThread.getContextClassLoader)
-      attemptedTask = Some(deserializedTask)
-      val deserTime = System.currentTimeMillis() - start
-      taskStart = System.currentTimeMillis()
-
-      // Run it
-      val result: Any = deserializedTask.run(taskId)
-
-      // Serialize and deserialize the result to emulate what the Mesos
-      // executor does. This is useful to catch serialization errors early
-      // on in development (so when users move their local Spark programs
-      // to the cluster, they don't get surprised by serialization errors).
-      val serResult = objectSer.serialize(result)
-      deserializedTask.metrics.get.resultSize = serResult.limit()
-      val resultToReturn = objectSer.deserialize[Any](serResult)
-      val accumUpdates = ser.deserialize[collection.mutable.Map[Long, Any]](
-        ser.serialize(Accumulators.values))
-      val serviceTime = System.currentTimeMillis() - taskStart
-      logInfo("Finished " + taskId)
-      deserializedTask.metrics.get.executorRunTime = serviceTime.toInt
-      deserializedTask.metrics.get.jvmGCTime = getTotalGCTime - startGCTime
-      deserializedTask.metrics.get.executorDeserializeTime = deserTime.toInt
-      val taskResult = new DirectTaskResult(
-        result, accumUpdates, deserializedTask.metrics.getOrElse(null))
-      val serializedResult = ser.serialize(taskResult)
-      localActor ! LocalStatusUpdate(taskId, TaskState.FINISHED, serializedResult)
-    } catch {
-      case t: Throwable => {
-        val serviceTime = System.currentTimeMillis() - taskStart
-        val metrics = attemptedTask.flatMap(t => t.metrics)
-        for (m <- metrics) {
-          m.executorRunTime = serviceTime.toInt
-          m.jvmGCTime = getTotalGCTime - startGCTime
-        }
-        val failure = new ExceptionFailure(t.getClass.getName, t.toString, t.getStackTrace, metrics)
-        localActor ! LocalStatusUpdate(taskId, TaskState.FAILED, ser.serialize(failure))
-      }
-    }
-  }
-
-  /**
-   * Download any missing dependencies if we receive a new set of files and JARs from the
-   * SparkContext. Also adds any new JARs we fetched to the class loader.
-   */
-  private def updateDependencies(newFiles: HashMap[String, Long], newJars: HashMap[String, Long]) {
-    synchronized {
-      // Fetch missing dependencies
-      for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) {
-        logInfo("Fetching " + name + " with timestamp " + timestamp)
-        Utils.fetchFile(name, new File(SparkFiles.getRootDirectory))
-        currentFiles(name) = timestamp
-      }
+  override def statusUpdate(taskId: Long, state: TaskState, serializedData: ByteBuffer) {
+    if (TaskState.isFinished(state)) synchronized {
+      taskIdToTaskSetId.get(taskId) match {
+        case Some(taskSetId) =>
+          val taskSetManager = activeTaskSets(taskSetId)
+          taskSetTaskIds(taskSetId) -= taskId
+
+          state match {
+            case TaskState.FINISHED =>
+              taskSetManager.taskEnded(taskId, state, serializedData)
+            case TaskState.FAILED =>
+              taskSetManager.taskFailed(taskId, state, serializedData)
+            case TaskState.KILLED =>
+              taskSetManager.error("Task %d was killed".format(taskId))
+            case _ => {}
+          }
 
-      for ((name, timestamp) <- newJars if currentJars.getOrElse(name, -1L) < timestamp) {
-        logInfo("Fetching " + name + " with timestamp " + timestamp)
-        Utils.fetchFile(name, new File(SparkFiles.getRootDirectory))
-        currentJars(name) = timestamp
-        // Add it to our class loader
-        val localName = name.split("/").last
-        val url = new File(SparkFiles.getRootDirectory, localName).toURI.toURL
-        if (!classLoader.getURLs.contains(url)) {
-          logInfo("Adding " + url + " to class loader")
-          classLoader.addURL(url)
-        }
+          localActor ! LocalStatusUpdate(taskId, state, serializedData)
+        case None =>
+          logInfo("Ignoring update from TID " + taskId + " because its task set is gone")
       }
     }
   }
 
-  def statusUpdate(taskId :Long, state: TaskState, serializedData: ByteBuffer) {
-    synchronized {
-      val taskSetId = taskIdToTaskSetId(taskId)
-      val taskSetManager = activeTaskSets(taskSetId)
-      taskSetTaskIds(taskSetId) -= taskId
-      taskSetManager.statusUpdate(taskId, state, serializedData)
-    }
-  }
-
-  override def stop() {
-    threadPool.shutdownNow()
+  override def stop() {
+    //threadPool.shutdownNow()
   }
 
   override def defaultParallelism() = threads
diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalTaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalTaskSetManager.scala
index c2e2399ccb..f72e77d40f 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalTaskSetManager.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalTaskSetManager.scala
@@ -132,17 +132,6 @@ private[spark] class LocalTaskSetManager(sched: LocalScheduler, val taskSet: Tas
     return None
   }
 
-  def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) {
-    SparkEnv.set(env)
-    state match {
-      case TaskState.FINISHED =>
-        taskEnded(tid, state, serializedData)
-      case TaskState.FAILED =>
-        taskFailed(tid, state, serializedData)
-      case _ => {}
-    }
-  }
-
   def taskStarted(task: Task[_], info: TaskInfo) {
     sched.listener.taskStarted(task, info)
   }
@@ -195,5 +184,7 @@ private[spark] class LocalTaskSetManager(sched: LocalScheduler, val taskSet: Tas
   }
 
   override def error(message: String) {
+    sched.listener.taskSetFailed(taskSet, message)
+    sched.taskSetFinished(this)
   }
 }
diff --git a/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala
index 029f24a51b..ac84640751 100644
--- a/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala
@@ -35,7 +35,7 @@ class AsyncRDDActionsSuite extends FunSuite with BeforeAndAfterAll {
   @transient private var sc: SparkContext = _
 
   override def beforeAll() {
-    sc = new SparkContext("local-cluster[2,1,512]", "test")
+    sc = new SparkContext("local[2]", "test")
   }
 
   override def afterAll() {
-- 
GitLab
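For context, the shape of the change is: LocalScheduler drops its private runTask/updateDependencies code path and routes local-mode tasks through the same Executor used by the cluster backends. LocalScheduler mixes in ExecutorBackend, LocalActor owns an Executor created with isLocal = true, and task completion flows back through statusUpdate into the task set managers. The Scala sketch below is a minimal, self-contained illustration of that callback loop under those assumptions; it is not Spark code, and MiniExecutorBackend, MiniExecutor, and MiniLocalScheduler are hypothetical stand-ins that only assume a JDK thread pool.

import java.nio.ByteBuffer
import java.util.concurrent.Executors

// Hypothetical stand-ins (not Spark classes) for the ExecutorBackend/Executor pair.
sealed trait MiniTaskState
case object Finished extends MiniTaskState
case object Failed extends MiniTaskState

// The scheduler-facing callback interface: whoever launches a task is told how it ended.
trait MiniExecutorBackend {
  def statusUpdate(taskId: Long, state: MiniTaskState, data: ByteBuffer): Unit
}

// A tiny executor that runs task bodies on a thread pool and reports the outcome
// back through the backend, mirroring the launchTask/statusUpdate round trip in the patch.
class MiniExecutor(threads: Int) {
  private val pool = Executors.newFixedThreadPool(threads)

  def launchTask(backend: MiniExecutorBackend, taskId: Long, body: () => Array[Byte]): Unit = {
    pool.submit(new Runnable {
      def run(): Unit = {
        try {
          val result = body()
          backend.statusUpdate(taskId, Finished, ByteBuffer.wrap(result))
        } catch {
          case t: Throwable =>
            backend.statusUpdate(taskId, Failed, ByteBuffer.wrap(t.toString.getBytes("UTF-8")))
        }
      }
    })
  }

  def stop(): Unit = pool.shutdown()
}

// A local scheduler that is its own backend, the same way LocalScheduler now
// extends TaskScheduler with ExecutorBackend.
object MiniLocalScheduler extends MiniExecutorBackend {
  override def statusUpdate(taskId: Long, state: MiniTaskState, data: ByteBuffer): Unit =
    println(s"task $taskId -> $state (${data.remaining()} result bytes)")

  def main(args: Array[String]): Unit = {
    val executor = new MiniExecutor(2)
    executor.launchTask(this, 0L, () => Array[Byte](1, 2, 3))  // reports Finished
    executor.launchTask(this, 1L, () => sys.error("boom"))     // reports Failed
    Thread.sleep(500)  // crude wait so the demo output appears before shutdown
    executor.stop()
  }
}

Keeping a single launch-and-report path is the point of the patch: local mode now exercises the same task deserialization, metrics, and result serialization logic as a cluster deployment, which is why the old emulation code in runTask and the local-only statusUpdate variants could be deleted.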