diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala index e183efccbb6f7288b09501a8a8937c5fa90fe307..b45e599588ad39fc6750956eb8fdc94f3c11aa23 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala @@ -121,9 +121,15 @@ private[spark] class YarnRMClient(args: ApplicationMasterArguments) extends Logg /** Returns the maximum number of attempts to register the AM. */ def getMaxRegAttempts(sparkConf: SparkConf, yarnConf: YarnConfiguration): Int = { - sparkConf.getOption("spark.yarn.maxAppAttempts").map(_.toInt).getOrElse( - yarnConf.getInt( - YarnConfiguration.RM_AM_MAX_ATTEMPTS, YarnConfiguration.DEFAULT_RM_AM_MAX_ATTEMPTS)) + val sparkMaxAttempts = sparkConf.getOption("spark.yarn.maxAppAttempts").map(_.toInt) + val yarnMaxAttempts = yarnConf.getInt( + YarnConfiguration.RM_AM_MAX_ATTEMPTS, YarnConfiguration.DEFAULT_RM_AM_MAX_ATTEMPTS) + val retval: Int = sparkMaxAttempts match { + case Some(x) => if (x <= yarnMaxAttempts) x else yarnMaxAttempts + case None => yarnMaxAttempts + } + + retval } }