diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala
index fcc72ff49276da7378bd2be2eb9daeaf0e9b2077..9a0e3b555789204617b8fd63f29627f030301deb 100644
--- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala
+++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala
@@ -21,7 +21,7 @@ import java.util.concurrent.TimeUnit
 
 import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
-import scala.util.control.ControlThrowable
+import scala.util.control.{ControlThrowable, NonFatal}
 
 import com.codahale.metrics.{Gauge, MetricRegistry}
 
@@ -245,14 +245,17 @@ private[spark] class ExecutorAllocationManager(
   }
 
   /**
-   * Reset the allocation manager to the initial state. Currently this will only be called in
-   * yarn-client mode when AM re-registers after a failure.
+   * Reset the allocation manager when the cluster manager loses track of the driver's state.
+   * This is currently only done in YARN client mode, when the AM is restarted.
+   *
+   * This method forgets any state about existing executors, and forces the scheduler to
+   * re-evaluate the number of needed executors the next time it's run.
    */
   def reset(): Unit = synchronized {
-    initializing = true
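+    // An addTime of 0 means the add trigger has already expired, so the next schedule() pass
+    // re-evaluates and re-sends executor requests instead of waiting for a new backlog event.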
+    addTime = 0L
     numExecutorsTarget = initialNumExecutors
-    numExecutorsToAdd = 1
-
     executorsPendingToRemove.clear()
     removeTimes.clear()
   }
@@ -376,8 +377,17 @@ private[spark] class ExecutorAllocationManager(
       return 0
     }
 
-    val addRequestAcknowledged = testing ||
-      client.requestTotalExecutors(numExecutorsTarget, localityAwareTasks, hostToLocalTaskCount)
+    val addRequestAcknowledged = try {
+      testing ||
+        client.requestTotalExecutors(numExecutorsTarget, localityAwareTasks, hostToLocalTaskCount)
+    } catch {
+      case NonFatal(e) =>
+        // Use INFO level so the error doesn't show up by default in shells. Errors here are most
+        // commonly caused by YARN AM restarts, which are recoverable and would otherwise generate
+        // a lot of noisy output.
+        logInfo("Error reaching cluster manager.", e)
+        false
+    }
     if (addRequestAcknowledged) {
       val executorsString = "executor" + { if (delta > 1) "s" else "" }
       logInfo(s"Requesting $delta new $executorsString because tasks are backlogged" +
diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
index 0387b44dbcc100f76bb03db4623f58c932894f01..e227bff88f71d5308f6e76e6adeab1eee71317fa 100644
--- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
+++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
@@ -130,7 +130,6 @@ private[spark] class ApplicationMaster(
   private var nextAllocationInterval = initialAllocationInterval
 
   private var rpcEnv: RpcEnv = null
-  private var amEndpoint: RpcEndpointRef = _
 
   // In cluster mode, used to tell the AM when the user's SparkContext has been initialized.
   private val sparkContextPromise = Promise[SparkContext]()
@@ -405,32 +404,26 @@ private[spark] class ApplicationMaster(
       securityMgr,
       localResources)
 
+    // Initialize the AM endpoint *after* the allocator has been initialized. This ensures
+    // that when the driver sends an initial executor request (e.g. after an AM restart),
+    // the allocator is ready to service requests.
+    rpcEnv.setupEndpoint("YarnAM", new AMEndpoint(rpcEnv, driverRef))
+
     allocator.allocateResources()
     reporterThread = launchReporterThread()
   }
 
   /**
-   * Create an [[RpcEndpoint]] that communicates with the driver.
-   *
-   * In cluster mode, the AM and the driver belong to same process
-   * so the AMEndpoint need not monitor lifecycle of the driver.
-   *
-   * @return A reference to the driver's RPC endpoint.
+   * @return A reference to the driver's scheduler backend [[RpcEndpoint]].
    */
-  private def runAMEndpoint(
-      host: String,
-      port: String,
-      isClusterMode: Boolean): RpcEndpointRef = {
-    val driverEndpoint = rpcEnv.setupEndpointRef(
+  private def createSchedulerRef(host: String, port: String): RpcEndpointRef = {
+    rpcEnv.setupEndpointRef(
       RpcAddress(host, port.toInt),
       YarnSchedulerBackend.ENDPOINT_NAME)
-    amEndpoint =
-      rpcEnv.setupEndpoint("YarnAM", new AMEndpoint(rpcEnv, driverEndpoint, isClusterMode))
-    driverEndpoint
   }
 
   private def runDriver(securityMgr: SecurityManager): Unit = {
-    addAmIpFilter()
+    addAmIpFilter(None)
     userClassThread = startUserApplication()
 
     // This a bit hacky, but we need to wait until the spark.driver.port property has
@@ -442,10 +435,9 @@ private[spark] class ApplicationMaster(
         Duration(totalWaitTime, TimeUnit.MILLISECONDS))
       if (sc != null) {
         rpcEnv = sc.env.rpcEnv
-        val driverRef = runAMEndpoint(
+        val driverRef = createSchedulerRef(
           sc.getConf.get("spark.driver.host"),
-          sc.getConf.get("spark.driver.port"),
-          isClusterMode = true)
+          sc.getConf.get("spark.driver.port"))
         registerAM(sc.getConf, rpcEnv, driverRef, sc.ui.map(_.webUrl), securityMgr)
         registered = true
       } else {
@@ -471,7 +463,7 @@ private[spark] class ApplicationMaster(
     rpcEnv = RpcEnv.create("sparkYarnAM", hostname, hostname, -1, sparkConf, securityMgr,
       amCores, true)
     val driverRef = waitForSparkDriver()
-    addAmIpFilter()
+    addAmIpFilter(Some(driverRef))
     registerAM(sparkConf, rpcEnv, driverRef, sparkConf.getOption("spark.driver.appUIAddress"),
       securityMgr)
     registered = true
@@ -620,20 +612,23 @@ private[spark] class ApplicationMaster(
 
     sparkConf.set("spark.driver.host", driverHost)
     sparkConf.set("spark.driver.port", driverPort.toString)
-
-    runAMEndpoint(driverHost, driverPort.toString, isClusterMode = false)
+    createSchedulerRef(driverHost, driverPort.toString)
   }
 
   /** Add the Yarn IP filter that is required for properly securing the UI. */
-  private def addAmIpFilter() = {
+  private def addAmIpFilter(driver: Option[RpcEndpointRef]) = {
     val proxyBase = System.getenv(ApplicationConstants.APPLICATION_WEB_PROXY_BASE_ENV)
     val amFilter = "org.apache.hadoop.yarn.server.webproxy.amfilter.AmIpFilter"
     val params = client.getAmIpFilterParams(yarnConf, proxyBase)
-    if (isClusterMode) {
-      System.setProperty("spark.ui.filters", amFilter)
-      params.foreach { case (k, v) => System.setProperty(s"spark.$amFilter.param.$k", v) }
-    } else {
-      amEndpoint.send(AddWebUIFilter(amFilter, params.toMap, proxyBase))
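+    // In cluster mode there is no driver ref yet, so the filter is set through system properties
+    // and picked up when the SparkContext starts; in client mode it is sent to the running driver.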
+    driver match {
+      case Some(d) =>
+        d.send(AddWebUIFilter(amFilter, params.toMap, proxyBase))
+
+      case None =>
+        System.setProperty("spark.ui.filters", amFilter)
+        params.foreach { case (k, v) => System.setProperty(s"spark.$amFilter.param.$k", v) }
     }
   }
 
@@ -704,20 +697,15 @@ private[spark] class ApplicationMaster(
   /**
    * An [[RpcEndpoint]] that communicates with the driver's scheduler backend.
    */
-  private class AMEndpoint(
-      override val rpcEnv: RpcEnv, driver: RpcEndpointRef, isClusterMode: Boolean)
+  private class AMEndpoint(override val rpcEnv: RpcEnv, driver: RpcEndpointRef)
     extends RpcEndpoint with Logging {
 
     override def onStart(): Unit = {
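+      // Let the driver know that a (possibly new) AM has registered; the driver resets its
+      // allocation state in response (see YarnSchedulerBackend).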
       driver.send(RegisterClusterManager(self))
     }
 
-    override def receive: PartialFunction[Any, Unit] = {
-      case x: AddWebUIFilter =>
-        logInfo(s"Add WebUI Filter. $x")
-        driver.send(x)
-    }
-
     override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
       case r: RequestExecutors =>
         Option(allocator) match {
diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala
index 8452f43774194a97a32ad6215302a2b69d2a9653..415a29fd887e81dd6f868dafe96804b0e70eacca 100644
--- a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala
+++ b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala
@@ -69,9 +69,6 @@ private[spark] abstract class YarnSchedulerBackend(
   /** Scheduler extension services. */
   private val services: SchedulerExtensionServices = new SchedulerExtensionServices()
 
-  // Flag to specify whether this schedulerBackend should be reset.
-  private var shouldResetOnAmRegister = false
-
   /**
    * Bind to YARN. This *must* be done before calling [[start()]].
    *
@@ -262,13 +259,9 @@ private[spark] abstract class YarnSchedulerBackend(
       case RegisterClusterManager(am) =>
         logInfo(s"ApplicationMaster registered as $am")
         amEndpoint = Option(am)
-        if (!shouldResetOnAmRegister) {
-          shouldResetOnAmRegister = true
-        } else {
-          // AM is already registered before, this potentially means that AM failed and
-          // a new one registered after the failure. This will only happen in yarn-client mode.
-          reset()
-        }
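+        // The AM may be re-registering after a failure, in which case it has no record of the
+        // executors previously requested, so reset unconditionally to re-sync that state.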
+        reset()
 
       case AddWebUIFilter(filterName, filterParams, proxyBase) =>
         addWebUIFilter(filterName, filterParams, proxyBase)