diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 692ed8083475cf5ea5e00d8f2d8d3bdf2a4e8a96..d944f268755de168222b36d3b4a7725d01db83d6 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -187,6 +187,13 @@ class DAGScheduler(
   /** If enabled, FetchFailed will not cause stage retry, in order to surface the problem. */
   private val disallowStageRetryForTest = sc.getConf.getBoolean("spark.test.noStageRetry", false)
 
+  /**
+   * Number of consecutive stage attempts allowed before a stage is aborted.
+   */
+  private[scheduler] val maxConsecutiveStageAttempts =
+    sc.getConf.getInt("spark.stage.maxConsecutiveAttempts",
+      DAGScheduler.DEFAULT_MAX_CONSECUTIVE_STAGE_ATTEMPTS)
+
   private val messageScheduler =
     ThreadUtils.newDaemonSingleThreadScheduledExecutor("dag-scheduler-message")
 
@@ -1282,8 +1289,9 @@ class DAGScheduler(
               s"longer running")
           }
 
+          failedStage.fetchFailedAttemptIds.add(task.stageAttemptId)
           val shouldAbortStage =
-            failedStage.failedOnFetchAndShouldAbort(task.stageAttemptId) ||
+            failedStage.fetchFailedAttemptIds.size >= maxConsecutiveStageAttempts ||
             disallowStageRetryForTest
 
           if (shouldAbortStage) {
@@ -1292,7 +1300,7 @@ class DAGScheduler(
             } else {
               s"""$failedStage (${failedStage.name})
                  |has failed the maximum allowable number of
-                 |times: ${Stage.MAX_CONSECUTIVE_FETCH_FAILURES}.
+                 |times: $maxConsecutiveStageAttempts.
                  |Most recent failure reason: $failureMessage""".stripMargin.replaceAll("\n", " ")
             }
             abortStage(failedStage, abortMessage, None)
@@ -1726,4 +1734,7 @@ private[spark] object DAGScheduler {
   // this is a simplistic way to avoid resubmitting tasks in the non-fetchable map stage one by one
   // as more failure events come in
   val RESUBMIT_TIMEOUT = 200
+
+  // Number of consecutive stage attempts allowed before a stage is aborted
+  val DEFAULT_MAX_CONSECUTIVE_STAGE_ATTEMPTS = 4
 }
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
index 32e5df6d75f4f998c55196ec0d714a7be9e9fcb7..290fd073caf2761ba6f78758e2bca89ae2ebbc2f 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
@@ -87,23 +87,12 @@ private[scheduler] abstract class Stage(
    * We keep track of each attempt ID that has failed to avoid recording duplicate failures if
    * multiple tasks from the same stage attempt fail (SPARK-5945).
    */
-  private val fetchFailedAttemptIds = new HashSet[Int]
+  val fetchFailedAttemptIds = new HashSet[Int]
 
   private[scheduler] def clearFailures() : Unit = {
     fetchFailedAttemptIds.clear()
   }
 
-  /**
-   * Check whether we should abort the failedStage due to multiple consecutive fetch failures.
-   *
-   * This method updates the running set of failed stage attempts and returns
-   * true if the number of failures exceeds the allowable number of failures.
-   */
-  private[scheduler] def failedOnFetchAndShouldAbort(stageAttemptId: Int): Boolean = {
-    fetchFailedAttemptIds.add(stageAttemptId)
-    fetchFailedAttemptIds.size >= Stage.MAX_CONSECUTIVE_FETCH_FAILURES
-  }
-
   /** Creates a new attempt for this stage by creating a new StageInfo with a new attempt ID. */
   def makeNewStageAttempt(
       numPartitionsToCompute: Int,
@@ -128,8 +117,3 @@ private[scheduler] abstract class Stage(
   /** Returns the sequence of partition ids that are missing (i.e. needs to be computed). */
   def findMissingPartitions(): Seq[Int]
 }
-
-private[scheduler] object Stage {
-  // The number of consecutive failures allowed before a stage is aborted
-  val MAX_CONSECUTIVE_FETCH_FAILURES = 4
-}
diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
index 8eaf9dfcf49b1e18c935a5ddb312961e82064528..dfad5db68a91403c98e453f9460c10400613248f 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -801,7 +801,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou
     val reduceRdd = new MyRDD(sc, 2, List(shuffleDep), tracker = mapOutputTracker)
     submit(reduceRdd, Array(0, 1))
 
-    for (attempt <- 0 until Stage.MAX_CONSECUTIVE_FETCH_FAILURES) {
+    for (attempt <- 0 until scheduler.maxConsecutiveStageAttempts) {
       // Complete all the tasks for the current attempt of stage 0 successfully
       completeShuffleMapStageSuccessfully(0, attempt, numShufflePartitions = 2)
 
@@ -813,7 +813,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou
       // map output, for the next iteration through the loop
       scheduler.resubmitFailedStages()
 
-      if (attempt < Stage.MAX_CONSECUTIVE_FETCH_FAILURES - 1) {
+      if (attempt < scheduler.maxConsecutiveStageAttempts - 1) {
         assert(scheduler.runningStages.nonEmpty)
         assert(!ended)
       } else {
@@ -847,11 +847,11 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou
 
     // In the first two iterations, Stage 0 succeeds and stage 1 fails. In the next two iterations,
     // stage 2 fails.
-    for (attempt <- 0 until Stage.MAX_CONSECUTIVE_FETCH_FAILURES) {
+    for (attempt <- 0 until scheduler.maxConsecutiveStageAttempts) {
       // Complete all the tasks for the current attempt of stage 0 successfully
       completeShuffleMapStageSuccessfully(0, attempt, numShufflePartitions = 2)
 
-      if (attempt < Stage.MAX_CONSECUTIVE_FETCH_FAILURES / 2) {
+      if (attempt < scheduler.maxConsecutiveStageAttempts / 2) {
         // Now we should have a new taskSet, for a new attempt of stage 1.
         // Fail all these tasks with FetchFailure
         completeNextStageWithFetchFailure(1, attempt, shuffleDepOne)
@@ -859,8 +859,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou
         completeShuffleMapStageSuccessfully(1, attempt, numShufflePartitions = 1)
 
         // Fail stage 2
-        completeNextStageWithFetchFailure(2, attempt - Stage.MAX_CONSECUTIVE_FETCH_FAILURES / 2,
-          shuffleDepTwo)
+        completeNextStageWithFetchFailure(2,
+          attempt - scheduler.maxConsecutiveStageAttempts / 2, shuffleDepTwo)
       }
 
       // this will trigger a resubmission of stage 0, since we've lost some of its
@@ -872,7 +872,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou
     completeShuffleMapStageSuccessfully(1, 4, numShufflePartitions = 1)
 
     // Succeed stage2 with a "42"
-    completeNextResultStageWithSuccess(2, Stage.MAX_CONSECUTIVE_FETCH_FAILURES/2)
+    completeNextResultStageWithSuccess(2, scheduler.maxConsecutiveStageAttempts / 2)
 
     assert(results === Map(0 -> 42))
     assertDataStructuresEmpty()
@@ -895,7 +895,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou
     submit(finalRdd, Array(0))
 
     // First, execute stages 0 and 1, failing stage 1 up to MAX-1 times.
-    for (attempt <- 0 until Stage.MAX_CONSECUTIVE_FETCH_FAILURES - 1) {
+    for (attempt <- 0 until scheduler.maxConsecutiveStageAttempts - 1) {
       // Make each task in stage 0 success
       completeShuffleMapStageSuccessfully(0, attempt, numShufflePartitions = 2)
 
diff --git a/docs/configuration.md b/docs/configuration.md
index 63392a741a1f038d1ed4364261aa524fe84ca90d..4729f1b0404c18e50f1ce564b2d2878a2fe7ed7d 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -1506,6 +1506,11 @@ Apart from these, the following properties are also available, and may be useful
     of this setting is to act as a safety-net to prevent runaway uncancellable tasks from rendering
     an executor unusable.
   </td>
+  <td><code>spark.stage.maxConsecutiveAttempts</code></td>
+  <td>4</td>
+  <td>
+    Number of consecutive stage attempts allowed before a stage is aborted.
+  </td>
 </tr>
 </table>