diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
index 614278c8b2d220b10894c03014f934c05f881fa4..a4b575c85d5fbaa1628821a189a0dad4456f360a 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
@@ -20,9 +20,11 @@ package org.apache.spark.deploy.yarn
 import java.io.{File, IOException}
 import java.lang.reflect.InvocationTargetException
 import java.net.{Socket, URI, URL}
-import java.util.concurrent.atomic.AtomicReference
+import java.util.concurrent.{TimeoutException, TimeUnit}
 
 import scala.collection.mutable.HashMap
+import scala.concurrent.Promise
+import scala.concurrent.duration.Duration
 import scala.util.control.NonFatal
 
 import org.apache.hadoop.fs.{FileSystem, Path}
@@ -106,12 +108,11 @@ private[spark] class ApplicationMaster(
   // Next wait interval before allocator poll.
   private var nextAllocationInterval = initialAllocationInterval
 
-  // Fields used in client mode.
   private var rpcEnv: RpcEnv = null
   private var amEndpoint: RpcEndpointRef = _
 
-  // Fields used in cluster mode.
-  private val sparkContextRef = new AtomicReference[SparkContext](null)
+  // In cluster mode, used to tell the AM when the user's SparkContext has been initialized.
+  private val sparkContextPromise = Promise[SparkContext]()
 
   private var credentialRenewer: AMCredentialRenewer = _
 
@@ -316,23 +317,15 @@ private[spark] class ApplicationMaster(
   }
 
   private def sparkContextInitialized(sc: SparkContext) = {
-    sparkContextRef.synchronized {
-      sparkContextRef.compareAndSet(null, sc)
-      sparkContextRef.notifyAll()
-    }
-  }
-
-  private def sparkContextStopped(sc: SparkContext) = {
-    sparkContextRef.compareAndSet(sc, null)
+    sparkContextPromise.success(sc)
   }
 
   private def registerAM(
+      _sparkConf: SparkConf,
       _rpcEnv: RpcEnv,
       driverRef: RpcEndpointRef,
       uiAddress: String,
       securityMgr: SecurityManager) = {
-    val sc = sparkContextRef.get()
-
     val appId = client.getAttemptId().getApplicationId().toString()
     val attemptId = client.getAttemptId().getAttemptId().toString()
     val historyAddress =
@@ -341,7 +334,6 @@ private[spark] class ApplicationMaster(
         .map { address => s"${address}${HistoryServer.UI_PATH_PREFIX}/${appId}/${attemptId}" }
         .getOrElse("")
 
-    val _sparkConf = if (sc != null) sc.getConf else sparkConf
     val driverUrl = RpcEndpointAddress(
       _sparkConf.get("spark.driver.host"),
       _sparkConf.get("spark.driver.port").toInt,
@@ -385,21 +377,35 @@ private[spark] class ApplicationMaster(
 
     // This is a bit hacky, but we need to wait until the spark.driver.port property has
     // been set by the Thread executing the user class.
-    val sc = waitForSparkContextInitialized()
-
-    // If there is no SparkContext at this point, just fail the app.
-    if (sc == null) {
-      finish(FinalApplicationStatus.FAILED,
-        ApplicationMaster.EXIT_SC_NOT_INITED,
-        "Timed out waiting for SparkContext.")
-    } else {
-      rpcEnv = sc.env.rpcEnv
-      val driverRef = runAMEndpoint(
-        sc.getConf.get("spark.driver.host"),
-        sc.getConf.get("spark.driver.port"),
-        isClusterMode = true)
-      registerAM(rpcEnv, driverRef, sc.ui.map(_.appUIAddress).getOrElse(""), securityMgr)
+    logInfo("Waiting for spark context initialization...")
+    val totalWaitTime = sparkConf.get(AM_MAX_WAIT_TIME)
+    try {
+      val sc = ThreadUtils.awaitResult(sparkContextPromise.future,
+        Duration(totalWaitTime, TimeUnit.MILLISECONDS))
+      if (sc != null) {
+        rpcEnv = sc.env.rpcEnv
+        val driverRef = runAMEndpoint(
+          sc.getConf.get("spark.driver.host"),
+          sc.getConf.get("spark.driver.port"),
+          isClusterMode = true)
+        registerAM(sc.getConf, rpcEnv, driverRef, sc.ui.map(_.appUIAddress).getOrElse(""),
+          securityMgr)
+      } else {
+        // Sanity check; should never happen in normal operation, since sc should only be null
+        // if the user app did not create a SparkContext.
+        if (!finished) {
+          throw new IllegalStateException("SparkContext is null but app is still running!")
+        }
+      }
       userClassThread.join()
+    } catch {
+      case e: SparkException if e.getCause().isInstanceOf[TimeoutException] =>
+        logError(
+          s"SparkContext did not initialize after waiting for $totalWaitTime ms. " +
+           "Please check earlier log output for errors. Failing the application.")
+        finish(FinalApplicationStatus.FAILED,
+          ApplicationMaster.EXIT_SC_NOT_INITED,
+          "Timed out waiting for SparkContext.")
     }
   }
 
@@ -409,7 +415,8 @@ private[spark] class ApplicationMaster(
       clientMode = true)
     val driverRef = waitForSparkDriver()
     addAmIpFilter()
-    registerAM(rpcEnv, driverRef, sparkConf.get("spark.driver.appUIAddress", ""), securityMgr)
+    registerAM(sparkConf, rpcEnv, driverRef, sparkConf.get("spark.driver.appUIAddress", ""),
+      securityMgr)
 
     // In client mode the AM endpoint will stop the reporter thread.
     reporterThread.join()
@@ -525,26 +532,6 @@ private[spark] class ApplicationMaster(
     }
   }
 
-  private def waitForSparkContextInitialized(): SparkContext = {
-    logInfo("Waiting for spark context initialization")
-    sparkContextRef.synchronized {
-      val totalWaitTime = sparkConf.get(AM_MAX_WAIT_TIME)
-      val deadline = System.currentTimeMillis() + totalWaitTime
-
-      while (sparkContextRef.get() == null && System.currentTimeMillis < deadline && !finished) {
-        logInfo("Waiting for spark context initialization ... ")
-        sparkContextRef.wait(10000L)
-      }
-
-      val sparkContext = sparkContextRef.get()
-      if (sparkContext == null) {
-        logError(("SparkContext did not initialize after waiting for %d ms. Please check earlier"
-          + " log output for errors. Failing the application.").format(totalWaitTime))
-      }
-      sparkContext
-    }
-  }
-
   private def waitForSparkDriver(): RpcEndpointRef = {
     logInfo("Waiting for Spark driver to be reachable.")
     var driverUp = false
@@ -647,6 +634,13 @@ private[spark] class ApplicationMaster(
                   ApplicationMaster.EXIT_EXCEPTION_USER_CLASS,
                   "User class threw exception: " + cause)
             }
+            sparkContextPromise.tryFailure(e.getCause())
+        } finally {
+          // Notify the thread waiting for the SparkContext, in case the application did not
+          // instantiate one. This will do nothing when the user code instantiates a SparkContext
+          // (with the correct master), or when the user code throws an exception (due to the
+          // tryFailure above).
+          sparkContextPromise.trySuccess(null)
         }
       }
     }
@@ -759,10 +753,6 @@ object ApplicationMaster extends Logging {
     master.sparkContextInitialized(sc)
   }
 
-  private[spark] def sparkContextStopped(sc: SparkContext): Boolean = {
-    master.sparkContextStopped(sc)
-  }
-
   private[spark] def getAttemptId(): ApplicationAttemptId = {
     master.getAttemptId
   }
diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala
index 72ec4d6b34af6498834cc41583b0d6e67862e618..96c9151fc351d73a427857cffb97d76941df1933 100644
--- a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala
+++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala
@@ -34,9 +34,4 @@ private[spark] class YarnClusterScheduler(sc: SparkContext) extends YarnSchedule
     logInfo("YarnClusterScheduler.postStartHook done")
   }
 
-  override def stop() {
-    super.stop()
-    ApplicationMaster.sparkContextStopped(sc)
-  }
-
 }
diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala
index 8ab7b21c22139462f5cdbb9978876df3bbf8a2e2..fb7926f6a1e2808110923a6aea9956d507634c64 100644
--- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala
+++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala
@@ -33,6 +33,7 @@ import org.scalatest.concurrent.Eventually._
 
 import org.apache.spark._
 import org.apache.spark.deploy.SparkHadoopUtil
+import org.apache.spark.deploy.yarn.config._
 import org.apache.spark.internal.Logging
 import org.apache.spark.launcher._
 import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationStart,
@@ -192,6 +193,14 @@ class YarnClusterSuite extends BaseYarnClusterSuite {
     }
   }
 
+  test("timeout to get SparkContext in cluster mode triggers failure") {
+    val timeout = 2000
+    val finalState = runSpark(false, mainClassName(SparkContextTimeoutApp.getClass),
+      appArgs = Seq((timeout * 4).toString),
+      extraConf = Map(AM_MAX_WAIT_TIME.key -> timeout.toString))
+    finalState should be (SparkAppHandle.State.FAILED)
+  }
+
   private def testBasicYarnApp(clientMode: Boolean, conf: Map[String, String] = Map()): Unit = {
     val result = File.createTempFile("result", null, tempDir)
     val finalState = runSpark(clientMode, mainClassName(YarnClusterDriver.getClass),
@@ -469,3 +478,16 @@ private object YarnLauncherTestApp {
   }
 
 }
+
+/**
+ * Used to test the AM code path that waits for the user's SparkContext to be created. Expects a
+ * single argument with the duration to sleep for, in ms.
+ */
+private object SparkContextTimeoutApp {
+
+  def main(args: Array[String]): Unit = {
+    val Array(sleepTime) = args
+    Thread.sleep(sleepTime.toLong)
+  }
+
+}
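
A note on the waiting pattern this patch introduces: ThreadUtils.awaitResult is a thin wrapper
around Await.result that rethrows any non-fatal failure, including the timeout, wrapped in a
SparkException; that is why the catch clause in runDriver matches a SparkException whose cause is
a TimeoutException instead of catching TimeoutException directly. The sketch below shows the same
two-thread handshake using only the Scala standard library, with Await.result standing in for
Spark's internal ThreadUtils.awaitResult and a plain String standing in for the SparkContext; all
names in it are illustrative, not Spark's.

import java.util.concurrent.{TimeoutException, TimeUnit}

import scala.concurrent.{Await, Promise}
import scala.concurrent.duration.Duration

object PromiseHandshakeSketch {

  // Stands in for sparkContextPromise; String stands in for SparkContext.
  private val contextPromise = Promise[String]()

  def main(args: Array[String]): Unit = {
    // Stands in for the user class thread started by startUserApplication().
    val userThread = new Thread(new Runnable {
      override def run(): Unit = {
        try {
          // The user code "creates its context". Comment this line out to see
          // the finally block deliver null to the waiting thread instead.
          contextPromise.success("context stand-in")
        } catch {
          // Mirrors sparkContextPromise.tryFailure(e.getCause()) in the patch.
          case e: Throwable => contextPromise.tryFailure(e)
        } finally {
          // No-op when the promise is already completed; otherwise it delivers
          // null so the waiting thread is never left hanging.
          contextPromise.trySuccess(null)
        }
      }
    })
    userThread.start()

    try {
      // Await.result throws TimeoutException directly; Spark's
      // ThreadUtils.awaitResult would wrap it in a SparkException.
      val ctx = Await.result(contextPromise.future, Duration(2000, TimeUnit.MILLISECONDS))
      if (ctx != null) {
        println(s"got: $ctx") // the AM would call registerAM(...) here
      } else {
        println("user code finished without creating a context")
      }
      userThread.join()
    } catch {
      case _: TimeoutException =>
        println("timed out; the AM would fail the app with EXIT_SC_NOT_INITED")
    }
  }
}

The property the sketch mirrors from the patch is that the completing thread always resolves the
promise exactly once: success(sc) on the happy path, tryFailure for user exceptions, and a final
trySuccess(null) that does nothing unless neither of the others ran, so the waiting thread can
always distinguish "no context was created" from "still starting".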