Commit e3fec51f authored by Marcelo Vanzin

[SPARK-16930][YARN] Fix a couple of races in cluster app initialization.

There are two narrow races that could cause the ApplicationMaster to miss
the moment when the user application instantiates the SparkContext, which
could cause the app to fail even though nothing was wrong with it. It was
also possible for a failing application to get stuck for a long time in the
loop that waits for the context, instead of failing quickly.

The change uses a promise to track the SparkContext instance, which gets
rid of the races and allows for some simplification of the code.
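The core of the fix is a standard promise/future hand-off: the user-class thread
completes a Promise[SparkContext] exactly once, and the AM thread awaits the
future with a timeout instead of polling an AtomicReference under a lock. The
sketch below is illustrative only (Context, PromiseHandOffSketch and the timeouts
are made-up names), and plain Await.result stands in for Spark's
ThreadUtils.awaitResult, which wraps failures in a SparkException:

import java.util.concurrent.{TimeoutException, TimeUnit}

import scala.concurrent.{Await, Promise}
import scala.concurrent.duration.Duration

// Stand-in for the user's SparkContext in this sketch.
case class Context(driverHost: String, driverPort: Int)

object PromiseHandOffSketch {
  // Completed exactly once, from the thread running the user class.
  private val contextPromise = Promise[Context]()

  def main(args: Array[String]): Unit = {
    val userClassThread = new Thread(new Runnable {
      override def run(): Unit = {
        try {
          Thread.sleep(500) // user code doing setup work
          contextPromise.success(Context("localhost", 7077))
        } catch {
          case e: Throwable => contextPromise.tryFailure(e)
        } finally {
          // No-op if the promise is already completed; wakes the waiter if the
          // user code never created a context.
          contextPromise.trySuccess(null)
        }
      }
    })
    userClassThread.start()

    // The waiter cannot miss the signal: if the promise was completed before
    // this point, Await.result returns immediately instead of blocking.
    try {
      val ctx = Await.result(contextPromise.future, Duration(10000, TimeUnit.MILLISECONDS))
      if (ctx != null) println(s"got context: $ctx") else println("user code created no context")
    } catch {
      case _: TimeoutException => println("timed out waiting for the context; failing the app")
      case e: Exception => println(s"user class failed before creating a context: ${e.getMessage}")
    }
    userClassThread.join()
  }
}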

Tested with existing unit tests, and a new one added to exercise the
timeout code.

Author: Marcelo Vanzin <vanzin@cloudera.com>

Closes #14542 from vanzin/SPARK-16930.
parent 928ca1c6
ApplicationMaster.scala
@@ -20,9 +20,11 @@ package org.apache.spark.deploy.yarn
import java.io.{File, IOException}
import java.lang.reflect.InvocationTargetException
import java.net.{Socket, URI, URL}
import java.util.concurrent.atomic.AtomicReference
import java.util.concurrent.{TimeoutException, TimeUnit}
import scala.collection.mutable.HashMap
import scala.concurrent.Promise
import scala.concurrent.duration.Duration
import scala.util.control.NonFatal
import org.apache.hadoop.fs.{FileSystem, Path}
@@ -106,12 +108,11 @@ private[spark] class ApplicationMaster(
// Next wait interval before allocator poll.
private var nextAllocationInterval = initialAllocationInterval
// Fields used in client mode.
private var rpcEnv: RpcEnv = null
private var amEndpoint: RpcEndpointRef = _
// Fields used in cluster mode.
private val sparkContextRef = new AtomicReference[SparkContext](null)
// In cluster mode, used to tell the AM when the user's SparkContext has been initialized.
private val sparkContextPromise = Promise[SparkContext]()
private var credentialRenewer: AMCredentialRenewer = _
@@ -316,23 +317,15 @@ private[spark] class ApplicationMaster(
}
private def sparkContextInitialized(sc: SparkContext) = {
sparkContextRef.synchronized {
sparkContextRef.compareAndSet(null, sc)
sparkContextRef.notifyAll()
}
}
private def sparkContextStopped(sc: SparkContext) = {
sparkContextRef.compareAndSet(sc, null)
sparkContextPromise.success(sc)
}
private def registerAM(
_sparkConf: SparkConf,
_rpcEnv: RpcEnv,
driverRef: RpcEndpointRef,
uiAddress: String,
securityMgr: SecurityManager) = {
val sc = sparkContextRef.get()
val appId = client.getAttemptId().getApplicationId().toString()
val attemptId = client.getAttemptId().getAttemptId().toString()
val historyAddress =
@@ -341,7 +334,6 @@ private[spark] class ApplicationMaster(
.map { address => s"${address}${HistoryServer.UI_PATH_PREFIX}/${appId}/${attemptId}" }
.getOrElse("")
val _sparkConf = if (sc != null) sc.getConf else sparkConf
val driverUrl = RpcEndpointAddress(
_sparkConf.get("spark.driver.host"),
_sparkConf.get("spark.driver.port").toInt,
@@ -385,21 +377,35 @@ private[spark] class ApplicationMaster(
// This a bit hacky, but we need to wait until the spark.driver.port property has
// been set by the Thread executing the user class.
val sc = waitForSparkContextInitialized()
// If there is no SparkContext at this point, just fail the app.
if (sc == null) {
finish(FinalApplicationStatus.FAILED,
ApplicationMaster.EXIT_SC_NOT_INITED,
"Timed out waiting for SparkContext.")
} else {
rpcEnv = sc.env.rpcEnv
val driverRef = runAMEndpoint(
sc.getConf.get("spark.driver.host"),
sc.getConf.get("spark.driver.port"),
isClusterMode = true)
registerAM(rpcEnv, driverRef, sc.ui.map(_.appUIAddress).getOrElse(""), securityMgr)
logInfo("Waiting for spark context initialization...")
val totalWaitTime = sparkConf.get(AM_MAX_WAIT_TIME)
try {
val sc = ThreadUtils.awaitResult(sparkContextPromise.future,
Duration(totalWaitTime, TimeUnit.MILLISECONDS))
if (sc != null) {
rpcEnv = sc.env.rpcEnv
val driverRef = runAMEndpoint(
sc.getConf.get("spark.driver.host"),
sc.getConf.get("spark.driver.port"),
isClusterMode = true)
registerAM(sc.getConf, rpcEnv, driverRef, sc.ui.map(_.appUIAddress).getOrElse(""),
securityMgr)
} else {
// Sanity check; should never happen in normal operation, since sc should only be null
// if the user app did not create a SparkContext.
if (!finished) {
throw new IllegalStateException("SparkContext is null but app is still running!")
}
}
userClassThread.join()
} catch {
case e: SparkException if e.getCause().isInstanceOf[TimeoutException] =>
logError(
s"SparkContext did not initialize after waiting for $totalWaitTime ms. " +
"Please check earlier log output for errors. Failing the application.")
finish(FinalApplicationStatus.FAILED,
ApplicationMaster.EXIT_SC_NOT_INITED,
"Timed out waiting for SparkContext.")
}
}
@@ -409,7 +415,8 @@ private[spark] class ApplicationMaster(
clientMode = true)
val driverRef = waitForSparkDriver()
addAmIpFilter()
registerAM(rpcEnv, driverRef, sparkConf.get("spark.driver.appUIAddress", ""), securityMgr)
registerAM(sparkConf, rpcEnv, driverRef, sparkConf.get("spark.driver.appUIAddress", ""),
securityMgr)
// In client mode the actor will stop the reporter thread.
reporterThread.join()
@@ -525,26 +532,6 @@ private[spark] class ApplicationMaster(
}
}
private def waitForSparkContextInitialized(): SparkContext = {
logInfo("Waiting for spark context initialization")
sparkContextRef.synchronized {
val totalWaitTime = sparkConf.get(AM_MAX_WAIT_TIME)
val deadline = System.currentTimeMillis() + totalWaitTime
while (sparkContextRef.get() == null && System.currentTimeMillis < deadline && !finished) {
logInfo("Waiting for spark context initialization ... ")
sparkContextRef.wait(10000L)
}
val sparkContext = sparkContextRef.get()
if (sparkContext == null) {
logError(("SparkContext did not initialize after waiting for %d ms. Please check earlier"
+ " log output for errors. Failing the application.").format(totalWaitTime))
}
sparkContext
}
}
private def waitForSparkDriver(): RpcEndpointRef = {
logInfo("Waiting for Spark driver to be reachable.")
var driverUp = false
@@ -647,6 +634,13 @@ private[spark] class ApplicationMaster(
ApplicationMaster.EXIT_EXCEPTION_USER_CLASS,
"User class threw exception: " + cause)
}
sparkContextPromise.tryFailure(e.getCause())
} finally {
// Notify the thread waiting for the SparkContext, in case the application did not
// instantiate one. This will do nothing when the user code instantiates a SparkContext
// (with the correct master), or when the user code throws an exception (due to the
// tryFailure above).
sparkContextPromise.trySuccess(null)
}
}
}
@@ -759,10 +753,6 @@ object ApplicationMaster extends Logging {
master.sparkContextInitialized(sc)
}
private[spark] def sparkContextStopped(sc: SparkContext): Boolean = {
master.sparkContextStopped(sc)
}
private[spark] def getAttemptId(): ApplicationAttemptId = {
master.getAttemptId
}
YarnClusterScheduler.scala
@@ -34,9 +34,4 @@ private[spark] class YarnClusterScheduler(sc: SparkContext) extends YarnSchedule
logInfo("YarnClusterScheduler.postStartHook done")
}
override def stop() {
super.stop()
ApplicationMaster.sparkContextStopped(sc)
}
}
YarnClusterSuite.scala
@@ -33,6 +33,7 @@ import org.scalatest.concurrent.Eventually._
import org.apache.spark._
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.deploy.yarn.config._
import org.apache.spark.internal.Logging
import org.apache.spark.launcher._
import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationStart,
@@ -192,6 +193,14 @@ class YarnClusterSuite extends BaseYarnClusterSuite {
}
}
test("timeout to get SparkContext in cluster mode triggers failure") {
val timeout = 2000
val finalState = runSpark(false, mainClassName(SparkContextTimeoutApp.getClass),
appArgs = Seq((timeout * 4).toString),
extraConf = Map(AM_MAX_WAIT_TIME.key -> timeout.toString))
finalState should be (SparkAppHandle.State.FAILED)
}
private def testBasicYarnApp(clientMode: Boolean, conf: Map[String, String] = Map()): Unit = {
val result = File.createTempFile("result", null, tempDir)
val finalState = runSpark(clientMode, mainClassName(YarnClusterDriver.getClass),
@@ -469,3 +478,16 @@ private object YarnLauncherTestApp {
}
}
/**
* Used to test code in the AM that detects the SparkContext instance. Expects a single argument
* with the duration to sleep for, in ms.
*/
private object SparkContextTimeoutApp {
def main(args: Array[String]): Unit = {
val Array(sleepTime) = args
Thread.sleep(java.lang.Long.parseLong(sleepTime))
}
}
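As an aside on the catch/finally changes in ApplicationMaster above: a Promise can
only be completed once, and the try* variants (trySuccess / tryFailure) return
false instead of throwing when it already holds a result, which is what makes the
unconditional trySuccess(null) in the finally block safe alongside the earlier
success/tryFailure calls. A small, illustrative sketch (not Spark code):

import scala.concurrent.Promise
import scala.util.{Failure, Success}

object PromiseCompletionSketch {
  def main(args: Array[String]): Unit = {
    val p = Promise[String]()

    // First completion wins; here the "user code" published a value.
    p.success("context ready")

    // Later attempts via the try* variants are no-ops that report false...
    val failedLater = p.tryFailure(new RuntimeException("user class threw"))
    val nulledLater = p.trySuccess(null)
    println(s"tryFailure applied: $failedLater, trySuccess applied: $nulledLater") // false, false

    // ...whereas calling success()/failure() again would throw IllegalStateException.

    // Any waiter observes only the first completion.
    p.future.value.foreach {
      case Success(v) => println(s"waiter observed: $v")
      case Failure(e) => println(s"waiter observed failure: ${e.getMessage}")
    }
  }
}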