diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md
index b5fb077441468458eb121c7db642568f3c107464..86276b1aa9cea0c65d7ccf64c10ed9115791469a 100644
--- a/docs/running-on-yarn.md
+++ b/docs/running-on-yarn.md
@@ -22,10 +22,12 @@ Most of the configs are the same for Spark on YARN as for other deployment modes
 <table class="table">
 <tr><th>Property Name</th><th>Default</th><th>Meaning</th></tr>
 <tr>
-  <td><code>spark.yarn.applicationMaster.waitTries</code></td>
-  <td>10</td>
+  <td><code>spark.yarn.am.waitTime</code></td>
+  <td>100000</td>
   <td>
-    Set the number of times the ApplicationMaster waits for the the Spark master and then also the number of tries it waits for the SparkContext to be initialized
+    In yarn-cluster mode, time in milliseconds for the application master to wait for the
+    SparkContext to be initialized. In yarn-client mode, time for the application master to wait
+    for the driver to connect to it.
   </td>
 </tr>
 <tr>
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
index 987b3373fb8ffef4373caf2127655e89449ee622..dc7a078446324908cea2c2b32bfc648fb798a77d 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
@@ -329,43 +329,43 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments,
 
   private def waitForSparkContextInitialized(): SparkContext = {
     logInfo("Waiting for spark context initialization")
-    try {
-      sparkContextRef.synchronized {
-        var count = 0
-        val waitTime = 10000L
-        val numTries = sparkConf.getInt("spark.yarn.applicationMaster.waitTries", 10)
-        while (sparkContextRef.get() == null && count < numTries && !finished) {
-          logInfo("Waiting for spark context initialization ... " + count)
-          count = count + 1
-          sparkContextRef.wait(waitTime)
-        }
+    sparkContextRef.synchronized {
+      val waitTries = sparkConf.getOption("spark.yarn.applicationMaster.waitTries")
+        .map(_.toLong * 10000L)
+      if (waitTries.isDefined) {
+        logWarning(
+          "spark.yarn.applicationMaster.waitTries is deprecated, use spark.yarn.am.waitTime")
+      }
+      val totalWaitTime = sparkConf.getLong("spark.yarn.am.waitTime", waitTries.getOrElse(100000L))
+      val deadline = System.currentTimeMillis() + totalWaitTime
 
-        val sparkContext = sparkContextRef.get()
-        if (sparkContext == null) {
-          logError(("SparkContext did not initialize after waiting for %d ms. Please check earlier"
-            + " log output for errors. Failing the application.").format(numTries * waitTime))
-        }
-        sparkContext
+      while (sparkContextRef.get() == null && System.currentTimeMillis < deadline && !finished) {
+        logInfo("Waiting for spark context initialization ... ")
+        sparkContextRef.wait(10000L)
+      }
+
+      val sparkContext = sparkContextRef.get()
+      if (sparkContext == null) {
+        logError(("SparkContext did not initialize after waiting for %d ms. Please check earlier"
+          + " log output for errors. Failing the application.").format(totalWaitTime))
       }
+      sparkContext
     }
   }
 
   private def waitForSparkDriver(): ActorRef = {
     logInfo("Waiting for Spark driver to be reachable.")
     var driverUp = false
-    var count = 0
     val hostport = args.userArgs(0)
     val (driverHost, driverPort) = Utils.parseHostPort(hostport)
 
-    // spark driver should already be up since it launched us, but we don't want to
+    // Spark driver should already be up since it launched us, but we don't want to
     // wait forever, so wait 100 seconds max to match the cluster mode setting.
-    // Leave this config unpublished for now. SPARK-3779 to investigating changing
-    // this config to be time based.
-    val numTries = sparkConf.getInt("spark.yarn.applicationMaster.waitTries", 1000)
+    val totalWaitTime = sparkConf.getLong("spark.yarn.am.waitTime", 100000L)
+    val deadline = System.currentTimeMillis + totalWaitTime
 
-    while (!driverUp && !finished && count < numTries) {
+    while (!driverUp && !finished && System.currentTimeMillis < deadline) {
       try {
-        count = count + 1
         val socket = new Socket(driverHost, driverPort)
         socket.close()
         logInfo("Driver now available: %s:%s".format(driverHost, driverPort))
@@ -374,7 +374,7 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments,
         case e: Exception =>
           logError("Failed to connect to driver at %s:%s, retrying ...".
             format(driverHost, driverPort))
-          Thread.sleep(100)
+          Thread.sleep(100L)
       }
     }