diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index c109ff930ca781ae33d1883833cc58b49f33380c..6f54fa7a5a631ec3154ef9db8d59de33df7e8e94 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -43,11 +43,10 @@ import org.apache.spark.deploy.{LocalSparkCluster, SparkHadoopUtil}
 import org.apache.spark.partial.{ApproximateEvaluator, PartialResult}
 import org.apache.spark.rdd._
 import org.apache.spark.scheduler._
-import org.apache.spark.scheduler.cluster.{ClusterScheduler, CoarseGrainedSchedulerBackend,
-SimrSchedulerBackend, SparkDeploySchedulerBackend}
-import org.apache.spark.scheduler.cluster.mesos.{CoarseMesosSchedulerBackend,
-MesosSchedulerBackend}
-import org.apache.spark.scheduler.local.LocalScheduler
+import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend,
+  SparkDeploySchedulerBackend, SimrSchedulerBackend}
+import org.apache.spark.scheduler.cluster.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend}
+import org.apache.spark.scheduler.local.LocalBackend
 import org.apache.spark.storage.{BlockManagerSource, RDDInfo, StorageStatus, StorageUtils}
 import org.apache.spark.ui.SparkUI
 import org.apache.spark.util._
@@ -560,9 +559,7 @@ class SparkContext(
     }
     addedFiles(key) = System.currentTimeMillis
 
-    // Fetch the file locally in case a job is executed locally.
-    // Jobs that run through LocalScheduler will already fetch the required dependencies,
-    // but jobs run in DAGScheduler.runLocally() will not so we must fetch the files here.
+    // Fetch the file locally in case a job is executed using DAGScheduler.runLocally().
     Utils.fetchFile(path, new File(SparkFiles.getRootDirectory), conf)
 
     logInfo("Added file " + path + " at " + key + " with timestamp " + addedFiles(key))
@@ -1070,18 +1067,30 @@ object SparkContext {
     // Regular expression for connection to Simr cluster
     val SIMR_REGEX = """simr://(.*)""".r
 
+    // When running locally, don't try to re-execute tasks on failure.
+    val MAX_LOCAL_TASK_FAILURES = 1
+
     master match {
       case "local" =>
-        new LocalScheduler(1, 0, sc)
+        val scheduler = new TaskSchedulerImpl(sc, MAX_LOCAL_TASK_FAILURES, isLocal = true)
+        val backend = new LocalBackend(scheduler, 1)
+        scheduler.initialize(backend)
+        scheduler
 
       case LOCAL_N_REGEX(threads) =>
-        new LocalScheduler(threads.toInt, 0, sc)
+        val scheduler = new TaskSchedulerImpl(sc, MAX_LOCAL_TASK_FAILURES, isLocal = true)
+        val backend = new LocalBackend(scheduler, threads.toInt)
+        scheduler.initialize(backend)
+        scheduler
 
       case LOCAL_N_FAILURES_REGEX(threads, maxFailures) =>
-        new LocalScheduler(threads.toInt, maxFailures.toInt, sc)
+        val scheduler = new TaskSchedulerImpl(sc, maxFailures.toInt, isLocal = true)
+        val backend = new LocalBackend(scheduler, threads.toInt)
+        scheduler.initialize(backend)
+        scheduler
 
       case SPARK_REGEX(sparkUrl) =>
-        val scheduler = new ClusterScheduler(sc)
+        val scheduler = new TaskSchedulerImpl(sc)
         val masterUrls = sparkUrl.split(",").map("spark://" + _)
         val backend = new SparkDeploySchedulerBackend(scheduler, sc, masterUrls, appName)
         scheduler.initialize(backend)
@@ -1096,7 +1105,7 @@ object SparkContext {
               memoryPerSlaveInt, sc.executorMemory))
         }
 
-        val scheduler = new ClusterScheduler(sc)
+        val scheduler = new TaskSchedulerImpl(sc)
         val localCluster = new LocalSparkCluster(
           numSlaves.toInt, coresPerSlave.toInt, memoryPerSlaveInt)
         val masterUrls = localCluster.start()
@@ -1111,7 +1120,7 @@ object SparkContext {
         val scheduler = try {
           val clazz = Class.forName("org.apache.spark.scheduler.cluster.YarnClusterScheduler")
           val cons = clazz.getConstructor(classOf[SparkContext])
-          cons.newInstance(sc).asInstanceOf[ClusterScheduler]
+          cons.newInstance(sc).asInstanceOf[TaskSchedulerImpl]
         } catch {
           // TODO: Enumerate the exact reasons why it can fail
           // But irrespective of it, it means we cannot proceed !
@@ -1127,7 +1136,7 @@ object SparkContext {
         val scheduler = try {
           val clazz = Class.forName("org.apache.spark.scheduler.cluster.YarnClientClusterScheduler")
           val cons = clazz.getConstructor(classOf[SparkContext])
-          cons.newInstance(sc).asInstanceOf[ClusterScheduler]
+          cons.newInstance(sc).asInstanceOf[TaskSchedulerImpl]
 
         } catch {
           case th: Throwable => {
@@ -1137,7 +1146,7 @@ object SparkContext {
 
         val backend = try {
           val clazz = Class.forName("org.apache.spark.scheduler.cluster.YarnClientSchedulerBackend")
-          val cons = clazz.getConstructor(classOf[ClusterScheduler], classOf[SparkContext])
+          val cons = clazz.getConstructor(classOf[TaskSchedulerImpl], classOf[SparkContext])
           cons.newInstance(scheduler, sc).asInstanceOf[CoarseGrainedSchedulerBackend]
         } catch {
           case th: Throwable => {
@@ -1150,7 +1159,7 @@ object SparkContext {
 
       case mesosUrl @ MESOS_REGEX(_) =>
         MesosNativeLibrary.load()
-        val scheduler = new ClusterScheduler(sc)
+        val scheduler = new TaskSchedulerImpl(sc)
         val coarseGrained = sc.conf.getOrElse("spark.mesos.coarse", "false").toBoolean
         val url = mesosUrl.stripPrefix("mesos://") // strip scheme from raw Mesos URLs
         val backend = if (coarseGrained) {
@@ -1162,7 +1171,7 @@ object SparkContext {
         scheduler
 
       case SIMR_REGEX(simrUrl) =>
-        val scheduler = new ClusterScheduler(sc)
+        val scheduler = new TaskSchedulerImpl(sc)
         val backend = new SimrSchedulerBackend(scheduler, sc, simrUrl)
         scheduler.initialize(backend)
         scheduler
diff --git a/core/src/main/scala/org/apache/spark/TaskEndReason.scala b/core/src/main/scala/org/apache/spark/TaskEndReason.scala
index c1e5e04b31e6035193828c06a9b664f9baf87c65..faf6dcd6186237b5d58a4612d78dc81bda2b54da 100644
--- a/core/src/main/scala/org/apache/spark/TaskEndReason.scala
+++ b/core/src/main/scala/org/apache/spark/TaskEndReason.scala
@@ -53,5 +53,3 @@ private[spark] case class ExceptionFailure(
 private[spark] case object TaskResultLost extends TaskEndReason
 
 private[spark] case object TaskKilled extends TaskEndReason
-
-private[spark] case class OtherFailure(message: String) extends TaskEndReason
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala
index ec47ba1b5676866df20b6788d85d4638ef74b900..a801d857707b4ec8ffc1b69e45017be98ee91a3f 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala
@@ -140,12 +140,12 @@ class WorkerWebUI(val worker: Worker, val workDir: File, requestedPort: Option[I
         <body>
           {linkToMaster}
           <div>
-            <div style="float:left;width:40%">{backButton}</div>
+            <div style="float:left; margin-right:10px">{backButton}</div>
             <div style="float:left;">{range}</div>
-            <div style="float:right;">{nextButton}</div>
+            <div style="float:right; margin-left:10px">{nextButton}</div>
           </div>
           <br />
-          <div style="height:500px;overflow:auto;padding:5px;">
+          <div style="height:500px; overflow:auto; padding:5px;">
             <pre>{logText}</pre>
           </div>
         </body>
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 77aa24e6b6c62bb8406035a8fa88c89217b1463b..e06e49d9d282f2410b6aa62ae0b6924fa1d3b6bf 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -152,7 +152,8 @@ class DAGScheduler(
   val waiting = new HashSet[Stage] // Stages we need to run whose parents aren't done
   val running = new HashSet[Stage] // Stages we are running right now
   val failed = new HashSet[Stage]  // Stages that must be resubmitted due to fetch failures
-  val pendingTasks = new TimeStampedHashMap[Stage, HashSet[Task[_]]] // Missing tasks from each stage
+  // Missing tasks from each stage
+  val pendingTasks = new TimeStampedHashMap[Stage, HashSet[Task[_]]]
   var lastFetchFailureTime: Long = 0  // Used to wait a bit to avoid repeated resubmits
 
   val activeJobs = new HashSet[ActiveJob]
@@ -240,7 +241,8 @@ class DAGScheduler(
     shuffleToMapStage.get(shuffleDep.shuffleId) match {
       case Some(stage) => stage
       case None =>
-        val stage = newOrUsedStage(shuffleDep.rdd, shuffleDep.rdd.partitions.size, shuffleDep, jobId)
+        val stage =
+          newOrUsedStage(shuffleDep.rdd, shuffleDep.rdd.partitions.size, shuffleDep, jobId)
         shuffleToMapStage(shuffleDep.shuffleId) = stage
         stage
     }
@@ -249,7 +251,8 @@ class DAGScheduler(
   /**
    * Create a Stage -- either directly for use as a result stage, or as part of the (re)-creation
    * of a shuffle map stage in newOrUsedStage.  The stage will be associated with the provided
-   * jobId. Production of shuffle map stages should always use newOrUsedStage, not newStage directly.
+   * jobId. Production of shuffle map stages should always use newOrUsedStage, not newStage
+   * directly.
    */
   private def newStage(
       rdd: RDD[_],
@@ -359,7 +362,8 @@ class DAGScheduler(
         stageIdToJobIds.getOrElseUpdate(s.id, new HashSet[Int]()) += jobId
         jobIdToStageIds.getOrElseUpdate(jobId, new HashSet[Int]()) += s.id
         val parents = getParentStages(s.rdd, jobId)
-        val parentsWithoutThisJobId = parents.filter(p => !stageIdToJobIds.get(p.id).exists(_.contains(jobId)))
+        val parentsWithoutThisJobId = parents.filter(p =>
+          !stageIdToJobIds.get(p.id).exists(_.contains(jobId)))
         updateJobIdStageIdMapsList(parentsWithoutThisJobId ++ stages.tail)
       }
     }
@@ -367,8 +371,9 @@ class DAGScheduler(
   }
 
   /**
-   * Removes job and any stages that are not needed by any other job.  Returns the set of ids for stages that
-   * were removed.  The associated tasks for those stages need to be cancelled if we got here via job cancellation.
+   * Removes job and any stages that are not needed by any other job.  Returns the set of ids for
+   * stages that were removed.  The associated tasks for those stages need to be cancelled if we
+   * got here via job cancellation.
    */
   private def removeJobAndIndependentStages(jobId: Int): Set[Int] = {
     val registeredStages = jobIdToStageIds(jobId)
@@ -379,7 +384,8 @@ class DAGScheduler(
       stageIdToJobIds.filterKeys(stageId => registeredStages.contains(stageId)).foreach {
         case (stageId, jobSet) =>
           if (!jobSet.contains(jobId)) {
-            logError("Job %d not registered for stage %d even though that stage was registered for the job"
+            logError(
+              "Job %d not registered for stage %d even though that stage was registered for the job"
               .format(jobId, stageId))
           } else {
             def removeStage(stageId: Int) {
@@ -390,7 +396,8 @@ class DAGScheduler(
                   running -= s
                 }
                 stageToInfos -= s
-                shuffleToMapStage.keys.filter(shuffleToMapStage(_) == s).foreach(shuffleToMapStage.remove)
+                shuffleToMapStage.keys.filter(shuffleToMapStage(_) == s).foreach(shuffleId =>
+                  shuffleToMapStage.remove(shuffleId))
                 if (pendingTasks.contains(s) && !pendingTasks(s).isEmpty) {
                   logDebug("Removing pending status for stage %d".format(stageId))
                 }
@@ -408,7 +415,8 @@ class DAGScheduler(
               stageIdToStage -= stageId
               stageIdToJobIds -= stageId
 
-              logDebug("After removal of stage %d, remaining stages = %d".format(stageId, stageIdToStage.size))
+              logDebug("After removal of stage %d, remaining stages = %d"
+                .format(stageId, stageIdToStage.size))
             }
 
             jobSet -= jobId
@@ -460,7 +468,8 @@ class DAGScheduler(
     assert(partitions.size > 0)
     val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
     val waiter = new JobWaiter(this, jobId, partitions.size, resultHandler)
-    eventProcessActor ! JobSubmitted(jobId, rdd, func2, partitions.toArray, allowLocal, callSite, waiter, properties)
+    eventProcessActor ! JobSubmitted(
+      jobId, rdd, func2, partitions.toArray, allowLocal, callSite, waiter, properties)
     waiter
   }
 
@@ -495,7 +504,8 @@ class DAGScheduler(
     val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
     val partitions = (0 until rdd.partitions.size).toArray
     val jobId = nextJobId.getAndIncrement()
-    eventProcessActor ! JobSubmitted(jobId, rdd, func2, partitions, allowLocal = false, callSite, listener, properties)
+    eventProcessActor ! JobSubmitted(
+      jobId, rdd, func2, partitions, allowLocal = false, callSite, listener, properties)
     listener.awaitResult()    // Will throw an exception if the job fails
   }
 
@@ -530,8 +540,8 @@ class DAGScheduler(
       case JobSubmitted(jobId, rdd, func, partitions, allowLocal, callSite, listener, properties) =>
         var finalStage: Stage = null
         try {
-          // New stage creation at times and if its not protected, the scheduler thread is killed.
-          // e.g. it can fail when jobs are run on HadoopRDD whose underlying hdfs files have been deleted
+          // New stage creation may throw an exception if, for example, jobs are run on a HadoopRDD
+          // whose underlying HDFS files have been deleted.
           finalStage = newStage(rdd, partitions.size, None, jobId, Some(callSite))
         } catch {
           case e: Exception =>
@@ -564,7 +574,8 @@ class DAGScheduler(
       case JobGroupCancelled(groupId) =>
         // Cancel all jobs belonging to this job group.
         // First finds all active jobs with this group id, and then kill stages for them.
-        val activeInGroup = activeJobs.filter(groupId == _.properties.get(SparkContext.SPARK_JOB_GROUP_ID))
+        val activeInGroup = activeJobs.filter(activeJob =>
+          groupId == activeJob.properties.get(SparkContext.SPARK_JOB_GROUP_ID))
         val jobIds = activeInGroup.map(_.jobId)
         jobIds.foreach { handleJobCancellation }
 
@@ -586,7 +597,8 @@ class DAGScheduler(
           stage <- stageIdToStage.get(task.stageId);
           stageInfo <- stageToInfos.get(stage)
         ) {
-          if (taskInfo.serializedSize > TASK_SIZE_TO_WARN * 1024 && !stageInfo.emittedTaskSizeWarning) {
+          if (taskInfo.serializedSize > TASK_SIZE_TO_WARN * 1024 &&
+              !stageInfo.emittedTaskSizeWarning) {
             stageInfo.emittedTaskSizeWarning = true
             logWarning(("Stage %d (%s) contains a task of very large " +
               "size (%d KB). The maximum recommended task size is %d KB.").format(
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorLossReason.scala b/core/src/main/scala/org/apache/spark/scheduler/ExecutorLossReason.scala
similarity index 96%
rename from core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorLossReason.scala
rename to core/src/main/scala/org/apache/spark/scheduler/ExecutorLossReason.scala
index 5077b2b48b5749aaf0b6264ffc93767f5560a9db..2bc43a91864491ff8f06b8ef2a04ed51131734b8 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorLossReason.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ExecutorLossReason.scala
@@ -15,7 +15,7 @@
  * limitations under the License.
  */
 
-package org.apache.spark.scheduler.cluster
+package org.apache.spark.scheduler
 
 import org.apache.spark.executor.ExecutorExitCode
 
diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala
index 60927831a159a7d4e2b092b97775ac3005ed1c1e..be5c95e59e05bf6a9ab679caf8f2f77e03b90962 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala
@@ -328,10 +328,6 @@ class JobLogger(val user: String, val logDirName: String)
                       task.stageId + " SHUFFLE_ID=" + shuffleId + " MAP_ID=" +
                       mapId + " REDUCE_ID=" + reduceId
         stageLogInfo(task.stageId, taskStatus)
-      case OtherFailure(message) =>
-        taskStatus += " STATUS=FAILURE TID=" + taskInfo.taskId +
-                      " STAGE_ID=" + task.stageId + " INFO=" + message
-        stageLogInfo(task.stageId, taskStatus)
       case _ =>
     }
   }
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Pool.scala b/core/src/main/scala/org/apache/spark/scheduler/Pool.scala
index 596f9adde94938ad6a9c092aeb0d34cb937b665e..17912422150782a72252385ee696044355397f87 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Pool.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Pool.scala
@@ -117,8 +117,4 @@ private[spark] class Pool(
       parent.decreaseRunningTasks(taskNum)
     }
   }
-
-  override def hasPendingTasks(): Boolean = {
-    schedulableQueue.exists(_.hasPendingTasks())
-  }
 }
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Schedulable.scala b/core/src/main/scala/org/apache/spark/scheduler/Schedulable.scala
index 1c7ea2dccc7d9b60b735fa37c4cc604db8fe6158..d573e125a33d1af5eec279b8d84c5c74c63d1f73 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Schedulable.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Schedulable.scala
@@ -42,5 +42,4 @@ private[spark] trait Schedulable {
   def executorLost(executorId: String, host: String): Unit
   def checkSpeculatableTasks(): Boolean
   def getSortedTaskSetQueue(): ArrayBuffer[TaskSetManager]
-  def hasPendingTasks(): Boolean
 }
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala
similarity index 89%
rename from core/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerBackend.scala
rename to core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala
index 65d3fc81875f5ed01d35b4d63310f21028a4960e..02bdbba825781968fe672ca969e9919be7c3a319 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala
@@ -15,12 +15,12 @@
  * limitations under the License.
  */
 
-package org.apache.spark.scheduler.cluster
+package org.apache.spark.scheduler
 
 import org.apache.spark.SparkContext
 
 /**
- * A backend interface for cluster scheduling systems that allows plugging in different ones under
+ * A backend interface for scheduling systems that allows plugging in different ones under
  * ClusterScheduler. We assume a Mesos-like model where the application gets resource offers as
  * machines become available and can launch tasks on them.
  */
diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala
index 3841b5616dca24471a5d3e85baab617308a24f0a..ee63b3c4a15a26aad6c061cd65c29779ae7ab00c 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala
@@ -63,7 +63,7 @@ trait SparkListener {
    * Called when a task begins remotely fetching its result (will not be called for tasks that do
    * not need to fetch the result remotely).
    */
- def onTaskGettingResult(taskGettingResult: SparkListenerTaskGettingResult) { }
+  def onTaskGettingResult(taskGettingResult: SparkListenerTaskGettingResult) { }
 
   /**
    * Called when a task ends
@@ -131,8 +131,8 @@ object StatsReportListener extends Logging {
 
   def showDistribution(heading: String, d: Distribution, formatNumber: Double => String) {
     val stats = d.statCounter
-    logInfo(heading + stats)
     val quantiles = d.getQuantiles(probabilities).map{formatNumber}
+    logInfo(heading + stats)
     logInfo(percentilesHeader)
     logInfo("\t" + quantiles.mkString("\t"))
   }
@@ -173,8 +173,6 @@ object StatsReportListener extends Logging {
     showMillisDistribution(heading, extractLongDistribution(stage, getMetric))
   }
 
-
-
   val seconds = 1000L
   val minutes = seconds * 60
   val hours = minutes * 60
@@ -198,7 +196,6 @@ object StatsReportListener extends Logging {
 }
 
 
-
 case class RuntimePercentage(executorPct: Double, fetchPct: Option[Double], other: Double)
 object RuntimePercentage {
   def apply(totalTime: Long, metrics: TaskMetrics): RuntimePercentage = {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala
index d5824e79547974e643b348b12465fa6fe78a2fe0..85687ea330660533c2fb95c1f5016c3db0ffb152 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala
@@ -91,4 +91,3 @@ private[spark] class SparkListenerBus() extends Logging {
     return true
   }
 }
-
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/TaskResultGetter.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
similarity index 93%
rename from core/src/main/scala/org/apache/spark/scheduler/cluster/TaskResultGetter.scala
rename to core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
index 319c91b9334925cf4fa63e374ef6ed9a4a87513e..29b0247f8a851930591cd4e1682060b0e003ea71 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/TaskResultGetter.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
@@ -15,21 +15,20 @@
  * limitations under the License.
  */
 
-package org.apache.spark.scheduler.cluster
+package org.apache.spark.scheduler
 
 import java.nio.ByteBuffer
 import java.util.concurrent.{LinkedBlockingDeque, ThreadFactory, ThreadPoolExecutor, TimeUnit}
 
 import org.apache.spark._
 import org.apache.spark.TaskState.TaskState
-import org.apache.spark.scheduler.{DirectTaskResult, IndirectTaskResult, TaskResult}
 import org.apache.spark.serializer.SerializerInstance
 import org.apache.spark.util.Utils
 
 /**
  * Runs a thread pool that deserializes and remotely fetches (if necessary) task results.
  */
-private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: ClusterScheduler)
+private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedulerImpl)
   extends Logging {
 
   private val THREADS = sparkEnv.conf.getOrElse("spark.resultGetter.threads", "4").toInt
@@ -43,7 +42,7 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: ClusterSche
   }
 
   def enqueueSuccessfulTask(
-    taskSetManager: ClusterTaskSetManager, tid: Long, serializedData: ByteBuffer) {
+    taskSetManager: TaskSetManager, tid: Long, serializedData: ByteBuffer) {
     getTaskResultExecutor.execute(new Runnable {
       override def run() {
         try {
@@ -79,7 +78,7 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: ClusterSche
     })
   }
 
-  def enqueueFailedTask(taskSetManager: ClusterTaskSetManager, tid: Long, taskState: TaskState,
+  def enqueueFailedTask(taskSetManager: TaskSetManager, tid: Long, taskState: TaskState,
     serializedData: ByteBuffer) {
     var reason: Option[TaskEndReason] = None
     getTaskResultExecutor.execute(new Runnable {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
index 10e047810827c974e7e9230bca60693c1e67cf40..17b6d97e90e0a85f1e16090b03d7546d3510dfeb 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
@@ -20,11 +20,12 @@ package org.apache.spark.scheduler
 import org.apache.spark.scheduler.SchedulingMode.SchedulingMode
 
 /**
- * Low-level task scheduler interface, implemented by both ClusterScheduler and LocalScheduler.
- * Each TaskScheduler schedulers task for a single SparkContext.
- * These schedulers get sets of tasks submitted to them from the DAGScheduler for each stage,
- * and are responsible for sending the tasks to the cluster, running them, retrying if there
- * are failures, and mitigating stragglers. They return events to the DAGScheduler.
+ * Low-level task scheduler interface, currently implemented exclusively by the ClusterScheduler.
+ * This interface allows plugging in different task schedulers. Each TaskScheduler schedulers tasks
+ * for a single SparkContext. These schedulers get sets of tasks submitted to them from the
+ * DAGScheduler for each stage, and are responsible for sending the tasks to the cluster, running
+ * them, retrying if there are failures, and mitigating stragglers. They return events to the
+ * DAGScheduler.
  */
 private[spark] trait TaskScheduler {
 
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
similarity index 93%
rename from core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala
rename to core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
index 2707740d44361d66e7f8c490a6cb2b487020be52..56a038dc699375a77fda599bdca111aab801b664 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
@@ -15,7 +15,7 @@
  * limitations under the License.
  */
 
-package org.apache.spark.scheduler.cluster
+package org.apache.spark.scheduler
 
 import java.nio.ByteBuffer
 import java.util.concurrent.atomic.AtomicLong
@@ -28,37 +28,40 @@ import scala.concurrent.duration._
 
 import org.apache.spark._
 import org.apache.spark.TaskState.TaskState
-import org.apache.spark.scheduler._
 import org.apache.spark.scheduler.SchedulingMode.SchedulingMode
 
 /**
- * The main TaskScheduler implementation, for running tasks on a cluster. Clients should first call
- * initialize() and start(), then submit task sets through the runTasks method.
- *
- * This class can work with multiple types of clusters by acting through a SchedulerBackend.
+ * Schedules tasks for multiple types of clusters by acting through a SchedulerBackend.
+ * It can also work with a local setup by using a LocalBackend and setting isLocal to true.
  * It handles common logic, like determining a scheduling order across jobs, waking up to launch
  * speculative tasks, etc.
  *
+ * Clients should first call initialize() and start(), then submit task sets through the
+ * runTasks method.
+ *
  * THREADING: SchedulerBackends and task-submitting clients can call this class from multiple
  * threads, so it needs locks in public API methods to maintain its state. In addition, some
  * SchedulerBackends sycnchronize on themselves when they want to send events here, and then
  * acquire a lock on us, so we need to make sure that we don't try to lock the backend while
  * we are holding a lock on ourselves.
  */
-private[spark] class ClusterScheduler(val sc: SparkContext)
-  extends TaskScheduler
-  with Logging
+private[spark] class TaskSchedulerImpl(
+    val sc: SparkContext,
+    val maxTaskFailures: Int = System.getProperty("spark.task.maxFailures", "4").toInt,
+    isLocal: Boolean = false)
+  extends TaskScheduler with Logging
 {
   val conf = sc.conf
+
   // How often to check for speculative tasks
   val SPECULATION_INTERVAL = conf.getOrElse("spark.speculation.interval", "100").toLong
 
   // Threshold above which we warn user initial TaskSet may be starved
   val STARVATION_TIMEOUT = conf.getOrElse("spark.starvation.timeout", "15000").toLong
 
-  // ClusterTaskSetManagers are not thread safe, so any access to one should be synchronized
+  // TaskSetManagers are not thread safe, so any access to one should be synchronized
   // on this class.
-  val activeTaskSets = new HashMap[String, ClusterTaskSetManager]
+  val activeTaskSets = new HashMap[String, TaskSetManager]
 
   val taskIdToTaskSetId = new HashMap[Long, String]
   val taskIdToExecutorId = new HashMap[Long, String]
@@ -120,7 +123,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
   override def start() {
     backend.start()
 
-    if (conf.getOrElse("spark.speculation", "false").toBoolean) {
+    if (!isLocal && conf.getOrElse("spark.speculation", "false").toBoolean) {
       logInfo("Starting speculative execution thread")
       import sc.env.actorSystem.dispatcher
       sc.env.actorSystem.scheduler.schedule(SPECULATION_INTERVAL milliseconds,
@@ -134,12 +137,12 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
     val tasks = taskSet.tasks
     logInfo("Adding task set " + taskSet.id + " with " + tasks.length + " tasks")
     this.synchronized {
-      val manager = new ClusterTaskSetManager(this, taskSet)
+      val manager = new TaskSetManager(this, taskSet, maxTaskFailures)
       activeTaskSets(taskSet.id) = manager
       schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties)
       taskSetTaskIds(taskSet.id) = new HashSet[Long]()
 
-      if (!hasReceivedTask) {
+      if (!isLocal && !hasReceivedTask) {
         starvationTimer.scheduleAtFixedRate(new TimerTask() {
           override def run() {
             if (!hasLaunchedTask) {
@@ -293,19 +296,19 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
     }
   }
 
-  def handleTaskGettingResult(taskSetManager: ClusterTaskSetManager, tid: Long) {
+  def handleTaskGettingResult(taskSetManager: TaskSetManager, tid: Long) {
     taskSetManager.handleTaskGettingResult(tid)
   }
 
   def handleSuccessfulTask(
-    taskSetManager: ClusterTaskSetManager,
+    taskSetManager: TaskSetManager,
     tid: Long,
     taskResult: DirectTaskResult[_]) = synchronized {
     taskSetManager.handleSuccessfulTask(tid, taskResult)
   }
 
   def handleFailedTask(
-    taskSetManager: ClusterTaskSetManager,
+    taskSetManager: TaskSetManager,
     tid: Long,
     taskState: TaskState,
     reason: Option[TaskEndReason]) = synchronized {
@@ -353,7 +356,6 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
 
   override def defaultParallelism() = backend.defaultParallelism()
 
-
   // Check for speculatable tasks in all our active jobs.
   def checkSpeculatableTasks() {
     var shouldRevive = false
@@ -365,13 +367,6 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
     }
   }
 
-  // Check for pending tasks in all our active jobs.
-  def hasPendingTasks: Boolean = {
-    synchronized {
-      rootPool.hasPendingTasks()
-    }
-  }
-
   def executorLost(executorId: String, reason: ExecutorLossReason) {
     var failedExecutor: Option[String] = None
 
@@ -430,7 +425,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
 }
 
 
-object ClusterScheduler {
+private[spark] object TaskSchedulerImpl {
   /**
    * Used to balance containers across hosts.
    *
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
index 90f6bcefac0bfab0e29196c963f3d9ef3ac29fa2..9b95e418d81c384e0f428ca62b879499df447b6f 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
@@ -17,32 +17,702 @@
 
 package org.apache.spark.scheduler
 
-import java.nio.ByteBuffer
+import java.io.NotSerializableException
+import java.util.Arrays
 
+import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.HashMap
+import scala.collection.mutable.HashSet
+import scala.math.max
+import scala.math.min
+
+import org.apache.spark.{ExceptionFailure, FetchFailed, Logging, Resubmitted, SparkEnv,
+  Success, TaskEndReason, TaskKilled, TaskResultLost, TaskState}
 import org.apache.spark.TaskState.TaskState
+import org.apache.spark.util.{Clock, SystemClock}
+
 
 /**
- * Tracks and schedules the tasks within a single TaskSet. This class keeps track of the status of
- * each task and is responsible for retries on failure and locality. The main interfaces to it
- * are resourceOffer, which asks the TaskSet whether it wants to run a task on one node, and
- * statusUpdate, which tells it that one of its tasks changed state (e.g. finished).
+ * Schedules the tasks within a single TaskSet in the ClusterScheduler. This class keeps track of
+ * each task, retries tasks if they fail (up to a limited number of times), and
+ * handles locality-aware scheduling for this TaskSet via delay scheduling. The main interfaces
+ * to it are resourceOffer, which asks the TaskSet whether it wants to run a task on one node,
+ * and statusUpdate, which tells it that one of its tasks changed state (e.g. finished).
+ *
+ * THREADING: This class is designed to only be called from code with a lock on the
+ * TaskScheduler (e.g. its event handlers). It should not be called from other threads.
  *
- * THREADING: This class is designed to only be called from code with a lock on the TaskScheduler
- * (e.g. its event handlers). It should not be called from other threads.
+ * @param sched           the ClusterScheduler associated with the TaskSetManager
+ * @param taskSet         the TaskSet to manage scheduling for
+ * @param maxTaskFailures if any particular task fails more than this number of times, the entire
+ *                        task set will be aborted
  */
-private[spark] trait TaskSetManager extends Schedulable {
-  def schedulableQueue = null
-  
-  def schedulingMode = SchedulingMode.NONE
-  
-  def taskSet: TaskSet
+private[spark] class TaskSetManager(
+    sched: TaskSchedulerImpl,
+    val taskSet: TaskSet,
+    val maxTaskFailures: Int,
+    clock: Clock = SystemClock)
+  extends Schedulable with Logging
+{
+  // CPUs to request per task
+  val CPUS_PER_TASK = System.getProperty("spark.task.cpus", "1").toInt
+
+  // Quantile of tasks at which to start speculation
+  val SPECULATION_QUANTILE = System.getProperty("spark.speculation.quantile", "0.75").toDouble
+  val SPECULATION_MULTIPLIER = System.getProperty("spark.speculation.multiplier", "1.5").toDouble
+
+  // Serializer for closures and tasks.
+  val env = SparkEnv.get
+  val ser = env.closureSerializer.newInstance()
+
+  val tasks = taskSet.tasks
+  val numTasks = tasks.length
+  val copiesRunning = new Array[Int](numTasks)
+  val successful = new Array[Boolean](numTasks)
+  val numFailures = new Array[Int](numTasks)
+  val taskAttempts = Array.fill[List[TaskInfo]](numTasks)(Nil)
+  var tasksSuccessful = 0
+
+  var weight = 1
+  var minShare = 0
+  var priority = taskSet.priority
+  var stageId = taskSet.stageId
+  var name = "TaskSet_"+taskSet.stageId.toString
+  var parent: Pool = null
+
+  var runningTasks = 0
+  private val runningTasksSet = new HashSet[Long]
+
+  // Set of pending tasks for each executor. These collections are actually
+  // treated as stacks, in which new tasks are added to the end of the
+  // ArrayBuffer and removed from the end. This makes it faster to detect
+  // tasks that repeatedly fail because whenever a task failed, it is put
+  // back at the head of the stack. They are also only cleaned up lazily;
+  // when a task is launched, it remains in all the pending lists except
+  // the one that it was launched from, but gets removed from them later.
+  private val pendingTasksForExecutor = new HashMap[String, ArrayBuffer[Int]]
+
+  // Set of pending tasks for each host. Similar to pendingTasksForExecutor,
+  // but at host level.
+  private val pendingTasksForHost = new HashMap[String, ArrayBuffer[Int]]
+
+  // Set of pending tasks for each rack -- similar to the above.
+  private val pendingTasksForRack = new HashMap[String, ArrayBuffer[Int]]
+
+  // Set containing pending tasks with no locality preferences.
+  val pendingTasksWithNoPrefs = new ArrayBuffer[Int]
+
+  // Set containing all pending tasks (also used as a stack, as above).
+  val allPendingTasks = new ArrayBuffer[Int]
+
+  // Tasks that can be speculated. Since these will be a small fraction of total
+  // tasks, we'll just hold them in a HashSet.
+  val speculatableTasks = new HashSet[Int]
+
+  // Task index, start and finish time for each task attempt (indexed by task ID)
+  val taskInfos = new HashMap[Long, TaskInfo]
+
+  // Did the TaskSet fail?
+  var failed = false
+  var causeOfFailure = ""
+
+  // How frequently to reprint duplicate exceptions in full, in milliseconds
+  val EXCEPTION_PRINT_INTERVAL =
+    System.getProperty("spark.logging.exceptionPrintInterval", "10000").toLong
+
+  // Map of recent exceptions (identified by string representation and top stack frame) to
+  // duplicate count (how many times the same exception has appeared) and time the full exception
+  // was printed. This should ideally be an LRU map that can drop old exceptions automatically.
+  val recentExceptions = HashMap[String, (Int, Long)]()
+
+  // Figure out the current map output tracker epoch and set it on all tasks
+  val epoch = sched.mapOutputTracker.getEpoch
+  logDebug("Epoch for " + taskSet + ": " + epoch)
+  for (t <- tasks) {
+    t.epoch = epoch
+  }
+
+  // Add all our tasks to the pending lists. We do this in reverse order
+  // of task index so that tasks with low indices get launched first.
+  for (i <- (0 until numTasks).reverse) {
+    addPendingTask(i)
+  }
+
+  // Figure out which locality levels we have in our TaskSet, so we can do delay scheduling
+  val myLocalityLevels = computeValidLocalityLevels()
+  val localityWaits = myLocalityLevels.map(getLocalityWait) // Time to wait at each level
+
+  // Delay scheduling variables: we keep track of our current locality level and the time we
+  // last launched a task at that level, and move up a level when localityWaits[curLevel] expires.
+  // We then move down if we manage to launch a "more local" task.
+  var currentLocalityIndex = 0    // Index of our current locality level in validLocalityLevels
+  var lastLaunchTime = clock.getTime()  // Time we last launched a task at this level
+
+  override def schedulableQueue = null
+
+  override def schedulingMode = SchedulingMode.NONE
+
+  /**
+   * Add a task to all the pending-task lists that it should be on. If readding is set, we are
+   * re-adding the task so only include it in each list if it's not already there.
+   */
+  private def addPendingTask(index: Int, readding: Boolean = false) {
+    // Utility method that adds `index` to a list only if readding=false or it's not already there
+    def addTo(list: ArrayBuffer[Int]) {
+      if (!readding || !list.contains(index)) {
+        list += index
+      }
+    }
+
+    var hadAliveLocations = false
+    for (loc <- tasks(index).preferredLocations) {
+      for (execId <- loc.executorId) {
+        if (sched.isExecutorAlive(execId)) {
+          addTo(pendingTasksForExecutor.getOrElseUpdate(execId, new ArrayBuffer))
+          hadAliveLocations = true
+        }
+      }
+      if (sched.hasExecutorsAliveOnHost(loc.host)) {
+        addTo(pendingTasksForHost.getOrElseUpdate(loc.host, new ArrayBuffer))
+        for (rack <- sched.getRackForHost(loc.host)) {
+          addTo(pendingTasksForRack.getOrElseUpdate(rack, new ArrayBuffer))
+        }
+        hadAliveLocations = true
+      }
+    }
+
+    if (!hadAliveLocations) {
+      // Even though the task might've had preferred locations, all of those hosts or executors
+      // are dead; put it in the no-prefs list so we can schedule it elsewhere right away.
+      addTo(pendingTasksWithNoPrefs)
+    }
+
+    if (!readding) {
+      allPendingTasks += index  // No point scanning this whole list to find the old task there
+    }
+  }
+
+  /**
+   * Return the pending tasks list for a given executor ID, or an empty list if
+   * there is no map entry for that host
+   */
+  private def getPendingTasksForExecutor(executorId: String): ArrayBuffer[Int] = {
+    pendingTasksForExecutor.getOrElse(executorId, ArrayBuffer())
+  }
+
+  /**
+   * Return the pending tasks list for a given host, or an empty list if
+   * there is no map entry for that host
+   */
+  private def getPendingTasksForHost(host: String): ArrayBuffer[Int] = {
+    pendingTasksForHost.getOrElse(host, ArrayBuffer())
+  }
+
+  /**
+   * Return the pending rack-local task list for a given rack, or an empty list if
+   * there is no map entry for that rack
+   */
+  private def getPendingTasksForRack(rack: String): ArrayBuffer[Int] = {
+    pendingTasksForRack.getOrElse(rack, ArrayBuffer())
+  }
+
+  /**
+   * Dequeue a pending task from the given list and return its index.
+   * Return None if the list is empty.
+   * This method also cleans up any tasks in the list that have already
+   * been launched, since we want that to happen lazily.
+   */
+  private def findTaskFromList(list: ArrayBuffer[Int]): Option[Int] = {
+    while (!list.isEmpty) {
+      val index = list.last
+      list.trimEnd(1)
+      if (copiesRunning(index) == 0 && !successful(index)) {
+        return Some(index)
+      }
+    }
+    return None
+  }
+
+  /** Check whether a task is currently running an attempt on a given host */
+  private def hasAttemptOnHost(taskIndex: Int, host: String): Boolean = {
+    !taskAttempts(taskIndex).exists(_.host == host)
+  }
+
+  /**
+   * Return a speculative task for a given executor if any are available. The task should not have
+   * an attempt running on this host, in case the host is slow. In addition, the task should meet
+   * the given locality constraint.
+   */
+  private def findSpeculativeTask(execId: String, host: String, locality: TaskLocality.Value)
+    : Option[(Int, TaskLocality.Value)] =
+  {
+    speculatableTasks.retain(index => !successful(index)) // Remove finished tasks from set
+
+    if (!speculatableTasks.isEmpty) {
+      // Check for process-local or preference-less tasks; note that tasks can be process-local
+      // on multiple nodes when we replicate cached blocks, as in Spark Streaming
+      for (index <- speculatableTasks if !hasAttemptOnHost(index, host)) {
+        val prefs = tasks(index).preferredLocations
+        val executors = prefs.flatMap(_.executorId)
+        if (prefs.size == 0 || executors.contains(execId)) {
+          speculatableTasks -= index
+          return Some((index, TaskLocality.PROCESS_LOCAL))
+        }
+      }
+
+      // Check for node-local tasks
+      if (TaskLocality.isAllowed(locality, TaskLocality.NODE_LOCAL)) {
+        for (index <- speculatableTasks if !hasAttemptOnHost(index, host)) {
+          val locations = tasks(index).preferredLocations.map(_.host)
+          if (locations.contains(host)) {
+            speculatableTasks -= index
+            return Some((index, TaskLocality.NODE_LOCAL))
+          }
+        }
+      }
 
+      // Check for rack-local tasks
+      if (TaskLocality.isAllowed(locality, TaskLocality.RACK_LOCAL)) {
+        for (rack <- sched.getRackForHost(host)) {
+          for (index <- speculatableTasks if !hasAttemptOnHost(index, host)) {
+            val racks = tasks(index).preferredLocations.map(_.host).map(sched.getRackForHost)
+            if (racks.contains(rack)) {
+              speculatableTasks -= index
+              return Some((index, TaskLocality.RACK_LOCAL))
+            }
+          }
+        }
+      }
+
+      // Check for non-local tasks
+      if (TaskLocality.isAllowed(locality, TaskLocality.ANY)) {
+        for (index <- speculatableTasks if !hasAttemptOnHost(index, host)) {
+          speculatableTasks -= index
+          return Some((index, TaskLocality.ANY))
+        }
+      }
+    }
+
+    return None
+  }
+
+  /**
+   * Dequeue a pending task for a given node and return its index and locality level.
+   * Only search for tasks matching the given locality constraint.
+   */
+  private def findTask(execId: String, host: String, locality: TaskLocality.Value)
+    : Option[(Int, TaskLocality.Value)] =
+  {
+    for (index <- findTaskFromList(getPendingTasksForExecutor(execId))) {
+      return Some((index, TaskLocality.PROCESS_LOCAL))
+    }
+
+    if (TaskLocality.isAllowed(locality, TaskLocality.NODE_LOCAL)) {
+      for (index <- findTaskFromList(getPendingTasksForHost(host))) {
+        return Some((index, TaskLocality.NODE_LOCAL))
+      }
+    }
+
+    if (TaskLocality.isAllowed(locality, TaskLocality.RACK_LOCAL)) {
+      for {
+        rack <- sched.getRackForHost(host)
+        index <- findTaskFromList(getPendingTasksForRack(rack))
+      } {
+        return Some((index, TaskLocality.RACK_LOCAL))
+      }
+    }
+
+    // Look for no-pref tasks after rack-local tasks since they can run anywhere.
+    for (index <- findTaskFromList(pendingTasksWithNoPrefs)) {
+      return Some((index, TaskLocality.PROCESS_LOCAL))
+    }
+
+    if (TaskLocality.isAllowed(locality, TaskLocality.ANY)) {
+      for (index <- findTaskFromList(allPendingTasks)) {
+        return Some((index, TaskLocality.ANY))
+      }
+    }
+
+    // Finally, if all else has failed, find a speculative task
+    return findSpeculativeTask(execId, host, locality)
+  }
+
+  /**
+   * Respond to an offer of a single executor from the scheduler by finding a task
+   */
   def resourceOffer(
       execId: String,
       host: String,
       availableCpus: Int,
       maxLocality: TaskLocality.TaskLocality)
-    : Option[TaskDescription]
+    : Option[TaskDescription] =
+  {
+    if (tasksSuccessful < numTasks && availableCpus >= CPUS_PER_TASK) {
+      val curTime = clock.getTime()
+
+      var allowedLocality = getAllowedLocalityLevel(curTime)
+      if (allowedLocality > maxLocality) {
+        allowedLocality = maxLocality   // We're not allowed to search for farther-away tasks
+      }
+
+      findTask(execId, host, allowedLocality) match {
+        case Some((index, taskLocality)) => {
+          // Found a task; do some bookkeeping and return a task description
+          val task = tasks(index)
+          val taskId = sched.newTaskId()
+          // Figure out whether this should count as a preferred launch
+          logInfo("Starting task %s:%d as TID %s on executor %s: %s (%s)".format(
+            taskSet.id, index, taskId, execId, host, taskLocality))
+          // Do various bookkeeping
+          copiesRunning(index) += 1
+          val info = new TaskInfo(taskId, index, curTime, execId, host, taskLocality)
+          taskInfos(taskId) = info
+          taskAttempts(index) = info :: taskAttempts(index)
+          // Update our locality level for delay scheduling
+          currentLocalityIndex = getLocalityIndex(taskLocality)
+          lastLaunchTime = curTime
+          // Serialize and return the task
+          val startTime = clock.getTime()
+          // We rely on the DAGScheduler to catch non-serializable closures and RDDs, so in here
+          // we assume the task can be serialized without exceptions.
+          val serializedTask = Task.serializeWithDependencies(
+            task, sched.sc.addedFiles, sched.sc.addedJars, ser)
+          val timeTaken = clock.getTime() - startTime
+          addRunningTask(taskId)
+          logInfo("Serialized task %s:%d as %d bytes in %d ms".format(
+            taskSet.id, index, serializedTask.limit, timeTaken))
+          val taskName = "task %s:%d".format(taskSet.id, index)
+          if (taskAttempts(index).size == 1)
+            taskStarted(task,info)
+          return Some(new TaskDescription(taskId, execId, taskName, index, serializedTask))
+        }
+        case _ =>
+      }
+    }
+    return None
+  }
+
+  /**
+   * Get the level we can launch tasks according to delay scheduling, based on current wait time.
+   */
+  private def getAllowedLocalityLevel(curTime: Long): TaskLocality.TaskLocality = {
+    while (curTime - lastLaunchTime >= localityWaits(currentLocalityIndex) &&
+        currentLocalityIndex < myLocalityLevels.length - 1)
+    {
+      // Jump to the next locality level, and remove our waiting time for the current one since
+      // we don't want to count it again on the next one
+      lastLaunchTime += localityWaits(currentLocalityIndex)
+      currentLocalityIndex += 1
+    }
+    myLocalityLevels(currentLocalityIndex)
+  }
+
+  /**
+   * Find the index in myLocalityLevels for a given locality. This is also designed to work with
+   * localities that are not in myLocalityLevels (in case we somehow get those) by returning the
+   * next-biggest level we have. Uses the fact that the last value in myLocalityLevels is ANY.
+   */
+  def getLocalityIndex(locality: TaskLocality.TaskLocality): Int = {
+    var index = 0
+    while (locality > myLocalityLevels(index)) {
+      index += 1
+    }
+    index
+  }
+
+  private def taskStarted(task: Task[_], info: TaskInfo) {
+    sched.dagScheduler.taskStarted(task, info)
+  }
+
+  def handleTaskGettingResult(tid: Long) = {
+    val info = taskInfos(tid)
+    info.markGettingResult()
+    sched.dagScheduler.taskGettingResult(tasks(info.index), info)
+  }
+
+  /**
+   * Marks the task as successful and notifies the DAGScheduler that a task has ended.
+   */
+  def handleSuccessfulTask(tid: Long, result: DirectTaskResult[_]) = {
+    val info = taskInfos(tid)
+    val index = info.index
+    info.markSuccessful()
+    removeRunningTask(tid)
+    if (!successful(index)) {
+      logInfo("Finished TID %s in %d ms on %s (progress: %d/%d)".format(
+        tid, info.duration, info.host, tasksSuccessful, numTasks))
+      sched.dagScheduler.taskEnded(
+        tasks(index), Success, result.value, result.accumUpdates, info, result.metrics)
+
+      // Mark successful and stop if all the tasks have succeeded.
+      tasksSuccessful += 1
+      successful(index) = true
+      if (tasksSuccessful == numTasks) {
+        sched.taskSetFinished(this)
+      }
+    } else {
+      logInfo("Ignorning task-finished event for TID " + tid + " because task " +
+        index + " has already completed successfully")
+    }
+  }
+
+  /**
+   * Marks the task as failed, re-adds it to the list of pending tasks, and notifies the
+   * DAG Scheduler.
+   */
+  def handleFailedTask(tid: Long, state: TaskState, reason: Option[TaskEndReason]) {
+    val info = taskInfos(tid)
+    if (info.failed) {
+      return
+    }
+    removeRunningTask(tid)
+    val index = info.index
+    info.markFailed()
+    var failureReason = "unknown"
+    if (!successful(index)) {
+      logWarning("Lost TID %s (task %s:%d)".format(tid, taskSet.id, index))
+      copiesRunning(index) -= 1
+      // Check if the problem is a map output fetch failure. In that case, this
+      // task will never succeed on any node, so tell the scheduler about it.
+      reason.foreach {
+        case fetchFailed: FetchFailed =>
+          logWarning("Loss was due to fetch failure from " + fetchFailed.bmAddress)
+          sched.dagScheduler.taskEnded(tasks(index), fetchFailed, null, null, info, null)
+          successful(index) = true
+          tasksSuccessful += 1
+          sched.taskSetFinished(this)
+          removeAllRunningTasks()
+          return
+
+        case TaskKilled =>
+          logWarning("Task %d was killed.".format(tid))
+          sched.dagScheduler.taskEnded(tasks(index), reason.get, null, null, info, null)
+          return
+
+        case ef: ExceptionFailure =>
+          sched.dagScheduler.taskEnded(
+            tasks(index), ef, null, null, info, ef.metrics.getOrElse(null))
+          if (ef.className == classOf[NotSerializableException].getName()) {
+            // If the task result wasn't rerializable, there's no point in trying to re-execute it.
+            logError("Task %s:%s had a not serializable result: %s; not retrying".format(
+              taskSet.id, index, ef.description))
+            abort("Task %s:%s had a not serializable result: %s".format(
+              taskSet.id, index, ef.description))
+            return
+          }
+          val key = ef.description
+          failureReason = "Exception failure: %s".format(ef.description)
+          val now = clock.getTime()
+          val (printFull, dupCount) = {
+            if (recentExceptions.contains(key)) {
+              val (dupCount, printTime) = recentExceptions(key)
+              if (now - printTime > EXCEPTION_PRINT_INTERVAL) {
+                recentExceptions(key) = (0, now)
+                (true, 0)
+              } else {
+                recentExceptions(key) = (dupCount + 1, printTime)
+                (false, dupCount + 1)
+              }
+            } else {
+              recentExceptions(key) = (0, now)
+              (true, 0)
+            }
+          }
+          if (printFull) {
+            val locs = ef.stackTrace.map(loc => "\tat %s".format(loc.toString))
+            logWarning("Loss was due to %s\n%s\n%s".format(
+              ef.className, ef.description, locs.mkString("\n")))
+          } else {
+            logInfo("Loss was due to %s [duplicate %d]".format(ef.description, dupCount))
+          }
+
+        case TaskResultLost =>
+          failureReason = "Lost result for TID %s on host %s".format(tid, info.host)
+          logWarning(failureReason)
+          sched.dagScheduler.taskEnded(tasks(index), TaskResultLost, null, null, info, null)
+
+        case _ => {}
+      }
+      // On non-fetch failures, re-enqueue the task as pending for a max number of retries
+      addPendingTask(index)
+      if (state != TaskState.KILLED) {
+        numFailures(index) += 1
+        if (numFailures(index) >= maxTaskFailures) {
+          logError("Task %s:%d failed %d times; aborting job".format(
+            taskSet.id, index, maxTaskFailures))
+          abort("Task %s:%d failed %d times (most recent failure: %s)".format(
+            taskSet.id, index, maxTaskFailures, failureReason))
+        }
+      }
+    } else {
+      logInfo("Ignoring task-lost event for TID " + tid +
+        " because task " + index + " is already finished")
+    }
+  }
+
+  def error(message: String) {
+    // Save the error message
+    abort("Error: " + message)
+  }
+
+  def abort(message: String) {
+    failed = true
+    causeOfFailure = message
+    // TODO: Kill running tasks if we were not terminated due to a Mesos error
+    sched.dagScheduler.taskSetFailed(taskSet, message)
+    removeAllRunningTasks()
+    sched.taskSetFinished(this)
+  }
+
+  /** If the given task ID is not in the set of running tasks, adds it.
+   *
+   * Used to keep track of the number of running tasks, for enforcing scheduling policies.
+   */
+  def addRunningTask(tid: Long) {
+    if (runningTasksSet.add(tid) && parent != null) {
+      parent.increaseRunningTasks(1)
+    }
+    runningTasks = runningTasksSet.size
+  }
+
+  /** If the given task ID is in the set of running tasks, removes it. */
+  def removeRunningTask(tid: Long) {
+    if (runningTasksSet.remove(tid) && parent != null) {
+      parent.decreaseRunningTasks(1)
+    }
+    runningTasks = runningTasksSet.size
+  }
+
+  private[scheduler] def removeAllRunningTasks() {
+    val numRunningTasks = runningTasksSet.size
+    runningTasksSet.clear()
+    if (parent != null) {
+      parent.decreaseRunningTasks(numRunningTasks)
+    }
+    runningTasks = 0
+  }
+
+  override def getSchedulableByName(name: String): Schedulable = {
+    return null
+  }
+
+  override def addSchedulable(schedulable: Schedulable) {}
+
+  override def removeSchedulable(schedulable: Schedulable) {}
+
+  override def getSortedTaskSetQueue(): ArrayBuffer[TaskSetManager] = {
+    var sortedTaskSetQueue = ArrayBuffer[TaskSetManager](this)
+    sortedTaskSetQueue += this
+    return sortedTaskSetQueue
+  }
+
+  /** Called by TaskScheduler when an executor is lost so we can re-enqueue our tasks */
+  override def executorLost(execId: String, host: String) {
+    logInfo("Re-queueing tasks for " + execId + " from TaskSet " + taskSet.id)
+
+    // Re-enqueue pending tasks for this host based on the status of the cluster -- for example, a
+    // task that used to have locations on only this host might now go to the no-prefs list. Note
+    // that it's okay if we add a task to the same queue twice (if it had multiple preferred
+    // locations), because findTaskFromList will skip already-running tasks.
+    for (index <- getPendingTasksForExecutor(execId)) {
+      addPendingTask(index, readding=true)
+    }
+    for (index <- getPendingTasksForHost(host)) {
+      addPendingTask(index, readding=true)
+    }
+
+    // Re-enqueue any tasks that ran on the failed executor if this is a shuffle map stage
+    if (tasks(0).isInstanceOf[ShuffleMapTask]) {
+      for ((tid, info) <- taskInfos if info.executorId == execId) {
+        val index = taskInfos(tid).index
+        if (successful(index)) {
+          successful(index) = false
+          copiesRunning(index) -= 1
+          tasksSuccessful -= 1
+          addPendingTask(index)
+          // Tell the DAGScheduler that this task was resubmitted so that it doesn't think our
+          // stage finishes when a total of tasks.size tasks finish.
+          sched.dagScheduler.taskEnded(tasks(index), Resubmitted, null, null, info, null)
+        }
+      }
+    }
+    // Also re-enqueue any tasks that were running on the node
+    for ((tid, info) <- taskInfos if info.running && info.executorId == execId) {
+      handleFailedTask(tid, TaskState.KILLED, None)
+    }
+  }
+
+  /**
+   * Check for tasks to be speculated and return true if there are any. This is called periodically
+   * by the TaskScheduler.
+   *
+   * TODO: To make this scale to large jobs, we need to maintain a list of running tasks, so that
+   * we don't scan the whole task set. It might also help to make this sorted by launch time.
+   */
+  override def checkSpeculatableTasks(): Boolean = {
+    // Can't speculate if we only have one task, or if all tasks have finished.
+    if (numTasks == 1 || tasksSuccessful == numTasks) {
+      return false
+    }
+    var foundTasks = false
+    val minFinishedForSpeculation = (SPECULATION_QUANTILE * numTasks).floor.toInt
+    logDebug("Checking for speculative tasks: minFinished = " + minFinishedForSpeculation)
+    if (tasksSuccessful >= minFinishedForSpeculation && tasksSuccessful > 0) {
+      val time = clock.getTime()
+      val durations = taskInfos.values.filter(_.successful).map(_.duration).toArray
+      Arrays.sort(durations)
+      val medianDuration = durations(min((0.5 * tasksSuccessful).round.toInt, durations.size - 1))
+      val threshold = max(SPECULATION_MULTIPLIER * medianDuration, 100)
+      // TODO: Threshold should also look at standard deviation of task durations and have a lower
+      // bound based on that.
+      logDebug("Task length threshold for speculation: " + threshold)
+      for ((tid, info) <- taskInfos) {
+        val index = info.index
+        if (!successful(index) && copiesRunning(index) == 1 && info.timeRunning(time) > threshold &&
+          !speculatableTasks.contains(index)) {
+          logInfo(
+            "Marking task %s:%d (on %s) as speculatable because it ran more than %.0f ms".format(
+              taskSet.id, index, info.host, threshold))
+          speculatableTasks += index
+          foundTasks = true
+        }
+      }
+    }
+    return foundTasks
+  }
+
+  private def getLocalityWait(level: TaskLocality.TaskLocality): Long = {
+    val defaultWait = System.getProperty("spark.locality.wait", "3000")
+    level match {
+      case TaskLocality.PROCESS_LOCAL =>
+        System.getProperty("spark.locality.wait.process", defaultWait).toLong
+      case TaskLocality.NODE_LOCAL =>
+        System.getProperty("spark.locality.wait.node", defaultWait).toLong
+      case TaskLocality.RACK_LOCAL =>
+        System.getProperty("spark.locality.wait.rack", defaultWait).toLong
+      case TaskLocality.ANY =>
+        0L
+    }
+  }
 
-  def error(message: String)
+  /**
+   * Compute the locality levels used in this TaskSet. Assumes that all tasks have already been
+   * added to queues using addPendingTask.
+   */
+  private def computeValidLocalityLevels(): Array[TaskLocality.TaskLocality] = {
+    import TaskLocality.{PROCESS_LOCAL, NODE_LOCAL, RACK_LOCAL, ANY}
+    val levels = new ArrayBuffer[TaskLocality.TaskLocality]
+    if (!pendingTasksForExecutor.isEmpty && getLocalityWait(PROCESS_LOCAL) != 0) {
+      levels += PROCESS_LOCAL
+    }
+    if (!pendingTasksForHost.isEmpty && getLocalityWait(NODE_LOCAL) != 0) {
+      levels += NODE_LOCAL
+    }
+    if (!pendingTasksForRack.isEmpty && getLocalityWait(RACK_LOCAL) != 0) {
+      levels += RACK_LOCAL
+    }
+    levels += ANY
+    logDebug("Valid locality levels for " + taskSet + ": " + levels.mkString(", "))
+    levels.toArray
+  }
 }
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/WorkerOffer.scala b/core/src/main/scala/org/apache/spark/scheduler/WorkerOffer.scala
similarity index 95%
rename from core/src/main/scala/org/apache/spark/scheduler/cluster/WorkerOffer.scala
rename to core/src/main/scala/org/apache/spark/scheduler/WorkerOffer.scala
index 938f62883a104df5e60b11333e939ef20293792e..ba6bab3f91a65eb0bf972c2368c2df61bac0462b 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/WorkerOffer.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/WorkerOffer.scala
@@ -15,7 +15,7 @@
  * limitations under the License.
  */
 
-package org.apache.spark.scheduler.cluster
+package org.apache.spark.scheduler
 
 /**
  * Represents free resources available on an executor.
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala
deleted file mode 100644
index a46b16b92fa455f9a3971a3ee0de3a9c03ca7946..0000000000000000000000000000000000000000
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala
+++ /dev/null
@@ -1,714 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.scheduler.cluster
-
-import java.io.NotSerializableException
-import java.util.Arrays
-
-import scala.collection.mutable.ArrayBuffer
-import scala.collection.mutable.HashMap
-import scala.collection.mutable.HashSet
-import scala.math.max
-import scala.math.min
-
-import org.apache.spark.{ExceptionFailure, FetchFailed, Logging, Resubmitted, SparkEnv,
-  Success, TaskEndReason, TaskKilled, TaskResultLost, TaskState}
-import org.apache.spark.TaskState.TaskState
-import org.apache.spark.scheduler._
-import org.apache.spark.util.{SystemClock, Clock}
-
-
-/**
- * Schedules the tasks within a single TaskSet in the ClusterScheduler. This class keeps track of
- * the status of each task, retries tasks if they fail (up to a limited number of times), and
- * handles locality-aware scheduling for this TaskSet via delay scheduling. The main interfaces
- * to it are resourceOffer, which asks the TaskSet whether it wants to run a task on one node,
- * and statusUpdate, which tells it that one of its tasks changed state (e.g. finished).
- *
- * THREADING: This class is designed to only be called from code with a lock on the
- * ClusterScheduler (e.g. its event handlers). It should not be called from other threads.
- */
-private[spark] class ClusterTaskSetManager(
-    sched: ClusterScheduler,
-    val taskSet: TaskSet,
-    clock: Clock = SystemClock)
-  extends TaskSetManager
-  with Logging
-{
-  val conf = sched.sc.conf
-  // CPUs to request per task
-  val CPUS_PER_TASK = conf.getOrElse("spark.task.cpus", "1").toInt
-
-  // Maximum times a task is allowed to fail before failing the job
-  val MAX_TASK_FAILURES = conf.getOrElse("spark.task.maxFailures", "4").toInt
-
-  // Quantile of tasks at which to start speculation
-  val SPECULATION_QUANTILE = conf.getOrElse("spark.speculation.quantile", "0.75").toDouble
-  val SPECULATION_MULTIPLIER = conf.getOrElse("spark.speculation.multiplier", "1.5").toDouble
-
-  // Serializer for closures and tasks.
-  val env = SparkEnv.get
-  val ser = env.closureSerializer.newInstance()
-
-  val tasks = taskSet.tasks
-  val numTasks = tasks.length
-  val copiesRunning = new Array[Int](numTasks)
-  val successful = new Array[Boolean](numTasks)
-  val numFailures = new Array[Int](numTasks)
-  val taskAttempts = Array.fill[List[TaskInfo]](numTasks)(Nil)
-  var tasksSuccessful = 0
-
-  var weight = 1
-  var minShare = 0
-  var priority = taskSet.priority
-  var stageId = taskSet.stageId
-  var name = "TaskSet_"+taskSet.stageId.toString
-  var parent: Pool = null
-
-  var runningTasks = 0
-  private val runningTasksSet = new HashSet[Long]
-
-  // Set of pending tasks for each executor. These collections are actually
-  // treated as stacks, in which new tasks are added to the end of the
-  // ArrayBuffer and removed from the end. This makes it faster to detect
-  // tasks that repeatedly fail because whenever a task failed, it is put
-  // back at the head of the stack. They are also only cleaned up lazily;
-  // when a task is launched, it remains in all the pending lists except
-  // the one that it was launched from, but gets removed from them later.
-  private val pendingTasksForExecutor = new HashMap[String, ArrayBuffer[Int]]
-
-  // Set of pending tasks for each host. Similar to pendingTasksForExecutor,
-  // but at host level.
-  private val pendingTasksForHost = new HashMap[String, ArrayBuffer[Int]]
-
-  // Set of pending tasks for each rack -- similar to the above.
-  private val pendingTasksForRack = new HashMap[String, ArrayBuffer[Int]]
-
-  // Set containing pending tasks with no locality preferences.
-  val pendingTasksWithNoPrefs = new ArrayBuffer[Int]
-
-  // Set containing all pending tasks (also used as a stack, as above).
-  val allPendingTasks = new ArrayBuffer[Int]
-
-  // Tasks that can be speculated. Since these will be a small fraction of total
-  // tasks, we'll just hold them in a HashSet.
-  val speculatableTasks = new HashSet[Int]
-
-  // Task index, start and finish time for each task attempt (indexed by task ID)
-  val taskInfos = new HashMap[Long, TaskInfo]
-
-  // Did the TaskSet fail?
-  var failed = false
-  var causeOfFailure = ""
-
-  // How frequently to reprint duplicate exceptions in full, in milliseconds
-  val EXCEPTION_PRINT_INTERVAL =
-    conf.getOrElse("spark.logging.exceptionPrintInterval", "10000").toLong
-
-  // Map of recent exceptions (identified by string representation and top stack frame) to
-  // duplicate count (how many times the same exception has appeared) and time the full exception
-  // was printed. This should ideally be an LRU map that can drop old exceptions automatically.
-  val recentExceptions = HashMap[String, (Int, Long)]()
-
-  // Figure out the current map output tracker epoch and set it on all tasks
-  val epoch = sched.mapOutputTracker.getEpoch
-  logDebug("Epoch for " + taskSet + ": " + epoch)
-  for (t <- tasks) {
-    t.epoch = epoch
-  }
-
-  // Add all our tasks to the pending lists. We do this in reverse order
-  // of task index so that tasks with low indices get launched first.
-  for (i <- (0 until numTasks).reverse) {
-    addPendingTask(i)
-  }
-
-  // Figure out which locality levels we have in our TaskSet, so we can do delay scheduling
-  val myLocalityLevels = computeValidLocalityLevels()
-  val localityWaits = myLocalityLevels.map(getLocalityWait) // Time to wait at each level
-
-  // Delay scheduling variables: we keep track of our current locality level and the time we
-  // last launched a task at that level, and move up a level when localityWaits[curLevel] expires.
-  // We then move down if we manage to launch a "more local" task.
-  var currentLocalityIndex = 0    // Index of our current locality level in validLocalityLevels
-  var lastLaunchTime = clock.getTime()  // Time we last launched a task at this level
-
-  /**
-   * Add a task to all the pending-task lists that it should be on. If readding is set, we are
-   * re-adding the task so only include it in each list if it's not already there.
-   */
-  private def addPendingTask(index: Int, readding: Boolean = false) {
-    // Utility method that adds `index` to a list only if readding=false or it's not already there
-    def addTo(list: ArrayBuffer[Int]) {
-      if (!readding || !list.contains(index)) {
-        list += index
-      }
-    }
-
-    var hadAliveLocations = false
-    for (loc <- tasks(index).preferredLocations) {
-      for (execId <- loc.executorId) {
-        if (sched.isExecutorAlive(execId)) {
-          addTo(pendingTasksForExecutor.getOrElseUpdate(execId, new ArrayBuffer))
-          hadAliveLocations = true
-        }
-      }
-      if (sched.hasExecutorsAliveOnHost(loc.host)) {
-        addTo(pendingTasksForHost.getOrElseUpdate(loc.host, new ArrayBuffer))
-        for (rack <- sched.getRackForHost(loc.host)) {
-          addTo(pendingTasksForRack.getOrElseUpdate(rack, new ArrayBuffer))
-        }
-        hadAliveLocations = true
-      }
-    }
-
-    if (!hadAliveLocations) {
-      // Even though the task might've had preferred locations, all of those hosts or executors
-      // are dead; put it in the no-prefs list so we can schedule it elsewhere right away.
-      addTo(pendingTasksWithNoPrefs)
-    }
-
-    if (!readding) {
-      allPendingTasks += index  // No point scanning this whole list to find the old task there
-    }
-  }
-
-  /**
-   * Return the pending tasks list for a given executor ID, or an empty list if
-   * there is no map entry for that host
-   */
-  private def getPendingTasksForExecutor(executorId: String): ArrayBuffer[Int] = {
-    pendingTasksForExecutor.getOrElse(executorId, ArrayBuffer())
-  }
-
-  /**
-   * Return the pending tasks list for a given host, or an empty list if
-   * there is no map entry for that host
-   */
-  private def getPendingTasksForHost(host: String): ArrayBuffer[Int] = {
-    pendingTasksForHost.getOrElse(host, ArrayBuffer())
-  }
-
-  /**
-   * Return the pending rack-local task list for a given rack, or an empty list if
-   * there is no map entry for that rack
-   */
-  private def getPendingTasksForRack(rack: String): ArrayBuffer[Int] = {
-    pendingTasksForRack.getOrElse(rack, ArrayBuffer())
-  }
-
-  /**
-   * Dequeue a pending task from the given list and return its index.
-   * Return None if the list is empty.
-   * This method also cleans up any tasks in the list that have already
-   * been launched, since we want that to happen lazily.
-   */
-  private def findTaskFromList(list: ArrayBuffer[Int]): Option[Int] = {
-    while (!list.isEmpty) {
-      val index = list.last
-      list.trimEnd(1)
-      if (copiesRunning(index) == 0 && !successful(index)) {
-        return Some(index)
-      }
-    }
-    return None
-  }
-
-  /** Check whether a task is currently running an attempt on a given host */
-  private def hasAttemptOnHost(taskIndex: Int, host: String): Boolean = {
-    !taskAttempts(taskIndex).exists(_.host == host)
-  }
-
-  /**
-   * Return a speculative task for a given executor if any are available. The task should not have
-   * an attempt running on this host, in case the host is slow. In addition, the task should meet
-   * the given locality constraint.
-   */
-  private def findSpeculativeTask(execId: String, host: String, locality: TaskLocality.Value)
-    : Option[(Int, TaskLocality.Value)] =
-  {
-    speculatableTasks.retain(index => !successful(index)) // Remove finished tasks from set
-
-    if (!speculatableTasks.isEmpty) {
-      // Check for process-local or preference-less tasks; note that tasks can be process-local
-      // on multiple nodes when we replicate cached blocks, as in Spark Streaming
-      for (index <- speculatableTasks if !hasAttemptOnHost(index, host)) {
-        val prefs = tasks(index).preferredLocations
-        val executors = prefs.flatMap(_.executorId)
-        if (prefs.size == 0 || executors.contains(execId)) {
-          speculatableTasks -= index
-          return Some((index, TaskLocality.PROCESS_LOCAL))
-        }
-      }
-
-      // Check for node-local tasks
-      if (TaskLocality.isAllowed(locality, TaskLocality.NODE_LOCAL)) {
-        for (index <- speculatableTasks if !hasAttemptOnHost(index, host)) {
-          val locations = tasks(index).preferredLocations.map(_.host)
-          if (locations.contains(host)) {
-            speculatableTasks -= index
-            return Some((index, TaskLocality.NODE_LOCAL))
-          }
-        }
-      }
-
-      // Check for rack-local tasks
-      if (TaskLocality.isAllowed(locality, TaskLocality.RACK_LOCAL)) {
-        for (rack <- sched.getRackForHost(host)) {
-          for (index <- speculatableTasks if !hasAttemptOnHost(index, host)) {
-            val racks = tasks(index).preferredLocations.map(_.host).map(sched.getRackForHost)
-            if (racks.contains(rack)) {
-              speculatableTasks -= index
-              return Some((index, TaskLocality.RACK_LOCAL))
-            }
-          }
-        }
-      }
-
-      // Check for non-local tasks
-      if (TaskLocality.isAllowed(locality, TaskLocality.ANY)) {
-        for (index <- speculatableTasks if !hasAttemptOnHost(index, host)) {
-          speculatableTasks -= index
-          return Some((index, TaskLocality.ANY))
-        }
-      }
-    }
-
-    return None
-  }
-
-  /**
-   * Dequeue a pending task for a given node and return its index and locality level.
-   * Only search for tasks matching the given locality constraint.
-   */
-  private def findTask(execId: String, host: String, locality: TaskLocality.Value)
-    : Option[(Int, TaskLocality.Value)] =
-  {
-    for (index <- findTaskFromList(getPendingTasksForExecutor(execId))) {
-      return Some((index, TaskLocality.PROCESS_LOCAL))
-    }
-
-    if (TaskLocality.isAllowed(locality, TaskLocality.NODE_LOCAL)) {
-      for (index <- findTaskFromList(getPendingTasksForHost(host))) {
-        return Some((index, TaskLocality.NODE_LOCAL))
-      }
-    }
-
-    if (TaskLocality.isAllowed(locality, TaskLocality.RACK_LOCAL)) {
-      for {
-        rack <- sched.getRackForHost(host)
-        index <- findTaskFromList(getPendingTasksForRack(rack))
-      } {
-        return Some((index, TaskLocality.RACK_LOCAL))
-      }
-    }
-
-    // Look for no-pref tasks after rack-local tasks since they can run anywhere.
-    for (index <- findTaskFromList(pendingTasksWithNoPrefs)) {
-      return Some((index, TaskLocality.PROCESS_LOCAL))
-    }
-
-    if (TaskLocality.isAllowed(locality, TaskLocality.ANY)) {
-      for (index <- findTaskFromList(allPendingTasks)) {
-        return Some((index, TaskLocality.ANY))
-      }
-    }
-
-    // Finally, if all else has failed, find a speculative task
-    return findSpeculativeTask(execId, host, locality)
-  }
-
-  /**
-   * Respond to an offer of a single executor from the scheduler by finding a task
-   */
-  override def resourceOffer(
-      execId: String,
-      host: String,
-      availableCpus: Int,
-      maxLocality: TaskLocality.TaskLocality)
-    : Option[TaskDescription] =
-  {
-    if (tasksSuccessful < numTasks && availableCpus >= CPUS_PER_TASK) {
-      val curTime = clock.getTime()
-
-      var allowedLocality = getAllowedLocalityLevel(curTime)
-      if (allowedLocality > maxLocality) {
-        allowedLocality = maxLocality   // We're not allowed to search for farther-away tasks
-      }
-
-      findTask(execId, host, allowedLocality) match {
-        case Some((index, taskLocality)) => {
-          // Found a task; do some bookkeeping and return a task description
-          val task = tasks(index)
-          val taskId = sched.newTaskId()
-          // Figure out whether this should count as a preferred launch
-          logInfo("Starting task %s:%d as TID %s on executor %s: %s (%s)".format(
-            taskSet.id, index, taskId, execId, host, taskLocality))
-          // Do various bookkeeping
-          copiesRunning(index) += 1
-          val info = new TaskInfo(taskId, index, curTime, execId, host, taskLocality)
-          taskInfos(taskId) = info
-          taskAttempts(index) = info :: taskAttempts(index)
-          // Update our locality level for delay scheduling
-          currentLocalityIndex = getLocalityIndex(taskLocality)
-          lastLaunchTime = curTime
-          // Serialize and return the task
-          val startTime = clock.getTime()
-          // We rely on the DAGScheduler to catch non-serializable closures and RDDs, so in here
-          // we assume the task can be serialized without exceptions.
-          val serializedTask = Task.serializeWithDependencies(
-            task, sched.sc.addedFiles, sched.sc.addedJars, ser)
-          val timeTaken = clock.getTime() - startTime
-          addRunningTask(taskId)
-          logInfo("Serialized task %s:%d as %d bytes in %d ms".format(
-            taskSet.id, index, serializedTask.limit, timeTaken))
-          val taskName = "task %s:%d".format(taskSet.id, index)
-          info.serializedSize = serializedTask.limit
-          if (taskAttempts(index).size == 1)
-            taskStarted(task,info)
-          return Some(new TaskDescription(taskId, execId, taskName, index, serializedTask))
-        }
-        case _ =>
-      }
-    }
-    return None
-  }
-
-  /**
-   * Get the level we can launch tasks according to delay scheduling, based on current wait time.
-   */
-  private def getAllowedLocalityLevel(curTime: Long): TaskLocality.TaskLocality = {
-    while (curTime - lastLaunchTime >= localityWaits(currentLocalityIndex) &&
-        currentLocalityIndex < myLocalityLevels.length - 1)
-    {
-      // Jump to the next locality level, and remove our waiting time for the current one since
-      // we don't want to count it again on the next one
-      lastLaunchTime += localityWaits(currentLocalityIndex)
-      currentLocalityIndex += 1
-    }
-    myLocalityLevels(currentLocalityIndex)
-  }
-
-  /**
-   * Find the index in myLocalityLevels for a given locality. This is also designed to work with
-   * localities that are not in myLocalityLevels (in case we somehow get those) by returning the
-   * next-biggest level we have. Uses the fact that the last value in myLocalityLevels is ANY.
-   */
-  def getLocalityIndex(locality: TaskLocality.TaskLocality): Int = {
-    var index = 0
-    while (locality > myLocalityLevels(index)) {
-      index += 1
-    }
-    index
-  }
-
-  private def taskStarted(task: Task[_], info: TaskInfo) {
-    sched.dagScheduler.taskStarted(task, info)
-  }
-
-  def handleTaskGettingResult(tid: Long) = {
-    val info = taskInfos(tid)
-    info.markGettingResult()
-    sched.dagScheduler.taskGettingResult(tasks(info.index), info)
-  }
-
-  /**
-   * Marks the task as successful and notifies the DAGScheduler that a task has ended.
-   */
-  def handleSuccessfulTask(tid: Long, result: DirectTaskResult[_]) = {
-    val info = taskInfos(tid)
-    val index = info.index
-    info.markSuccessful()
-    removeRunningTask(tid)
-    if (!successful(index)) {
-      logInfo("Finished TID %s in %d ms on %s (progress: %d/%d)".format(
-        tid, info.duration, info.host, tasksSuccessful, numTasks))
-      sched.dagScheduler.taskEnded(
-        tasks(index), Success, result.value, result.accumUpdates, info, result.metrics)
-
-      // Mark successful and stop if all the tasks have succeeded.
-      tasksSuccessful += 1
-      successful(index) = true
-      if (tasksSuccessful == numTasks) {
-        sched.taskSetFinished(this)
-      }
-    } else {
-      logInfo("Ignorning task-finished event for TID " + tid + " because task " +
-        index + " has already completed successfully")
-    }
-  }
-
-  /**
-   * Marks the task as failed, re-adds it to the list of pending tasks, and notifies the
-   * DAG Scheduler.
-   */
-  def handleFailedTask(tid: Long, state: TaskState, reason: Option[TaskEndReason]) {
-    val info = taskInfos(tid)
-    if (info.failed) {
-      return
-    }
-    removeRunningTask(tid)
-    val index = info.index
-    info.markFailed()
-    if (!successful(index)) {
-      logWarning("Lost TID %s (task %s:%d)".format(tid, taskSet.id, index))
-      copiesRunning(index) -= 1
-      // Check if the problem is a map output fetch failure. In that case, this
-      // task will never succeed on any node, so tell the scheduler about it.
-      reason.foreach {
-        case fetchFailed: FetchFailed =>
-          logWarning("Loss was due to fetch failure from " + fetchFailed.bmAddress)
-          sched.dagScheduler.taskEnded(tasks(index), fetchFailed, null, null, info, null)
-          successful(index) = true
-          tasksSuccessful += 1
-          sched.taskSetFinished(this)
-          removeAllRunningTasks()
-          return
-
-        case TaskKilled =>
-          logWarning("Task %d was killed.".format(tid))
-          sched.dagScheduler.taskEnded(tasks(index), reason.get, null, null, info, null)
-          return
-
-        case ef: ExceptionFailure =>
-          sched.dagScheduler.taskEnded(tasks(index), ef, null, null, info, ef.metrics.getOrElse(null))
-          if (ef.className == classOf[NotSerializableException].getName()) {
-            // If the task result wasn't serializable, there's no point in trying to re-execute it.
-            logError("Task %s:%s had a not serializable result: %s; not retrying".format(
-              taskSet.id, index, ef.description))
-            abort("Task %s:%s had a not serializable result: %s".format(
-              taskSet.id, index, ef.description))
-            return
-          }
-          val key = ef.description
-          val now = clock.getTime()
-          val (printFull, dupCount) = {
-            if (recentExceptions.contains(key)) {
-              val (dupCount, printTime) = recentExceptions(key)
-              if (now - printTime > EXCEPTION_PRINT_INTERVAL) {
-                recentExceptions(key) = (0, now)
-                (true, 0)
-              } else {
-                recentExceptions(key) = (dupCount + 1, printTime)
-                (false, dupCount + 1)
-              }
-            } else {
-              recentExceptions(key) = (0, now)
-              (true, 0)
-            }
-          }
-          if (printFull) {
-            val locs = ef.stackTrace.map(loc => "\tat %s".format(loc.toString))
-            logWarning("Loss was due to %s\n%s\n%s".format(
-              ef.className, ef.description, locs.mkString("\n")))
-          } else {
-            logInfo("Loss was due to %s [duplicate %d]".format(ef.description, dupCount))
-          }
-
-        case TaskResultLost =>
-          logWarning("Lost result for TID %s on host %s".format(tid, info.host))
-          sched.dagScheduler.taskEnded(tasks(index), TaskResultLost, null, null, info, null)
-
-        case _ => {}
-      }
-      // On non-fetch failures, re-enqueue the task as pending for a max number of retries
-      addPendingTask(index)
-      if (state != TaskState.KILLED) {
-        numFailures(index) += 1
-        if (numFailures(index) >= MAX_TASK_FAILURES) {
-          logError("Task %s:%d failed %d times; aborting job".format(
-            taskSet.id, index, MAX_TASK_FAILURES))
-          abort("Task %s:%d failed %d times".format(taskSet.id, index, MAX_TASK_FAILURES))
-        }
-      }
-    } else {
-      logInfo("Ignoring task-lost event for TID " + tid +
-        " because task " + index + " is already finished")
-    }
-  }
-
-  override def error(message: String) {
-    // Save the error message
-    abort("Error: " + message)
-  }
-
-  def abort(message: String) {
-    failed = true
-    causeOfFailure = message
-    // TODO: Kill running tasks if we were not terminated due to a Mesos error
-    sched.dagScheduler.taskSetFailed(taskSet, message)
-    removeAllRunningTasks()
-    sched.taskSetFinished(this)
-  }
-
-  /** If the given task ID is not in the set of running tasks, adds it.
-   *
-   * Used to keep track of the number of running tasks, for enforcing scheduling policies.
-   */
-  def addRunningTask(tid: Long) {
-    if (runningTasksSet.add(tid) && parent != null) {
-      parent.increaseRunningTasks(1)
-    }
-    runningTasks = runningTasksSet.size
-  }
-
-  /** If the given task ID is in the set of running tasks, removes it. */
-  def removeRunningTask(tid: Long) {
-    if (runningTasksSet.remove(tid) && parent != null) {
-      parent.decreaseRunningTasks(1)
-    }
-    runningTasks = runningTasksSet.size
-  }
-
-  private[cluster] def removeAllRunningTasks() {
-    val numRunningTasks = runningTasksSet.size
-    runningTasksSet.clear()
-    if (parent != null) {
-      parent.decreaseRunningTasks(numRunningTasks)
-    }
-    runningTasks = 0
-  }
-
-  override def getSchedulableByName(name: String): Schedulable = {
-    return null
-  }
-
-  override def addSchedulable(schedulable: Schedulable) {}
-
-  override def removeSchedulable(schedulable: Schedulable) {}
-
-  override def getSortedTaskSetQueue(): ArrayBuffer[TaskSetManager] = {
-    var sortedTaskSetQueue = ArrayBuffer[TaskSetManager](this)
-    sortedTaskSetQueue += this
-    return sortedTaskSetQueue
-  }
-
-  /** Called by cluster scheduler when an executor is lost so we can re-enqueue our tasks */
-  override def executorLost(execId: String, host: String) {
-    logInfo("Re-queueing tasks for " + execId + " from TaskSet " + taskSet.id)
-
-    // Re-enqueue pending tasks for this host based on the status of the cluster -- for example, a
-    // task that used to have locations on only this host might now go to the no-prefs list. Note
-    // that it's okay if we add a task to the same queue twice (if it had multiple preferred
-    // locations), because findTaskFromList will skip already-running tasks.
-    for (index <- getPendingTasksForExecutor(execId)) {
-      addPendingTask(index, readding=true)
-    }
-    for (index <- getPendingTasksForHost(host)) {
-      addPendingTask(index, readding=true)
-    }
-
-    // Re-enqueue any tasks that ran on the failed executor if this is a shuffle map stage
-    if (tasks(0).isInstanceOf[ShuffleMapTask]) {
-      for ((tid, info) <- taskInfos if info.executorId == execId) {
-        val index = taskInfos(tid).index
-        if (successful(index)) {
-          successful(index) = false
-          copiesRunning(index) -= 1
-          tasksSuccessful -= 1
-          addPendingTask(index)
-          // Tell the DAGScheduler that this task was resubmitted so that it doesn't think our
-          // stage finishes when a total of tasks.size tasks finish.
-          sched.dagScheduler.taskEnded(tasks(index), Resubmitted, null, null, info, null)
-        }
-      }
-    }
-    // Also re-enqueue any tasks that were running on the node
-    for ((tid, info) <- taskInfos if info.running && info.executorId == execId) {
-      handleFailedTask(tid, TaskState.KILLED, None)
-    }
-  }
-
-  /**
-   * Check for tasks to be speculated and return true if there are any. This is called periodically
-   * by the ClusterScheduler.
-   *
-   * TODO: To make this scale to large jobs, we need to maintain a list of running tasks, so that
-   * we don't scan the whole task set. It might also help to make this sorted by launch time.
-   */
-  override def checkSpeculatableTasks(): Boolean = {
-    // Can't speculate if we only have one task, or if all tasks have finished.
-    if (numTasks == 1 || tasksSuccessful == numTasks) {
-      return false
-    }
-    var foundTasks = false
-    val minFinishedForSpeculation = (SPECULATION_QUANTILE * numTasks).floor.toInt
-    logDebug("Checking for speculative tasks: minFinished = " + minFinishedForSpeculation)
-    if (tasksSuccessful >= minFinishedForSpeculation && tasksSuccessful > 0) {
-      val time = clock.getTime()
-      val durations = taskInfos.values.filter(_.successful).map(_.duration).toArray
-      Arrays.sort(durations)
-      val medianDuration = durations(min((0.5 * tasksSuccessful).round.toInt, durations.size - 1))
-      val threshold = max(SPECULATION_MULTIPLIER * medianDuration, 100)
-      // TODO: Threshold should also look at standard deviation of task durations and have a lower
-      // bound based on that.
-      logDebug("Task length threshold for speculation: " + threshold)
-      for ((tid, info) <- taskInfos) {
-        val index = info.index
-        if (!successful(index) && copiesRunning(index) == 1 && info.timeRunning(time) > threshold &&
-          !speculatableTasks.contains(index)) {
-          logInfo(
-            "Marking task %s:%d (on %s) as speculatable because it ran more than %.0f ms".format(
-              taskSet.id, index, info.host, threshold))
-          speculatableTasks += index
-          foundTasks = true
-        }
-      }
-    }
-    return foundTasks
-  }
-
-  override def hasPendingTasks(): Boolean = {
-    numTasks > 0 && tasksSuccessful < numTasks
-  }
-
-  private def getLocalityWait(level: TaskLocality.TaskLocality): Long = {
-    val defaultWait = conf.getOrElse("spark.locality.wait", "3000")
-    level match {
-      case TaskLocality.PROCESS_LOCAL =>
-        conf.getOrElse("spark.locality.wait.process",  defaultWait).toLong
-      case TaskLocality.NODE_LOCAL =>
-        conf.getOrElse("spark.locality.wait.node",  defaultWait).toLong
-      case TaskLocality.RACK_LOCAL =>
-        conf.getOrElse("spark.locality.wait.rack",  defaultWait).toLong
-      case TaskLocality.ANY =>
-        0L
-    }
-  }
-
-  /**
-   * Compute the locality levels used in this TaskSet. Assumes that all tasks have already been
-   * added to queues using addPendingTask.
-   */
-  private def computeValidLocalityLevels(): Array[TaskLocality.TaskLocality] = {
-    import TaskLocality.{PROCESS_LOCAL, NODE_LOCAL, RACK_LOCAL, ANY}
-    val levels = new ArrayBuffer[TaskLocality.TaskLocality]
-    if (!pendingTasksForExecutor.isEmpty && getLocalityWait(PROCESS_LOCAL) != 0) {
-      levels += PROCESS_LOCAL
-    }
-    if (!pendingTasksForHost.isEmpty && getLocalityWait(NODE_LOCAL) != 0) {
-      levels += NODE_LOCAL
-    }
-    if (!pendingTasksForRack.isEmpty && getLocalityWait(RACK_LOCAL) != 0) {
-      levels += RACK_LOCAL
-    }
-    levels += ANY
-    logDebug("Valid locality levels for " + taskSet + ": " + levels.mkString(", "))
-    levels.toArray
-  }
-}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
index 156b01b1498837141ba92b75b39f6c756a4ec0c1..b4a3ecca3909512dac4693e78e0d58487fe74597 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
@@ -28,8 +28,10 @@ import akka.actor._
 import akka.pattern.ask
 import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent}
 
+import org.apache.spark.{SparkException, Logging, TaskState}
 import org.apache.spark.{Logging, SparkException, TaskState}
-import org.apache.spark.scheduler.TaskDescription
+import org.apache.spark.scheduler.{TaskSchedulerImpl, SchedulerBackend, SlaveLost, TaskDescription,
+  WorkerOffer}
 import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._
 import org.apache.spark.util.{AkkaUtils, Utils}
 
@@ -42,7 +44,7 @@ import org.apache.spark.util.{AkkaUtils, Utils}
  * (spark.deploy.*).
  */
 private[spark]
-class CoarseGrainedSchedulerBackend(scheduler: ClusterScheduler, actorSystem: ActorSystem)
+class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: ActorSystem)
   extends SchedulerBackend with Logging
 {
   // Use an atomic variable to track total number of cores in the cluster for simplicity and speed
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala
index d74f000ebb910ee3db05136dbf0f98ed6e156452..f41fbbd1f34418120125a549e3cffd565c69415e 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala
@@ -19,10 +19,12 @@ package org.apache.spark.scheduler.cluster
 
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.{Path, FileSystem}
+
 import org.apache.spark.{Logging, SparkContext}
+import org.apache.spark.scheduler.TaskSchedulerImpl
 
 private[spark] class SimrSchedulerBackend(
-    scheduler: ClusterScheduler,
+    scheduler: TaskSchedulerImpl,
     sc: SparkContext,
     driverFilePath: String)
   extends CoarseGrainedSchedulerBackend(scheduler, sc.env.actorSystem)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
index de69e3260d5a90fbbc0b7470bd3f8722b858986a..224077566d06aff6fabe4921560d06e848f7c6d7 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
@@ -17,14 +17,16 @@
 
 package org.apache.spark.scheduler.cluster
 
+import scala.collection.mutable.HashMap
+
 import org.apache.spark.{Logging, SparkContext}
 import org.apache.spark.deploy.client.{Client, ClientListener}
 import org.apache.spark.deploy.{Command, ApplicationDescription}
-import scala.collection.mutable.HashMap
+import org.apache.spark.scheduler.{ExecutorExited, ExecutorLossReason, SlaveLost, TaskSchedulerImpl}
 import org.apache.spark.util.Utils
 
 private[spark] class SparkDeploySchedulerBackend(
-    scheduler: ClusterScheduler,
+    scheduler: TaskSchedulerImpl,
     sc: SparkContext,
     masters: Array[String],
     appName: String)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala
index 1695374152c50226014daf5553cd5802792a7bd5..9e2cd3f6994e74436f8d29eeaf360c246cc2364b 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala
@@ -30,7 +30,8 @@ import org.apache.mesos._
 import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, TaskState => MesosTaskState, _}
 
 import org.apache.spark.{SparkException, Logging, SparkContext, TaskState}
-import org.apache.spark.scheduler.cluster.{ClusterScheduler, CoarseGrainedSchedulerBackend}
+import org.apache.spark.scheduler.TaskSchedulerImpl
+import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend
 
 /**
  * A SchedulerBackend that runs tasks on Mesos, but uses "coarse-grained" tasks, where it holds
@@ -43,7 +44,7 @@ import org.apache.spark.scheduler.cluster.{ClusterScheduler, CoarseGrainedSchedu
  * remove this.
  */
 private[spark] class CoarseMesosSchedulerBackend(
-    scheduler: ClusterScheduler,
+    scheduler: TaskSchedulerImpl,
     sc: SparkContext,
     master: String,
     appName: String)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala
index 8dfd4d5fb39371ddff4c64f96344098e7cff3aa9..be963829830225f72a72ae7d4d0b96b2157de7d9 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala
@@ -30,9 +30,8 @@ import org.apache.mesos._
 import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, TaskState => MesosTaskState, _}
 
 import org.apache.spark.{Logging, SparkException, SparkContext, TaskState}
-import org.apache.spark.scheduler.TaskDescription
-import org.apache.spark.scheduler.cluster.{ClusterScheduler, ExecutorExited, ExecutorLossReason}
-import org.apache.spark.scheduler.cluster.{SchedulerBackend, SlaveLost, WorkerOffer}
+import org.apache.spark.scheduler.{ExecutorExited, ExecutorLossReason, SchedulerBackend, SlaveLost,
+  TaskDescription, TaskSchedulerImpl, WorkerOffer}
 import org.apache.spark.util.Utils
 
 /**
@@ -41,7 +40,7 @@ import org.apache.spark.util.Utils
  * from multiple apps can run on different cores) and in time (a core can switch ownership).
  */
 private[spark] class MesosSchedulerBackend(
-    scheduler: ClusterScheduler,
+    scheduler: TaskSchedulerImpl,
     sc: SparkContext,
     master: String,
     appName: String)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala
new file mode 100644
index 0000000000000000000000000000000000000000..4edc6a0d3f2a06a17a0768e90b3a69bfd7d15348
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala
@@ -0,0 +1,108 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler.local
+
+import java.nio.ByteBuffer
+
+import akka.actor.{Actor, ActorRef, Props}
+
+import org.apache.spark.{Logging, SparkContext, SparkEnv, TaskState}
+import org.apache.spark.TaskState.TaskState
+import org.apache.spark.executor.{Executor, ExecutorBackend}
+import org.apache.spark.scheduler.{SchedulerBackend, TaskSchedulerImpl, WorkerOffer}
+
+private case class ReviveOffers()
+
+private case class StatusUpdate(taskId: Long, state: TaskState, serializedData: ByteBuffer)
+
+private case class KillTask(taskId: Long)
+
+/**
+ * Calls to LocalBackend are all serialized through LocalActor. Using an actor makes the calls on
+ * LocalBackend asynchronous, which is necessary to prevent deadlock between LocalBackend
+ * and the ClusterScheduler.
+ */
+private[spark] class LocalActor(
+  scheduler: TaskSchedulerImpl,
+  executorBackend: LocalBackend,
+  private val totalCores: Int) extends Actor with Logging {
+
+  private var freeCores = totalCores
+
+  private val localExecutorId = "localhost"
+  private val localExecutorHostname = "localhost"
+
+  val executor = new Executor(localExecutorId, localExecutorHostname, Seq.empty, isLocal = true)
+
+  def receive = {
+    case ReviveOffers =>
+      reviveOffers()
+
+    case StatusUpdate(taskId, state, serializedData) =>
+      scheduler.statusUpdate(taskId, state, serializedData)
+      if (TaskState.isFinished(state)) {
+        freeCores += 1
+        reviveOffers()
+      }
+
+    case KillTask(taskId) =>
+      executor.killTask(taskId)
+  }
+
+  def reviveOffers() {
+    val offers = Seq(new WorkerOffer(localExecutorId, localExecutorHostname, freeCores))
+    for (task <- scheduler.resourceOffers(offers).flatten) {
+      freeCores -= 1
+      executor.launchTask(executorBackend, task.taskId, task.serializedTask)
+    }
+  }
+}
+
+/**
+ * LocalBackend is used when running a local version of Spark where the executor, backend, and
+ * master all run in the same JVM. It sits behind a ClusterScheduler and handles launching tasks
+ * on a single Executor (created by the LocalBackend) running locally.
+ */
+private[spark] class LocalBackend(scheduler: TaskSchedulerImpl, val totalCores: Int)
+  extends SchedulerBackend with ExecutorBackend {
+
+  var localActor: ActorRef = null
+
+  override def start() {
+    localActor = SparkEnv.get.actorSystem.actorOf(
+      Props(new LocalActor(scheduler, this, totalCores)),
+      "LocalBackendActor")
+  }
+
+  override def stop() {
+  }
+
+  override def reviveOffers() {
+    localActor ! ReviveOffers
+  }
+
+  override def defaultParallelism() = totalCores
+
+  override def killTask(taskId: Long, executorId: String) {
+    localActor ! KillTask(taskId)
+  }
+
+  override def statusUpdate(taskId: Long, state: TaskState, serializedData: ByteBuffer) {
+    localActor ! StatusUpdate(taskId, state, serializedData)
+  }
+}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala
deleted file mode 100644
index 7c173e3ad5078819f107007c27486f9d9ab3373b..0000000000000000000000000000000000000000
--- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala
+++ /dev/null
@@ -1,224 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.scheduler.local
-
-import java.nio.ByteBuffer
-import java.util.concurrent.atomic.AtomicInteger
-
-import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet}
-
-import akka.actor._
-
-import org.apache.spark._
-import org.apache.spark.TaskState.TaskState
-import org.apache.spark.executor.{Executor, ExecutorBackend}
-import org.apache.spark.scheduler._
-import org.apache.spark.scheduler.SchedulingMode.SchedulingMode
-
-
-/**
- * A FIFO or Fair TaskScheduler implementation that runs tasks locally in a thread pool. Optionally
- * the scheduler also allows each task to fail up to maxFailures times, which is useful for
- * testing fault recovery.
- */
-
-private[local]
-case class LocalReviveOffers()
-
-private[local]
-case class LocalStatusUpdate(taskId: Long, state: TaskState, serializedData: ByteBuffer)
-
-private[local]
-case class KillTask(taskId: Long)
-
-private[spark]
-class LocalActor(localScheduler: LocalScheduler, private var freeCores: Int)
-  extends Actor with Logging {
-
-  val executor = new Executor(
-    "localhost", "localhost", localScheduler.sc.conf.getAll, isLocal = true)
-
-  def receive = {
-    case LocalReviveOffers =>
-      launchTask(localScheduler.resourceOffer(freeCores))
-
-    case LocalStatusUpdate(taskId, state, serializeData) =>
-      if (TaskState.isFinished(state)) {
-        freeCores += 1
-        launchTask(localScheduler.resourceOffer(freeCores))
-      }
-
-    case KillTask(taskId) =>
-      executor.killTask(taskId)
-  }
-
-  private def launchTask(tasks: Seq[TaskDescription]) {
-    for (task <- tasks) {
-      freeCores -= 1
-      executor.launchTask(localScheduler, task.taskId, task.serializedTask)
-    }
-  }
-}
-
-private[spark] class LocalScheduler(val threads: Int, val maxFailures: Int, val sc: SparkContext)
-  extends TaskScheduler
-  with ExecutorBackend
-  with Logging {
-
-  val env = SparkEnv.get
-  val conf = env.conf
-  val attemptId = new AtomicInteger
-  var dagScheduler: DAGScheduler = null
-
-  // Application dependencies (added through SparkContext) that we've fetched so far on this node.
-  // Each map holds the master's timestamp for the version of that file or JAR we got.
-  val currentFiles: HashMap[String, Long] = new HashMap[String, Long]()
-  val currentJars: HashMap[String, Long] = new HashMap[String, Long]()
-
-  var schedulableBuilder: SchedulableBuilder = null
-  var rootPool: Pool = null
-  val schedulingMode: SchedulingMode = SchedulingMode.withName(
-    conf.getOrElse("spark.scheduler.mode", "FIFO"))
-  val activeTaskSets = new HashMap[String, LocalTaskSetManager]
-  val taskIdToTaskSetId = new HashMap[Long, String]
-  val taskSetTaskIds = new HashMap[String, HashSet[Long]]
-
-  var localActor: ActorRef = null
-
-  override def start() {
-    // temporarily set rootPool name to empty
-    rootPool = new Pool("", schedulingMode, 0, 0)
-    schedulableBuilder = {
-      schedulingMode match {
-        case SchedulingMode.FIFO =>
-          new FIFOSchedulableBuilder(rootPool)
-        case SchedulingMode.FAIR =>
-          new FairSchedulableBuilder(rootPool, conf)
-      }
-    }
-    schedulableBuilder.buildPools()
-
-    localActor = env.actorSystem.actorOf(Props(new LocalActor(this, threads)), "Test")
-  }
-
-  override def setDAGScheduler(dagScheduler: DAGScheduler) {
-    this.dagScheduler = dagScheduler
-  }
-
-  override def submitTasks(taskSet: TaskSet) {
-    synchronized {
-      val manager = new LocalTaskSetManager(this, taskSet)
-      schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties)
-      activeTaskSets(taskSet.id) = manager
-      taskSetTaskIds(taskSet.id) = new HashSet[Long]()
-      localActor ! LocalReviveOffers
-    }
-  }
-
-  override def cancelTasks(stageId: Int): Unit = synchronized {
-    logInfo("Cancelling stage " + stageId)
-    logInfo("Cancelling stage " + activeTaskSets.map(_._2.stageId))
-    activeTaskSets.find(_._2.stageId == stageId).foreach { case (_, tsm) =>
-      // There are two possible cases here:
-      // 1. The task set manager has been created and some tasks have been scheduled.
-      //    In this case, send a kill signal to the executors to kill the task and then abort
-      //    the stage.
-      // 2. The task set manager has been created but no tasks has been scheduled. In this case,
-      //    simply abort the stage.
-      val taskIds = taskSetTaskIds(tsm.taskSet.id)
-      if (taskIds.size > 0) {
-        taskIds.foreach { tid =>
-          localActor ! KillTask(tid)
-        }
-      }
-      logInfo("Stage %d was cancelled".format(stageId))
-      taskSetFinished(tsm)
-    }
-  }
-
-  def resourceOffer(freeCores: Int): Seq[TaskDescription] = {
-    synchronized {
-      var freeCpuCores = freeCores
-      val tasks = new ArrayBuffer[TaskDescription](freeCores)
-      val sortedTaskSetQueue = rootPool.getSortedTaskSetQueue()
-      for (manager <- sortedTaskSetQueue) {
-        logDebug("parentName:%s,name:%s,runningTasks:%s".format(
-          manager.parent.name, manager.name, manager.runningTasks))
-      }
-
-      var launchTask = false
-      for (manager <- sortedTaskSetQueue) {
-        do {
-          launchTask = false
-          manager.resourceOffer(null, null, freeCpuCores, null) match {
-            case Some(task) =>
-              tasks += task
-              taskIdToTaskSetId(task.taskId) = manager.taskSet.id
-              taskSetTaskIds(manager.taskSet.id) += task.taskId
-              freeCpuCores -= 1
-              launchTask = true
-            case None => {}
-          }
-        } while(launchTask)
-      }
-      return tasks
-    }
-  }
-
-  def taskSetFinished(manager: TaskSetManager) {
-    synchronized {
-      activeTaskSets -= manager.taskSet.id
-      manager.parent.removeSchedulable(manager)
-      logInfo("Remove TaskSet %s from pool %s".format(manager.taskSet.id, manager.parent.name))
-      taskIdToTaskSetId --= taskSetTaskIds(manager.taskSet.id)
-      taskSetTaskIds -= manager.taskSet.id
-    }
-  }
-
-  override def statusUpdate(taskId: Long, state: TaskState, serializedData: ByteBuffer) {
-    if (TaskState.isFinished(state)) {
-      synchronized {
-        taskIdToTaskSetId.get(taskId) match {
-          case Some(taskSetId) =>
-            val taskSetManager = activeTaskSets.get(taskSetId)
-            taskSetManager.foreach { tsm =>
-              taskSetTaskIds(taskSetId) -= taskId
-
-              state match {
-                case TaskState.FINISHED =>
-                  tsm.taskEnded(taskId, state, serializedData)
-                case TaskState.FAILED =>
-                  tsm.taskFailed(taskId, state, serializedData)
-                case TaskState.KILLED =>
-                  tsm.error("Task %d was killed".format(taskId))
-                case _ => {}
-              }
-            }
-          case None =>
-            logInfo("Ignoring update from TID " + taskId + " because its task set is gone")
-        }
-      }
-      localActor ! LocalStatusUpdate(taskId, state, serializedData)
-    }
-  }
-
-  override def stop() {
-  }
-
-  override def defaultParallelism() = threads
-}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalTaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalTaskSetManager.scala
deleted file mode 100644
index 53bf78267e02668f48cf0c2239c8c2313ad6d2ca..0000000000000000000000000000000000000000
--- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalTaskSetManager.scala
+++ /dev/null
@@ -1,191 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.scheduler.local
-
-import java.nio.ByteBuffer
-import scala.collection.mutable.ArrayBuffer
-import scala.collection.mutable.HashMap
-
-import org.apache.spark.{ExceptionFailure, Logging, SparkEnv, SparkException, Success, TaskState}
-import org.apache.spark.TaskState.TaskState
-import org.apache.spark.scheduler.{DirectTaskResult, IndirectTaskResult, Pool, Schedulable, Task,
-  TaskDescription, TaskInfo, TaskLocality, TaskResult, TaskSet, TaskSetManager}
-
-
-private[spark] class LocalTaskSetManager(sched: LocalScheduler, val taskSet: TaskSet)
-  extends TaskSetManager with Logging {
-
-  var parent: Pool = null
-  var weight: Int = 1
-  var minShare: Int = 0
-  var runningTasks: Int = 0
-  var priority: Int = taskSet.priority
-  var stageId: Int = taskSet.stageId
-  var name: String = "TaskSet_" + taskSet.stageId.toString
-
-  var failCount = new Array[Int](taskSet.tasks.size)
-  val taskInfos = new HashMap[Long, TaskInfo]
-  val numTasks = taskSet.tasks.size
-  var numFinished = 0
-  val env = SparkEnv.get
-  val ser = env.closureSerializer.newInstance()
-  val copiesRunning = new Array[Int](numTasks)
-  val finished = new Array[Boolean](numTasks)
-  val numFailures = new Array[Int](numTasks)
-  val MAX_TASK_FAILURES = sched.maxFailures
-
-  def increaseRunningTasks(taskNum: Int): Unit = {
-    runningTasks += taskNum
-    if (parent != null) {
-     parent.increaseRunningTasks(taskNum)
-    }
-  }
-
-  def decreaseRunningTasks(taskNum: Int): Unit = {
-    runningTasks -= taskNum
-    if (parent != null) {
-      parent.decreaseRunningTasks(taskNum)
-    }
-  }
-
-  override def addSchedulable(schedulable: Schedulable): Unit = {
-    // nothing
-  }
-
-  override def removeSchedulable(schedulable: Schedulable): Unit = {
-    // nothing
-  }
-
-  override def getSchedulableByName(name: String): Schedulable = {
-    return null
-  }
-
-  override def executorLost(executorId: String, host: String): Unit = {
-    // nothing
-  }
-
-  override def checkSpeculatableTasks() = true
-
-  override def getSortedTaskSetQueue(): ArrayBuffer[TaskSetManager] = {
-    var sortedTaskSetQueue = new ArrayBuffer[TaskSetManager]
-    sortedTaskSetQueue += this
-    return sortedTaskSetQueue
-  }
-
-  override def hasPendingTasks() = true
-
-  def findTask(): Option[Int] = {
-    for (i <- 0 to numTasks-1) {
-      if (copiesRunning(i) == 0 && !finished(i)) {
-        return Some(i)
-      }
-    }
-    return None
-  }
-
-  override def resourceOffer(
-      execId: String,
-      host: String,
-      availableCpus: Int,
-      maxLocality: TaskLocality.TaskLocality)
-    : Option[TaskDescription] =
-  {
-    SparkEnv.set(sched.env)
-    logDebug("availableCpus:%d, numFinished:%d, numTasks:%d".format(
-      availableCpus.toInt, numFinished, numTasks))
-    if (availableCpus > 0 && numFinished < numTasks) {
-      findTask() match {
-        case Some(index) =>
-          val taskId = sched.attemptId.getAndIncrement()
-          val task = taskSet.tasks(index)
-          val info = new TaskInfo(taskId, index, System.currentTimeMillis(), "local", "local:1",
-            TaskLocality.NODE_LOCAL)
-          taskInfos(taskId) = info
-          // We rely on the DAGScheduler to catch non-serializable closures and RDDs, so in here
-          // we assume the task can be serialized without exceptions.
-          val bytes = Task.serializeWithDependencies(
-            task, sched.sc.addedFiles, sched.sc.addedJars, ser)
-          logInfo("Size of task " + taskId + " is " + bytes.limit + " bytes")
-          val taskName = "task %s:%d".format(taskSet.id, index)
-          copiesRunning(index) += 1
-          increaseRunningTasks(1)
-          taskStarted(task, info)
-          return Some(new TaskDescription(taskId, null, taskName, index, bytes))
-        case None => {}
-      }
-    }
-    return None
-  }
-
-  def taskStarted(task: Task[_], info: TaskInfo) {
-    sched.dagScheduler.taskStarted(task, info)
-  }
-
-  def taskEnded(tid: Long, state: TaskState, serializedData: ByteBuffer) {
-    val info = taskInfos(tid)
-    val index = info.index
-    val task = taskSet.tasks(index)
-    info.markSuccessful()
-    val result = ser.deserialize[TaskResult[_]](serializedData, getClass.getClassLoader) match {
-      case directResult: DirectTaskResult[_] => directResult
-      case IndirectTaskResult(blockId) => {
-        throw new SparkException("Expect only DirectTaskResults when using LocalScheduler")
-      }
-    }
-    result.metrics.resultSize = serializedData.limit()
-    sched.dagScheduler.taskEnded(task, Success, result.value, result.accumUpdates, info,
-      result.metrics)
-    numFinished += 1
-    decreaseRunningTasks(1)
-    finished(index) = true
-    if (numFinished == numTasks) {
-      sched.taskSetFinished(this)
-    }
-  }
-
-  def taskFailed(tid: Long, state: TaskState, serializedData: ByteBuffer) {
-    val info = taskInfos(tid)
-    val index = info.index
-    val task = taskSet.tasks(index)
-    info.markFailed()
-    decreaseRunningTasks(1)
-    val reason: ExceptionFailure = ser.deserialize[ExceptionFailure](
-      serializedData, getClass.getClassLoader)
-    sched.dagScheduler.taskEnded(task, reason, null, null, info, reason.metrics.getOrElse(null))
-    if (!finished(index)) {
-      copiesRunning(index) -= 1
-      numFailures(index) += 1
-      val locs = reason.stackTrace.map(loc => "\tat %s".format(loc.toString))
-      logInfo("Loss was due to %s\n%s\n%s".format(
-        reason.className, reason.description, locs.mkString("\n")))
-      if (numFailures(index) > MAX_TASK_FAILURES) {
-        val errorMessage = "Task %s:%d failed more than %d times; aborting job %s".format(
-          taskSet.id, index, MAX_TASK_FAILURES, reason.description)
-        decreaseRunningTasks(runningTasks)
-        sched.dagScheduler.taskSetFailed(taskSet, errorMessage)
-        // need to delete failed Taskset from schedule queue
-        sched.taskSetFinished(this)
-      }
-    }
-  }
-
-  override def error(message: String) {
-    sched.dagScheduler.taskSetFailed(taskSet, message)
-    sched.taskSetFinished(this)
-  }
-}
diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala
index f592df283aa648449dabcdfc6c90c44f913d3f1c..151eedb7837fb108ee0cc3d0ee14e7f6f83ebe0c 100644
--- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala
@@ -74,10 +74,16 @@ class ShuffleBlockManager(blockManager: BlockManager) {
    * Contains all the state related to a particular shuffle. This includes a pool of unused
    * ShuffleFileGroups, as well as all ShuffleFileGroups that have been created for the shuffle.
    */
-  private class ShuffleState() {
+  private class ShuffleState(val numBuckets: Int) {
     val nextFileId = new AtomicInteger(0)
     val unusedFileGroups = new ConcurrentLinkedQueue[ShuffleFileGroup]()
     val allFileGroups = new ConcurrentLinkedQueue[ShuffleFileGroup]()
+
+    /**
+     * The mapIds of all map tasks completed on this Executor for this shuffle.
+     * NB: This is only populated if consolidateShuffleFiles is FALSE. We don't need it otherwise.
+     */
+    val completedMapTasks = new ConcurrentLinkedQueue[Int]()
   }
 
   type ShuffleId = Int
@@ -88,7 +94,7 @@ class ShuffleBlockManager(blockManager: BlockManager) {
 
   def forMapTask(shuffleId: Int, mapId: Int, numBuckets: Int, serializer: Serializer) = {
     new ShuffleWriterGroup {
-      shuffleStates.putIfAbsent(shuffleId, new ShuffleState())
+      shuffleStates.putIfAbsent(shuffleId, new ShuffleState(numBuckets))
       private val shuffleState = shuffleStates(shuffleId)
       private var fileGroup: ShuffleFileGroup = null
 
@@ -113,6 +119,8 @@ class ShuffleBlockManager(blockManager: BlockManager) {
             fileGroup.recordMapOutput(mapId, offsets)
           }
           recycleFileGroup(fileGroup)
+        } else {
+          shuffleState.completedMapTasks.add(mapId)
         }
       }
 
@@ -158,7 +166,18 @@ class ShuffleBlockManager(blockManager: BlockManager) {
   }
 
   private def cleanup(cleanupTime: Long) {
-    shuffleStates.clearOldValues(cleanupTime)
+    shuffleStates.clearOldValues(cleanupTime, (shuffleId, state) => {
+      if (consolidateShuffleFiles) {
+        for (fileGroup <- state.allFileGroups; file <- fileGroup.files) {
+          file.delete()
+        }
+      } else {
+        for (mapId <- state.completedMapTasks; reduceId <- 0 until state.numBuckets) {
+          val blockId = new ShuffleBlockId(shuffleId, mapId, reduceId)
+          blockManager.diskBlockManager.getFile(blockId).delete()
+        }
+      }
+    })
   }
 }
 
diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala
index e596690bc3df87475adac7f3abbf97e35d176634..a31a7e1d58374568b84e75572efb76028d3404df 100644
--- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala
+++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala
@@ -56,7 +56,8 @@ private[spark] class ExecutorsUI(val sc: SparkContext) {
     val diskSpaceUsed = storageStatusList.flatMap(_.blocks.values.map(_.diskSize)).fold(0L)(_+_)
 
     val execHead = Seq("Executor ID", "Address", "RDD blocks", "Memory used", "Disk used",
-      "Active tasks", "Failed tasks", "Complete tasks", "Total tasks")
+      "Active tasks", "Failed tasks", "Complete tasks", "Total tasks", "Task Time", "Shuffle Read",
+      "Shuffle Write")
 
     def execRow(kv: Seq[String]) = {
       <tr>
@@ -73,6 +74,9 @@ private[spark] class ExecutorsUI(val sc: SparkContext) {
         <td>{kv(7)}</td>
         <td>{kv(8)}</td>
         <td>{kv(9)}</td>
+        <td>{Utils.msDurationToString(kv(10).toLong)}</td>
+        <td>{Utils.bytesToString(kv(11).toLong)}</td>
+        <td>{Utils.bytesToString(kv(12).toLong)}</td>
       </tr>
     }
 
@@ -111,6 +115,9 @@ private[spark] class ExecutorsUI(val sc: SparkContext) {
     val failedTasks = listener.executorToTasksFailed.getOrElse(execId, 0)
     val completedTasks = listener.executorToTasksComplete.getOrElse(execId, 0)
     val totalTasks = activeTasks + failedTasks + completedTasks
+    val totalDuration = listener.executorToDuration.getOrElse(execId, 0)
+    val totalShuffleRead = listener.executorToShuffleRead.getOrElse(execId, 0)
+    val totalShuffleWrite = listener.executorToShuffleWrite.getOrElse(execId, 0)
 
     Seq(
       execId,
@@ -122,7 +129,10 @@ private[spark] class ExecutorsUI(val sc: SparkContext) {
       activeTasks.toString,
       failedTasks.toString,
       completedTasks.toString,
-      totalTasks.toString
+      totalTasks.toString,
+      totalDuration.toString,
+      totalShuffleRead.toString,
+      totalShuffleWrite.toString
     )
   }
 
@@ -130,6 +140,9 @@ private[spark] class ExecutorsUI(val sc: SparkContext) {
     val executorToTasksActive = HashMap[String, HashSet[TaskInfo]]()
     val executorToTasksComplete = HashMap[String, Int]()
     val executorToTasksFailed = HashMap[String, Int]()
+    val executorToDuration = HashMap[String, Long]()
+    val executorToShuffleRead = HashMap[String, Long]()
+    val executorToShuffleWrite = HashMap[String, Long]()
 
     override def onTaskStart(taskStart: SparkListenerTaskStart) {
       val eid = taskStart.taskInfo.executorId
@@ -140,6 +153,9 @@ private[spark] class ExecutorsUI(val sc: SparkContext) {
     override def onTaskEnd(taskEnd: SparkListenerTaskEnd) {
       val eid = taskEnd.taskInfo.executorId
       val activeTasks = executorToTasksActive.getOrElseUpdate(eid, new HashSet[TaskInfo]())
+      val newDuration = executorToDuration.getOrElse(eid, 0L) + taskEnd.taskInfo.duration
+      executorToDuration.put(eid, newDuration)
+
       activeTasks -= taskEnd.taskInfo
       val (failureInfo, metrics): (Option[ExceptionFailure], Option[TaskMetrics]) =
         taskEnd.reason match {
@@ -150,6 +166,17 @@ private[spark] class ExecutorsUI(val sc: SparkContext) {
             executorToTasksComplete(eid) = executorToTasksComplete.getOrElse(eid, 0) + 1
             (None, Option(taskEnd.taskMetrics))
         }
+
+      // update shuffle read/write
+      if (null != taskEnd.taskMetrics) {
+        taskEnd.taskMetrics.shuffleReadMetrics.foreach(shuffleRead =>
+          executorToShuffleRead.put(eid, executorToShuffleRead.getOrElse(eid, 0L) +
+            shuffleRead.remoteBytesRead))
+
+        taskEnd.taskMetrics.shuffleWriteMetrics.foreach(shuffleWrite =>
+          executorToShuffleWrite.put(eid, executorToShuffleWrite.getOrElse(eid, 0L) +
+            shuffleWrite.shuffleBytesWritten))
+      }
     }
   }
 }
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorSummary.scala b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorSummary.scala
new file mode 100644
index 0000000000000000000000000000000000000000..3c53e88380193daa24d6090c16ef306e77ca1de9
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorSummary.scala
@@ -0,0 +1,27 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ui.jobs
+
+/** class for reporting aggregated metrics for each executors in stageUI */
+private[spark] class ExecutorSummary {
+  var taskTime : Long = 0
+  var failedTasks : Int = 0
+  var succeededTasks : Int = 0
+  var shuffleRead : Long = 0
+  var shuffleWrite : Long = 0
+}
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala
new file mode 100644
index 0000000000000000000000000000000000000000..0dd876480afa0e6d5ac1ade7ac7a82789374ab4e
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala
@@ -0,0 +1,90 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ui.jobs
+
+import scala.xml.Node
+
+import org.apache.spark.scheduler.SchedulingMode
+import org.apache.spark.util.Utils
+import scala.collection.mutable
+
+/** Page showing executor summary */
+private[spark] class ExecutorTable(val parent: JobProgressUI, val stageId: Int) {
+
+  val listener = parent.listener
+  val dateFmt = parent.dateFmt
+  val isFairScheduler = listener.sc.getSchedulingMode == SchedulingMode.FAIR
+
+  def toNodeSeq(): Seq[Node] = {
+    listener.synchronized {
+      executorTable()
+    }
+  }
+
+  /** Special table which merges two header cells. */
+  private def executorTable[T](): Seq[Node] = {
+    <table class="table table-bordered table-striped table-condensed sortable">
+      <thead>
+        <th>Executor ID</th>
+        <th>Address</th>
+        <th>Task Time</th>
+        <th>Total Tasks</th>
+        <th>Failed Tasks</th>
+        <th>Succeeded Tasks</th>
+        <th>Shuffle Read</th>
+        <th>Shuffle Write</th>
+      </thead>
+      <tbody>
+        {createExecutorTable()}
+      </tbody>
+    </table>
+  }
+
+  private def createExecutorTable() : Seq[Node] = {
+    // make a executor-id -> address map
+    val executorIdToAddress = mutable.HashMap[String, String]()
+    val storageStatusList = parent.sc.getExecutorStorageStatus
+    for (statusId <- 0 until storageStatusList.size) {
+      val blockManagerId = parent.sc.getExecutorStorageStatus(statusId).blockManagerId
+      val address = blockManagerId.hostPort
+      val executorId = blockManagerId.executorId
+      executorIdToAddress.put(executorId, address)
+    }
+
+    val executorIdToSummary = listener.stageIdToExecutorSummaries.get(stageId)
+    executorIdToSummary match {
+      case Some(x) => {
+        x.toSeq.sortBy(_._1).map{
+          case (k,v) => {
+            <tr>
+              <td>{k}</td>
+              <td>{executorIdToAddress.getOrElse(k, "CANNOT FIND ADDRESS")}</td>
+              <td>{parent.formatDuration(v.taskTime)}</td>
+              <td>{v.failedTasks + v.succeededTasks}</td>
+              <td>{v.failedTasks}</td>
+              <td>{v.succeededTasks}</td>
+              <td>{Utils.bytesToString(v.shuffleRead)}</td>
+              <td>{Utils.bytesToString(v.shuffleWrite)}</td>
+            </tr>
+          }
+        }
+      }
+      case _ => { Seq[Node]() }
+    }
+  }
+}
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala
index 6ff8e9fb143f94f1785ba27433a19c326dfde14a..eed3544b70bb7e23594991a01e61a81a96728128 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala
@@ -57,6 +57,7 @@ private[spark] class JobProgressListener(val sc: SparkContext) extends SparkList
   val stageIdToTasksFailed = HashMap[Int, Int]()
   val stageIdToTaskInfos =
     HashMap[Int, HashSet[(TaskInfo, Option[TaskMetrics], Option[ExceptionFailure])]]()
+  val stageIdToExecutorSummaries = HashMap[Int, HashMap[String, ExecutorSummary]]()
 
   override def onJobStart(jobStart: SparkListenerJobStart) {}
 
@@ -124,8 +125,38 @@ private[spark] class JobProgressListener(val sc: SparkContext) extends SparkList
 
   override def onTaskEnd(taskEnd: SparkListenerTaskEnd) = synchronized {
     val sid = taskEnd.task.stageId
+
+    // create executor summary map if necessary
+    val executorSummaryMap = stageIdToExecutorSummaries.getOrElseUpdate(key = sid,
+      op = new HashMap[String, ExecutorSummary]())
+    executorSummaryMap.getOrElseUpdate(key = taskEnd.taskInfo.executorId,
+      op = new ExecutorSummary())
+
+    val executorSummary = executorSummaryMap.get(taskEnd.taskInfo.executorId)
+    executorSummary match {
+      case Some(y) => {
+        // first update failed-task, succeed-task
+        taskEnd.reason match {
+          case Success =>
+            y.succeededTasks += 1
+          case _ =>
+            y.failedTasks += 1
+        }
+
+        // update duration
+        y.taskTime += taskEnd.taskInfo.duration
+
+        Option(taskEnd.taskMetrics).foreach { taskMetrics =>
+          taskMetrics.shuffleReadMetrics.foreach { y.shuffleRead += _.remoteBytesRead }
+          taskMetrics.shuffleWriteMetrics.foreach { y.shuffleWrite += _.shuffleBytesWritten }
+        }
+      }
+      case _ => {}
+    }
+
     val tasksActive = stageIdToTasksActive.getOrElseUpdate(sid, new HashSet[TaskInfo]())
     tasksActive -= taskEnd.taskInfo
+
     val (failureInfo, metrics): (Option[ExceptionFailure], Option[TaskMetrics]) =
       taskEnd.reason match {
         case e: ExceptionFailure =>
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
index 996e1b4d1aa51e805640b939e9834b9587bb9419..8dcfeacb60fc35e108d8439e0d6601a62e96a09e 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
@@ -66,7 +66,7 @@ private[spark] class StagePage(parent: JobProgressUI) {
         <div>
           <ul class="unstyled">
             <li>
-              <strong>Total duration across all tasks: </strong>
+              <strong>Total task time across all tasks: </strong>
               {parent.formatDuration(listener.stageIdToTime.getOrElse(stageId, 0L) + activeTime)}
             </li>
             {if (hasShuffleRead)
@@ -166,11 +166,12 @@ private[spark] class StagePage(parent: JobProgressUI) {
           def quantileRow(data: Seq[String]): Seq[Node] = <tr> {data.map(d => <td>{d}</td>)} </tr>
           Some(listingTable(quantileHeaders, quantileRow, listings, fixedWidth = true))
         }
-
+      val executorTable = new ExecutorTable(parent, stageId)
       val content =
         summary ++
         <h4>Summary Metrics for {numCompleted} Completed Tasks</h4> ++
         <div>{summaryTable.getOrElse("No tasks have reported metrics yet.")}</div> ++
+        <h4>Aggregated Metrics by Executors</h4> ++ executorTable.toNodeSeq() ++
         <h4>Tasks</h4> ++ taskTable
 
       headerSparkPage(content, parent.sc, "Details for Stage %d".format(stageId), Stages)
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala
index 9ad6de3c6d8de79c758f1d0764b1a171bd56012c..463d85dfd54fdf79ddc15510b301c2a3ab8ff297 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala
@@ -48,7 +48,7 @@ private[spark] class StageTable(val stages: Seq[StageInfo], val parent: JobProgr
         {if (isFairScheduler) {<th>Pool Name</th>} else {}}
         <th>Description</th>
         <th>Submitted</th>
-        <th>Duration</th>
+        <th>Task Time</th>
         <th>Tasks: Succeeded/Total</th>
         <th>Shuffle Read</th>
         <th>Shuffle Write</th>
diff --git a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala
index 431d88838f02f7030ddd4f0a60026d814ba73af0..9ea7fc2dfd42b3247854884703a6ad1b73653ea4 100644
--- a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala
+++ b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala
@@ -32,7 +32,7 @@ class MetadataCleaner(
 {
   val name = cleanerType.toString
 
-  private val delaySeconds = MetadataCleaner.getDelaySeconds(conf)
+  private val delaySeconds = MetadataCleaner.getDelaySeconds(conf, cleanerType)
   private val periodSeconds = math.max(10, delaySeconds / 10)
   private val timer = new Timer(name + " cleanup timer", true)
 
diff --git a/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala b/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala
index dbff571de9759feb2a0e4f4f7832eddcf03074e8..181ae2fd45baf8a4f69ea1b94d8d2aa31d23abba 100644
--- a/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala
+++ b/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala
@@ -104,19 +104,28 @@ class TimeStampedHashMap[A, B] extends Map[A, B]() with Logging {
   def toMap: immutable.Map[A, B] = iterator.toMap
 
   /**
-   * Removes old key-value pairs that have timestamp earlier than `threshTime`
+   * Removes old key-value pairs that have timestamp earlier than `threshTime`,
+   * calling the supplied function on each such entry before removing.
    */
-  def clearOldValues(threshTime: Long) {
+  def clearOldValues(threshTime: Long, f: (A, B) => Unit) {
     val iterator = internalMap.entrySet().iterator()
-    while(iterator.hasNext) {
+    while (iterator.hasNext) {
       val entry = iterator.next()
       if (entry.getValue._2 < threshTime) {
+        f(entry.getKey, entry.getValue._1)
         logDebug("Removing key " + entry.getKey)
         iterator.remove()
       }
     }
   }
 
+  /**
+   * Removes old key-value pairs that have timestamp earlier than `threshTime`
+   */
+  def clearOldValues(threshTime: Long) {
+    clearOldValues(threshTime, (_, _) => ())
+  }
+
   private def currentTime: Long = System.currentTimeMillis()
 
 }
diff --git a/core/src/test/scala/org/apache/spark/FailureSuite.scala b/core/src/test/scala/org/apache/spark/FailureSuite.scala
index af448fcb37a1f2676ea2a68e8b5f48f6c492e609..befdc1589f009de6c16e53280f9725bd25d1ccf2 100644
--- a/core/src/test/scala/org/apache/spark/FailureSuite.scala
+++ b/core/src/test/scala/org/apache/spark/FailureSuite.scala
@@ -42,7 +42,7 @@ class FailureSuite extends FunSuite with LocalSparkContext {
   // Run a 3-task map job in which task 1 deterministically fails once, and check
   // whether the job completes successfully and we ran 4 tasks in total.
   test("failure in a single-stage job") {
-    sc = new SparkContext("local[1,1]", "test")
+    sc = new SparkContext("local[1,2]", "test")
     val results = sc.makeRDD(1 to 3, 3).map { x =>
       FailureSuiteState.synchronized {
         FailureSuiteState.tasksRun += 1
@@ -62,7 +62,7 @@ class FailureSuite extends FunSuite with LocalSparkContext {
 
   // Run a map-reduce job in which a reduce task deterministically fails once.
   test("failure in a two-stage job") {
-    sc = new SparkContext("local[1,1]", "test")
+    sc = new SparkContext("local[1,2]", "test")
     val results = sc.makeRDD(1 to 3).map(x => (x, x)).groupByKey(3).map {
       case (k, v) =>
         FailureSuiteState.synchronized {
diff --git a/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala
index 151af0d213c65e6144441f42be49163b2389faf4..f28d5c7b133b379ba8375dba8d9eded4b7eedf20 100644
--- a/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala
+++ b/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala
@@ -19,20 +19,21 @@ package org.apache.spark
 
 import org.scalatest.{FunSuite, PrivateMethodTester}
 
-import org.apache.spark.scheduler.TaskScheduler
-import org.apache.spark.scheduler.cluster.{ClusterScheduler, SimrSchedulerBackend, SparkDeploySchedulerBackend}
+import org.apache.spark.scheduler.{TaskSchedulerImpl, TaskScheduler}
+import org.apache.spark.scheduler.cluster.{SimrSchedulerBackend, SparkDeploySchedulerBackend}
 import org.apache.spark.scheduler.cluster.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend}
-import org.apache.spark.scheduler.local.LocalScheduler
+import org.apache.spark.scheduler.local.LocalBackend
 
 class SparkContextSchedulerCreationSuite
   extends FunSuite with PrivateMethodTester with LocalSparkContext with Logging {
 
-  def createTaskScheduler(master: String): TaskScheduler = {
+  def createTaskScheduler(master: String): TaskSchedulerImpl = {
     // Create local SparkContext to setup a SparkEnv. We don't actually want to start() the
     // real schedulers, so we don't want to create a full SparkContext with the desired scheduler.
     sc = new SparkContext("local", "test")
     val createTaskSchedulerMethod = PrivateMethod[TaskScheduler]('createTaskScheduler)
-    SparkContext invokePrivate createTaskSchedulerMethod(sc, master, "test")
+    val sched = SparkContext invokePrivate createTaskSchedulerMethod(sc, master, "test")
+    sched.asInstanceOf[TaskSchedulerImpl]
   }
 
   test("bad-master") {
@@ -43,55 +44,49 @@ class SparkContextSchedulerCreationSuite
   }
 
   test("local") {
-    createTaskScheduler("local") match {
-      case s: LocalScheduler =>
-        assert(s.threads === 1)
-        assert(s.maxFailures === 0)
+    val sched = createTaskScheduler("local")
+    sched.backend match {
+      case s: LocalBackend => assert(s.totalCores === 1)
       case _ => fail()
     }
   }
 
   test("local-n") {
-    createTaskScheduler("local[5]") match {
-      case s: LocalScheduler =>
-        assert(s.threads === 5)
-        assert(s.maxFailures === 0)
+    val sched = createTaskScheduler("local[5]")
+    assert(sched.maxTaskFailures === 1)
+    sched.backend match {
+      case s: LocalBackend => assert(s.totalCores === 5)
       case _ => fail()
     }
   }
 
   test("local-n-failures") {
-    createTaskScheduler("local[4, 2]") match {
-      case s: LocalScheduler =>
-        assert(s.threads === 4)
-        assert(s.maxFailures === 2)
+    val sched = createTaskScheduler("local[4, 2]")
+    assert(sched.maxTaskFailures === 2)
+    sched.backend match {
+      case s: LocalBackend => assert(s.totalCores === 4)
       case _ => fail()
     }
   }
 
   test("simr") {
-    createTaskScheduler("simr://uri") match {
-      case s: ClusterScheduler =>
-        assert(s.backend.isInstanceOf[SimrSchedulerBackend])
+    createTaskScheduler("simr://uri").backend match {
+      case s: SimrSchedulerBackend => // OK
       case _ => fail()
     }
   }
 
   test("local-cluster") {
-    createTaskScheduler("local-cluster[3, 14, 512]") match {
-      case s: ClusterScheduler =>
-        assert(s.backend.isInstanceOf[SparkDeploySchedulerBackend])
+    createTaskScheduler("local-cluster[3, 14, 512]").backend match {
+      case s: SparkDeploySchedulerBackend => // OK
       case _ => fail()
     }
   }
 
   def testYarn(master: String, expectedClassName: String) {
     try {
-      createTaskScheduler(master) match {
-        case s: ClusterScheduler =>
-          assert(s.getClass === Class.forName(expectedClassName))
-        case _ => fail()
-      }
+      val sched = createTaskScheduler(master)
+      assert(sched.getClass === Class.forName(expectedClassName))
     } catch {
       case e: SparkException =>
         assert(e.getMessage.contains("YARN mode not available"))
@@ -110,11 +105,8 @@ class SparkContextSchedulerCreationSuite
 
   def testMesos(master: String, expectedClass: Class[_]) {
     try {
-      createTaskScheduler(master) match {
-        case s: ClusterScheduler =>
-          assert(s.backend.getClass === expectedClass)
-        case _ => fail()
-      }
+      val sched = createTaskScheduler(master)
+      assert(sched.backend.getClass === expectedClass)
     } catch {
       case e: UnsatisfiedLinkError =>
         assert(e.getMessage.contains("no mesos in"))
diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/ClusterSchedulerSuite.scala
similarity index 95%
rename from core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterSchedulerSuite.scala
rename to core/src/test/scala/org/apache/spark/scheduler/ClusterSchedulerSuite.scala
index 34d2e4cb8c7c5ab0b85e961094db887646524a29..7bf2020fe378eae92a89355d20005efc9c822ec9 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/ClusterSchedulerSuite.scala
@@ -15,14 +15,12 @@
  * limitations under the License.
  */
 
-package org.apache.spark.scheduler.cluster
+package org.apache.spark.scheduler
 
 import org.scalatest.FunSuite
 import org.scalatest.BeforeAndAfter
 
 import org.apache.spark._
-import org.apache.spark.scheduler._
-import org.apache.spark.scheduler.cluster._
 import scala.collection.mutable.ArrayBuffer
 
 import java.util.Properties
@@ -31,9 +29,9 @@ class FakeTaskSetManager(
     initPriority: Int,
     initStageId: Int,
     initNumTasks: Int,
-    clusterScheduler: ClusterScheduler,
+    clusterScheduler: TaskSchedulerImpl,
     taskSet: TaskSet)
-  extends ClusterTaskSetManager(clusterScheduler, taskSet) {
+  extends TaskSetManager(clusterScheduler, taskSet, 0) {
 
   parent = null
   weight = 1
@@ -106,7 +104,7 @@ class FakeTaskSetManager(
 
 class ClusterSchedulerSuite extends FunSuite with LocalSparkContext with Logging {
 
-  def createDummyTaskSetManager(priority: Int, stage: Int, numTasks: Int, cs: ClusterScheduler, taskSet: TaskSet): FakeTaskSetManager = {
+  def createDummyTaskSetManager(priority: Int, stage: Int, numTasks: Int, cs: TaskSchedulerImpl, taskSet: TaskSet): FakeTaskSetManager = {
     new FakeTaskSetManager(priority, stage, numTasks, cs , taskSet)
   }
 
@@ -133,7 +131,7 @@ class ClusterSchedulerSuite extends FunSuite with LocalSparkContext with Logging
 
   test("FIFO Scheduler Test") {
     sc = new SparkContext("local", "ClusterSchedulerSuite")
-    val clusterScheduler = new ClusterScheduler(sc)
+    val clusterScheduler = new TaskSchedulerImpl(sc)
     var tasks = ArrayBuffer[Task[_]]()
     val task = new FakeTask(0)
     tasks += task
@@ -160,7 +158,7 @@ class ClusterSchedulerSuite extends FunSuite with LocalSparkContext with Logging
 
   test("Fair Scheduler Test") {
     sc = new SparkContext("local", "ClusterSchedulerSuite")
-    val clusterScheduler = new ClusterScheduler(sc)
+    val clusterScheduler = new TaskSchedulerImpl(sc)
     var tasks = ArrayBuffer[Task[_]]()
     val task = new FakeTask(0)
     tasks += task
@@ -217,7 +215,7 @@ class ClusterSchedulerSuite extends FunSuite with LocalSparkContext with Logging
 
   test("Nested Pool Test") {
     sc = new SparkContext("local", "ClusterSchedulerSuite")
-    val clusterScheduler = new ClusterScheduler(sc)
+    val clusterScheduler = new TaskSchedulerImpl(sc)
     var tasks = ArrayBuffer[Task[_]]()
     val task = new FakeTask(0)
     tasks += task
diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/FakeTask.scala b/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala
similarity index 91%
rename from core/src/test/scala/org/apache/spark/scheduler/cluster/FakeTask.scala
rename to core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala
index 0f01515179f0b8683716ffcc0e2ba1b31b3dff87..0b90c4e74c8a481d0befe61a25f6c5d0301c9328 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/cluster/FakeTask.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala
@@ -15,10 +15,9 @@
  * limitations under the License.
  */
 
-package org.apache.spark.scheduler.cluster
+package org.apache.spark.scheduler
 
 import org.apache.spark.TaskContext
-import org.apache.spark.scheduler.{TaskLocation, Task}
 
 class FakeTask(stageId: Int, prefLocs: Seq[TaskLocation] = Nil) extends Task[Int](stageId, 0) {
   override def runTask(context: TaskContext): Int = 0
diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
index 2e41438a527aac204662ec5f6c7b687a408b4da4..d4320e5e14458afef262316bdf18d4e75257ac18 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
@@ -19,23 +19,26 @@ package org.apache.spark.scheduler
 
 import scala.collection.mutable.{Buffer, HashSet}
 
-import org.scalatest.{BeforeAndAfterAll, FunSuite}
+import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, FunSuite}
 import org.scalatest.matchers.ShouldMatchers
 
 import org.apache.spark.{LocalSparkContext, SparkContext}
 import org.apache.spark.SparkContext._
 
 class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatchers
-    with BeforeAndAfterAll {
+    with BeforeAndAfter with BeforeAndAfterAll {
   /** Length of time to wait while draining listener events. */
   val WAIT_TIMEOUT_MILLIS = 10000
 
+  before {
+    sc = new SparkContext("local", "SparkListenerSuite")
+  }
+
   override def afterAll {
     System.clearProperty("spark.akka.frameSize")
   }
 
   test("basic creation of StageInfo") {
-    sc = new SparkContext("local", "DAGSchedulerSuite")
     val listener = new SaveStageInfo
     sc.addSparkListener(listener)
     val rdd1 = sc.parallelize(1 to 100, 4)
@@ -56,7 +59,6 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc
   }
 
   test("StageInfo with fewer tasks than partitions") {
-    sc = new SparkContext("local", "DAGSchedulerSuite")
     val listener = new SaveStageInfo
     sc.addSparkListener(listener)
     val rdd1 = sc.parallelize(1 to 100, 4)
@@ -72,7 +74,6 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc
   }
 
   test("local metrics") {
-    sc = new SparkContext("local", "DAGSchedulerSuite")
     val listener = new SaveStageInfo
     sc.addSparkListener(listener)
     sc.addSparkListener(new StatsReportListener)
@@ -135,10 +136,6 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc
   }
 
   test("onTaskGettingResult() called when result fetched remotely") {
-    // Need to use local cluster mode here, because results are not ever returned through the
-    // block manager when using the LocalScheduler.
-    sc = new SparkContext("local-cluster[1,1,512]", "test")
-
     val listener = new SaveTaskEvents
     sc.addSparkListener(listener)
  
@@ -157,10 +154,6 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc
   }
 
   test("onTaskGettingResult() not called when result sent directly") {
-    // Need to use local cluster mode here, because results are not ever returned through the
-    // block manager when using the LocalScheduler.
-    sc = new SparkContext("local-cluster[1,1,512]", "test")
-
     val listener = new SaveTaskEvents
     sc.addSparkListener(listener)
  
diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/TaskResultGetterSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala
similarity index 85%
rename from core/src/test/scala/org/apache/spark/scheduler/cluster/TaskResultGetterSuite.scala
rename to core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala
index 618fae7c16083f3230b099e6b2758dff4ff2fe8c..4b52d9651ebe82cbc8ff640e09445d89e0bd4299 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/cluster/TaskResultGetterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala
@@ -15,14 +15,13 @@
  * limitations under the License.
  */
 
-package org.apache.spark.scheduler.cluster
+package org.apache.spark.scheduler
 
 import java.nio.ByteBuffer
 
 import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, FunSuite}
 
-import org.apache.spark.{SparkConf, LocalSparkContext, SparkContext, SparkEnv}
-import org.apache.spark.scheduler.{DirectTaskResult, IndirectTaskResult, TaskResult}
+import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkEnv}
 import org.apache.spark.storage.TaskResultBlockId
 
 /**
@@ -31,12 +30,12 @@ import org.apache.spark.storage.TaskResultBlockId
  * Used to test the case where a BlockManager evicts the task result (or dies) before the
  * TaskResult is retrieved.
  */
-class ResultDeletingTaskResultGetter(sparkEnv: SparkEnv, scheduler: ClusterScheduler)
+class ResultDeletingTaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedulerImpl)
   extends TaskResultGetter(sparkEnv, scheduler) {
   var removedResult = false
 
   override def enqueueSuccessfulTask(
-    taskSetManager: ClusterTaskSetManager, tid: Long, serializedData: ByteBuffer) {
+    taskSetManager: TaskSetManager, tid: Long, serializedData: ByteBuffer) {
     if (!removedResult) {
       // Only remove the result once, since we'd like to test the case where the task eventually
       // succeeds.
@@ -44,13 +43,13 @@ class ResultDeletingTaskResultGetter(sparkEnv: SparkEnv, scheduler: ClusterSched
         case IndirectTaskResult(blockId) =>
           sparkEnv.blockManager.master.removeBlock(blockId)
         case directResult: DirectTaskResult[_] =>
-          taskSetManager.abort("Internal error: expect only indirect results") 
+          taskSetManager.abort("Internal error: expect only indirect results")
       }
       serializedData.rewind()
       removedResult = true
     }
     super.enqueueSuccessfulTask(taskSetManager, tid, serializedData)
-  } 
+  }
 }
 
 /**
@@ -65,22 +64,18 @@ class TaskResultGetterSuite extends FunSuite with BeforeAndAfter with BeforeAndA
     System.setProperty("spark.akka.frameSize", "1")
   }
 
-  before {
-    // Use local-cluster mode because results are returned differently when running with the
-    // LocalScheduler.
-    sc = new SparkContext("local-cluster[1,1,512]", "test")
-  }
-
   override def afterAll {
     System.clearProperty("spark.akka.frameSize")
   }
 
   test("handling results smaller than Akka frame size") {
+    sc = new SparkContext("local", "test")
     val result = sc.parallelize(Seq(1), 1).map(x => 2 * x).reduce((x, y) => x)
     assert(result === 2)
   }
 
-  test("handling results larger than Akka frame size") { 
+  test("handling results larger than Akka frame size") {
+    sc = new SparkContext("local", "test")
     val akkaFrameSize =
       sc.env.actorSystem.settings.config.getBytes("akka.remote.netty.tcp.maximum-frame-size").toInt
     val result = sc.parallelize(Seq(1), 1).map(x => 1.to(akkaFrameSize).toArray).reduce((x, y) => x)
@@ -92,10 +87,13 @@ class TaskResultGetterSuite extends FunSuite with BeforeAndAfter with BeforeAndA
   }
 
   test("task retried if result missing from block manager") {
+    // Set the maximum number of task failures to > 0, so that the task set isn't aborted
+    // after the result is missing.
+    sc = new SparkContext("local[1,2]", "test")
     // If this test hangs, it's probably because no resource offers were made after the task
     // failed.
-    val scheduler: ClusterScheduler = sc.taskScheduler match {
-      case clusterScheduler: ClusterScheduler =>
+    val scheduler: TaskSchedulerImpl = sc.taskScheduler match {
+      case clusterScheduler: TaskSchedulerImpl =>
         clusterScheduler
       case _ =>
         assert(false, "Expect local cluster to use ClusterScheduler")
diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
similarity index 93%
rename from core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManagerSuite.scala
rename to core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
index 3711382f2ee2c06ad794d3ce4123797063dc528b..5d33e662535df9ae865f5b3b5a4ea205ee550805 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
@@ -15,7 +15,7 @@
  * limitations under the License.
  */
 
-package org.apache.spark.scheduler.cluster
+package org.apache.spark.scheduler
 
 import scala.collection.mutable.ArrayBuffer
 import scala.collection.mutable
@@ -23,7 +23,6 @@ import scala.collection.mutable
 import org.scalatest.FunSuite
 
 import org.apache.spark._
-import org.apache.spark.scheduler._
 import org.apache.spark.executor.TaskMetrics
 import java.nio.ByteBuffer
 import org.apache.spark.util.{Utils, FakeClock}
@@ -56,10 +55,10 @@ class FakeDAGScheduler(taskScheduler: FakeClusterScheduler) extends DAGScheduler
  * A mock ClusterScheduler implementation that just remembers information about tasks started and
  * feedback received from the TaskSetManagers. Note that it's important to initialize this with
  * a list of "live" executors and their hostnames for isExecutorAlive and hasExecutorsAliveOnHost
- * to work, and these are required for locality in ClusterTaskSetManager.
+ * to work, and these are required for locality in TaskSetManager.
  */
 class FakeClusterScheduler(sc: SparkContext, liveExecutors: (String, String)* /* execId, host */)
-  extends ClusterScheduler(sc)
+  extends TaskSchedulerImpl(sc)
 {
   val startedTasks = new ArrayBuffer[Long]
   val endedTasks = new mutable.HashMap[Long, TaskEndReason]
@@ -79,16 +78,19 @@ class FakeClusterScheduler(sc: SparkContext, liveExecutors: (String, String)* /*
   override def hasExecutorsAliveOnHost(host: String): Boolean = executors.values.exists(_ == host)
 }
 
-class ClusterTaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging {
+class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging {
   import TaskLocality.{ANY, PROCESS_LOCAL, NODE_LOCAL, RACK_LOCAL}
+
   private val conf = new SparkConf
+
   val LOCALITY_WAIT = conf.getOrElse("spark.locality.wait", "3000").toLong
+  val MAX_TASK_FAILURES = 4
 
   test("TaskSet with no preferences") {
     sc = new SparkContext("local", "test")
     val sched = new FakeClusterScheduler(sc, ("exec1", "host1"))
     val taskSet = createTaskSet(1)
-    val manager = new ClusterTaskSetManager(sched, taskSet)
+    val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES)
 
     // Offer a host with no CPUs
     assert(manager.resourceOffer("exec1", "host1", 0, ANY) === None)
@@ -114,7 +116,7 @@ class ClusterTaskSetManagerSuite extends FunSuite with LocalSparkContext with Lo
     sc = new SparkContext("local", "test")
     val sched = new FakeClusterScheduler(sc, ("exec1", "host1"))
     val taskSet = createTaskSet(3)
-    val manager = new ClusterTaskSetManager(sched, taskSet)
+    val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES)
 
     // First three offers should all find tasks
     for (i <- 0 until 3) {
@@ -151,7 +153,7 @@ class ClusterTaskSetManagerSuite extends FunSuite with LocalSparkContext with Lo
       Seq()   // Last task has no locality prefs
     )
     val clock = new FakeClock
-    val manager = new ClusterTaskSetManager(sched, taskSet, clock)
+    val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock)
 
     // First offer host1, exec1: first task should be chosen
     assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 0)
@@ -197,7 +199,7 @@ class ClusterTaskSetManagerSuite extends FunSuite with LocalSparkContext with Lo
       Seq(TaskLocation("host2"))
     )
     val clock = new FakeClock
-    val manager = new ClusterTaskSetManager(sched, taskSet, clock)
+    val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock)
 
     // First offer host1: first task should be chosen
     assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 0)
@@ -234,7 +236,7 @@ class ClusterTaskSetManagerSuite extends FunSuite with LocalSparkContext with Lo
       Seq(TaskLocation("host3"))
     )
     val clock = new FakeClock
-    val manager = new ClusterTaskSetManager(sched, taskSet, clock)
+    val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock)
 
     // First offer host1: first task should be chosen
     assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 0)
@@ -262,7 +264,7 @@ class ClusterTaskSetManagerSuite extends FunSuite with LocalSparkContext with Lo
     val sched = new FakeClusterScheduler(sc, ("exec1", "host1"))
     val taskSet = createTaskSet(1)
     val clock = new FakeClock
-    val manager = new ClusterTaskSetManager(sched, taskSet, clock)
+    val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock)
 
     assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 0)
 
@@ -279,17 +281,17 @@ class ClusterTaskSetManagerSuite extends FunSuite with LocalSparkContext with Lo
     val sched = new FakeClusterScheduler(sc, ("exec1", "host1"))
     val taskSet = createTaskSet(1)
     val clock = new FakeClock
-    val manager = new ClusterTaskSetManager(sched, taskSet, clock)
+    val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock)
 
     // Fail the task MAX_TASK_FAILURES times, and check that the task set is aborted
     // after the last failure.
-    (1 to manager.MAX_TASK_FAILURES).foreach { index =>
+    (1 to manager.maxTaskFailures).foreach { index =>
       val offerResult = manager.resourceOffer("exec1", "host1", 1, ANY)
       assert(offerResult != None,
         "Expect resource offer on iteration %s to return a task".format(index))
       assert(offerResult.get.index === 0)
       manager.handleFailedTask(offerResult.get.taskId, TaskState.FINISHED, Some(TaskResultLost))
-      if (index < manager.MAX_TASK_FAILURES) {
+      if (index < MAX_TASK_FAILURES) {
         assert(!sched.taskSetsFailed.contains(taskSet.id))
       } else {
         assert(sched.taskSetsFailed.contains(taskSet.id))
diff --git a/core/src/test/scala/org/apache/spark/scheduler/local/LocalSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/local/LocalSchedulerSuite.scala
deleted file mode 100644
index 1e676c1719337179e02940472632baeb7966b63a..0000000000000000000000000000000000000000
--- a/core/src/test/scala/org/apache/spark/scheduler/local/LocalSchedulerSuite.scala
+++ /dev/null
@@ -1,227 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.scheduler.local
-
-import java.util.concurrent.Semaphore
-import java.util.concurrent.CountDownLatch
-
-import scala.collection.mutable.HashMap
-
-import org.scalatest.{BeforeAndAfterEach, FunSuite}
-
-import org.apache.spark._
-
-
-class Lock() {
-  var finished = false
-  def jobWait() = {
-    synchronized {
-      while(!finished) {
-        this.wait()
-      }
-    }
-  }
-
-  def jobFinished() = {
-    synchronized {
-      finished = true
-      this.notifyAll()
-    }
-  }
-}
-
-object TaskThreadInfo {
-  val threadToLock = HashMap[Int, Lock]()
-  val threadToRunning = HashMap[Int, Boolean]()
-  val threadToStarted = HashMap[Int, CountDownLatch]()
-}
-
-/*
- * 1. each thread contains one job.
- * 2. each job contains one stage.
- * 3. each stage only contains one task.
- * 4. each task(launched) must be lanched orderly(using threadToStarted) to make sure
- *    it will get cpu core resource, and will wait to finished after user manually
- *    release "Lock" and then cluster will contain another free cpu cores.
- * 5. each task(pending) must use "sleep" to  make sure it has been added to taskSetManager queue,
- *    thus it will be scheduled later when cluster has free cpu cores.
- */
-class LocalSchedulerSuite extends FunSuite with LocalSparkContext with BeforeAndAfterEach {
-
-  override def afterEach() {
-    super.afterEach()
-    System.clearProperty("spark.scheduler.mode")
-  }
-
-  def createThread(threadIndex: Int, poolName: String, sc: SparkContext, sem: Semaphore) {
-
-    TaskThreadInfo.threadToRunning(threadIndex) = false
-    val nums = sc.parallelize(threadIndex to threadIndex, 1)
-    TaskThreadInfo.threadToLock(threadIndex) = new Lock()
-    TaskThreadInfo.threadToStarted(threadIndex) = new CountDownLatch(1)
-    new Thread {
-      if (poolName != null) {
-        sc.setLocalProperty("spark.scheduler.pool", poolName)
-      }
-      override def run() {
-        val ans = nums.map(number => {
-          TaskThreadInfo.threadToRunning(number) = true
-          TaskThreadInfo.threadToStarted(number).countDown()
-          TaskThreadInfo.threadToLock(number).jobWait()
-          TaskThreadInfo.threadToRunning(number) = false
-          number
-        }).collect()
-        assert(ans.toList === List(threadIndex))
-        sem.release()
-      }
-    }.start()
-  }
-
-  test("Local FIFO scheduler end-to-end test") {
-    System.setProperty("spark.scheduler.mode", "FIFO")
-    sc = new SparkContext("local[4]", "test")
-    val sem = new Semaphore(0)
-
-    createThread(1,null,sc,sem)
-    TaskThreadInfo.threadToStarted(1).await()
-    createThread(2,null,sc,sem)
-    TaskThreadInfo.threadToStarted(2).await()
-    createThread(3,null,sc,sem)
-    TaskThreadInfo.threadToStarted(3).await()
-    createThread(4,null,sc,sem)
-    TaskThreadInfo.threadToStarted(4).await()
-    // thread 5 and 6 (stage pending)must meet following two points
-    // 1. stages (taskSetManager) of jobs in thread 5 and 6 should be add to taskSetManager
-    //    queue before executing TaskThreadInfo.threadToLock(1).jobFinished()
-    // 2. priority of stage in thread 5 should be prior to priority of stage in thread 6
-    // So I just use "sleep" 1s here for each thread.
-    // TODO: any better solution?
-    createThread(5,null,sc,sem)
-    Thread.sleep(1000)
-    createThread(6,null,sc,sem)
-    Thread.sleep(1000)
-
-    assert(TaskThreadInfo.threadToRunning(1) === true)
-    assert(TaskThreadInfo.threadToRunning(2) === true)
-    assert(TaskThreadInfo.threadToRunning(3) === true)
-    assert(TaskThreadInfo.threadToRunning(4) === true)
-    assert(TaskThreadInfo.threadToRunning(5) === false)
-    assert(TaskThreadInfo.threadToRunning(6) === false)
-
-    TaskThreadInfo.threadToLock(1).jobFinished()
-    TaskThreadInfo.threadToStarted(5).await()
-
-    assert(TaskThreadInfo.threadToRunning(1) === false)
-    assert(TaskThreadInfo.threadToRunning(2) === true)
-    assert(TaskThreadInfo.threadToRunning(3) === true)
-    assert(TaskThreadInfo.threadToRunning(4) === true)
-    assert(TaskThreadInfo.threadToRunning(5) === true)
-    assert(TaskThreadInfo.threadToRunning(6) === false)
-
-    TaskThreadInfo.threadToLock(3).jobFinished()
-    TaskThreadInfo.threadToStarted(6).await()
-
-    assert(TaskThreadInfo.threadToRunning(1) === false)
-    assert(TaskThreadInfo.threadToRunning(2) === true)
-    assert(TaskThreadInfo.threadToRunning(3) === false)
-    assert(TaskThreadInfo.threadToRunning(4) === true)
-    assert(TaskThreadInfo.threadToRunning(5) === true)
-    assert(TaskThreadInfo.threadToRunning(6) === true)
-
-    TaskThreadInfo.threadToLock(2).jobFinished()
-    TaskThreadInfo.threadToLock(4).jobFinished()
-    TaskThreadInfo.threadToLock(5).jobFinished()
-    TaskThreadInfo.threadToLock(6).jobFinished()
-    sem.acquire(6)
-  }
-
-  test("Local fair scheduler end-to-end test") {
-    System.setProperty("spark.scheduler.mode", "FAIR")
-    val xmlPath = getClass.getClassLoader.getResource("fairscheduler.xml").getFile()
-    System.setProperty("spark.scheduler.allocation.file", xmlPath)
-
-    sc = new SparkContext("local[8]", "LocalSchedulerSuite")
-    val sem = new Semaphore(0)
-
-    createThread(10,"1",sc,sem)
-    TaskThreadInfo.threadToStarted(10).await()
-    createThread(20,"2",sc,sem)
-    TaskThreadInfo.threadToStarted(20).await()
-    createThread(30,"3",sc,sem)
-    TaskThreadInfo.threadToStarted(30).await()
-
-    assert(TaskThreadInfo.threadToRunning(10) === true)
-    assert(TaskThreadInfo.threadToRunning(20) === true)
-    assert(TaskThreadInfo.threadToRunning(30) === true)
-
-    createThread(11,"1",sc,sem)
-    TaskThreadInfo.threadToStarted(11).await()
-    createThread(21,"2",sc,sem)
-    TaskThreadInfo.threadToStarted(21).await()
-    createThread(31,"3",sc,sem)
-    TaskThreadInfo.threadToStarted(31).await()
-
-    assert(TaskThreadInfo.threadToRunning(11) === true)
-    assert(TaskThreadInfo.threadToRunning(21) === true)
-    assert(TaskThreadInfo.threadToRunning(31) === true)
-
-    createThread(12,"1",sc,sem)
-    TaskThreadInfo.threadToStarted(12).await()
-    createThread(22,"2",sc,sem)
-    TaskThreadInfo.threadToStarted(22).await()
-    createThread(32,"3",sc,sem)
-
-    assert(TaskThreadInfo.threadToRunning(12) === true)
-    assert(TaskThreadInfo.threadToRunning(22) === true)
-    assert(TaskThreadInfo.threadToRunning(32) === false)
-
-    TaskThreadInfo.threadToLock(10).jobFinished()
-    TaskThreadInfo.threadToStarted(32).await()
-
-    assert(TaskThreadInfo.threadToRunning(32) === true)
-
-    //1. Similar with above scenario, sleep 1s for stage of 23 and 33 to be added to taskSetManager
-    //   queue so that cluster will assign free cpu core to stage 23 after stage 11 finished.
-    //2. priority of 23 and 33 will be meaningless as using fair scheduler here.
-    createThread(23,"2",sc,sem)
-    createThread(33,"3",sc,sem)
-    Thread.sleep(1000)
-
-    TaskThreadInfo.threadToLock(11).jobFinished()
-    TaskThreadInfo.threadToStarted(23).await()
-
-    assert(TaskThreadInfo.threadToRunning(23) === true)
-    assert(TaskThreadInfo.threadToRunning(33) === false)
-
-    TaskThreadInfo.threadToLock(12).jobFinished()
-    TaskThreadInfo.threadToStarted(33).await()
-
-    assert(TaskThreadInfo.threadToRunning(33) === true)
-
-    TaskThreadInfo.threadToLock(20).jobFinished()
-    TaskThreadInfo.threadToLock(21).jobFinished()
-    TaskThreadInfo.threadToLock(22).jobFinished()
-    TaskThreadInfo.threadToLock(23).jobFinished()
-    TaskThreadInfo.threadToLock(30).jobFinished()
-    TaskThreadInfo.threadToLock(31).jobFinished()
-    TaskThreadInfo.threadToLock(32).jobFinished()
-    TaskThreadInfo.threadToLock(33).jobFinished()
-
-    sem.acquire(11)
-  }
-}
diff --git a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala
new file mode 100644
index 0000000000000000000000000000000000000000..67a57a0e7f9d0cb8c08805a358c5c691c430511f
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ui.jobs
+
+import org.scalatest.FunSuite
+import org.apache.spark.scheduler._
+import org.apache.spark.{LocalSparkContext, SparkContext, Success}
+import org.apache.spark.scheduler.SparkListenerTaskStart
+import org.apache.spark.executor.{ShuffleReadMetrics, TaskMetrics}
+
+class JobProgressListenerSuite extends FunSuite with LocalSparkContext {
+  test("test executor id to summary") {
+    val sc = new SparkContext("local", "test")
+    val listener = new JobProgressListener(sc)
+    val taskMetrics = new TaskMetrics()
+    val shuffleReadMetrics = new ShuffleReadMetrics()
+
+    // nothing in it
+    assert(listener.stageIdToExecutorSummaries.size == 0)
+
+    // finish this task, should get updated shuffleRead
+    shuffleReadMetrics.remoteBytesRead = 1000
+    taskMetrics.shuffleReadMetrics = Some(shuffleReadMetrics)
+    var taskInfo = new TaskInfo(1234L, 0, 0L, "exe-1", "host1", TaskLocality.NODE_LOCAL)
+    taskInfo.finishTime = 1
+    listener.onTaskEnd(new SparkListenerTaskEnd(
+      new ShuffleMapTask(0, null, null, 0, null), Success, taskInfo, taskMetrics))
+    assert(listener.stageIdToExecutorSummaries.getOrElse(0, fail()).getOrElse("exe-1", fail())
+      .shuffleRead == 1000)
+
+    // finish a task with unknown executor-id, nothing should happen
+    taskInfo = new TaskInfo(1234L, 0, 1000L, "exe-unknown", "host1", TaskLocality.NODE_LOCAL)
+    taskInfo.finishTime = 1
+    listener.onTaskEnd(new SparkListenerTaskEnd(
+      new ShuffleMapTask(0, null, null, 0, null), Success, taskInfo, taskMetrics))
+    assert(listener.stageIdToExecutorSummaries.size == 1)
+
+    // finish this task, should get updated duration
+    shuffleReadMetrics.remoteBytesRead = 1000
+    taskMetrics.shuffleReadMetrics = Some(shuffleReadMetrics)
+    taskInfo = new TaskInfo(1235L, 0, 0L, "exe-1", "host1", TaskLocality.NODE_LOCAL)
+    taskInfo.finishTime = 1
+    listener.onTaskEnd(new SparkListenerTaskEnd(
+      new ShuffleMapTask(0, null, null, 0, null), Success, taskInfo, taskMetrics))
+    assert(listener.stageIdToExecutorSummaries.getOrElse(0, fail()).getOrElse("exe-1", fail())
+      .shuffleRead == 2000)
+
+    // finish this task, should get updated duration
+    shuffleReadMetrics.remoteBytesRead = 1000
+    taskMetrics.shuffleReadMetrics = Some(shuffleReadMetrics)
+    taskInfo = new TaskInfo(1236L, 0, 0L, "exe-2", "host1", TaskLocality.NODE_LOCAL)
+    taskInfo.finishTime = 1
+    listener.onTaskEnd(new SparkListenerTaskEnd(
+      new ShuffleMapTask(0, null, null, 0, null), Success, taskInfo, taskMetrics))
+    assert(listener.stageIdToExecutorSummaries.getOrElse(0, fail()).getOrElse("exe-2", fail())
+      .shuffleRead == 1000)
+  }
+}
diff --git a/examples/src/main/java/org/apache/spark/streaming/examples/JavaKafkaWordCount.java b/examples/src/main/java/org/apache/spark/streaming/examples/JavaKafkaWordCount.java
index 9a8e4209eddc7da8d87401dafe81b46f3456065a..22994fb2ec71c70261b604b569124e0e0330fc1c 100644
--- a/examples/src/main/java/org/apache/spark/streaming/examples/JavaKafkaWordCount.java
+++ b/examples/src/main/java/org/apache/spark/streaming/examples/JavaKafkaWordCount.java
@@ -53,7 +53,7 @@ public class JavaKafkaWordCount {
     }
 
     // Create the context with a 1 second batch size
-    JavaStreamingContext ssc = new JavaStreamingContext(args[0], "NetworkWordCount",
+    JavaStreamingContext ssc = new JavaStreamingContext(args[0], "KafkaWordCount",
             new Duration(2000), System.getenv("SPARK_HOME"), System.getenv("SPARK_EXAMPLES_JAR"));
 
     int numThreads = Integer.parseInt(args[4]);
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
new file mode 100644
index 0000000000000000000000000000000000000000..8247c1ebc5d2b32910cc4efc045d32c134296bc4
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -0,0 +1,232 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.api.python
+import org.apache.spark.api.java.JavaRDD
+import org.apache.spark.mllib.regression._
+import org.apache.spark.mllib.classification._
+import org.apache.spark.mllib.clustering._
+import org.apache.spark.mllib.recommendation._
+import org.apache.spark.rdd.RDD
+import java.nio.ByteBuffer
+import java.nio.ByteOrder
+import java.nio.DoubleBuffer
+
+/**
+ * The Java stubs necessary for the Python mllib bindings.
+ */
+class PythonMLLibAPI extends Serializable {
+  private def deserializeDoubleVector(bytes: Array[Byte]): Array[Double] = {
+    val packetLength = bytes.length
+    if (packetLength < 16) {
+      throw new IllegalArgumentException("Byte array too short.")
+    }
+    val bb = ByteBuffer.wrap(bytes)
+    bb.order(ByteOrder.nativeOrder())
+    val magic = bb.getLong()
+    if (magic != 1) {
+      throw new IllegalArgumentException("Magic " + magic + " is wrong.")
+    }
+    val length = bb.getLong()
+    if (packetLength != 16 + 8 * length) {
+      throw new IllegalArgumentException("Length " + length + " is wrong.")
+    }
+    val db = bb.asDoubleBuffer()
+    val ans = new Array[Double](length.toInt)
+    db.get(ans)
+    return ans
+  }
+
+  private def serializeDoubleVector(doubles: Array[Double]): Array[Byte] = {
+    val len = doubles.length
+    val bytes = new Array[Byte](16 + 8 * len)
+    val bb = ByteBuffer.wrap(bytes)
+    bb.order(ByteOrder.nativeOrder())
+    bb.putLong(1)
+    bb.putLong(len)
+    val db = bb.asDoubleBuffer()
+    db.put(doubles)
+    return bytes
+  }
+
+  private def deserializeDoubleMatrix(bytes: Array[Byte]): Array[Array[Double]] = {
+    val packetLength = bytes.length
+    if (packetLength < 24) {
+      throw new IllegalArgumentException("Byte array too short.")
+    }
+    val bb = ByteBuffer.wrap(bytes)
+    bb.order(ByteOrder.nativeOrder())
+    val magic = bb.getLong()
+    if (magic != 2) {
+      throw new IllegalArgumentException("Magic " + magic + " is wrong.")
+    }
+    val rows = bb.getLong()
+    val cols = bb.getLong()
+    if (packetLength != 24 + 8 * rows * cols) {
+      throw new IllegalArgumentException("Size " + rows + "x" + cols + " is wrong.")
+    }
+    val db = bb.asDoubleBuffer()
+    val ans = new Array[Array[Double]](rows.toInt)
+    var i = 0
+    for (i <- 0 until rows.toInt) {
+      ans(i) = new Array[Double](cols.toInt)
+      db.get(ans(i))
+    }
+    return ans
+  }
+
+  private def serializeDoubleMatrix(doubles: Array[Array[Double]]): Array[Byte] = {
+    val rows = doubles.length
+    var cols = 0
+    if (rows > 0) {
+      cols = doubles(0).length
+    }
+    val bytes = new Array[Byte](24 + 8 * rows * cols)
+    val bb = ByteBuffer.wrap(bytes)
+    bb.order(ByteOrder.nativeOrder())
+    bb.putLong(2)
+    bb.putLong(rows)
+    bb.putLong(cols)
+    val db = bb.asDoubleBuffer()
+    var i = 0
+    for (i <- 0 until rows) {
+      db.put(doubles(i))
+    }
+    return bytes
+  }
+
+  private def trainRegressionModel(trainFunc: (RDD[LabeledPoint], Array[Double]) => GeneralizedLinearModel,
+      dataBytesJRDD: JavaRDD[Array[Byte]], initialWeightsBA: Array[Byte]):
+      java.util.LinkedList[java.lang.Object] = {
+    val data = dataBytesJRDD.rdd.map(xBytes => {
+        val x = deserializeDoubleVector(xBytes)
+        LabeledPoint(x(0), x.slice(1, x.length))
+    })
+    val initialWeights = deserializeDoubleVector(initialWeightsBA)
+    val model = trainFunc(data, initialWeights)
+    val ret = new java.util.LinkedList[java.lang.Object]()
+    ret.add(serializeDoubleVector(model.weights))
+    ret.add(model.intercept: java.lang.Double)
+    return ret
+  }
+
+  /**
+   * Java stub for Python mllib LinearRegressionWithSGD.train()
+   */
+  def trainLinearRegressionModelWithSGD(dataBytesJRDD: JavaRDD[Array[Byte]],
+      numIterations: Int, stepSize: Double, miniBatchFraction: Double,
+      initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = {
+    return trainRegressionModel((data, initialWeights) =>
+        LinearRegressionWithSGD.train(data, numIterations, stepSize,
+                                      miniBatchFraction, initialWeights),
+        dataBytesJRDD, initialWeightsBA)
+  }
+
+  /**
+   * Java stub for Python mllib LassoWithSGD.train()
+   */
+  def trainLassoModelWithSGD(dataBytesJRDD: JavaRDD[Array[Byte]], numIterations: Int,
+      stepSize: Double, regParam: Double, miniBatchFraction: Double,
+      initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = {
+    return trainRegressionModel((data, initialWeights) =>
+        LassoWithSGD.train(data, numIterations, stepSize, regParam,
+                           miniBatchFraction, initialWeights),
+        dataBytesJRDD, initialWeightsBA)
+  }
+
+  /**
+   * Java stub for Python mllib RidgeRegressionWithSGD.train()
+   */
+  def trainRidgeModelWithSGD(dataBytesJRDD: JavaRDD[Array[Byte]], numIterations: Int,
+      stepSize: Double, regParam: Double, miniBatchFraction: Double,
+      initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = {
+    return trainRegressionModel((data, initialWeights) =>
+        RidgeRegressionWithSGD.train(data, numIterations, stepSize, regParam,
+                                     miniBatchFraction, initialWeights),
+        dataBytesJRDD, initialWeightsBA)
+  }
+
+  /**
+   * Java stub for Python mllib SVMWithSGD.train()
+   */
+  def trainSVMModelWithSGD(dataBytesJRDD: JavaRDD[Array[Byte]], numIterations: Int,
+      stepSize: Double, regParam: Double, miniBatchFraction: Double,
+      initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = {
+    return trainRegressionModel((data, initialWeights) =>
+        SVMWithSGD.train(data, numIterations, stepSize, regParam,
+                                     miniBatchFraction, initialWeights),
+        dataBytesJRDD, initialWeightsBA)
+  }
+
+  /**
+   * Java stub for Python mllib LogisticRegressionWithSGD.train()
+   */
+  def trainLogisticRegressionModelWithSGD(dataBytesJRDD: JavaRDD[Array[Byte]],
+      numIterations: Int, stepSize: Double, miniBatchFraction: Double,
+      initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = {
+    return trainRegressionModel((data, initialWeights) =>
+        LogisticRegressionWithSGD.train(data, numIterations, stepSize,
+                                     miniBatchFraction, initialWeights),
+        dataBytesJRDD, initialWeightsBA)
+  }
+
+  /**
+   * Java stub for Python mllib KMeans.train()
+   */
+  def trainKMeansModel(dataBytesJRDD: JavaRDD[Array[Byte]], k: Int,
+      maxIterations: Int, runs: Int, initializationMode: String):
+      java.util.List[java.lang.Object] = {
+    val data = dataBytesJRDD.rdd.map(xBytes => deserializeDoubleVector(xBytes))
+    val model = KMeans.train(data, k, maxIterations, runs, initializationMode)
+    val ret = new java.util.LinkedList[java.lang.Object]()
+    ret.add(serializeDoubleMatrix(model.clusterCenters))
+    return ret
+  }
+
+  private def unpackRating(ratingBytes: Array[Byte]): Rating = {
+    val bb = ByteBuffer.wrap(ratingBytes)
+    bb.order(ByteOrder.nativeOrder())
+    val user = bb.getInt()
+    val product = bb.getInt()
+    val rating = bb.getDouble()
+    return new Rating(user, product, rating)
+  }
+
+  /**
+   * Java stub for Python mllib ALS.train().  This stub returns a handle
+   * to the Java object instead of the content of the Java object.  Extra care
+   * needs to be taken in the Python code to ensure it gets freed on exit; see
+   * the Py4J documentation.
+   */
+  def trainALSModel(ratingsBytesJRDD: JavaRDD[Array[Byte]], rank: Int,
+      iterations: Int, lambda: Double, blocks: Int): MatrixFactorizationModel = {
+    val ratings = ratingsBytesJRDD.rdd.map(unpackRating)
+    return ALS.train(ratings, rank, iterations, lambda, blocks)
+  }
+
+  /**
+   * Java stub for Python mllib ALS.trainImplicit().  This stub returns a
+   * handle to the Java object instead of the content of the Java object.
+   * Extra care needs to be taken in the Python code to ensure it gets freed on
+   * exit; see the Py4J documentation.
+   */
+  def trainImplicitALSModel(ratingsBytesJRDD: JavaRDD[Array[Byte]], rank: Int,
+      iterations: Int, lambda: Double, blocks: Int, alpha: Double): MatrixFactorizationModel = {
+    val ratings = ratingsBytesJRDD.rdd.map(unpackRating)
+    return ALS.trainImplicit(ratings, rank, iterations, lambda, blocks, alpha)
+  }
+}
diff --git a/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
index 963b5b88be4311302802aa7c82b4a8279189e41f..1bba6a5ae4ac86ccb2b3986e5704b3bdc3825353 100644
--- a/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
+++ b/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
@@ -437,8 +437,10 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl
   }
 
   def monitorApplication(appId: ApplicationId): Boolean = {
+    val interval = new SparkConf().getOrElse("spark.yarn.report.interval", "1000").toLong
+
     while (true) {
-      Thread.sleep(1000)
+      Thread.sleep(interval)
       val report = super.getApplicationReport(appId)
 
       logInfo("Application report from ASM: \n" +
diff --git a/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala b/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala
index 71d1cbd416f5a23c63d94acbadfa9ea2807f3a86..abc3447746f9e27d9db43cc7cd12e078f12b941c 100644
--- a/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala
+++ b/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala
@@ -27,8 +27,8 @@ import scala.collection.JavaConversions._
 import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet}
 
 import org.apache.spark.Logging
-import org.apache.spark.scheduler.SplitInfo
-import org.apache.spark.scheduler.cluster.{ClusterScheduler, CoarseGrainedSchedulerBackend}
+import org.apache.spark.scheduler.{SplitInfo,TaskSchedulerImpl}
+import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend
 import org.apache.spark.util.Utils
 
 import org.apache.hadoop.conf.Configuration
@@ -233,9 +233,9 @@ private[yarn] class YarnAllocationHandler(
       // Note that the list we create below tries to ensure that not all containers end up within
       // a host if there is a sufficiently large number of hosts/containers.
       val allocatedContainersToProcess = new ArrayBuffer[Container](allocatedContainers.size)
-      allocatedContainersToProcess ++= ClusterScheduler.prioritizeContainers(dataLocalContainers)
-      allocatedContainersToProcess ++= ClusterScheduler.prioritizeContainers(rackLocalContainers)
-      allocatedContainersToProcess ++= ClusterScheduler.prioritizeContainers(offRackContainers)
+      allocatedContainersToProcess ++= TaskSchedulerImpl.prioritizeContainers(dataLocalContainers)
+      allocatedContainersToProcess ++= TaskSchedulerImpl.prioritizeContainers(rackLocalContainers)
+      allocatedContainersToProcess ++= TaskSchedulerImpl.prioritizeContainers(offRackContainers)
 
       // Run each of the allocated containers.
       for (container <- allocatedContainersToProcess) {
diff --git a/new-yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientClusterScheduler.scala b/new-yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientClusterScheduler.scala
index 63a0449e5a0730085554d2b8ae86067135fa8dba..522e0a9ad7eeb50f4c2b6b781a68ea998639b30a 100644
--- a/new-yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientClusterScheduler.scala
+++ b/new-yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientClusterScheduler.scala
@@ -20,13 +20,14 @@ package org.apache.spark.scheduler.cluster
 import org.apache.spark._
 import org.apache.hadoop.conf.Configuration
 import org.apache.spark.deploy.yarn.YarnAllocationHandler
+import org.apache.spark.scheduler.TaskSchedulerImpl
 import org.apache.spark.util.Utils
 
 /**
  *
  * This scheduler launch worker through Yarn - by call into Client to launch WorkerLauncher as AM.
  */
-private[spark] class YarnClientClusterScheduler(sc: SparkContext, conf: Configuration) extends ClusterScheduler(sc) {
+private[spark] class YarnClientClusterScheduler(sc: SparkContext, conf: Configuration) extends TaskSchedulerImpl(sc) {
 
   def this(sc: SparkContext) = this(sc, new Configuration())
 
diff --git a/new-yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/new-yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala
index 6feaaff01425606acf9c1da751b84201a6c95f59..4b69f5078b0ab10818f1d6ecd5d9655cd00327c8 100644
--- a/new-yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala
+++ b/new-yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala
@@ -20,9 +20,10 @@ package org.apache.spark.scheduler.cluster
 import org.apache.hadoop.yarn.api.records.{ApplicationId, YarnApplicationState}
 import org.apache.spark.{SparkException, Logging, SparkContext}
 import org.apache.spark.deploy.yarn.{Client, ClientArguments}
+import org.apache.spark.scheduler.TaskSchedulerImpl
 
 private[spark] class YarnClientSchedulerBackend(
-    scheduler: ClusterScheduler,
+    scheduler: TaskSchedulerImpl,
     sc: SparkContext)
   extends CoarseGrainedSchedulerBackend(scheduler, sc.env.actorSystem)
   with Logging {
diff --git a/new-yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala b/new-yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala
index 29b3f22e13697b38bc501e2f914d8fc0a202d722..a4638cc863611c0e152be72f47aed0c222aba5f8 100644
--- a/new-yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala
+++ b/new-yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala
@@ -19,6 +19,7 @@ package org.apache.spark.scheduler.cluster
 
 import org.apache.spark._
 import org.apache.spark.deploy.yarn.{ApplicationMaster, YarnAllocationHandler}
+import org.apache.spark.scheduler.TaskSchedulerImpl
 import org.apache.spark.util.Utils
 import org.apache.hadoop.conf.Configuration
 
@@ -26,7 +27,7 @@ import org.apache.hadoop.conf.Configuration
  *
  * This is a simple extension to ClusterScheduler - to ensure that appropriate initialization of ApplicationMaster, etc is done
  */
-private[spark] class YarnClusterScheduler(sc: SparkContext, conf: Configuration) extends ClusterScheduler(sc) {
+private[spark] class YarnClusterScheduler(sc: SparkContext, conf: Configuration) extends TaskSchedulerImpl(sc) {
 
   logInfo("Created YarnClusterScheduler")
 
diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index ffb54a24ac7ec2d6ba2d54935b4f397cbadcf10d..37d6f1b60da00bf99ae3458b70232d76f86d7be7 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -114,6 +114,9 @@ object SparkBuild extends Build {
     fork := true,
     javaOptions += "-Xmx3g",
 
+    // Show full stack trace and duration in test cases.
+    testOptions in Test += Tests.Argument("-oDF"),
+
     // Only allow one test at a time, even across projects, since they run in the same JVM
     concurrentRestrictions in Global += Tags.limit(Tags.Test, 1),
 
@@ -260,7 +263,7 @@ object SparkBuild extends Build {
    libraryDependencies <+= scalaVersion(v => "org.scala-lang"  % "scala-reflect"  % v )
   )
 
-  
+
   def examplesSettings = sharedSettings ++ Seq(
     name := "spark-examples",
     libraryDependencies ++= Seq(
diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py
index 128f078d12c1f59566aa02ca00fee27228c9dbab..d8ca9fce0037b4bc33653e7dbf9513973edf546b 100644
--- a/python/pyspark/java_gateway.py
+++ b/python/pyspark/java_gateway.py
@@ -63,5 +63,6 @@ def launch_gateway():
     java_import(gateway.jvm, "org.apache.spark.SparkConf")
     java_import(gateway.jvm, "org.apache.spark.api.java.*")
     java_import(gateway.jvm, "org.apache.spark.api.python.*")
+    java_import(gateway.jvm, "org.apache.spark.mllib.api.python.*")
     java_import(gateway.jvm, "scala.Tuple2")
     return gateway
diff --git a/python/pyspark/mllib/__init__.py b/python/pyspark/mllib/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b1a5df109b46edc3e5221401dcadf2d7e248ca74
--- /dev/null
+++ b/python/pyspark/mllib/__init__.py
@@ -0,0 +1,20 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""
+Python bindings for MLlib.
+"""
diff --git a/python/pyspark/mllib/_common.py b/python/pyspark/mllib/_common.py
new file mode 100644
index 0000000000000000000000000000000000000000..e74ba0fabc09ce8142c4eee10b8ff09aeea4f812
--- /dev/null
+++ b/python/pyspark/mllib/_common.py
@@ -0,0 +1,227 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from numpy import ndarray, copyto, float64, int64, int32, ones, array_equal, array, dot, shape
+from pyspark import SparkContext
+
+# Double vector format:
+#
+# [8-byte 1] [8-byte length] [length*8 bytes of data]
+#
+# Double matrix format:
+#
+# [8-byte 2] [8-byte rows] [8-byte cols] [rows*cols*8 bytes of data]
+#
+# This is all in machine-endian.  That means that the Java interpreter and the
+# Python interpreter must agree on what endian the machine is.
+
+def _deserialize_byte_array(shape, ba, offset):
+    """Wrapper around ndarray aliasing hack.
+
+    >>> x = array([1.0, 2.0, 3.0, 4.0, 5.0])
+    >>> array_equal(x, _deserialize_byte_array(x.shape, x.data, 0))
+    True
+    >>> x = array([1.0, 2.0, 3.0, 4.0]).reshape(2,2)
+    >>> array_equal(x, _deserialize_byte_array(x.shape, x.data, 0))
+    True
+    """
+    ar = ndarray(shape=shape, buffer=ba, offset=offset, dtype="float64",
+            order='C')
+    return ar.copy()
+
+def _serialize_double_vector(v):
+    """Serialize a double vector into a mutually understood format."""
+    if type(v) != ndarray:
+        raise TypeError("_serialize_double_vector called on a %s; "
+                "wanted ndarray" % type(v))
+    if v.dtype != float64:
+        raise TypeError("_serialize_double_vector called on an ndarray of %s; "
+                "wanted ndarray of float64" % v.dtype)
+    if v.ndim != 1:
+        raise TypeError("_serialize_double_vector called on a %ddarray; "
+                "wanted a 1darray" % v.ndim)
+    length = v.shape[0]
+    ba = bytearray(16 + 8*length)
+    header = ndarray(shape=[2], buffer=ba, dtype="int64")
+    header[0] = 1
+    header[1] = length
+    copyto(ndarray(shape=[length], buffer=ba, offset=16,
+            dtype="float64"), v)
+    return ba
+
+def _deserialize_double_vector(ba):
+    """Deserialize a double vector from a mutually understood format.
+
+    >>> x = array([1.0, 2.0, 3.0, 4.0, -1.0, 0.0, -0.0])
+    >>> array_equal(x, _deserialize_double_vector(_serialize_double_vector(x)))
+    True
+    """
+    if type(ba) != bytearray:
+        raise TypeError("_deserialize_double_vector called on a %s; "
+                "wanted bytearray" % type(ba))
+    if len(ba) < 16:
+        raise TypeError("_deserialize_double_vector called on a %d-byte array, "
+                "which is too short" % len(ba))
+    if (len(ba) & 7) != 0:
+        raise TypeError("_deserialize_double_vector called on a %d-byte array, "
+                "which is not a multiple of 8" % len(ba))
+    header = ndarray(shape=[2], buffer=ba, dtype="int64")
+    if header[0] != 1:
+        raise TypeError("_deserialize_double_vector called on bytearray "
+                        "with wrong magic")
+    length = header[1]
+    if len(ba) != 8*length + 16:
+        raise TypeError("_deserialize_double_vector called on bytearray "
+                        "with wrong length")
+    return _deserialize_byte_array([length], ba, 16)
+
+def _serialize_double_matrix(m):
+    """Serialize a double matrix into a mutually understood format."""
+    if (type(m) == ndarray and m.dtype == float64 and m.ndim == 2):
+        rows = m.shape[0]
+        cols = m.shape[1]
+        ba = bytearray(24 + 8 * rows * cols)
+        header = ndarray(shape=[3], buffer=ba, dtype="int64")
+        header[0] = 2
+        header[1] = rows
+        header[2] = cols
+        copyto(ndarray(shape=[rows, cols], buffer=ba, offset=24,
+                       dtype="float64", order='C'), m)
+        return ba
+    else:
+        raise TypeError("_serialize_double_matrix called on a "
+                        "non-double-matrix")
+
+def _deserialize_double_matrix(ba):
+    """Deserialize a double matrix from a mutually understood format."""
+    if type(ba) != bytearray:
+        raise TypeError("_deserialize_double_matrix called on a %s; "
+                "wanted bytearray" % type(ba))
+    if len(ba) < 24:
+        raise TypeError("_deserialize_double_matrix called on a %d-byte array, "
+                "which is too short" % len(ba))
+    if (len(ba) & 7) != 0:
+        raise TypeError("_deserialize_double_matrix called on a %d-byte array, "
+                "which is not a multiple of 8" % len(ba))
+    header = ndarray(shape=[3], buffer=ba, dtype="int64")
+    if (header[0] != 2):
+        raise TypeError("_deserialize_double_matrix called on bytearray "
+                        "with wrong magic")
+    rows = header[1]
+    cols = header[2]
+    if (len(ba) != 8*rows*cols + 24):
+        raise TypeError("_deserialize_double_matrix called on bytearray "
+                        "with wrong length")
+    return _deserialize_byte_array([rows, cols], ba, 24)
+
+def _linear_predictor_typecheck(x, coeffs):
+    """Check that x is a one-dimensional vector of the right shape.
+    This is a temporary hackaround until I actually implement bulk predict."""
+    if type(x) == ndarray:
+        if x.ndim == 1:
+            if x.shape == coeffs.shape:
+                pass
+            else:
+                raise RuntimeError("Got array of %d elements; wanted %d"
+                        % (shape(x)[0], shape(coeffs)[0]))
+        else:
+            raise RuntimeError("Bulk predict not yet supported.")
+    elif (type(x) == RDD):
+        raise RuntimeError("Bulk predict not yet supported.")
+    else:
+        raise TypeError("Argument of type " + type(x).__name__ + " unsupported")
+
+def _get_unmangled_rdd(data, serializer):
+    dataBytes = data.map(serializer)
+    dataBytes._bypass_serializer = True
+    dataBytes.cache()
+    return dataBytes
+
+# Map a pickled Python RDD of numpy double vectors to a Java RDD of
+# _serialized_double_vectors
+def _get_unmangled_double_vector_rdd(data):
+    return _get_unmangled_rdd(data, _serialize_double_vector)
+
+class LinearModel(object):
+    """Something that has a vector of coefficients and an intercept."""
+    def __init__(self, coeff, intercept):
+        self._coeff = coeff
+        self._intercept = intercept
+
+class LinearRegressionModelBase(LinearModel):
+    """A linear regression model.
+
+    >>> lrmb = LinearRegressionModelBase(array([1.0, 2.0]), 0.1)
+    >>> abs(lrmb.predict(array([-1.03, 7.777])) - 14.624) < 1e-6
+    True
+    """
+    def predict(self, x):
+        """Predict the value of the dependent variable given a vector x"""
+        """containing values for the independent variables."""
+        _linear_predictor_typecheck(x, self._coeff)
+        return dot(self._coeff, x) + self._intercept
+
+# If we weren't given initial weights, take a zero vector of the appropriate
+# length.
+def _get_initial_weights(initial_weights, data):
+    if initial_weights is None:
+        initial_weights = data.first()
+        if type(initial_weights) != ndarray:
+            raise TypeError("At least one data element has type "
+                    + type(initial_weights).__name__ + " which is not ndarray")
+        if initial_weights.ndim != 1:
+            raise TypeError("At least one data element has "
+                    + initial_weights.ndim + " dimensions, which is not 1")
+        initial_weights = ones([initial_weights.shape[0] - 1])
+    return initial_weights
+
+# train_func should take two parameters, namely data and initial_weights, and
+# return the result of a call to the appropriate JVM stub.
+# _regression_train_wrapper is responsible for setup and error checking.
+def _regression_train_wrapper(sc, train_func, klass, data, initial_weights):
+    initial_weights = _get_initial_weights(initial_weights, data)
+    dataBytes = _get_unmangled_double_vector_rdd(data)
+    ans = train_func(dataBytes, _serialize_double_vector(initial_weights))
+    if len(ans) != 2:
+        raise RuntimeError("JVM call result had unexpected length")
+    elif type(ans[0]) != bytearray:
+        raise RuntimeError("JVM call result had first element of type "
+                + type(ans[0]).__name__ + " which is not bytearray")
+    elif type(ans[1]) != float:
+        raise RuntimeError("JVM call result had second element of type "
+                + type(ans[0]).__name__ + " which is not float")
+    return klass(_deserialize_double_vector(ans[0]), ans[1])
+
+def _serialize_rating(r):
+    ba = bytearray(16)
+    intpart = ndarray(shape=[2], buffer=ba, dtype=int32)
+    doublepart = ndarray(shape=[1], buffer=ba, dtype=float64, offset=8)
+    intpart[0], intpart[1], doublepart[0] = r
+    return ba
+
+def _test():
+    import doctest
+    globs = globals().copy()
+    globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2)
+    (failure_count, test_count) = doctest.testmod(globs=globs,
+            optionflags=doctest.ELLIPSIS)
+    globs['sc'].stop()
+    if failure_count:
+        exit(-1)
+
+if __name__ == "__main__":
+    _test()
diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py
new file mode 100644
index 0000000000000000000000000000000000000000..70de332d3468ea06e851ec4dbd797926dd32495a
--- /dev/null
+++ b/python/pyspark/mllib/classification.py
@@ -0,0 +1,86 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from numpy import array, dot, shape
+from pyspark import SparkContext
+from pyspark.mllib._common import \
+    _get_unmangled_rdd, _get_unmangled_double_vector_rdd, \
+    _serialize_double_matrix, _deserialize_double_matrix, \
+    _serialize_double_vector, _deserialize_double_vector, \
+    _get_initial_weights, _serialize_rating, _regression_train_wrapper, \
+    LinearModel, _linear_predictor_typecheck
+from math import exp, log
+
+class LogisticRegressionModel(LinearModel):
+    """A linear binary classification model derived from logistic regression.
+
+    >>> data = array([0.0, 0.0, 1.0, 1.0, 1.0, 2.0, 1.0, 3.0]).reshape(4,2)
+    >>> lrm = LogisticRegressionWithSGD.train(sc, sc.parallelize(data))
+    >>> lrm.predict(array([1.0])) != None
+    True
+    """
+    def predict(self, x):
+        _linear_predictor_typecheck(x, self._coeff)
+        margin = dot(x, self._coeff) + self._intercept
+        prob = 1/(1 + exp(-margin))
+        return 1 if prob > 0.5 else 0
+
+class LogisticRegressionWithSGD(object):
+    @classmethod
+    def train(cls, sc, data, iterations=100, step=1.0,
+              mini_batch_fraction=1.0, initial_weights=None):
+        """Train a logistic regression model on the given data."""
+        return _regression_train_wrapper(sc, lambda d, i:
+                sc._jvm.PythonMLLibAPI().trainLogisticRegressionModelWithSGD(d._jrdd,
+                        iterations, step, mini_batch_fraction, i),
+                LogisticRegressionModel, data, initial_weights)
+
+class SVMModel(LinearModel):
+    """A support vector machine.
+
+    >>> data = array([0.0, 0.0, 1.0, 1.0, 1.0, 2.0, 1.0, 3.0]).reshape(4,2)
+    >>> svm = SVMWithSGD.train(sc, sc.parallelize(data))
+    >>> svm.predict(array([1.0])) != None
+    True
+    """
+    def predict(self, x):
+        _linear_predictor_typecheck(x, self._coeff)
+        margin = dot(x, self._coeff) + self._intercept
+        return 1 if margin >= 0 else 0
+
+class SVMWithSGD(object):
+    @classmethod
+    def train(cls, sc, data, iterations=100, step=1.0, reg_param=1.0,
+              mini_batch_fraction=1.0, initial_weights=None):
+        """Train a support vector machine on the given data."""
+        return _regression_train_wrapper(sc, lambda d, i:
+                sc._jvm.PythonMLLibAPI().trainSVMModelWithSGD(d._jrdd,
+                        iterations, step, reg_param, mini_batch_fraction, i),
+                SVMModel, data, initial_weights)
+
+def _test():
+    import doctest
+    globs = globals().copy()
+    globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2)
+    (failure_count, test_count) = doctest.testmod(globs=globs,
+            optionflags=doctest.ELLIPSIS)
+    globs['sc'].stop()
+    if failure_count:
+        exit(-1)
+
+if __name__ == "__main__":
+    _test()
diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py
new file mode 100644
index 0000000000000000000000000000000000000000..8cf20e591af7b70ad87880217421cffa50a13bf8
--- /dev/null
+++ b/python/pyspark/mllib/clustering.py
@@ -0,0 +1,79 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from numpy import array, dot
+from math import sqrt
+from pyspark import SparkContext
+from pyspark.mllib._common import \
+    _get_unmangled_rdd, _get_unmangled_double_vector_rdd, \
+    _serialize_double_matrix, _deserialize_double_matrix, \
+    _serialize_double_vector, _deserialize_double_vector, \
+    _get_initial_weights, _serialize_rating, _regression_train_wrapper
+
+class KMeansModel(object):
+    """A clustering model derived from the k-means method.
+
+    >>> data = array([0.0,0.0, 1.0,1.0, 9.0,8.0, 8.0,9.0]).reshape(4,2)
+    >>> clusters = KMeans.train(sc, sc.parallelize(data), 2, maxIterations=10, runs=30, initialization_mode="random")
+    >>> clusters.predict(array([0.0, 0.0])) == clusters.predict(array([1.0, 1.0]))
+    True
+    >>> clusters.predict(array([8.0, 9.0])) == clusters.predict(array([9.0, 8.0]))
+    True
+    >>> clusters = KMeans.train(sc, sc.parallelize(data), 2)
+    """
+    def __init__(self, centers_):
+        self.centers = centers_
+
+    def predict(self, x):
+        """Find the cluster to which x belongs in this model."""
+        best = 0
+        best_distance = 1e75
+        for i in range(0, self.centers.shape[0]):
+            diff = x - self.centers[i]
+            distance = sqrt(dot(diff, diff))
+            if distance < best_distance:
+                best = i
+                best_distance = distance
+        return best
+
+class KMeans(object):
+    @classmethod
+    def train(cls, sc, data, k, maxIterations=100, runs=1,
+            initialization_mode="k-means||"):
+        """Train a k-means clustering model."""
+        dataBytes = _get_unmangled_double_vector_rdd(data)
+        ans = sc._jvm.PythonMLLibAPI().trainKMeansModel(dataBytes._jrdd,
+                k, maxIterations, runs, initialization_mode)
+        if len(ans) != 1:
+            raise RuntimeError("JVM call result had unexpected length")
+        elif type(ans[0]) != bytearray:
+            raise RuntimeError("JVM call result had first element of type "
+                    + type(ans[0]) + " which is not bytearray")
+        return KMeansModel(_deserialize_double_matrix(ans[0]))
+
+def _test():
+    import doctest
+    globs = globals().copy()
+    globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2)
+    (failure_count, test_count) = doctest.testmod(globs=globs,
+            optionflags=doctest.ELLIPSIS)
+    globs['sc'].stop()
+    if failure_count:
+        exit(-1)
+
+if __name__ == "__main__":
+    _test()
diff --git a/python/pyspark/mllib/recommendation.py b/python/pyspark/mllib/recommendation.py
new file mode 100644
index 0000000000000000000000000000000000000000..14d06cba2137fc6125bc419c467ed4b679ab3df4
--- /dev/null
+++ b/python/pyspark/mllib/recommendation.py
@@ -0,0 +1,74 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from pyspark import SparkContext
+from pyspark.mllib._common import \
+    _get_unmangled_rdd, _get_unmangled_double_vector_rdd, \
+    _serialize_double_matrix, _deserialize_double_matrix, \
+    _serialize_double_vector, _deserialize_double_vector, \
+    _get_initial_weights, _serialize_rating, _regression_train_wrapper
+
+class MatrixFactorizationModel(object):
+    """A matrix factorisation model trained by regularized alternating
+    least-squares.
+
+    >>> r1 = (1, 1, 1.0)
+    >>> r2 = (1, 2, 2.0)
+    >>> r3 = (2, 1, 2.0)
+    >>> ratings = sc.parallelize([r1, r2, r3])
+    >>> model = ALS.trainImplicit(sc, ratings, 1)
+    >>> model.predict(2,2) is not None
+    True
+    """
+
+    def __init__(self, sc, java_model):
+        self._context = sc
+        self._java_model = java_model
+
+    def __del__(self):
+        self._context._gateway.detach(self._java_model)
+
+    def predict(self, user, product):
+        return self._java_model.predict(user, product)
+
+class ALS(object):
+    @classmethod
+    def train(cls, sc, ratings, rank, iterations=5, lambda_=0.01, blocks=-1):
+        ratingBytes = _get_unmangled_rdd(ratings, _serialize_rating)
+        mod = sc._jvm.PythonMLLibAPI().trainALSModel(ratingBytes._jrdd,
+                rank, iterations, lambda_, blocks)
+        return MatrixFactorizationModel(sc, mod)
+
+    @classmethod
+    def trainImplicit(cls, sc, ratings, rank, iterations=5, lambda_=0.01, blocks=-1, alpha=0.01):
+        ratingBytes = _get_unmangled_rdd(ratings, _serialize_rating)
+        mod = sc._jvm.PythonMLLibAPI().trainImplicitALSModel(ratingBytes._jrdd,
+                rank, iterations, lambda_, blocks, alpha)
+        return MatrixFactorizationModel(sc, mod)
+
+def _test():
+    import doctest
+    globs = globals().copy()
+    globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2)
+    (failure_count, test_count) = doctest.testmod(globs=globs,
+            optionflags=doctest.ELLIPSIS)
+    globs['sc'].stop()
+    if failure_count:
+        exit(-1)
+
+if __name__ == "__main__":
+    _test()
diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py
new file mode 100644
index 0000000000000000000000000000000000000000..a3a68b29e01cbd45f233a346d85c9855bf9ec74e
--- /dev/null
+++ b/python/pyspark/mllib/regression.py
@@ -0,0 +1,110 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from numpy import array, dot
+from pyspark import SparkContext
+from pyspark.mllib._common import \
+    _get_unmangled_rdd, _get_unmangled_double_vector_rdd, \
+    _serialize_double_matrix, _deserialize_double_matrix, \
+    _serialize_double_vector, _deserialize_double_vector, \
+    _get_initial_weights, _serialize_rating, _regression_train_wrapper, \
+    _linear_predictor_typecheck
+
+class LinearModel(object):
+    """Something that has a vector of coefficients and an intercept."""
+    def __init__(self, coeff, intercept):
+        self._coeff = coeff
+        self._intercept = intercept
+
+class LinearRegressionModelBase(LinearModel):
+    """A linear regression model.
+
+    >>> lrmb = LinearRegressionModelBase(array([1.0, 2.0]), 0.1)
+    >>> abs(lrmb.predict(array([-1.03, 7.777])) - 14.624) < 1e-6
+    True
+    """
+    def predict(self, x):
+        """Predict the value of the dependent variable given a vector x"""
+        """containing values for the independent variables."""
+        _linear_predictor_typecheck(x, self._coeff)
+        return dot(self._coeff, x) + self._intercept
+
+class LinearRegressionModel(LinearRegressionModelBase):
+    """A linear regression model derived from a least-squares fit.
+
+    >>> data = array([0.0, 0.0, 1.0, 1.0, 3.0, 2.0, 2.0, 3.0]).reshape(4,2)
+    >>> lrm = LinearRegressionWithSGD.train(sc, sc.parallelize(data), initial_weights=array([1.0]))
+    """
+
+class LinearRegressionWithSGD(object):
+    @classmethod
+    def train(cls, sc, data, iterations=100, step=1.0,
+              mini_batch_fraction=1.0, initial_weights=None):
+        """Train a linear regression model on the given data."""
+        return _regression_train_wrapper(sc, lambda d, i:
+                sc._jvm.PythonMLLibAPI().trainLinearRegressionModelWithSGD(
+                        d._jrdd, iterations, step, mini_batch_fraction, i),
+                LinearRegressionModel, data, initial_weights)
+
+class LassoModel(LinearRegressionModelBase):
+    """A linear regression model derived from a least-squares fit with an
+    l_1 penalty term.
+
+    >>> data = array([0.0, 0.0, 1.0, 1.0, 3.0, 2.0, 2.0, 3.0]).reshape(4,2)
+    >>> lrm = LassoWithSGD.train(sc, sc.parallelize(data), initial_weights=array([1.0]))
+    """
+    
+class LassoWithSGD(object):
+    @classmethod
+    def train(cls, sc, data, iterations=100, step=1.0, reg_param=1.0,
+              mini_batch_fraction=1.0, initial_weights=None):
+        """Train a Lasso regression model on the given data."""
+        return _regression_train_wrapper(sc, lambda d, i:
+                sc._jvm.PythonMLLibAPI().trainLassoModelWithSGD(d._jrdd,
+                        iterations, step, reg_param, mini_batch_fraction, i),
+                LassoModel, data, initial_weights)
+
+class RidgeRegressionModel(LinearRegressionModelBase):
+    """A linear regression model derived from a least-squares fit with an
+    l_2 penalty term.
+
+    >>> data = array([0.0, 0.0, 1.0, 1.0, 3.0, 2.0, 2.0, 3.0]).reshape(4,2)
+    >>> lrm = RidgeRegressionWithSGD.train(sc, sc.parallelize(data), initial_weights=array([1.0]))
+    """
+
+class RidgeRegressionWithSGD(object):
+    @classmethod
+    def train(cls, sc, data, iterations=100, step=1.0, reg_param=1.0,
+              mini_batch_fraction=1.0, initial_weights=None):
+        """Train a ridge regression model on the given data."""
+        return _regression_train_wrapper(sc, lambda d, i:
+                sc._jvm.PythonMLLibAPI().trainRidgeModelWithSGD(d._jrdd,
+                        iterations, step, reg_param, mini_batch_fraction, i),
+                RidgeRegressionModel, data, initial_weights)
+
+def _test():
+    import doctest
+    globs = globals().copy()
+    globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2)
+    (failure_count, test_count) = doctest.testmod(globs=globs,
+            optionflags=doctest.ELLIPSIS)
+    globs['sc'].stop()
+    if failure_count:
+        exit(-1)
+
+if __name__ == "__main__":
+    _test()
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index 811fa6f018b23f3c9883bd2a770f03c044786850..2a500ab919beaf3b8fcc4c50e47de610d3af7b0b 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -308,4 +308,4 @@ def write_int(value, stream):
 
 def write_with_length(obj, stream):
     write_int(len(obj), stream)
-    stream.write(obj)
\ No newline at end of file
+    stream.write(obj)
diff --git a/python/pyspark/shell.py b/python/pyspark/shell.py
index a47595909029754590cf2a4fac27e4d3b0d6b17a..ef07eb437baac423945a41ccdbb8f6b3652e7730 100644
--- a/python/pyspark/shell.py
+++ b/python/pyspark/shell.py
@@ -42,7 +42,7 @@ print "Using Python version %s (%s, %s)" % (
     platform.python_version(),
     platform.python_build()[0],
     platform.python_build()[1])
-print "Spark context avaiable as sc."
+print "Spark context available as sc."
 
 if add_files != None:
     print "Adding files: [%s]" % ", ".join(add_files)
diff --git a/spark-class b/spark-class
index 4eb95a9ba22248cebec651165ea9ebc79a5654f1..802e4aa1045e483fc96f061b44640d53f781df9b 100755
--- a/spark-class
+++ b/spark-class
@@ -129,11 +129,11 @@ fi
 
 # Compute classpath using external script
 CLASSPATH=`$FWDIR/bin/compute-classpath.sh`
-CLASSPATH="$SPARK_TOOLS_JAR:$CLASSPATH"
+CLASSPATH="$CLASSPATH:$SPARK_TOOLS_JAR"
 
 if $cygwin; then
-    CLASSPATH=`cygpath -wp $CLASSPATH`
-    export SPARK_TOOLS_JAR=`cygpath -w $SPARK_TOOLS_JAR`
+  CLASSPATH=`cygpath -wp $CLASSPATH`
+  export SPARK_TOOLS_JAR=`cygpath -w $SPARK_TOOLS_JAR`
 fi
 export CLASSPATH
 
diff --git a/spark-class2.cmd b/spark-class2.cmd
index 3869d0761bfaa8e7ba0e3688b1ec23f8e8a56d87..dc9dadf356e2626f251abac3fec0925147c03a92 100644
--- a/spark-class2.cmd
+++ b/spark-class2.cmd
@@ -17,7 +17,7 @@ rem See the License for the specific language governing permissions and
 rem limitations under the License.
 rem
 
-set SCALA_VERSION=2.9.3
+set SCALA_VERSION=2.10
 
 rem Figure out where the Spark framework is installed
 set FWDIR=%~dp0
@@ -75,7 +75,7 @@ rem Compute classpath using external script
 set DONT_PRINT_CLASSPATH=1
 call "%FWDIR%bin\compute-classpath.cmd"
 set DONT_PRINT_CLASSPATH=0
-set CLASSPATH=%SPARK_TOOLS_JAR%;%CLASSPATH%
+set CLASSPATH=%CLASSPATH%;%SPARK_TOOLS_JAR%
 
 rem Figure out where java is.
 set RUNNER=java
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala
index f106bba678e3504f129f7a162645f05ad3412b51..35e23c1355abf1764ab9ad3a13049241bc9ce203 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala
@@ -39,9 +39,9 @@ class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time)
   val graph = ssc.graph
   val checkpointDir = ssc.checkpointDir
   val checkpointDuration = ssc.checkpointDuration
-  val pendingTimes = ssc.scheduler.jobManager.getPendingTimes()
+  val pendingTimes = ssc.scheduler.getPendingTimes()
   val delaySeconds = MetadataCleaner.getDelaySeconds(ssc.conf)
-  val sparkConf = ssc.sc.conf
+  val sparkConf = ssc.conf
 
   def validate() {
     assert(master != null, "Checkpoint.master is null")
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/DStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/DStream.scala
index 8005202500f479b24bb8cb310684231e05483318..ce2a9d414285598b978c8d39bbc30859b80e594f 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/DStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/DStream.scala
@@ -17,24 +17,19 @@
 
 package org.apache.spark.streaming
 
-import org.apache.spark.streaming.dstream._
 import StreamingContext._
-import org.apache.spark.util.MetadataCleaner
-
-//import Time._
-
+import org.apache.spark.streaming.dstream._
+import org.apache.spark.streaming.scheduler.Job
 import org.apache.spark.Logging
 import org.apache.spark.rdd.RDD
 import org.apache.spark.storage.StorageLevel
+import org.apache.spark.util.MetadataCleaner
 
-import scala.collection.mutable.ArrayBuffer
 import scala.collection.mutable.HashMap
 import scala.reflect.ClassTag
 
 import java.io.{ObjectInputStream, IOException, ObjectOutputStream}
 
-import org.apache.hadoop.fs.{FileSystem, Path}
-import org.apache.hadoop.conf.Configuration
 
 /**
  * A Discretized Stream (DStream), the basic abstraction in Spark Streaming, is a continuous
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala b/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala
index b9a58fded67614d40d7a4363badb6ac30dc844fd..daed7ff7c3f1385489c8f691591f100087a39249 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala
@@ -21,6 +21,7 @@ import dstream.InputDStream
 import java.io.{ObjectInputStream, IOException, ObjectOutputStream}
 import collection.mutable.ArrayBuffer
 import org.apache.spark.Logging
+import org.apache.spark.streaming.scheduler.Job
 
 final private[streaming] class DStreamGraph extends Serializable with Logging {
   initLogging()
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/JobManager.scala b/streaming/src/main/scala/org/apache/spark/streaming/JobManager.scala
deleted file mode 100644
index 5233129506f9e8b2cdd4b0a208d8d01972564d0d..0000000000000000000000000000000000000000
--- a/streaming/src/main/scala/org/apache/spark/streaming/JobManager.scala
+++ /dev/null
@@ -1,88 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.streaming
-
-import org.apache.spark.Logging
-import org.apache.spark.SparkEnv
-import java.util.concurrent.Executors
-import collection.mutable.HashMap
-import collection.mutable.ArrayBuffer
-
-
-private[streaming]
-class JobManager(ssc: StreamingContext, numThreads: Int = 1) extends Logging {
-  
-  class JobHandler(ssc: StreamingContext, job: Job) extends Runnable {
-    def run() {
-      SparkEnv.set(ssc.env)
-      try {
-        val timeTaken = job.run()
-        logInfo("Total delay: %.5f s for job %s of time %s (execution: %.5f s)".format(
-          (System.currentTimeMillis() - job.time.milliseconds) / 1000.0, job.id, job.time.milliseconds, timeTaken / 1000.0))
-      } catch {
-        case e: Exception =>
-          logError("Running " + job + " failed", e)
-      }
-      clearJob(job)
-    }
-  }
-
-  initLogging()
-
-  val jobExecutor = Executors.newFixedThreadPool(numThreads) 
-  val jobs = new HashMap[Time, ArrayBuffer[Job]]
-
-  def runJob(job: Job) {
-    jobs.synchronized {
-      jobs.getOrElseUpdate(job.time, new ArrayBuffer[Job]) += job
-    }
-    jobExecutor.execute(new JobHandler(ssc, job))
-    logInfo("Added " + job + " to queue")
-  }
-
-  def stop() {
-    jobExecutor.shutdown()
-  }
-
-  private def clearJob(job: Job) {
-    var timeCleared = false
-    val time = job.time
-    jobs.synchronized {
-      val jobsOfTime = jobs.get(time)
-      if (jobsOfTime.isDefined) {
-        jobsOfTime.get -= job
-        if (jobsOfTime.get.isEmpty) {
-          jobs -= time
-          timeCleared = true
-        }
-      } else {
-        throw new Exception("Job finished for time " + job.time +
-          " but time does not exist in jobs")
-      }
-    }
-    if (timeCleared) {
-      ssc.scheduler.clearOldMetadata(time)
-    }
-  }
-
-  def getPendingTimes(): Array[Time] = {
-    jobs.synchronized {
-      jobs.keySet.toArray
-    }
-  }
-}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
index 286ec285a9d6836b71c79f1d81195617ad0b1d0b..339f6e64a20b8e334845e128fd5737f30c719abb 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
@@ -47,9 +47,9 @@ import org.apache.hadoop.mapreduce.lib.input.TextInputFormat
 import org.apache.hadoop.fs.Path
 import twitter4j.Status
 import twitter4j.auth.Authorization
+import org.apache.spark.streaming.scheduler._
 import akka.util.ByteString
 
-
 /**
  * A StreamingContext is the main entry point for Spark Streaming functionality. Besides the basic
  * information (such as, cluster URL and job name) to internally create a SparkContext, it provides
@@ -160,9 +160,10 @@ class StreamingContext private (
     }
   }
 
-  protected[streaming] var checkpointDuration: Duration = if (isCheckpointPresent) cp_.checkpointDuration else null
-  protected[streaming] var receiverJobThread: Thread = null
-  protected[streaming] var scheduler: Scheduler = null
+  protected[streaming] val checkpointDuration: Duration = {
+    if (isCheckpointPresent) cp_.checkpointDuration else graph.batchDuration
+  }
+  protected[streaming] val scheduler = new JobScheduler(this)
 
   /**
    * Return the associated Spark context
@@ -524,6 +525,13 @@ class StreamingContext private (
     graph.addOutputStream(outputStream)
   }
 
+  /** Add a [[org.apache.spark.streaming.scheduler.StreamingListener]] object for
+    * receiving system events related to streaming.
+    */
+  def addStreamingListener(streamingListener: StreamingListener) {
+    scheduler.listenerBus.addListener(streamingListener)
+  }
+
   protected def validate() {
     assert(graph != null, "Graph is null")
     graph.validate()
@@ -539,27 +547,22 @@ class StreamingContext private (
    * Start the execution of the streams.
    */
   def start() {
-    if (checkpointDir != null && checkpointDuration == null && graph != null) {
-      checkpointDuration = graph.batchDuration
-    }
-
     validate()
 
+    // Get the network input streams
     val networkInputStreams = graph.getInputStreams().filter(s => s match {
         case n: NetworkInputDStream[_] => true
         case _ => false
       }).map(_.asInstanceOf[NetworkInputDStream[_]]).toArray
 
+    // Start the network input tracker (must start before receivers)
     if (networkInputStreams.length > 0) {
-      // Start the network input tracker (must start before receivers)
       networkInputTracker = new NetworkInputTracker(this, networkInputStreams)
       networkInputTracker.start()
     }
-
     Thread.sleep(1000)
 
     // Start the scheduler
-    scheduler = new Scheduler(this)
     scheduler.start()
   }
 
@@ -570,7 +573,6 @@ class StreamingContext private (
     try {
       if (scheduler != null) scheduler.stop()
       if (networkInputTracker != null) networkInputTracker.stop()
-      if (receiverJobThread != null) receiverJobThread.interrupt()
       sc.stop()
       logInfo("StreamingContext stopped successfully")
     } catch {
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala
index 5842a7cd68fa2409d6812f76445e971dcd5ee0e9..29f673d8ae61c1e2a28e47cb24135ed07dc5b9df 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala
@@ -40,6 +40,7 @@ import org.apache.spark.api.java.{JavaPairRDD, JavaSparkContext, JavaRDD}
 import org.apache.spark.streaming._
 import org.apache.spark.streaming.dstream._
 import org.apache.spark.SparkConf
+import org.apache.spark.streaming.scheduler.StreamingListener
 
 /**
  * A StreamingContext is the main entry point for Spark Streaming functionality. Besides the basic
@@ -696,6 +697,13 @@ class JavaStreamingContext(val ssc: StreamingContext) {
     ssc.remember(duration)
   }
 
+  /** Add a [[org.apache.spark.streaming.scheduler.StreamingListener]] object for
+    * receiving system events related to streaming.
+    */
+  def addStreamingListener(streamingListener: StreamingListener) {
+    ssc.addStreamingListener(streamingListener)
+  }
+
   /**
    * Starts the execution of the streams.
    */
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ForEachDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ForEachDStream.scala
index 98b14cb224263778e3edfeb3aa1ad377f69f7bfc..364abcde68c95125d887a6ed0b40ad52611b63eb 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ForEachDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ForEachDStream.scala
@@ -18,7 +18,8 @@
 package org.apache.spark.streaming.dstream
 
 import org.apache.spark.rdd.RDD
-import org.apache.spark.streaming.{Duration, DStream, Job, Time}
+import org.apache.spark.streaming.{Duration, DStream, Time}
+import org.apache.spark.streaming.scheduler.Job
 import scala.reflect.ClassTag
 
 private[streaming]
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala
index bd607f9d18718b61e830d2ace34a6cf10a4d8239..1839ca35783e37fc3327451c1876cc317ba851d3 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala
@@ -33,6 +33,7 @@ import org.apache.spark.streaming._
 import org.apache.spark.{Logging, SparkEnv}
 import org.apache.spark.rdd.{RDD, BlockRDD}
 import org.apache.spark.storage.{BlockId, StorageLevel, StreamBlockId}
+import org.apache.spark.streaming.scheduler.{DeregisterReceiver, AddBlocks, RegisterReceiver}
 
 /**
  * Abstract class for defining any InputDStream that has to start a receiver on worker
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/BatchInfo.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/BatchInfo.scala
new file mode 100644
index 0000000000000000000000000000000000000000..4e8d07fe921fbcf2e06ceca5f21e04aced35a6c1
--- /dev/null
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/BatchInfo.scala
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.scheduler
+
+import org.apache.spark.streaming.Time
+
+/**
+ * Class having information on completed batches.
+ * @param batchTime   Time of the batch
+ * @param submissionTime  Clock time of when jobs of this batch was submitted to
+ *                        the streaming scheduler queue
+ * @param processingStartTime Clock time of when the first job of this batch started processing
+ * @param processingEndTime Clock time of when the last job of this batch finished processing
+ */
+case class BatchInfo(
+    batchTime: Time,
+    submissionTime: Long,
+    processingStartTime: Option[Long],
+    processingEndTime: Option[Long]
+  ) {
+
+  /**
+   * Time taken for the first job of this batch to start processing from the time this batch
+   * was submitted to the streaming scheduler. Essentially, it is
+   * `processingStartTime` - `submissionTime`.
+   */
+  def schedulingDelay = processingStartTime.map(_ - submissionTime)
+
+  /**
+   * Time taken for the all jobs of this batch to finish processing from the time they started
+   * processing. Essentially, it is `processingEndTime` - `processingStartTime`.
+   */
+  def processingDelay = processingEndTime.zip(processingStartTime).map(x => x._1 - x._2).headOption
+
+  /**
+   * Time taken for all the jobs of this batch to finish processing from the time they
+   * were submitted.  Essentially, it is `processingDelay` + `schedulingDelay`.
+   */
+  def totalDelay = schedulingDelay.zip(processingDelay).map(x => x._1 + x._2).headOption
+}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Job.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/Job.scala
similarity index 77%
rename from streaming/src/main/scala/org/apache/spark/streaming/Job.scala
rename to streaming/src/main/scala/org/apache/spark/streaming/scheduler/Job.scala
index 2128b7c7a64c27c98a8d88db6d27f801b8cf606e..7341bfbc99399b94a1143e12752c1120bf3fbdb3 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/Job.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/Job.scala
@@ -15,13 +15,17 @@
  * limitations under the License.
  */
 
-package org.apache.spark.streaming
+package org.apache.spark.streaming.scheduler
 
-import java.util.concurrent.atomic.AtomicLong
+import org.apache.spark.streaming.Time
 
+/**
+ * Class representing a Spark computation. It may contain multiple Spark jobs.
+ */
 private[streaming]
 class Job(val time: Time, func: () => _) {
-  val id = Job.getNewId()
+  var id: String = _
+
   def run(): Long = {
     val startTime = System.currentTimeMillis 
     func() 
@@ -29,13 +33,9 @@ class Job(val time: Time, func: () => _) {
     (stopTime - startTime)
   }
 
-  override def toString = "streaming job " + id + " @ " + time 
-}
-
-private[streaming]
-object Job {
-  val id = new AtomicLong(0)
-
-  def getNewId() = id.getAndIncrement()
-}
+  def setId(number: Int) {
+    id = "streaming job " + time + "." + number
+  }
 
+  override def toString = id
+}
\ No newline at end of file
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Scheduler.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala
similarity index 76%
rename from streaming/src/main/scala/org/apache/spark/streaming/Scheduler.scala
rename to streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala
index 82ed6bed6987d085ce8ca3ba31505ecb96fc9620..dbd08415a1d0b5d2dd137784ed4ad79a6436028c 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/Scheduler.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala
@@ -15,31 +15,35 @@
  * limitations under the License.
  */
 
-package org.apache.spark.streaming
+package org.apache.spark.streaming.scheduler
 
-import util.{ManualClock, RecurringTimer, Clock}
 import org.apache.spark.SparkEnv
 import org.apache.spark.Logging
+import org.apache.spark.streaming.{Checkpoint, Time, CheckpointWriter}
+import org.apache.spark.streaming.util.{ManualClock, RecurringTimer, Clock}
 
+/**
+ * This class generates jobs from DStreams as well as drives checkpointing and cleaning
+ * up DStream metadata.
+ */
 private[streaming]
-class Scheduler(ssc: StreamingContext) extends Logging {
+class JobGenerator(jobScheduler: JobScheduler) extends Logging {
 
   initLogging()
 
-  val concurrentJobs = ssc.sc.conf.getOrElse("spark.streaming.concurrentJobs", "1").toInt
-  val jobManager = new JobManager(ssc, concurrentJobs)
-  val checkpointWriter = if (ssc.checkpointDuration != null && ssc.checkpointDir != null) {
-    new CheckpointWriter(ssc.conf, ssc.checkpointDir)
-  } else {
-    null
-  }
-
+  val ssc = jobScheduler.ssc
   val clockClass = ssc.sc.conf.getOrElse(
     "spark.streaming.clock", "org.apache.spark.streaming.util.SystemClock")
   val clock = Class.forName(clockClass).newInstance().asInstanceOf[Clock]
   val timer = new RecurringTimer(clock, ssc.graph.batchDuration.milliseconds,
     longTime => generateJobs(new Time(longTime)))
   val graph = ssc.graph
+  lazy val checkpointWriter = if (ssc.checkpointDuration != null && ssc.checkpointDir != null) {
+    new CheckpointWriter(ssc.conf, ssc.checkpointDir)
+  } else {
+    null
+  }
+
   var latestTime: Time = null
 
   def start() = synchronized {
@@ -48,26 +52,24 @@ class Scheduler(ssc: StreamingContext) extends Logging {
     } else {
       startFirstTime()
     }
-    logInfo("Scheduler started")
+    logInfo("JobGenerator started")
   }
 
   def stop() = synchronized {
     timer.stop()
-    jobManager.stop()
     if (checkpointWriter != null) checkpointWriter.stop()
     ssc.graph.stop()
-    logInfo("Scheduler stopped")
+    logInfo("JobGenerator stopped")
   }
 
   private def startFirstTime() {
     val startTime = new Time(timer.getStartTime())
     graph.start(startTime - graph.batchDuration)
     timer.start(startTime.milliseconds)
-    logInfo("Scheduler's timer started at " + startTime)
+    logInfo("JobGenerator's timer started at " + startTime)
   }
 
   private def restart() {
-
     // If manual clock is being used for testing, then
     // either set the manual clock to the last checkpointed time,
     // or if the property is defined set it to that time
@@ -93,35 +95,34 @@ class Scheduler(ssc: StreamingContext) extends Logging {
     val timesToReschedule = (pendingTimes ++ downTimes).distinct.sorted(Time.ordering)
     logInfo("Batches to reschedule: " + timesToReschedule.mkString(", "))
     timesToReschedule.foreach(time =>
-      graph.generateJobs(time).foreach(jobManager.runJob)
+      jobScheduler.runJobs(time, graph.generateJobs(time))
     )
 
     // Restart the timer
     timer.start(restartTime.milliseconds)
-    logInfo("Scheduler's timer restarted at " + restartTime)
+    logInfo("JobGenerator's timer restarted at " + restartTime)
   }
 
   /** Generate jobs and perform checkpoint for the given `time`.  */
-  def generateJobs(time: Time) {
+  private def generateJobs(time: Time) {
     SparkEnv.set(ssc.env)
     logInfo("\n-----------------------------------------------------\n")
-    graph.generateJobs(time).foreach(jobManager.runJob)
+    jobScheduler.runJobs(time, graph.generateJobs(time))
     latestTime = time
     doCheckpoint(time)
   }
 
   /**
-   * Clear old metadata assuming jobs of `time` have finished processing.
-   * And also perform checkpoint.
+   * On batch completion, clear old metadata and checkpoint computation.
    */
-  def clearOldMetadata(time: Time) {
+  private[streaming] def onBatchCompletion(time: Time) {
     ssc.graph.clearOldMetadata(time)
     doCheckpoint(time)
   }
 
   /** Perform checkpoint for the give `time`. */
-  def doCheckpoint(time: Time) = synchronized {
-    if (ssc.checkpointDuration != null && (time - graph.zeroTime).isMultipleOf(ssc.checkpointDuration)) {
+  private def doCheckpoint(time: Time) = synchronized {
+    if (checkpointWriter != null && (time - graph.zeroTime).isMultipleOf(ssc.checkpointDuration)) {
       logInfo("Checkpointing graph for time " + time)
       ssc.graph.updateCheckpointData(time)
       checkpointWriter.write(new Checkpoint(ssc, time))
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala
new file mode 100644
index 0000000000000000000000000000000000000000..9511ccfbeddd6132b455882bddf65c8bd82bb5e6
--- /dev/null
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala
@@ -0,0 +1,108 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.scheduler
+
+import org.apache.spark.Logging
+import org.apache.spark.SparkEnv
+import java.util.concurrent.{TimeUnit, ConcurrentHashMap, Executors}
+import scala.collection.mutable.HashSet
+import org.apache.spark.streaming._
+
+/**
+ * This class schedules jobs to be run on Spark. It uses the JobGenerator to generate
+ * the jobs and runs them using a thread pool. Number of threads 
+ */
+private[streaming]
+class JobScheduler(val ssc: StreamingContext) extends Logging {
+
+  initLogging()
+
+  val jobSets = new ConcurrentHashMap[Time, JobSet]
+  val numConcurrentJobs = System.getProperty("spark.streaming.concurrentJobs", "1").toInt
+  val executor = Executors.newFixedThreadPool(numConcurrentJobs)
+  val generator = new JobGenerator(this)
+  val listenerBus = new StreamingListenerBus()
+
+  def clock = generator.clock
+
+  def start() {
+    generator.start()
+  }
+
+  def stop() {
+    generator.stop()
+    executor.shutdown()
+    if (!executor.awaitTermination(5, TimeUnit.SECONDS)) {
+      executor.shutdownNow()
+    }
+  }
+
+  def runJobs(time: Time, jobs: Seq[Job]) {
+    if (jobs.isEmpty) {
+      logInfo("No jobs added for time " + time)
+    } else {
+      val jobSet = new JobSet(time, jobs)
+      jobSets.put(time, jobSet)
+      jobSet.jobs.foreach(job => executor.execute(new JobHandler(job)))
+      logInfo("Added jobs for time " + time)
+    }
+  }
+
+  def getPendingTimes(): Array[Time] = {
+    jobSets.keySet.toArray(new Array[Time](0))
+  }
+
+  private def beforeJobStart(job: Job) {
+    val jobSet = jobSets.get(job.time)
+    if (!jobSet.hasStarted) {
+      listenerBus.post(StreamingListenerBatchStarted(jobSet.toBatchInfo()))
+    }
+    jobSet.beforeJobStart(job)
+    logInfo("Starting job " + job.id + " from job set of time " + jobSet.time)
+    SparkEnv.set(generator.ssc.env)
+  }
+
+  private def afterJobEnd(job: Job) {
+    val jobSet = jobSets.get(job.time)
+    jobSet.afterJobStop(job)
+    logInfo("Finished job " + job.id + " from job set of time " + jobSet.time)
+    if (jobSet.hasCompleted) {
+      jobSets.remove(jobSet.time)
+      generator.onBatchCompletion(jobSet.time)
+      logInfo("Total delay: %.3f s for time %s (execution: %.3f s)".format(
+        jobSet.totalDelay / 1000.0, jobSet.time.toString,
+        jobSet.processingDelay / 1000.0
+      ))
+      listenerBus.post(StreamingListenerBatchCompleted(jobSet.toBatchInfo()))
+    }
+  }
+
+  private[streaming]
+  class JobHandler(job: Job) extends Runnable {
+    def run() {
+      beforeJobStart(job)
+      try {
+        job.run()
+      } catch {
+        case e: Exception =>
+          logError("Running " + job + " failed", e)
+      }
+      afterJobEnd(job)
+    }
+  }
+}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala
new file mode 100644
index 0000000000000000000000000000000000000000..57268674ead9dd22a9c77a941f0195544400999c
--- /dev/null
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.scheduler
+
+import scala.collection.mutable.HashSet
+import org.apache.spark.streaming.Time
+
+/** Class representing a set of Jobs
+  * belong to the same batch.
+  */
+private[streaming]
+case class JobSet(time: Time, jobs: Seq[Job]) {
+
+  private val incompleteJobs = new HashSet[Job]()
+  var submissionTime = System.currentTimeMillis() // when this jobset was submitted
+  var processingStartTime = -1L // when the first job of this jobset started processing
+  var processingEndTime = -1L // when the last job of this jobset finished processing
+
+  jobs.zipWithIndex.foreach { case (job, i) => job.setId(i) }
+  incompleteJobs ++= jobs
+
+  def beforeJobStart(job: Job) {
+    if (processingStartTime < 0) processingStartTime = System.currentTimeMillis()
+  }
+
+  def afterJobStop(job: Job) {
+    incompleteJobs -= job
+    if (hasCompleted) processingEndTime = System.currentTimeMillis()
+  }
+
+  def hasStarted() = (processingStartTime > 0)
+
+  def hasCompleted() = incompleteJobs.isEmpty
+
+  // Time taken to process all the jobs from the time they started processing
+  // (i.e. not including the time they wait in the streaming scheduler queue)
+  def processingDelay = processingEndTime - processingStartTime
+
+  // Time taken to process all the jobs from the time they were submitted
+  // (i.e. including the time they wait in the streaming scheduler queue)
+  def totalDelay = {
+    processingEndTime - time.milliseconds
+  }
+
+  def toBatchInfo(): BatchInfo = {
+    new BatchInfo(
+      time,
+      submissionTime,
+      if (processingStartTime >= 0 ) Some(processingStartTime) else None,
+      if (processingEndTime >= 0 ) Some(processingEndTime) else None
+    )
+  }
+}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/NetworkInputTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/NetworkInputTracker.scala
similarity index 98%
rename from streaming/src/main/scala/org/apache/spark/streaming/NetworkInputTracker.scala
rename to streaming/src/main/scala/org/apache/spark/streaming/scheduler/NetworkInputTracker.scala
index 6e9a781978d2054f5836b53e321a3ec27521f314..abff55d77c829b5063e7c52ff51dfec4fac53c1a 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/NetworkInputTracker.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/NetworkInputTracker.scala
@@ -15,7 +15,7 @@
  * limitations under the License.
  */
 
-package org.apache.spark.streaming
+package org.apache.spark.streaming.scheduler
 
 import org.apache.spark.streaming.dstream.{NetworkInputDStream, NetworkReceiver}
 import org.apache.spark.streaming.dstream.{StopReceiver, ReportBlock, ReportError}
@@ -31,6 +31,7 @@ import akka.actor._
 import akka.pattern.ask
 import akka.dispatch._
 import org.apache.spark.storage.BlockId
+import org.apache.spark.streaming.{Time, StreamingContext}
 
 private[streaming] sealed trait NetworkInputTrackerMessage
 private[streaming] case class RegisterReceiver(streamId: Int, receiverActor: ActorRef) extends NetworkInputTrackerMessage
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListener.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListener.scala
new file mode 100644
index 0000000000000000000000000000000000000000..36225e190cd7917502f23debf6a7b9c77b14743e
--- /dev/null
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListener.scala
@@ -0,0 +1,75 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.scheduler
+
+import scala.collection.mutable.Queue
+import org.apache.spark.util.Distribution
+
+/** Base trait for events related to StreamingListener */
+sealed trait StreamingListenerEvent
+
+case class StreamingListenerBatchCompleted(batchInfo: BatchInfo) extends StreamingListenerEvent
+
+case class StreamingListenerBatchStarted(batchInfo: BatchInfo) extends StreamingListenerEvent
+
+
+/**
+ * A listener interface for receiving information about an ongoing streaming
+ * computation.
+ */
+trait StreamingListener {
+  /**
+   * Called when processing of a batch has completed
+   */
+  def onBatchCompleted(batchCompleted: StreamingListenerBatchCompleted) { }
+
+  /**
+   * Called when processing of a batch has started
+   */
+  def onBatchStarted(batchStarted: StreamingListenerBatchStarted) { }
+}
+
+
+/**
+ * A simple StreamingListener that logs summary statistics across Spark Streaming batches
+ * @param numBatchInfos Number of last batches to consider for generating statistics (default: 10)
+ */
+class StatsReportListener(numBatchInfos: Int = 10) extends StreamingListener {
+  // Queue containing latest completed batches
+  val batchInfos = new Queue[BatchInfo]()
+
+  override def onBatchCompleted(batchStarted: StreamingListenerBatchCompleted) {
+    batchInfos.enqueue(batchStarted.batchInfo)
+    if (batchInfos.size > numBatchInfos) batchInfos.dequeue()
+    printStats()
+  }
+
+  def printStats() {
+    showMillisDistribution("Total delay: ", _.totalDelay)
+    showMillisDistribution("Processing time: ", _.processingDelay)
+  }
+
+  def showMillisDistribution(heading: String, getMetric: BatchInfo => Option[Long]) {
+    org.apache.spark.scheduler.StatsReportListener.showMillisDistribution(
+      heading, extractDistribution(getMetric))
+  }
+
+  def extractDistribution(getMetric: BatchInfo => Option[Long]): Option[Distribution] = {
+    Distribution(batchInfos.flatMap(getMetric(_)).map(_.toDouble))
+  }
+}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListenerBus.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListenerBus.scala
new file mode 100644
index 0000000000000000000000000000000000000000..110a20f282f110879ad7836399f2f9e3784a1ac1
--- /dev/null
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListenerBus.scala
@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.scheduler
+
+import org.apache.spark.Logging
+import scala.collection.mutable.{SynchronizedBuffer, ArrayBuffer}
+import java.util.concurrent.LinkedBlockingQueue
+
+/** Asynchronously passes StreamingListenerEvents to registered StreamingListeners. */
+private[spark] class StreamingListenerBus() extends Logging {
+  private val listeners = new ArrayBuffer[StreamingListener]() with SynchronizedBuffer[StreamingListener]
+
+  /* Cap the capacity of the SparkListenerEvent queue so we get an explicit error (rather than
+   * an OOM exception) if it's perpetually being added to more quickly than it's being drained. */
+  private val EVENT_QUEUE_CAPACITY = 10000
+  private val eventQueue = new LinkedBlockingQueue[StreamingListenerEvent](EVENT_QUEUE_CAPACITY)
+  private var queueFullErrorMessageLogged = false
+
+  new Thread("StreamingListenerBus") {
+    setDaemon(true)
+    override def run() {
+      while (true) {
+        val event = eventQueue.take
+        event match {
+          case batchStarted: StreamingListenerBatchStarted =>
+            listeners.foreach(_.onBatchStarted(batchStarted))
+          case batchCompleted: StreamingListenerBatchCompleted =>
+            listeners.foreach(_.onBatchCompleted(batchCompleted))
+          case _ =>
+        }
+      }
+    }
+  }.start()
+
+  def addListener(listener: StreamingListener) {
+    listeners += listener
+  }
+
+  def post(event: StreamingListenerEvent) {
+    val eventAdded = eventQueue.offer(event)
+    if (!eventAdded && !queueFullErrorMessageLogged) {
+      logError("Dropping SparkListenerEvent because no remaining room in event queue. " +
+        "This likely means one of the SparkListeners is too slow and cannot keep up with the " +
+        "rate at which tasks are being started by the scheduler.")
+      queueFullErrorMessageLogged = true
+    }
+  }
+
+  /**
+   * Waits until there are no more events in the queue, or until the specified time has elapsed.
+   * Used for testing only. Returns true if the queue has emptied and false is the specified time
+   * elapsed before the queue emptied.
+   */
+  def waitUntilEmpty(timeoutMillis: Int): Boolean = {
+    val finishTime = System.currentTimeMillis + timeoutMillis
+    while (!eventQueue.isEmpty()) {
+      if (System.currentTimeMillis > finishTime) {
+        return false
+      }
+      /* Sleep rather than using wait/notify, because this is used only for testing and wait/notify
+       * add overhead in the general case. */
+      Thread.sleep(10)
+    }
+    return true
+  }
+}
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala
index 60e986cb9d68b69199fa9234467027789b8f4e0c..ee6b433d1f1fa2d05a33991cad2d7cdf0b81a7c5 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala
@@ -26,17 +26,6 @@ import util.ManualClock
 import org.apache.spark.{SparkContext, SparkConf}
 
 class BasicOperationsSuite extends TestSuiteBase {
-
-  override def framework = "BasicOperationsSuite"
-
-  conf.set("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock")
-
-  after {
-    // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
-    System.clearProperty("spark.driver.port")
-    System.clearProperty("spark.hostPort")
-  }
-
   test("map") {
     val input = Seq(1 to 4, 5 to 8, 9 to 12)
     testOperation(
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala
index ca230fd056a6ab0394e545a0416c541221aea735..c60a3f53905949ba19da2d7a4a238ff3bc816aed 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala
@@ -40,29 +40,25 @@ import org.apache.spark.streaming.util.ManualClock
  * the checkpointing of a DStream's RDDs as well as the checkpointing of
  * the whole DStream graph.
  */
-class CheckpointSuite extends TestSuiteBase with BeforeAndAfter {
+class CheckpointSuite extends TestSuiteBase {
 
-  before {
+  var ssc: StreamingContext = null
+
+  override def batchDuration = Milliseconds(500)
+
+  override def actuallyWait = true // to allow checkpoints to be written
+
+  override def beforeFunction() {
+    super.beforeFunction()
     FileUtils.deleteDirectory(new File(checkpointDir))
   }
 
-  after {
+  override def afterFunction() {
+    super.afterFunction()
     if (ssc != null) ssc.stop()
     FileUtils.deleteDirectory(new File(checkpointDir))
-
-    // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
-    System.clearProperty("spark.driver.port")
-    System.clearProperty("spark.hostPort")
   }
 
-  var ssc: StreamingContext = null
-
-  override def framework = "CheckpointSuite"
-
-  override def batchDuration = Milliseconds(500)
-
-  override def actuallyWait = true
-
   test("basic rdd checkpoints + dstream graph checkpoint recovery") {
 
     assert(batchDuration === Milliseconds(500), "batchDuration for this test must be 1 second")
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala
index 6337c5359c3dcac1d8206d8881b6b28864484522..da9b04de1ac44ee4299bb58003412718de0eb545 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala
@@ -32,17 +32,22 @@ import collection.mutable.ArrayBuffer
  * This testsuite tests master failures at random times while the stream is running using
  * the real clock.
  */
-class FailureSuite extends FunSuite with BeforeAndAfter with Logging {
+class FailureSuite extends TestSuiteBase with Logging {
 
   var directory = "FailureSuite"
   val numBatches = 30
-  val batchDuration = Milliseconds(1000)
 
-  before {
+  override def batchDuration = Milliseconds(1000)
+
+  override def useManualClock = false
+
+  override def beforeFunction() {
+    super.beforeFunction()
     FileUtils.deleteDirectory(new File(directory))
   }
 
-  after {
+  override def afterFunction() {
+    super.afterFunction()
     FileUtils.deleteDirectory(new File(directory))
   }
 
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala
index 8c16daa21c8b445d98a6585cfa59f26347f99bd1..52381c10b077d3109317617e4662bb7fc330f994 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala
@@ -50,16 +50,6 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter {
 
   val testPort = 9999
 
-  override def checkpointDir = "checkpoint"
-
-  conf.set("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock")
-
-  after {
-    // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
-    System.clearProperty("spark.driver.port")
-    System.clearProperty("spark.hostPort")
-  }
-
   test("socket input stream") {
     // Start the server
     val testServer = new TestServer()
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala
new file mode 100644
index 0000000000000000000000000000000000000000..fa6414209605405e2a70834409bb3851e10b6422
--- /dev/null
+++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming
+
+import org.apache.spark.streaming.scheduler._
+import scala.collection.mutable.ArrayBuffer
+import org.scalatest.matchers.ShouldMatchers
+
+class StreamingListenerSuite extends TestSuiteBase with ShouldMatchers {
+
+  val input = (1 to 4).map(Seq(_)).toSeq
+  val operation = (d: DStream[Int]) => d.map(x => x)
+
+  // To make sure that the processing start and end times in collected
+  // information are different for successive batches
+  override def batchDuration = Milliseconds(100)
+  override def actuallyWait = true
+
+  test("basic BatchInfo generation") {
+    val ssc = setupStreams(input, operation)
+    val collector = new BatchInfoCollector
+    ssc.addStreamingListener(collector)
+    runStreams(ssc, input.size, input.size)
+    val batchInfos = collector.batchInfos
+    batchInfos should have size 4
+
+    batchInfos.foreach(info => {
+      info.schedulingDelay should not be None
+      info.processingDelay should not be None
+      info.totalDelay should not be None
+      info.schedulingDelay.get should be >= 0L
+      info.processingDelay.get should be >= 0L
+      info.totalDelay.get should be >= 0L
+    })
+
+    isInIncreasingOrder(batchInfos.map(_.submissionTime)) should be (true)
+    isInIncreasingOrder(batchInfos.map(_.processingStartTime.get)) should be (true)
+    isInIncreasingOrder(batchInfos.map(_.processingEndTime.get)) should be (true)
+  }
+
+  /** Check if a sequence of numbers is in increasing order */
+  def isInIncreasingOrder(seq: Seq[Long]): Boolean = {
+    for(i <- 1 until seq.size) {
+      if (seq(i - 1) > seq(i)) return false
+    }
+    true
+  }
+
+  /** Listener that collects information on processed batches */
+  class BatchInfoCollector extends StreamingListener {
+    val batchInfos = new ArrayBuffer[BatchInfo]
+    override def onBatchCompleted(batchCompleted: StreamingListenerBatchCompleted) {
+      batchInfos += batchCompleted.batchInfo
+    }
+  }
+}
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala
index 3dd671849195cbc77ff1fd2c9396086dc2404342..33464bc3a1c76bb6545e6f5156f173231e41be16 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala
@@ -110,7 +110,7 @@ class TestOutputStreamWithPartitions[T: ClassTag](parent: DStream[T],
 trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging {
 
   // Name of the framework for Spark context
-  def framework = "TestSuiteBase"
+  def framework = this.getClass.getSimpleName
 
   // Master for Spark context
   def master = "local[2]"
@@ -127,15 +127,45 @@ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging {
   // Maximum time to wait before the test times out
   def maxWaitTimeMillis = 10000
 
+  // Whether to use manual clock or not
+  def useManualClock = true
+
   // Whether to actually wait in real time before changing manual clock
   def actuallyWait = false
 
-  // A SparkConf to use in tests. Can be modified before calling setupStreams to configure things.
+  //// A SparkConf to use in tests. Can be modified before calling setupStreams to configure things.
   val conf = new SparkConf()
     .setMaster(master)
     .setAppName(framework)
     .set("spark.cleaner.ttl", "3600")
 
+  // Default before function for any streaming test suite. Override this
+  // if you want to add your stuff to "before" (i.e., don't call before { } )
+  def beforeFunction() {
+    //if (useManualClock) {
+    //  System.setProperty(
+    //    "spark.streaming.clock",
+    //    "org.apache.spark.streaming.util.ManualClock"
+    //  )
+    //} else {
+    //  System.clearProperty("spark.streaming.clock")
+    //}
+    if (useManualClock) {
+      conf.set("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock")
+    }
+  }
+
+  // Default after function for any streaming test suite. Override this
+  // if you want to add your stuff to "after" (i.e., don't call after { } )
+  def afterFunction() {
+    // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
+    System.clearProperty("spark.driver.port")
+    System.clearProperty("spark.hostPort")
+  }
+
+  before(beforeFunction)
+  after(afterFunction)
+
   /**
    * Set up required DStreams to test the DStream operation using the two sequences
    * of input collections.
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/WindowOperationsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/WindowOperationsSuite.scala
index 3242c4cd11bd8763c66f74246d5f4d2400808f79..c92c34d49bd367da19f36cc0f16f64bb38970436 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/WindowOperationsSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/WindowOperationsSuite.scala
@@ -21,19 +21,9 @@ import org.apache.spark.streaming.StreamingContext._
 
 class WindowOperationsSuite extends TestSuiteBase {
 
-  conf.set("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock")
+  override def maxWaitTimeMillis = 20000  // large window tests can sometimes take longer
 
-  override def framework = "WindowOperationsSuite"
-
-  override def maxWaitTimeMillis = 20000
-
-  override def batchDuration = Seconds(1)
-
-  after {
-    // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
-    System.clearProperty("spark.driver.port")
-    System.clearProperty("spark.hostPort")
-  }
+  override def batchDuration = Seconds(1)  // making sure its visible in this class
 
   val largerSlideInput = Seq(
     Seq(("a", 1)),
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
index cc150888ebe380ee56c83d2e856413130cce8745..595a7ee8c3d83c329e07268a05675d4aac646812 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
@@ -422,8 +422,10 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl
   }
 
   def monitorApplication(appId: ApplicationId): Boolean = {
+    val interval = new SparkConf().getOrElse("spark.yarn.report.interval", "1000").toLong
+
     while (true) {
-      Thread.sleep(1000)
+      Thread.sleep(interval)
       val report = super.getApplicationReport(appId)
 
       logInfo("Application report from ASM: \n" +
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala
index 4c9fee56955643099491ac108747e9f5060fc711..5966a0f7577572e277277f56c41928b4ce93d59b 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala
@@ -27,8 +27,8 @@ import scala.collection.JavaConversions._
 import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet}
 
 import org.apache.spark.Logging
-import org.apache.spark.scheduler.SplitInfo
-import org.apache.spark.scheduler.cluster.{ClusterScheduler, CoarseGrainedSchedulerBackend}
+import org.apache.spark.scheduler.{SplitInfo,TaskSchedulerImpl}
+import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend
 import org.apache.spark.util.Utils
 
 import org.apache.hadoop.conf.Configuration
@@ -214,9 +214,9 @@ private[yarn] class YarnAllocationHandler(
       // host if there are sufficiently large number of hosts/containers.
 
       val allocatedContainers = new ArrayBuffer[Container](_allocatedContainers.size)
-      allocatedContainers ++= ClusterScheduler.prioritizeContainers(dataLocalContainers)
-      allocatedContainers ++= ClusterScheduler.prioritizeContainers(rackLocalContainers)
-      allocatedContainers ++= ClusterScheduler.prioritizeContainers(offRackContainers)
+      allocatedContainers ++= TaskSchedulerImpl.prioritizeContainers(dataLocalContainers)
+      allocatedContainers ++= TaskSchedulerImpl.prioritizeContainers(rackLocalContainers)
+      allocatedContainers ++= TaskSchedulerImpl.prioritizeContainers(offRackContainers)
 
       // Run each of the allocated containers
       for (container <- allocatedContainers) {
diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientClusterScheduler.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientClusterScheduler.scala
index 63a0449e5a0730085554d2b8ae86067135fa8dba..522e0a9ad7eeb50f4c2b6b781a68ea998639b30a 100644
--- a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientClusterScheduler.scala
+++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientClusterScheduler.scala
@@ -20,13 +20,14 @@ package org.apache.spark.scheduler.cluster
 import org.apache.spark._
 import org.apache.hadoop.conf.Configuration
 import org.apache.spark.deploy.yarn.YarnAllocationHandler
+import org.apache.spark.scheduler.TaskSchedulerImpl
 import org.apache.spark.util.Utils
 
 /**
  *
  * This scheduler launch worker through Yarn - by call into Client to launch WorkerLauncher as AM.
  */
-private[spark] class YarnClientClusterScheduler(sc: SparkContext, conf: Configuration) extends ClusterScheduler(sc) {
+private[spark] class YarnClientClusterScheduler(sc: SparkContext, conf: Configuration) extends TaskSchedulerImpl(sc) {
 
   def this(sc: SparkContext) = this(sc, new Configuration())
 
diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala
index 6feaaff01425606acf9c1da751b84201a6c95f59..4b69f5078b0ab10818f1d6ecd5d9655cd00327c8 100644
--- a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala
+++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala
@@ -20,9 +20,10 @@ package org.apache.spark.scheduler.cluster
 import org.apache.hadoop.yarn.api.records.{ApplicationId, YarnApplicationState}
 import org.apache.spark.{SparkException, Logging, SparkContext}
 import org.apache.spark.deploy.yarn.{Client, ClientArguments}
+import org.apache.spark.scheduler.TaskSchedulerImpl
 
 private[spark] class YarnClientSchedulerBackend(
-    scheduler: ClusterScheduler,
+    scheduler: TaskSchedulerImpl,
     sc: SparkContext)
   extends CoarseGrainedSchedulerBackend(scheduler, sc.env.actorSystem)
   with Logging {
diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala
index 29b3f22e13697b38bc501e2f914d8fc0a202d722..2d9fbcb400e5bc07e1af665116910909d7edd118 100644
--- a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala
+++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala
@@ -17,16 +17,20 @@
 
 package org.apache.spark.scheduler.cluster
 
+import org.apache.hadoop.conf.Configuration
+
 import org.apache.spark._
 import org.apache.spark.deploy.yarn.{ApplicationMaster, YarnAllocationHandler}
+import org.apache.spark.scheduler.TaskSchedulerImpl
 import org.apache.spark.util.Utils
-import org.apache.hadoop.conf.Configuration
 
 /**
  *
- * This is a simple extension to ClusterScheduler - to ensure that appropriate initialization of ApplicationMaster, etc is done
+ * This is a simple extension to ClusterScheduler - to ensure that appropriate initialization of
+ * ApplicationMaster, etc. is done
  */
-private[spark] class YarnClusterScheduler(sc: SparkContext, conf: Configuration) extends ClusterScheduler(sc) {
+private[spark] class YarnClusterScheduler(sc: SparkContext, conf: Configuration)
+  extends TaskSchedulerImpl(sc) {
 
   logInfo("Created YarnClusterScheduler")