diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
index 902bdda59860ed1b12dfbe72f39f3608be3647a8..d3e327b2497b7b2c28412c6440825ed82ea56e55 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
@@ -43,8 +43,11 @@ import org.apache.spark.util.{AkkaUtils, SignalLogger, Utils}
 /**
  * Common application master functionality for Spark on Yarn.
  */
-private[spark] class ApplicationMaster(args: ApplicationMasterArguments,
-  client: YarnRMClient) extends Logging {
+private[spark] class ApplicationMaster(
+    args: ApplicationMasterArguments,
+    client: YarnRMClient)
+  extends Logging {
+
   // TODO: Currently, task to container is computed once (TaskSetManager) - which need not be
   // optimal as more containers are available. Might need to handle this better.
 
@@ -231,6 +234,24 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments,
     reporterThread = launchReporterThread()
   }
 
+  /**
+   * Create an actor that communicates with the driver.
+   *
+   * In cluster mode, the AM and the driver belong to same process
+   * so the AM actor need not monitor lifecycle of the driver.
+   */
+  private def runAMActor(
+      host: String,
+      port: String,
+      isDriver: Boolean): Unit = {
+    val driverUrl = "akka.tcp://%s@%s:%s/user/%s".format(
+      SparkEnv.driverActorSystemName,
+      host,
+      port,
+      YarnSchedulerBackend.ACTOR_NAME)
+    actor = actorSystem.actorOf(Props(new AMActor(driverUrl, isDriver)), name = "YarnAM")
+  }
+
   private def runDriver(securityMgr: SecurityManager): Unit = {
     addAmIpFilter()
     userClassThread = startUserClass()
@@ -245,6 +266,11 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments,
         ApplicationMaster.EXIT_SC_NOT_INITED,
         "Timed out waiting for SparkContext.")
     } else {
+      actorSystem = sc.env.actorSystem
+      runAMActor(
+        sc.getConf.get("spark.driver.host"),
+        sc.getConf.get("spark.driver.port"),
+        isDriver = true)
       registerAM(sc.ui.map(_.appUIAddress).getOrElse(""), securityMgr)
       userClassThread.join()
     }
@@ -253,7 +279,7 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments,
   private def runExecutorLauncher(securityMgr: SecurityManager): Unit = {
     actorSystem = AkkaUtils.createActorSystem("sparkYarnAM", Utils.localHostName, 0,
       conf = sparkConf, securityManager = securityMgr)._1
-    actor = waitForSparkDriver()
+    waitForSparkDriver()
     addAmIpFilter()
     registerAM(sparkConf.get("spark.driver.appUIAddress", ""), securityMgr)
 
@@ -367,7 +393,7 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments,
     }
   }
 
-  private def waitForSparkDriver(): ActorRef = {
+  private def waitForSparkDriver(): Unit = {
     logInfo("Waiting for Spark driver to be reachable.")
     var driverUp = false
     val hostport = args.userArgs(0)
@@ -399,12 +425,7 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments,
     sparkConf.set("spark.driver.host", driverHost)
     sparkConf.set("spark.driver.port", driverPort.toString)
 
-    val driverUrl = "akka.tcp://%s@%s:%s/user/%s".format(
-      SparkEnv.driverActorSystemName,
-      driverHost,
-      driverPort.toString,
-      YarnSchedulerBackend.ACTOR_NAME)
-    actorSystem.actorOf(Props(new AMActor(driverUrl)), name = "YarnAM")
+    runAMActor(driverHost, driverPort.toString, isDriver = false)
   }
 
   /** Add the Yarn IP filter that is required for properly securing the UI. */
@@ -462,9 +483,9 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments,
   }
 
   /**
-   * Actor that communicates with the driver in client deploy mode.
+   * An actor that communicates with the driver's scheduler backend.
    */
-  private class AMActor(driverUrl: String) extends Actor {
+  private class AMActor(driverUrl: String, isDriver: Boolean) extends Actor {
     var driver: ActorSelection = _
 
     override def preStart() = {
@@ -474,13 +495,21 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments,
       // we can monitor Lifecycle Events.
       driver ! "Hello"
       driver ! RegisterClusterManager
-      context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent])
+      // In cluster mode, the AM can directly monitor the driver status instead
+      // of trying to deduce it from the lifecycle of the driver's actor
+      if (!isDriver) {
+        context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent])
+      }
     }
 
     override def receive = {
       case x: DisassociatedEvent =>
         logInfo(s"Driver terminated or disconnected! Shutting down. $x")
-        finish(FinalApplicationStatus.SUCCEEDED, ApplicationMaster.EXIT_SUCCESS)
+        // In cluster mode, do not rely on the disassociated event to exit
+        // This avoids potentially reporting incorrect exit codes if the driver fails
+        if (!isDriver) {
+          finish(FinalApplicationStatus.SUCCEEDED, ApplicationMaster.EXIT_SUCCESS)
+        }
 
       case x: AddWebUIFilter =>
         logInfo(s"Add WebUI Filter. $x")