diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala
index 7bf50de660ec3302f31d95ec68f2534a3b9795bd..9b45fc29388eaba971624bd031b891abbeac2de5 100644
--- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala
@@ -17,19 +17,17 @@
 
 package spark.scheduler
 
-import cluster.TaskInfo
-import java.util.concurrent.atomic.AtomicInteger
-import java.util.concurrent.LinkedBlockingQueue
-import java.util.concurrent.TimeUnit
+import java.io.NotSerializableException
 import java.util.Properties
+import java.util.concurrent.{LinkedBlockingQueue, TimeUnit}
+import java.util.concurrent.atomic.AtomicInteger
 
 import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map}
 
 import spark._
 import spark.executor.TaskMetrics
-import spark.partial.ApproximateActionListener
-import spark.partial.ApproximateEvaluator
-import spark.partial.PartialResult
+import spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult}
+import spark.scheduler.cluster.TaskInfo
 import spark.storage.{BlockManager, BlockManagerMaster}
 import spark.util.{MetadataCleaner, TimeStampedHashMap}
 
@@ -263,7 +261,8 @@ class DAGScheduler(
     assert(partitions.size > 0)
     val waiter = new JobWaiter(partitions.size, resultHandler)
     val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
-    val toSubmit = JobSubmitted(finalRdd, func2, partitions.toArray, allowLocal, callSite, waiter, properties)
+    val toSubmit = JobSubmitted(finalRdd, func2, partitions.toArray, allowLocal, callSite, waiter,
+      properties)
     return (toSubmit, waiter)
   }
 
@@ -288,7 +287,7 @@ class DAGScheduler(
         "Total number of partitions: " + maxPartitions)
     }
 
-    val (toSubmit, waiter) = prepareJob(
+    val (toSubmit: JobSubmitted, waiter: JobWaiter[_]) = prepareJob(
         finalRdd, func, partitions, callSite, allowLocal, resultHandler, properties)
     eventQueue.put(toSubmit)
     waiter.awaitResult() match {
@@ -512,6 +511,19 @@ class DAGScheduler(
       }
     }
     if (tasks.size > 0) {
+      // Preemptively serialize a task to make sure it can be serialized. We are catching this
+      // exception here because it would be fairly hard to catch the non-serializable exception
+      // further down the road, where there are several different implementations of the local
+      // and cluster schedulers.
+      try {
+        SparkEnv.get.closureSerializer.newInstance().serialize(tasks.head)
+      } catch {
+        case e: NotSerializableException =>
+          abortStage(stage, e.toString)
+          running -= stage
+          return
+      }
+
       sparkListeners.foreach(_.onStageSubmitted(SparkListenerStageSubmitted(stage, tasks.size)))
       logInfo("Submitting " + tasks.size + " missing tasks from " + stage + " (" + stage.rdd + ")")
       myPending ++= tasks
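
The new block in the DAGScheduler hunk above serializes the first task eagerly so that a non-serializable closure aborts the stage right away instead of failing later inside whichever scheduler runs the tasks. The snippet below is a minimal, standalone Scala sketch of that fail-fast idea; it uses plain Java serialization rather than Spark's SparkEnv.get.closureSerializer, and the names SerializationCheck and preflight are illustrative, not part of this patch.

import java.io.{ByteArrayOutputStream, NotSerializableException, ObjectOutputStream}

object SerializationCheck {
  // Try to serialize the task up front; return the error message on failure so the
  // caller can abort the stage instead of discovering the problem on an executor.
  def preflight(task: AnyRef): Option[String] = {
    val out = new ObjectOutputStream(new ByteArrayOutputStream())
    try {
      out.writeObject(task) // throws NotSerializableException if the task (or its
      None                  // closure) captures a non-serializable object
    } catch {
      case e: NotSerializableException => Some(e.toString)
    } finally {
      out.close()
    }
  }
}
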
diff --git a/core/src/test/scala/spark/FailureSuite.scala b/core/src/test/scala/spark/FailureSuite.scala
index 6c847b8fef8562886accfd13407aa110007a5277..5b133cdd6e867947bb324b068fe352cefd3c1178 100644
--- a/core/src/test/scala/spark/FailureSuite.scala
+++ b/core/src/test/scala/spark/FailureSuite.scala
@@ -18,9 +18,6 @@
 package spark
 
 import org.scalatest.FunSuite
-import org.scalatest.prop.Checkers
-
-import scala.collection.mutable.ArrayBuffer
 
 import SparkContext._
 
@@ -40,7 +37,7 @@ object FailureSuiteState {
 }
 
 class FailureSuite extends FunSuite with LocalSparkContext {
-  
+
   // Run a 3-task map job in which task 1 deterministically fails once, and check
   // whether the job completes successfully and we ran 4 tasks in total.
   test("failure in a single-stage job") {
@@ -66,7 +63,7 @@ class FailureSuite extends FunSuite with LocalSparkContext {
   test("failure in a two-stage job") {
     sc = new SparkContext("local[1,1]", "test")
     val results = sc.makeRDD(1 to 3).map(x => (x, x)).groupByKey(3).map {
-      case (k, v) => 
+      case (k, v) =>
         FailureSuiteState.synchronized {
           FailureSuiteState.tasksRun += 1
           if (k == 1 && FailureSuiteState.tasksFailed == 0) {
@@ -87,12 +84,40 @@ class FailureSuite extends FunSuite with LocalSparkContext {
     sc = new SparkContext("local[1,1]", "test")
     val results = sc.makeRDD(1 to 3).map(x => new NonSerializable)
 
-    val thrown = intercept[spark.SparkException] {
+    val thrown = intercept[SparkException] {
       results.collect()
     }
-    assert(thrown.getClass === classOf[spark.SparkException])
+    assert(thrown.getClass === classOf[SparkException])
+    assert(thrown.getMessage.contains("NotSerializableException"))
+
+    FailureSuiteState.clear()
+  }
+
+  test("failure because task closure is not serializable") {
+    sc = new SparkContext("local[1,1]", "test")
+    val a = new NonSerializable
+
+    // Non-serializable closure in the final result stage
+    val thrown = intercept[SparkException] {
+      sc.parallelize(1 to 10, 2).map(x => a).count()
+    }
+    assert(thrown.getClass === classOf[SparkException])
     assert(thrown.getMessage.contains("NotSerializableException"))
 
+    // Non-serializable closure in an earlier stage
+    val thrown1 = intercept[SparkException] {
+      sc.parallelize(1 to 10, 2).map(x => (x, a)).partitionBy(new HashPartitioner(3)).count()
+    }
+    assert(thrown1.getClass === classOf[SparkException])
+    assert(thrown1.getMessage.contains("NotSerializableException"))
+
+    // Non-serializable closure in foreach function
+    val thrown2 = intercept[SparkException] {
+      sc.parallelize(1 to 10, 2).foreach(x => println(a))
+    }
+    assert(thrown2.getClass === classOf[SparkException])
+    assert(thrown2.getMessage.contains("NotSerializableException"))
+
     FailureSuiteState.clear()
   }
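
The three new test cases above all fail because the closure captures the non-serializable value `a` on the driver. As a usage note, a common way to make such a job run, assuming the object only needs to exist on the executors, is to construct it inside the closure so it is never serialized. The sketch below reuses the suite's NonSerializable class purely for illustration; it is not code from this patch.

// Hypothetical fix for the pattern the tests reject: create the helper on the
// executor (here once per partition) instead of capturing it from the driver.
sc.parallelize(1 to 10, 2).mapPartitions { iter =>
  val helper = new NonSerializable // instantiated per partition, never shipped
  iter.map(x => (x, helper.hashCode))
}.count()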