From 0b5abbf5f96a5f6bfd15a65e8788cf3fa96fe54c Mon Sep 17 00:00:00 2001
From: Josh Rosen <joshrosen@databricks.com>
Date: Sat, 27 Jun 2015 14:40:45 -0700
Subject: [PATCH] [SPARK-8606] Prevent exceptions in
 RDD.getPreferredLocations() from crashing DAGScheduler

If `RDD.getPreferredLocations()` throws an exception, it may crash the DAGScheduler and SparkContext. This patch addresses this by wrapping task creation in `DAGScheduler#submitMissingTasks()` in a try-catch block that aborts only the affected stage instead of letting the exception propagate into the scheduler's event loop.

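For illustration, a minimal sketch of the failure mode this guards against, mirroring the regression test added below (the `BuggyLocationsRDD` class and the assumed `sc` SparkContext are illustrative, not part of this patch): an RDD whose `getPreferredLocations` throws now fails the submitting job with a `SparkException` instead of taking down the DAGScheduler and the SparkContext.

```scala
import org.apache.spark.{Partition, SparkContext, SparkException, TaskContext}
import org.apache.spark.rdd.RDD

// Illustrative RDD (hypothetical name) whose preferred-location lookup throws.
class BuggyLocationsRDD(sc: SparkContext) extends RDD[Int](sc, Nil) {

  override def getPartitions: Array[Partition] =
    Array(new Partition { override def index: Int = 0 })

  override def compute(split: Partition, context: TaskContext): Iterator[Int] =
    Iterator(1)

  // Before this patch the exception escaped into the DAGScheduler event loop
  // and shut down the SparkContext; with the patch the stage is aborted and
  // the failing action surfaces the error as a SparkException.
  override def getPreferredLocations(split: Partition): Seq[String] =
    throw new IllegalStateException("buggy preferred locations")
}

// Usage sketch, assuming an existing SparkContext named `sc`:
//   try { new BuggyLocationsRDD(sc).count() }
//   catch { case e: SparkException => ... }  // scheduler keeps running
```
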
Author: Josh Rosen <joshrosen@databricks.com>

Closes #7023 from JoshRosen/SPARK-8606 and squashes the following commits:

770b169 [Josh Rosen] Fix getPreferredLocations() DAGScheduler crash with try block.
44a9b55 [Josh Rosen] Add test of a buggy getPartitions() method
19aa9f7 [Josh Rosen] Add (failing) regression test for getPreferredLocations() DAGScheduler crash
---
 .../apache/spark/scheduler/DAGScheduler.scala | 37 +++++++++++--------
 .../spark/scheduler/DAGSchedulerSuite.scala   | 31 ++++++++++++++++
 2 files changed, 53 insertions(+), 15 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index b00a5fee09..a7cf0c23d9 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -907,22 +907,29 @@ class DAGScheduler(
         return
     }
 
-    val tasks: Seq[Task[_]] = stage match {
-      case stage: ShuffleMapStage =>
-        partitionsToCompute.map { id =>
-          val locs = getPreferredLocs(stage.rdd, id)
-          val part = stage.rdd.partitions(id)
-          new ShuffleMapTask(stage.id, taskBinary, part, locs)
-        }
+    val tasks: Seq[Task[_]] = try {
+      stage match {
+        case stage: ShuffleMapStage =>
+          partitionsToCompute.map { id =>
+            val locs = getPreferredLocs(stage.rdd, id)
+            val part = stage.rdd.partitions(id)
+            new ShuffleMapTask(stage.id, taskBinary, part, locs)
+          }
 
-      case stage: ResultStage =>
-        val job = stage.resultOfJob.get
-        partitionsToCompute.map { id =>
-          val p: Int = job.partitions(id)
-          val part = stage.rdd.partitions(p)
-          val locs = getPreferredLocs(stage.rdd, p)
-          new ResultTask(stage.id, taskBinary, part, locs, id)
-        }
+        case stage: ResultStage =>
+          val job = stage.resultOfJob.get
+          partitionsToCompute.map { id =>
+            val p: Int = job.partitions(id)
+            val part = stage.rdd.partitions(p)
+            val locs = getPreferredLocs(stage.rdd, p)
+            new ResultTask(stage.id, taskBinary, part, locs, id)
+          }
+      }
+    } catch {
+      case NonFatal(e) =>
+        abortStage(stage, s"Task creation failed: $e\n${e.getStackTraceString}")
+        runningStages -= stage
+        return
     }
 
     if (tasks.size > 0) {
diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
index 833b600746..6bc45f249f 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -784,6 +784,37 @@ class DAGSchedulerSuite
     assert(sc.parallelize(1 to 10, 2).first() === 1)
   }
 
+  test("getPartitions exceptions should not crash DAGScheduler and SparkContext (SPARK-8606)") {
+    val e1 = intercept[DAGSchedulerSuiteDummyException] {
+      val rdd = new MyRDD(sc, 2, Nil) {
+        override def getPartitions: Array[Partition] = {
+          throw new DAGSchedulerSuiteDummyException
+        }
+      }
+      rdd.reduceByKey(_ + _, 1).count()
+    }
+
+    // Make sure we can still run local commands as well as cluster commands.
+    assert(sc.parallelize(1 to 10, 2).count() === 10)
+    assert(sc.parallelize(1 to 10, 2).first() === 1)
+  }
+
+  test("getPreferredLocations errors should not crash DAGScheduler and SparkContext (SPARK-8606)") {
+    val e1 = intercept[SparkException] {
+      val rdd = new MyRDD(sc, 2, Nil) {
+        override def getPreferredLocations(split: Partition): Seq[String] = {
+          throw new DAGSchedulerSuiteDummyException
+        }
+      }
+      rdd.count()
+    }
+    assert(e1.getMessage.contains(classOf[DAGSchedulerSuiteDummyException].getName))
+
+    // Make sure we can still run local commands as well as cluster commands.
+    assert(sc.parallelize(1 to 10, 2).count() === 10)
+    assert(sc.parallelize(1 to 10, 2).first() === 1)
+  }
+
   test("accumulator not calculated for resubmitted result stage") {
     // just for register
     val accum = new Accumulator[Int](0, AccumulatorParam.IntAccumulatorParam)
-- 
GitLab