From ec2e2ed1e1b2fb313f087cc0b0bbb33d3e6c5f75 Mon Sep 17 00:00:00 2001
From: Reynold Xin <rxin@apache.org>
Date: Thu, 10 Oct 2013 18:55:25 -0700
Subject: [PATCH] Use the same Executor in LocalScheduler as in
 ClusterScheduler.

---
 .../org/apache/spark/executor/Executor.scala  |  19 +-
 .../scheduler/local/LocalScheduler.scala      | 177 +++++-------------
 .../scheduler/local/LocalTaskSetManager.scala |  13 +-
 .../spark/rdd/AsyncRDDActionsSuite.scala      |   2 +-
 4 files changed, 67 insertions(+), 144 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index 3d82790427..4c544275c2 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -36,7 +36,8 @@ import org.apache.spark.util.Utils
 private[spark] class Executor(
     executorId: String,
     slaveHostname: String,
-    properties: Seq[(String, String)])
+    properties: Seq[(String, String)],
+    isLocal: Boolean = false)
   extends Logging
 {
   // Application dependencies (added through SparkContext) that we've fetched so far on this node.
@@ -101,10 +102,17 @@ private[spark] class Executor(
   val executorSource = new ExecutorSource(this, executorId)
 
   // Initialize Spark environment (using system properties read above)
-  private val env = SparkEnv.createFromSystemProperties(executorId, slaveHostname, 0,
-    isDriver = false, isLocal = false)
-  SparkEnv.set(env)
-  env.metricsSystem.registerSource(executorSource)
+  private val env = {
+    if (!isLocal) {
+      val _env = SparkEnv.createFromSystemProperties(executorId, slaveHostname, 0,
+        isDriver = false, isLocal = false)
+      SparkEnv.set(_env)
+      _env.metricsSystem.registerSource(executorSource)
+      _env
+    } else {
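+      // Local mode: reuse the SparkEnv already created by the driver in this JVM.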
+      SparkEnv.get
+    }
+  }
 
   // Akka's message frame size. This is only used to warn the user when the task result is greater
   // than this value, in which case Akka will silently drop the task result message.
@@ -205,6 +213,7 @@ private[spark] class Executor(
         if (task.killed) {
           logInfo("Executor killed task " + taskId)
           execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled))
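+          // A task killed before it started should not be run at all.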
+          return
         }
 
         for (m <- task.metrics) {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala
index e132182231..cc16c688ca 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala
@@ -17,23 +17,19 @@
 
 package org.apache.spark.scheduler.local
 
-import java.io.File
-import java.lang.management.ManagementFactory
-import java.util.concurrent.atomic.AtomicInteger
 import java.nio.ByteBuffer
+import java.util.concurrent.atomic.AtomicInteger
 
-import scala.collection.JavaConversions._
-import scala.collection.mutable.ArrayBuffer
-import scala.collection.mutable.HashMap
-import scala.collection.mutable.HashSet
+import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet}
+
+import akka.actor._
 
 import org.apache.spark._
 import org.apache.spark.TaskState.TaskState
-import org.apache.spark.executor.ExecutorURLClassLoader
+import org.apache.spark.executor.{Executor, ExecutorBackend}
 import org.apache.spark.scheduler._
 import org.apache.spark.scheduler.SchedulingMode.SchedulingMode
-import akka.actor._
-import org.apache.spark.util.Utils
+
 
 /**
  * A FIFO or Fair TaskScheduler implementation that runs tasks locally in a thread pool. Optionally
@@ -51,7 +47,10 @@ private[local]
 case class KillTask(taskId: Long)
 
 private[spark]
-class LocalActor(localScheduler: LocalScheduler, var freeCores: Int) extends Actor with Logging {
+class LocalActor(localScheduler: LocalScheduler, private var freeCores: Int)
+  extends Actor with Logging {
+
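+  // A single in-process Executor runs all local tasks; isLocal = true makes it reuse the
+  // driver's SparkEnv instead of creating a new one.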
+  val executor = new Executor("local", "local", Seq.empty, isLocal = true)
 
   def receive = {
     case LocalReviveOffers =>
@@ -59,36 +58,27 @@ class LocalActor(localScheduler: LocalScheduler, var freeCores: Int) extends Act
 
     case LocalStatusUpdate(taskId, state, serializeData) =>
       freeCores += 1
-      localScheduler.statusUpdate(taskId, state, serializeData)
       launchTask(localScheduler.resourceOffer(freeCores))
 
     case KillTask(taskId) =>
-      killTask(taskId)
+      executor.killTask(taskId)
   }
 
-  def launchTask(tasks : Seq[TaskDescription]) {
+  private def launchTask(tasks: Seq[TaskDescription]) {
     for (task <- tasks) {
       freeCores -= 1
-      localScheduler.threadPool.submit(new Runnable {
-        def run() {
-          localScheduler.runTask(task.taskId, task.serializedTask)
-        }
-      })
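+      // The scheduler itself acts as the ExecutorBackend that receives this task's status updates.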
+      executor.launchTask(localScheduler, task.taskId, task.serializedTask)
     }
   }
-
-  def killTask(taskId: Long) {
-
-  }
 }
 
 private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc: SparkContext)
   extends TaskScheduler
+  with ExecutorBackend
   with Logging {
 
-  var attemptId = new AtomicInteger(0)
-  var threadPool = Utils.newDaemonFixedThreadPool(threads)
   val env = SparkEnv.get
+  val attemptId = new AtomicInteger
   var listener: TaskSchedulerListener = null
 
   // Application dependencies (added through SparkContext) that we've fetched so far on this node.
@@ -96,8 +86,6 @@ private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc:
   val currentFiles: HashMap[String, Long] = new HashMap[String, Long]()
   val currentJars: HashMap[String, Long] = new HashMap[String, Long]()
 
-  val classLoader = new ExecutorURLClassLoader(Array(), Thread.currentThread.getContextClassLoader)
-
   var schedulableBuilder: SchedulableBuilder = null
   var rootPool: Pool = null
   val schedulingMode: SchedulingMode = SchedulingMode.withName(
@@ -139,10 +127,20 @@ private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc:
   }
 
   override def cancelTasks(stageId: Int): Unit = synchronized {
-    schedulableBuilder.getTaskSetManagers(stageId).foreach { sched =>
-      val taskIds = taskSetTaskIds(sched.asInstanceOf[TaskSetManager].taskSet.id)
-      for (tid <- taskIds) {
-        localActor ! KillTask(tid)
+    logInfo("Cancelling stage " + stageId)
+    schedulableBuilder.getTaskSetManagers(stageId).foreach { tsm =>
+      // There are two possible cases here:
+      // 1. The task set manager has been created and some tasks have been scheduled.
+      //    In this case, send a kill signal to the executor to kill those tasks.
+      // 2. The task set manager has been created but no tasks have been scheduled. In this case,
+      //    simply abort the task set.
+      val taskIds = taskSetTaskIds(tsm.taskSet.id)
+      if (taskIds.size > 0) {
+        taskIds.foreach { tid =>
+          localActor ! KillTask(tid)
+        }
+      } else {
+        tsm.error("Stage %d was cancelled before any task was launched".format(stageId))
       }
     }
   }
@@ -186,107 +184,32 @@ private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc:
     }
   }
 
-  def runTask(taskId: Long, bytes: ByteBuffer) {
-    logInfo("Running " + taskId)
-    val info = new TaskInfo(taskId, 0, System.currentTimeMillis(), "local", "local:1", TaskLocality.NODE_LOCAL)
-    // Set the Spark execution environment for the worker thread
-    SparkEnv.set(env)
-    val ser = SparkEnv.get.closureSerializer.newInstance()
-    val objectSer = SparkEnv.get.serializer.newInstance()
-    var attemptedTask: Option[Task[_]] = None   
-    val start = System.currentTimeMillis()
-    var taskStart: Long = 0
-    def getTotalGCTime = ManagementFactory.getGarbageCollectorMXBeans.map(_.getCollectionTime).sum
-    val startGCTime = getTotalGCTime
-
-    try {
-      Accumulators.clear()
-      Thread.currentThread().setContextClassLoader(classLoader)
-
-      // Serialize and deserialize the task so that accumulators are changed to thread-local ones;
-      // this adds a bit of unnecessary overhead but matches how the Mesos Executor works.
-      val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(bytes)
-      updateDependencies(taskFiles, taskJars)   // Download any files added with addFile
-      val deserializedTask = ser.deserialize[Task[_]](
-        taskBytes, Thread.currentThread.getContextClassLoader)
-      attemptedTask = Some(deserializedTask)
-      val deserTime = System.currentTimeMillis() - start
-      taskStart = System.currentTimeMillis()
-
-      // Run it
-      val result: Any = deserializedTask.run(taskId)
-
-      // Serialize and deserialize the result to emulate what the Mesos
-      // executor does. This is useful to catch serialization errors early
-      // on in development (so when users move their local Spark programs
-      // to the cluster, they don't get surprised by serialization errors).
-      val serResult = objectSer.serialize(result)
-      deserializedTask.metrics.get.resultSize = serResult.limit()
-      val resultToReturn = objectSer.deserialize[Any](serResult)
-      val accumUpdates = ser.deserialize[collection.mutable.Map[Long, Any]](
-        ser.serialize(Accumulators.values))
-      val serviceTime = System.currentTimeMillis() - taskStart
-      logInfo("Finished " + taskId)
-      deserializedTask.metrics.get.executorRunTime = serviceTime.toInt
-      deserializedTask.metrics.get.jvmGCTime = getTotalGCTime - startGCTime
-      deserializedTask.metrics.get.executorDeserializeTime = deserTime.toInt
-      val taskResult = new DirectTaskResult(
-        result, accumUpdates, deserializedTask.metrics.getOrElse(null))
-      val serializedResult = ser.serialize(taskResult)
-      localActor ! LocalStatusUpdate(taskId, TaskState.FINISHED, serializedResult)
-    } catch {
-      case t: Throwable => {
-        val serviceTime = System.currentTimeMillis() - taskStart
-        val metrics = attemptedTask.flatMap(t => t.metrics)
-        for (m <- metrics) {
-          m.executorRunTime = serviceTime.toInt
-          m.jvmGCTime = getTotalGCTime - startGCTime
-        }
-        val failure = new ExceptionFailure(t.getClass.getName, t.toString, t.getStackTrace, metrics)
-        localActor ! LocalStatusUpdate(taskId, TaskState.FAILED, ser.serialize(failure))
-      }
-    }
-  }
-
-  /**
-   * Download any missing dependencies if we receive a new set of files and JARs from the
-   * SparkContext. Also adds any new JARs we fetched to the class loader.
-   */
-  private def updateDependencies(newFiles: HashMap[String, Long], newJars: HashMap[String, Long]) {
-    synchronized {
-      // Fetch missing dependencies
-      for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) {
-        logInfo("Fetching " + name + " with timestamp " + timestamp)
-        Utils.fetchFile(name, new File(SparkFiles.getRootDirectory))
-        currentFiles(name) = timestamp
-      }
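+  // ExecutorBackend callback: the in-process Executor reports task state changes here. Finished,
+  // failed and killed tasks are forwarded to their task set manager and the actor frees the core.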
+  override def statusUpdate(taskId: Long, state: TaskState, serializedData: ByteBuffer) {
+    if (TaskState.isFinished(state)) synchronized {
+      taskIdToTaskSetId.get(taskId) match {
+        case Some(taskSetId) =>
+          val taskSetManager = activeTaskSets(taskSetId)
+          taskSetTaskIds(taskSetId) -= taskId
+
+          state match {
+            case TaskState.FINISHED =>
+              taskSetManager.taskEnded(taskId, state, serializedData)
+            case TaskState.FAILED =>
+              taskSetManager.taskFailed(taskId, state, serializedData)
+            case TaskState.KILLED =>
+              taskSetManager.error("Task %d was killed".format(taskId))
+            case _ => {}
+          }
 
-      for ((name, timestamp) <- newJars if currentJars.getOrElse(name, -1L) < timestamp) {
-        logInfo("Fetching " + name + " with timestamp " + timestamp)
-        Utils.fetchFile(name, new File(SparkFiles.getRootDirectory))
-        currentJars(name) = timestamp
-        // Add it to our class loader
-        val localName = name.split("/").last
-        val url = new File(SparkFiles.getRootDirectory, localName).toURI.toURL
-        if (!classLoader.getURLs.contains(url)) {
-          logInfo("Adding " + url + " to class loader")
-          classLoader.addURL(url)
-        }
+          localActor ! LocalStatusUpdate(taskId, state, serializedData)
+        case None =>
+          logInfo("Ignoring update from TID " + taskId + " because its task set is gone")
       }
     }
   }
 
-  def statusUpdate(taskId :Long, state: TaskState, serializedData: ByteBuffer) {
-    synchronized {
-      val taskSetId = taskIdToTaskSetId(taskId)
-      val taskSetManager = activeTaskSets(taskSetId)
-      taskSetTaskIds(taskSetId) -= taskId
-      taskSetManager.statusUpdate(taskId, state, serializedData)
-    }
-  }
-
-   override def stop() {
-    threadPool.shutdownNow()
+  override def stop() {
+    // Tasks now run inside the Executor's own thread pool, so there is nothing to shut down here.
   }
 
   override def defaultParallelism() = threads
diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalTaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalTaskSetManager.scala
index c2e2399ccb..f72e77d40f 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalTaskSetManager.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalTaskSetManager.scala
@@ -132,17 +132,6 @@ private[spark] class LocalTaskSetManager(sched: LocalScheduler, val taskSet: Tas
     return None
   }
 
-  def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) {
-    SparkEnv.set(env)
-    state match {
-      case TaskState.FINISHED =>
-        taskEnded(tid, state, serializedData)
-      case TaskState.FAILED =>
-        taskFailed(tid, state, serializedData)
-      case _ => {}
-    }
-  }
-
   def taskStarted(task: Task[_], info: TaskInfo) {
     sched.listener.taskStarted(task, info)
   }
@@ -195,5 +184,7 @@ private[spark] class LocalTaskSetManager(sched: LocalScheduler, val taskSet: Tas
   }
 
   override def error(message: String) {
+    sched.listener.taskSetFailed(taskSet, message)
+    sched.taskSetFinished(this)
   }
 }
diff --git a/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala
index 029f24a51b..ac84640751 100644
--- a/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala
@@ -35,7 +35,7 @@ class AsyncRDDActionsSuite extends FunSuite with BeforeAndAfterAll {
   @transient private var sc: SparkContext = _
 
   override def beforeAll() {
-    sc = new SparkContext("local-cluster[2,1,512]", "test")
+    sc = new SparkContext("local[2]", "test")
   }
 
   override def afterAll() {
-- 
GitLab