Commit ec2e2ed1 authored by Reynold Xin

Use the same Executor in LocalScheduler as in ClusterScheduler.

parent 357733d2
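In short, local mode now runs tasks through the same Executor class that cluster mode uses: Executor gains an isLocal flag, LocalActor owns an Executor instance, and LocalScheduler implements ExecutorBackend so the executor reports task status back to the scheduler. The sketch below is a simplified, self-contained model of that wiring; the names mirror the diff, but the types are stand-ins rather than Spark's actual classes.

import java.nio.ByteBuffer

// Stand-in for org.apache.spark.executor.ExecutorBackend.
trait ExecutorBackend {
  def statusUpdate(taskId: Long, state: String, data: ByteBuffer): Unit
}

// Stand-in for org.apache.spark.executor.Executor with the new isLocal flag.
class Executor(executorId: String, host: String, isLocal: Boolean = false) {
  def launchTask(backend: ExecutorBackend, taskId: Long, serializedTask: ByteBuffer): Unit = {
    // The real Executor deserializes and runs the task on a thread pool;
    // here we only model the callback path through the backend.
    backend.statusUpdate(taskId, "FINISHED", serializedTask)
  }
}

// Stand-in for LocalScheduler: in local mode the scheduler itself is the backend.
class LocalScheduler extends ExecutorBackend {
  private val executor = new Executor("local", "local", isLocal = true)

  def submit(taskId: Long, serializedTask: ByteBuffer): Unit =
    executor.launchTask(this, taskId, serializedTask)

  override def statusUpdate(taskId: Long, state: String, data: ByteBuffer): Unit =
    println("task " + taskId + " -> " + state)  // the real scheduler updates its TaskSetManager here
}

object LocalModeWiring extends App {
  new LocalScheduler().submit(0L, ByteBuffer.allocate(0))
}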
Executor.scala

@@ -36,7 +36,8 @@ import org.apache.spark.util.Utils
 private[spark] class Executor(
     executorId: String,
     slaveHostname: String,
-    properties: Seq[(String, String)])
+    properties: Seq[(String, String)],
+    isLocal: Boolean = false)
   extends Logging
 {
   // Application dependencies (added through SparkContext) that we've fetched so far on this node.
@@ -101,10 +102,17 @@ private[spark] class Executor(
   val executorSource = new ExecutorSource(this, executorId)

   // Initialize Spark environment (using system properties read above)
-  private val env = SparkEnv.createFromSystemProperties(executorId, slaveHostname, 0,
-    isDriver = false, isLocal = false)
-  SparkEnv.set(env)
-  env.metricsSystem.registerSource(executorSource)
+  private val env = {
+    if (!isLocal) {
+      val _env = SparkEnv.createFromSystemProperties(executorId, slaveHostname, 0,
+        isDriver = false, isLocal = false)
+      SparkEnv.set(_env)
+      _env.metricsSystem.registerSource(executorSource)
+      _env
+    } else {
+      SparkEnv.get
+    }
+  }

   // Akka's message frame size. This is only used to warn the user when the task result is greater
   // than this value, in which case Akka will silently drop the task result message.
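The env block above is what makes the shared Executor safe in local mode: a local executor lives in the driver's JVM, so instead of creating a second SparkEnv it reuses the one the driver already registered. A generic sketch of this create-or-reuse pattern, using plain stand-ins rather than Spark's SparkEnv API:

object EnvInit {
  // Simplified holder mimicking the SparkEnv.set / SparkEnv.get pair used above.
  private var current: Option[String] = None
  def set(env: String): Unit = { current = Some(env) }
  def get: String = current.getOrElse(sys.error("no environment has been set"))

  // Mirrors the diff: a non-local executor builds and registers its own environment,
  // while a local executor reuses the one already in place.
  def initEnv(isLocal: Boolean): String =
    if (!isLocal) {
      val env = "executor-env"  // stand-in for SparkEnv.createFromSystemProperties(...)
      set(env)
      env
    } else {
      get
    }
}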
@@ -205,6 +213,7 @@
       if (task.killed) {
         logInfo("Executor killed task " + taskId)
         execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled))
+        return
       }

       for (m <- task.metrics) {
...
LocalScheduler.scala

@@ -17,23 +17,19 @@
 package org.apache.spark.scheduler.local

-import java.io.File
-import java.lang.management.ManagementFactory
-import java.util.concurrent.atomic.AtomicInteger
 import java.nio.ByteBuffer
+import java.util.concurrent.atomic.AtomicInteger

-import scala.collection.JavaConversions._
-import scala.collection.mutable.ArrayBuffer
-import scala.collection.mutable.HashMap
-import scala.collection.mutable.HashSet
+import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet}
+
+import akka.actor._

 import org.apache.spark._
 import org.apache.spark.TaskState.TaskState
-import org.apache.spark.executor.ExecutorURLClassLoader
+import org.apache.spark.executor.{Executor, ExecutorBackend}
 import org.apache.spark.scheduler._
 import org.apache.spark.scheduler.SchedulingMode.SchedulingMode
-import akka.actor._
-import org.apache.spark.util.Utils

 /**
  * A FIFO or Fair TaskScheduler implementation that runs tasks locally in a thread pool. Optionally
@@ -51,7 +47,10 @@ private[local]
 case class KillTask(taskId: Long)

 private[spark]
-class LocalActor(localScheduler: LocalScheduler, var freeCores: Int) extends Actor with Logging {
+class LocalActor(localScheduler: LocalScheduler, private var freeCores: Int)
+  extends Actor with Logging {
+
+  val executor = new Executor("local", "local", Seq.empty, isLocal = true)

   def receive = {
     case LocalReviveOffers =>
@@ -59,36 +58,27 @@ class LocalActor(localScheduler: LocalScheduler, var freeCores: Int) extends Act
     case LocalStatusUpdate(taskId, state, serializeData) =>
       freeCores += 1
-      localScheduler.statusUpdate(taskId, state, serializeData)
       launchTask(localScheduler.resourceOffer(freeCores))

     case KillTask(taskId) =>
-      killTask(taskId)
+      executor.killTask(taskId)
   }

-  def launchTask(tasks : Seq[TaskDescription]) {
+  private def launchTask(tasks: Seq[TaskDescription]) {
     for (task <- tasks) {
       freeCores -= 1
-      localScheduler.threadPool.submit(new Runnable {
-        def run() {
-          localScheduler.runTask(task.taskId, task.serializedTask)
-        }
-      })
+      executor.launchTask(localScheduler, task.taskId, task.serializedTask)
     }
   }
-
-  def killTask(taskId: Long) {
-  }
 }

 private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc: SparkContext)
   extends TaskScheduler
+  with ExecutorBackend
   with Logging {

-  var attemptId = new AtomicInteger(0)
-  var threadPool = Utils.newDaemonFixedThreadPool(threads)
   val env = SparkEnv.get
+  val attemptId = new AtomicInteger
   var listener: TaskSchedulerListener = null

   // Application dependencies (added through SparkContext) that we've fetched so far on this node.
@@ -96,8 +86,6 @@ private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc:
   val currentFiles: HashMap[String, Long] = new HashMap[String, Long]()
   val currentJars: HashMap[String, Long] = new HashMap[String, Long]()

-  val classLoader = new ExecutorURLClassLoader(Array(), Thread.currentThread.getContextClassLoader)
-
   var schedulableBuilder: SchedulableBuilder = null
   var rootPool: Pool = null
   val schedulingMode: SchedulingMode = SchedulingMode.withName(
@@ -139,10 +127,20 @@ private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc:
   }

   override def cancelTasks(stageId: Int): Unit = synchronized {
-    schedulableBuilder.getTaskSetManagers(stageId).foreach { sched =>
-      val taskIds = taskSetTaskIds(sched.asInstanceOf[TaskSetManager].taskSet.id)
-      for (tid <- taskIds) {
-        localActor ! KillTask(tid)
+    logInfo("Cancelling stage " + stageId)
+    schedulableBuilder.getTaskSetManagers(stageId).foreach { tsm =>
+      // There are two possible cases here:
+      // 1. The task set manager has been created and some tasks have been scheduled.
+      //    In this case, send a kill signal to the executors to kill the task.
+      // 2. The task set manager has been created but no tasks have been scheduled. In this
+      //    case, simply abort the task set.
+      val taskIds = taskSetTaskIds(tsm.taskSet.id)
+      if (taskIds.size > 0) {
+        taskIds.foreach { tid =>
+          localActor ! KillTask(tid)
+        }
+      } else {
+        tsm.error("Stage %d was cancelled before any task was launched".format(stageId))
       }
     }
   }
@@ -186,107 +184,32 @@ private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc:
     }
   }

-  def runTask(taskId: Long, bytes: ByteBuffer) {
-    logInfo("Running " + taskId)
-    val info = new TaskInfo(taskId, 0, System.currentTimeMillis(), "local", "local:1", TaskLocality.NODE_LOCAL)
-    // Set the Spark execution environment for the worker thread
-    SparkEnv.set(env)
-    val ser = SparkEnv.get.closureSerializer.newInstance()
-    val objectSer = SparkEnv.get.serializer.newInstance()
-    var attemptedTask: Option[Task[_]] = None
-    val start = System.currentTimeMillis()
-    var taskStart: Long = 0
-    def getTotalGCTime = ManagementFactory.getGarbageCollectorMXBeans.map(_.getCollectionTime).sum
-    val startGCTime = getTotalGCTime
-
-    try {
-      Accumulators.clear()
-      Thread.currentThread().setContextClassLoader(classLoader)
-      // Serialize and deserialize the task so that accumulators are changed to thread-local ones;
-      // this adds a bit of unnecessary overhead but matches how the Mesos Executor works.
-      val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(bytes)
-      updateDependencies(taskFiles, taskJars)   // Download any files added with addFile
-      val deserializedTask = ser.deserialize[Task[_]](
-        taskBytes, Thread.currentThread.getContextClassLoader)
-      attemptedTask = Some(deserializedTask)
-      val deserTime = System.currentTimeMillis() - start
-      taskStart = System.currentTimeMillis()
-
-      // Run it
-      val result: Any = deserializedTask.run(taskId)
-
-      // Serialize and deserialize the result to emulate what the Mesos
-      // executor does. This is useful to catch serialization errors early
-      // on in development (so when users move their local Spark programs
-      // to the cluster, they don't get surprised by serialization errors).
-      val serResult = objectSer.serialize(result)
-      deserializedTask.metrics.get.resultSize = serResult.limit()
-      val resultToReturn = objectSer.deserialize[Any](serResult)
-      val accumUpdates = ser.deserialize[collection.mutable.Map[Long, Any]](
-        ser.serialize(Accumulators.values))
-      val serviceTime = System.currentTimeMillis() - taskStart
-      logInfo("Finished " + taskId)
-      deserializedTask.metrics.get.executorRunTime = serviceTime.toInt
-      deserializedTask.metrics.get.jvmGCTime = getTotalGCTime - startGCTime
-      deserializedTask.metrics.get.executorDeserializeTime = deserTime.toInt
-      val taskResult = new DirectTaskResult(
-        result, accumUpdates, deserializedTask.metrics.getOrElse(null))
-      val serializedResult = ser.serialize(taskResult)
-      localActor ! LocalStatusUpdate(taskId, TaskState.FINISHED, serializedResult)
-    } catch {
-      case t: Throwable => {
-        val serviceTime = System.currentTimeMillis() - taskStart
-        val metrics = attemptedTask.flatMap(t => t.metrics)
-        for (m <- metrics) {
-          m.executorRunTime = serviceTime.toInt
-          m.jvmGCTime = getTotalGCTime - startGCTime
-        }
-        val failure = new ExceptionFailure(t.getClass.getName, t.toString, t.getStackTrace, metrics)
-        localActor ! LocalStatusUpdate(taskId, TaskState.FAILED, ser.serialize(failure))
-      }
-    }
-  }
-
-  /**
-   * Download any missing dependencies if we receive a new set of files and JARs from the
-   * SparkContext. Also adds any new JARs we fetched to the class loader.
-   */
-  private def updateDependencies(newFiles: HashMap[String, Long], newJars: HashMap[String, Long]) {
-    synchronized {
-      // Fetch missing dependencies
-      for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) {
-        logInfo("Fetching " + name + " with timestamp " + timestamp)
-        Utils.fetchFile(name, new File(SparkFiles.getRootDirectory))
-        currentFiles(name) = timestamp
-      }
-
-      for ((name, timestamp) <- newJars if currentJars.getOrElse(name, -1L) < timestamp) {
-        logInfo("Fetching " + name + " with timestamp " + timestamp)
-        Utils.fetchFile(name, new File(SparkFiles.getRootDirectory))
-        currentJars(name) = timestamp
-        // Add it to our class loader
-        val localName = name.split("/").last
-        val url = new File(SparkFiles.getRootDirectory, localName).toURI.toURL
-        if (!classLoader.getURLs.contains(url)) {
-          logInfo("Adding " + url + " to class loader")
-          classLoader.addURL(url)
-        }
-      }
-    }
-  }
-
-  def statusUpdate(taskId :Long, state: TaskState, serializedData: ByteBuffer) {
-    synchronized {
-      val taskSetId = taskIdToTaskSetId(taskId)
-      val taskSetManager = activeTaskSets(taskSetId)
-      taskSetTaskIds(taskSetId) -= taskId
-      taskSetManager.statusUpdate(taskId, state, serializedData)
-    }
-  }
-
-  override def stop() {
-    threadPool.shutdownNow()
-  }
+  override def statusUpdate(taskId: Long, state: TaskState, serializedData: ByteBuffer) {
+    if (TaskState.isFinished(state)) synchronized {
+      taskIdToTaskSetId.get(taskId) match {
+        case Some(taskSetId) =>
+          val taskSetManager = activeTaskSets(taskSetId)
+          taskSetTaskIds(taskSetId) -= taskId
+
+          state match {
+            case TaskState.FINISHED =>
+              taskSetManager.taskEnded(taskId, state, serializedData)
+            case TaskState.FAILED =>
+              taskSetManager.taskFailed(taskId, state, serializedData)
+            case TaskState.KILLED =>
+              taskSetManager.error("Task %d was killed".format(taskId))
+            case _ => {}
+          }
+
+          localActor ! LocalStatusUpdate(taskId, state, serializedData)
+        case None =>
+          logInfo("Ignoring update from TID " + taskId + " because its task set is gone")
+      }
+    }
+  }
+
+  override def stop() {
+    //threadPool.shutdownNow()
+  }

   override def defaultParallelism() = threads
...
LocalTaskSetManager.scala

@@ -132,17 +132,6 @@ private[spark] class LocalTaskSetManager(sched: LocalScheduler, val taskSet: Tas
     return None
   }

-  def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) {
-    SparkEnv.set(env)
-    state match {
-      case TaskState.FINISHED =>
-        taskEnded(tid, state, serializedData)
-      case TaskState.FAILED =>
-        taskFailed(tid, state, serializedData)
-      case _ => {}
-    }
-  }
-
   def taskStarted(task: Task[_], info: TaskInfo) {
     sched.listener.taskStarted(task, info)
   }
@@ -195,5 +184,7 @@ private[spark] class LocalTaskSetManager(sched: LocalScheduler, val taskSet: Tas
   }

   override def error(message: String) {
+    sched.listener.taskSetFailed(taskSet, message)
+    sched.taskSetFinished(this)
   }
 }
AsyncRDDActionsSuite.scala

@@ -35,7 +35,7 @@ class AsyncRDDActionsSuite extends FunSuite with BeforeAndAfterAll {
   @transient private var sc: SparkContext = _

   override def beforeAll() {
-    sc = new SparkContext("local-cluster[2,1,512]", "test")
+    sc = new SparkContext("local[2]", "test")
   }

   override def afterAll() {
...
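The test change above swaps the local-cluster master for the in-process "local[2]" master, which now exercises the shared Executor path. A minimal standalone usage sketch; the object name is illustrative, the SparkContext calls are standard:

import org.apache.spark.SparkContext

object LocalModeExample {
  def main(args: Array[String]) {
    // "local[2]" runs the driver and executor inside one JVM with 2 worker threads.
    val sc = new SparkContext("local[2]", "example")
    val total = sc.parallelize(1 to 100).reduce(_ + _)
    println("sum = " + total)
    sc.stop()
  }
}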