diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala
index e183efccbb6f7288b09501a8a8937c5fa90fe307..b45e599588ad39fc6750956eb8fdc94f3c11aa23 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala
@@ -121,9 +121,15 @@ private[spark] class YarnRMClient(args: ApplicationMasterArguments) extends Logg
 
   /** Returns the maximum number of attempts to register the AM. */
   def getMaxRegAttempts(sparkConf: SparkConf, yarnConf: YarnConfiguration): Int = {
-    sparkConf.getOption("spark.yarn.maxAppAttempts").map(_.toInt).getOrElse(
-      yarnConf.getInt(
-        YarnConfiguration.RM_AM_MAX_ATTEMPTS, YarnConfiguration.DEFAULT_RM_AM_MAX_ATTEMPTS))
+    val sparkMaxAttempts = sparkConf.getOption("spark.yarn.maxAppAttempts").map(_.toInt)
+    val yarnMaxAttempts = yarnConf.getInt(
+      YarnConfiguration.RM_AM_MAX_ATTEMPTS, YarnConfiguration.DEFAULT_RM_AM_MAX_ATTEMPTS)
+    val retval: Int = sparkMaxAttempts match {
+      case Some(x) => if (x <= yarnMaxAttempts) x else yarnMaxAttempts
+      case None => yarnMaxAttempts
+    }
+
+    retval
   }
 
 }