diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
index 614278c8b2d220b10894c03014f934c05f881fa4..a4b575c85d5fbaa1628821a189a0dad4456f360a 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
@@ -20,9 +20,11 @@ package org.apache.spark.deploy.yarn
 import java.io.{File, IOException}
 import java.lang.reflect.InvocationTargetException
 import java.net.{Socket, URI, URL}
-import java.util.concurrent.atomic.AtomicReference
+import java.util.concurrent.{TimeoutException, TimeUnit}
 
 import scala.collection.mutable.HashMap
+import scala.concurrent.Promise
+import scala.concurrent.duration.Duration
 import scala.util.control.NonFatal
 
 import org.apache.hadoop.fs.{FileSystem, Path}
@@ -106,12 +108,11 @@ private[spark] class ApplicationMaster(
   // Next wait interval before allocator poll.
   private var nextAllocationInterval = initialAllocationInterval
 
-  // Fields used in client mode.
   private var rpcEnv: RpcEnv = null
   private var amEndpoint: RpcEndpointRef = _
 
-  // Fields used in cluster mode.
-  private val sparkContextRef = new AtomicReference[SparkContext](null)
+  // In cluster mode, used to tell the AM when the user's SparkContext has been initialized.
+  private val sparkContextPromise = Promise[SparkContext]()
 
   private var credentialRenewer: AMCredentialRenewer = _
 
@@ -316,23 +317,15 @@ private[spark] class ApplicationMaster(
   }
 
   private def sparkContextInitialized(sc: SparkContext) = {
-    sparkContextRef.synchronized {
-      sparkContextRef.compareAndSet(null, sc)
-      sparkContextRef.notifyAll()
-    }
-  }
-
-  private def sparkContextStopped(sc: SparkContext) = {
-    sparkContextRef.compareAndSet(sc, null)
+    sparkContextPromise.success(sc)
   }
 
   private def registerAM(
+      _sparkConf: SparkConf,
       _rpcEnv: RpcEnv,
       driverRef: RpcEndpointRef,
       uiAddress: String,
       securityMgr: SecurityManager) = {
-    val sc = sparkContextRef.get()
-
     val appId = client.getAttemptId().getApplicationId().toString()
     val attemptId = client.getAttemptId().getAttemptId().toString()
     val historyAddress =
@@ -341,7 +334,6 @@ private[spark] class ApplicationMaster(
       .map { address => s"${address}${HistoryServer.UI_PATH_PREFIX}/${appId}/${attemptId}" }
       .getOrElse("")
 
-    val _sparkConf = if (sc != null) sc.getConf else sparkConf
     val driverUrl = RpcEndpointAddress(
       _sparkConf.get("spark.driver.host"),
       _sparkConf.get("spark.driver.port").toInt,
@@ -385,21 +377,35 @@ private[spark] class ApplicationMaster(
 
     // This a bit hacky, but we need to wait until the spark.driver.port property has
     // been set by the Thread executing the user class.
-    val sc = waitForSparkContextInitialized()
-
-    // If there is no SparkContext at this point, just fail the app.
-    if (sc == null) {
-      finish(FinalApplicationStatus.FAILED,
-        ApplicationMaster.EXIT_SC_NOT_INITED,
-        "Timed out waiting for SparkContext.")
-    } else {
-      rpcEnv = sc.env.rpcEnv
-      val driverRef = runAMEndpoint(
-        sc.getConf.get("spark.driver.host"),
-        sc.getConf.get("spark.driver.port"),
-        isClusterMode = true)
-      registerAM(rpcEnv, driverRef, sc.ui.map(_.appUIAddress).getOrElse(""), securityMgr)
+    logInfo("Waiting for spark context initialization...")
+    val totalWaitTime = sparkConf.get(AM_MAX_WAIT_TIME)
+    try {
+      val sc = ThreadUtils.awaitResult(sparkContextPromise.future,
+        Duration(totalWaitTime, TimeUnit.MILLISECONDS))
+      if (sc != null) {
+        rpcEnv = sc.env.rpcEnv
+        val driverRef = runAMEndpoint(
+          sc.getConf.get("spark.driver.host"),
+          sc.getConf.get("spark.driver.port"),
+          isClusterMode = true)
+        registerAM(sc.getConf, rpcEnv, driverRef, sc.ui.map(_.appUIAddress).getOrElse(""),
+          securityMgr)
+      } else {
+        // Sanity check; should never happen in normal operation, since sc should only be null
+        // if the user app did not create a SparkContext.
+        if (!finished) {
+          throw new IllegalStateException("SparkContext is null but app is still running!")
+        }
+      }
       userClassThread.join()
+    } catch {
+      case e: SparkException if e.getCause().isInstanceOf[TimeoutException] =>
+        logError(
+          s"SparkContext did not initialize after waiting for $totalWaitTime ms. " +
+            "Please check earlier log output for errors. Failing the application.")
+        finish(FinalApplicationStatus.FAILED,
+          ApplicationMaster.EXIT_SC_NOT_INITED,
+          "Timed out waiting for SparkContext.")
     }
   }
 
@@ -409,7 +415,8 @@ private[spark] class ApplicationMaster(
       clientMode = true)
     val driverRef = waitForSparkDriver()
     addAmIpFilter()
-    registerAM(rpcEnv, driverRef, sparkConf.get("spark.driver.appUIAddress", ""), securityMgr)
+    registerAM(sparkConf, rpcEnv, driverRef, sparkConf.get("spark.driver.appUIAddress", ""),
+      securityMgr)
 
     // In client mode the actor will stop the reporter thread.
     reporterThread.join()
@@ -525,26 +532,6 @@ private[spark] class ApplicationMaster(
     }
   }
 
-  private def waitForSparkContextInitialized(): SparkContext = {
-    logInfo("Waiting for spark context initialization")
-    sparkContextRef.synchronized {
-      val totalWaitTime = sparkConf.get(AM_MAX_WAIT_TIME)
-      val deadline = System.currentTimeMillis() + totalWaitTime
-
-      while (sparkContextRef.get() == null && System.currentTimeMillis < deadline && !finished) {
-        logInfo("Waiting for spark context initialization ... ")
-        sparkContextRef.wait(10000L)
-      }
-
-      val sparkContext = sparkContextRef.get()
-      if (sparkContext == null) {
-        logError(("SparkContext did not initialize after waiting for %d ms. Please check earlier"
-          + " log output for errors. Failing the application.").format(totalWaitTime))
-      }
-      sparkContext
-    }
-  }
-
   private def waitForSparkDriver(): RpcEndpointRef = {
     logInfo("Waiting for Spark driver to be reachable.")
     var driverUp = false
@@ -647,6 +634,13 @@ private[spark] class ApplicationMaster(
                   ApplicationMaster.EXIT_EXCEPTION_USER_CLASS,
                   "User class threw exception: " + cause)
             }
+            sparkContextPromise.tryFailure(e.getCause())
+        } finally {
+          // Notify the thread waiting for the SparkContext, in case the application did not
+          // instantiate one. This will do nothing when the user code instantiates a SparkContext
+          // (with the correct master), or when the user code throws an exception (due to the
+          // tryFailure above).
+          sparkContextPromise.trySuccess(null)
         }
       }
     }
@@ -759,10 +753,6 @@ object ApplicationMaster extends Logging {
     master.sparkContextInitialized(sc)
   }
 
-  private[spark] def sparkContextStopped(sc: SparkContext): Boolean = {
-    master.sparkContextStopped(sc)
-  }
-
   private[spark] def getAttemptId(): ApplicationAttemptId = {
     master.getAttemptId
   }
diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala
index 72ec4d6b34af6498834cc41583b0d6e67862e618..96c9151fc351d73a427857cffb97d76941df1933 100644
--- a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala
+++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala
@@ -34,9 +34,4 @@ private[spark] class YarnClusterScheduler(sc: SparkContext) extends YarnSchedule
     logInfo("YarnClusterScheduler.postStartHook done")
   }
 
-  override def stop() {
-    super.stop()
-    ApplicationMaster.sparkContextStopped(sc)
-  }
-
 }
diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala
index 8ab7b21c22139462f5cdbb9978876df3bbf8a2e2..fb7926f6a1e2808110923a6aea9956d507634c64 100644
--- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala
+++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala
@@ -33,6 +33,7 @@ import org.scalatest.concurrent.Eventually._
 
 import org.apache.spark._
 import org.apache.spark.deploy.SparkHadoopUtil
+import org.apache.spark.deploy.yarn.config._
 import org.apache.spark.internal.Logging
 import org.apache.spark.launcher._
 import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationStart,
@@ -192,6 +193,14 @@
     }
   }
 
+  test("timeout to get SparkContext in cluster mode triggers failure") {
+    val timeout = 2000
+    val finalState = runSpark(false, mainClassName(SparkContextTimeoutApp.getClass),
+      appArgs = Seq((timeout * 4).toString),
+      extraConf = Map(AM_MAX_WAIT_TIME.key -> timeout.toString))
+    finalState should be (SparkAppHandle.State.FAILED)
+  }
+
   private def testBasicYarnApp(clientMode: Boolean, conf: Map[String, String] = Map()): Unit = {
     val result = File.createTempFile("result", null, tempDir)
     val finalState = runSpark(clientMode, mainClassName(YarnClusterDriver.getClass),
@@ -469,3 +478,16 @@
   }
 
 }
+
+/**
+ * Used to test code in the AM that detects the SparkContext instance. Expects a single argument
+ * with the duration to sleep for, in ms.
+ */
+private object SparkContextTimeoutApp {
+
+  def main(args: Array[String]): Unit = {
+    val Array(sleepTime) = args
+    Thread.sleep(java.lang.Long.parseLong(sleepTime))
+  }
+
+}
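
Reviewer note on the synchronization change above: the AtomicReference plus wait()/notifyAll() handshake is replaced by a write-once Promise that the user-class thread completes and the AM thread awaits. Below is a minimal, self-contained sketch of that pattern, not Spark code: a plain String stands in for SparkContext, and scala.concurrent.Await.result stands in for Spark's ThreadUtils.awaitResult, which wraps any failure, including the timeout, in a SparkException. That wrapping is why the patch's catch clause matches `e: SparkException if e.getCause().isInstanceOf[TimeoutException]`.

import java.util.concurrent.TimeoutException

import scala.concurrent.{Await, Promise}
import scala.concurrent.duration._

object PromiseHandshakeSketch {

  // Stand-in for sparkContextPromise; the patch uses Promise[SparkContext]().
  private val contextPromise = Promise[String]()

  def main(args: Array[String]): Unit = {
    // Plays the role of the thread started by startUserApplication().
    val userClassThread = new Thread {
      override def run(): Unit = {
        try {
          Thread.sleep(500) // user code takes a while to create its context
          contextPromise.trySuccess("live-context")
        } catch {
          // Mirrors sparkContextPromise.tryFailure(e.getCause()) in the patch.
          case e: InterruptedException => contextPromise.tryFailure(e)
        } finally {
          // Mirrors sparkContextPromise.trySuccess(null): wakes the waiter even if
          // the user code never created a context; a no-op if already completed.
          contextPromise.trySuccess(null)
        }
      }
    }
    userClassThread.start()

    try {
      val ctx = Await.result(contextPromise.future, 2.seconds)
      if (ctx != null) {
        println(s"Context initialized: $ctx")
      } else {
        println("User code finished without creating a context.")
      }
      userClassThread.join()
    } catch {
      case _: TimeoutException =>
        println("Timed out waiting for the context; would fail the application here.")
    }
  }
}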
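One design consequence worth calling out: a Promise can be completed only once, which is why sparkContextStopped() and the YarnClusterScheduler.stop() override can be deleted outright. The old code had to clear the AtomicReference when the context stopped; the AM now only consumes a one-shot initialization signal, and the trySuccess()/tryFailure() calls in the finally and catch blocks degrade to safe no-ops when the promise has already been fulfilled.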