diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 0470fbeed1ada70e48fa5a4f4879bb60d01dc45c..abb6a8331664f7638936b2cce34a4a299bc8bdbe 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -220,8 +220,14 @@ class SparkContext(config: SparkConf) extends Logging {
     new MetadataCleaner(MetadataCleanerType.SPARK_CONTEXT, this.cleanup, conf)

   // Initialize the Spark UI, registering all associated listeners
-  private[spark] val ui = new SparkUI(this)
-  ui.bind()
+  private[spark] val ui: Option[SparkUI] =
+    if (conf.getBoolean("spark.ui.enabled", true)) {
+      Some(new SparkUI(this))
+    } else {
+      // For tests, do not enable the UI
+      None
+    }
+  ui.foreach(_.bind())

   /** A default Hadoop Configuration for the Hadoop code (e.g. file systems) that we reuse. */
   val hadoopConfiguration: Configuration = {
@@ -1008,7 +1014,7 @@ class SparkContext(config: SparkConf) extends Logging {
   /** Shut down the SparkContext. */
   def stop() {
     postApplicationEnd()
-    ui.stop()
+    ui.foreach(_.stop())
     // Do this only if not stopped already - best case effort.
     // prevent NPE if stopped more than once.
     val dagSchedulerCopy = dagScheduler
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
index 2a3711ae2a78c8e25ef603111a0b99c40ba97883..04046e2e5d11dc617198fd392e560223d34756bb 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
@@ -292,7 +292,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A
       logInfo(s"Add WebUI Filter. $filterName, $filterParams, $proxyBase")
       conf.set("spark.ui.filters", filterName)
       conf.set(s"spark.$filterName.params", filterParams)
-      JettyUtils.addFilters(scheduler.sc.ui.getHandlers, conf)
+      scheduler.sc.ui.foreach { ui => JettyUtils.addFilters(ui.getHandlers, conf) }
     }
   }
 }
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala
index 4f7133c4bc17c23c682166e95ee7d8e578739a9b..b781842000e6dadfef9ba40dc04e21fe346e0c47 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala
@@ -17,7 +17,6 @@

 package org.apache.spark.scheduler.cluster

-import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.{Path, FileSystem}

 import org.apache.spark.{Logging, SparkContext, SparkEnv}
@@ -46,16 +45,17 @@ private[spark] class SimrSchedulerBackend(

     val conf = new Configuration()
     val fs = FileSystem.get(conf)
+    val appUIAddress = sc.ui.map(_.appUIAddress).getOrElse("")

     logInfo("Writing to HDFS file: " + driverFilePath)
     logInfo("Writing Akka address: " + driverUrl)
-    logInfo("Writing Spark UI Address: " + sc.ui.appUIAddress)
+    logInfo("Writing Spark UI Address: " + appUIAddress)

     // Create temporary file to prevent race condition where executors get empty driverUrl file
     val temp = fs.create(tmpPath, true)
     temp.writeUTF(driverUrl)
     temp.writeInt(maxCores)
-    temp.writeUTF(sc.ui.appUIAddress)
+    temp.writeUTF(appUIAddress)
     temp.close()

     // "Atomic" rename
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
index 32138e5246700d98fbf1fb478f0c9ad46f037b14..c1d5ce0a360750f63639d00f161deb9b99743f8c 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
@@ -63,8 +63,10 @@ private[spark] class SparkDeploySchedulerBackend(
     val javaOpts = sparkJavaOpts ++ extraJavaOpts
     val command = Command("org.apache.spark.executor.CoarseGrainedExecutorBackend",
       args, sc.executorEnvs, classPathEntries, libraryPathEntries, javaOpts)
+    val appUIAddress = sc.ui.map(_.appUIAddress).getOrElse("")
+    val eventLogDir = sc.eventLogger.map(_.logDir)
     val appDesc = new ApplicationDescription(sc.appName, maxCores, sc.executorMemory, command,
-      sc.ui.appUIAddress, sc.eventLogger.map(_.logDir))
+      appUIAddress, eventLogDir)
     client = new AppClient(sc.env.actorSystem, masters, appDesc, this, conf)
     client.start()
   }
diff --git a/core/src/test/scala/org/apache/spark/ui/UISuite.scala b/core/src/test/scala/org/apache/spark/ui/UISuite.scala
index 038746d2eda4b9fd43591d3471532071a629df37..2f5664295670188db05946fb64c11e299a656250 100644
--- a/core/src/test/scala/org/apache/spark/ui/UISuite.scala
+++ b/core/src/test/scala/org/apache/spark/ui/UISuite.scala
@@ -36,11 +36,25 @@ import scala.xml.Node

 class UISuite extends FunSuite {

+  /**
+   * Create a test SparkContext with the SparkUI enabled.
+   * It is safe to `get` the SparkUI directly from the SparkContext returned here.
+   */
+  private def newSparkContext(): SparkContext = {
+    val conf = new SparkConf()
+      .setMaster("local")
+      .setAppName("test")
+      .set("spark.ui.enabled", "true")
+    val sc = new SparkContext(conf)
+    assert(sc.ui.isDefined)
+    sc
+  }
+
   ignore("basic ui visibility") {
-    withSpark(new SparkContext("local", "test")) { sc =>
+    withSpark(newSparkContext()) { sc =>
       // test if the ui is visible, and all the expected tabs are visible
       eventually(timeout(10 seconds), interval(50 milliseconds)) {
-        val html = Source.fromURL(sc.ui.appUIAddress).mkString
+        val html = Source.fromURL(sc.ui.get.appUIAddress).mkString
         assert(!html.contains("random data that should not be present"))
         assert(html.toLowerCase.contains("stages"))
         assert(html.toLowerCase.contains("storage"))
@@ -51,7 +65,7 @@ class UISuite extends FunSuite {
   }

   ignore("visibility at localhost:4040") {
-    withSpark(new SparkContext("local", "test")) { sc =>
+    withSpark(newSparkContext()) { sc =>
       // test if visible from http://localhost:4040
       eventually(timeout(10 seconds), interval(50 milliseconds)) {
         val html = Source.fromURL("http://localhost:4040").mkString
@@ -61,8 +75,8 @@ class UISuite extends FunSuite {
   }

   ignore("attaching a new tab") {
-    withSpark(new SparkContext("local", "test")) { sc =>
-      val sparkUI = sc.ui
+    withSpark(newSparkContext()) { sc =>
+      val sparkUI = sc.ui.get

       val newTab = new WebUITab(sparkUI, "foo") {
         attachPage(new WebUIPage("") {
@@ -73,7 +87,7 @@ class UISuite extends FunSuite {
       }
       sparkUI.attachTab(newTab)
       eventually(timeout(10 seconds), interval(50 milliseconds)) {
-        val html = Source.fromURL(sc.ui.appUIAddress).mkString
+        val html = Source.fromURL(sparkUI.appUIAddress).mkString
         assert(!html.contains("random data that should not be present"))

         // check whether new page exists
@@ -87,7 +101,7 @@ class UISuite extends FunSuite {
       }

       eventually(timeout(10 seconds), interval(50 milliseconds)) {
-        val html = Source.fromURL(sc.ui.appUIAddress.stripSuffix("/") + "/foo").mkString
+        val html = Source.fromURL(sparkUI.appUIAddress.stripSuffix("/") + "/foo").mkString
         // check whether new page exists
         assert(html.contains("magic"))
       }
@@ -129,16 +143,20 @@ class UISuite extends FunSuite {
   }

   test("verify appUIAddress contains the scheme") {
-    withSpark(new SparkContext("local", "test")) { sc =>
-      val uiAddress = sc.ui.appUIAddress
-      assert(uiAddress.equals("http://" + sc.ui.appUIHostPort))
+    withSpark(newSparkContext()) { sc =>
+      val ui = sc.ui.get
+      val uiAddress = ui.appUIAddress
+      val uiHostPort = ui.appUIHostPort
+      assert(uiAddress.equals("http://" + uiHostPort))
     }
   }

   test("verify appUIAddress contains the port") {
-    withSpark(new SparkContext("local", "test")) { sc =>
-      val splitUIAddress = sc.ui.appUIAddress.split(':')
-      assert(splitUIAddress(2).toInt == sc.ui.boundPort)
+    withSpark(newSparkContext()) { sc =>
+      val ui = sc.ui.get
+      val splitUIAddress = ui.appUIAddress.split(':')
+      val boundPort = ui.boundPort
+      assert(splitUIAddress(2).toInt == boundPort)
     }
   }
 }
diff --git a/pom.xml b/pom.xml
index 66458e203281b33c7ea1feb6a8e0429779b078f0..c33ea7db3bdc51efff3cac6b94d8988768d00284 100644
--- a/pom.xml
+++ b/pom.xml
@@ -885,6 +885,7 @@
               <java.awt.headless>true</java.awt.headless>
               <spark.test.home>${session.executionRootDirectory}</spark.test.home>
               <spark.testing>1</spark.testing>
+              <spark.ui.enabled>false</spark.ui.enabled>
             </systemProperties>
           </configuration>
           <executions>
diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index 486de9391387f961e49b90a86941a1e3682a0f87..c968a753c37f9cb1eaebbed1b3bcbaa14c805af1 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -337,7 +337,7 @@ object TestSettings {
     javaOptions in Test += "-Dspark.test.home=" + sparkHome,
     javaOptions in Test += "-Dspark.testing=1",
     javaOptions in Test += "-Dspark.ports.maxRetries=100",
-    javaOptions in Test += "-Dspark.ui.port=0",
+    javaOptions in Test += "-Dspark.ui.enabled=false",
     javaOptions in Test += "-Dsun.io.serialization.extendedDebugInfo=true",
     javaOptions in Test ++= System.getProperties.filter(_._1 startsWith "spark")
       .map { case (k,v) => s"-D$k=$v" }.toSeq,
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
index 101cec1c7a7c22aa6dd14404c943393f0aca4a4a..4fc77bbe1a3676b571a230485af910f122439e58 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
@@ -37,7 +37,7 @@ import org.apache.spark.storage.StorageLevel
 import org.apache.spark.streaming.dstream._
 import org.apache.spark.streaming.receiver.{ActorSupervisorStrategy, ActorReceiver, Receiver}
 import org.apache.spark.streaming.scheduler._
-import org.apache.spark.streaming.ui.StreamingTab
+import org.apache.spark.streaming.ui.{StreamingJobProgressListener, StreamingTab}
 import org.apache.spark.util.MetadataCleaner

 /**
@@ -158,7 +158,14 @@ class StreamingContext private[streaming] (

   private[streaming] val waiter = new ContextWaiter

-  private[streaming] val uiTab = new StreamingTab(this)
+  private[streaming] val progressListener = new StreamingJobProgressListener(this)
+
+  private[streaming] val uiTab: Option[StreamingTab] =
+    if (conf.getBoolean("spark.ui.enabled", true)) {
+      Some(new StreamingTab(this))
+    } else {
+      None
+    }

   /** Register streaming source to metrics system */
   private val streamingSource = new StreamingSource(this)
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingSource.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingSource.scala
index 75f0e8716dc7efa2507934d1cfd75c1372c36eeb..e35a568ddf1151470a4da2c7dd4ad95c585c5910 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingSource.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingSource.scala
@@ -26,7 +26,7 @@ private[streaming] class StreamingSource(ssc: StreamingContext) extends Source {
   override val metricRegistry = new MetricRegistry
   override val sourceName = "%s.StreamingMetrics".format(ssc.sparkContext.appName)

-  private val streamingListener = ssc.uiTab.listener
+  private val streamingListener = ssc.progressListener

   private def registerGauge[T](name: String, f: StreamingJobProgressListener => T,
       defaultValue: T) {
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingTab.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingTab.scala
index 34ac254f337eb6ac332a5b21727e39aae0a471b4..d9d04cd706a04ed286f110ac980725469f7210bd 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingTab.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingTab.scala
@@ -17,18 +17,31 @@

 package org.apache.spark.streaming.ui

-import org.apache.spark.Logging
+import org.apache.spark.{Logging, SparkException}
 import org.apache.spark.streaming.StreamingContext
-import org.apache.spark.ui.SparkUITab
+import org.apache.spark.ui.{SparkUI, SparkUITab}

-/** Spark Web UI tab that shows statistics of a streaming job */
+import StreamingTab._
+
+/**
+ * Spark Web UI tab that shows statistics of a streaming job.
+ * This assumes the given SparkContext has enabled its SparkUI.
+ */
 private[spark] class StreamingTab(ssc: StreamingContext)
-  extends SparkUITab(ssc.sc.ui, "streaming") with Logging {
+  extends SparkUITab(getSparkUI(ssc), "streaming") with Logging {

-  val parent = ssc.sc.ui
-  val listener = new StreamingJobProgressListener(ssc)
+  val parent = getSparkUI(ssc)
+  val listener = ssc.progressListener

   ssc.addStreamingListener(listener)
   attachPage(new StreamingPage(this))
   parent.attachTab(this)
 }
+
+private object StreamingTab {
+  def getSparkUI(ssc: StreamingContext): SparkUI = {
+    ssc.sc.ui.getOrElse {
+      throw new SparkException("Parent SparkUI to attach this tab to not found!")
+    }
+  }
+}
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/UISuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/UISuite.scala
index 2a0db7564915d0fc02b64f5204c308cb078411c6..4c7e43c2943c9ec5c5fe0436cfb1b08d58e49e18 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/UISuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/UISuite.scala
@@ -24,13 +24,22 @@ import org.scalatest.FunSuite
 import org.scalatest.concurrent.Eventually._
 import org.scalatest.time.SpanSugar._

+import org.apache.spark.SparkConf
+
 class UISuite extends FunSuite {

   // Ignored: See SPARK-1530
   ignore("streaming tab in spark UI") {
-    val ssc = new StreamingContext("local", "test", Seconds(1))
+    val conf = new SparkConf()
+      .setMaster("local")
+      .setAppName("test")
+      .set("spark.ui.enabled", "true")
+    val ssc = new StreamingContext(conf, Seconds(1))
+    assert(ssc.sc.ui.isDefined, "Spark UI is not started!")
+    val ui = ssc.sc.ui.get
+
     eventually(timeout(10 seconds), interval(50 milliseconds)) {
-      val html = Source.fromURL(ssc.sparkContext.ui.appUIAddress).mkString
+      val html = Source.fromURL(ui.appUIAddress).mkString
       assert(!html.contains("random data that should not be present"))
       // test if streaming tab exist
       assert(html.toLowerCase.contains("streaming"))
@@ -39,8 +48,7 @@ class UISuite extends FunSuite {
     }

     eventually(timeout(10 seconds), interval(50 milliseconds)) {
-      val html = Source.fromURL(
-        ssc.sparkContext.ui.appUIAddress.stripSuffix("/") + "/streaming").mkString
+      val html = Source.fromURL(ui.appUIAddress.stripSuffix("/") + "/streaming").mkString
       assert(html.toLowerCase.contains("batch"))
       assert(html.toLowerCase.contains("network"))
     }
diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
new file mode 100644
index 0000000000000000000000000000000000000000..878b6db546032f528dd23f0cd9a75abddf0b5811
--- /dev/null
+++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
@@ -0,0 +1,443 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.yarn
+
+import java.io.IOException
+import java.net.Socket
+import java.util.concurrent.atomic.AtomicReference
+
+import scala.collection.JavaConversions._
+import scala.util.Try
+
+import akka.actor._
+import akka.remote._
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.{FileSystem, Path}
+import org.apache.hadoop.util.ShutdownHookManager
+import org.apache.hadoop.yarn.api._
+import org.apache.hadoop.yarn.api.records._
+import org.apache.hadoop.yarn.conf.YarnConfiguration
+
+import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkContext, SparkEnv}
+import org.apache.spark.deploy.SparkHadoopUtil
+import org.apache.spark.deploy.history.HistoryServer
+import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend
+import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.AddWebUIFilter
+import org.apache.spark.util.{AkkaUtils, SignalLogger, Utils}
+
+/**
+ * Common application master functionality for Spark on Yarn.
+ */
+private[spark] class ApplicationMaster(args: ApplicationMasterArguments,
+  client: YarnRMClient) extends Logging {
+  // TODO: Currently, task to container is computed once (TaskSetManager) - which need not be
+  // optimal as more containers are available. Might need to handle this better.
+
+  private val sparkConf = new SparkConf()
+  private val yarnConf: YarnConfiguration = SparkHadoopUtil.get.newConfiguration(sparkConf)
+    .asInstanceOf[YarnConfiguration]
+  private val isDriver = args.userClass != null
+
+  // Default to numExecutors * 2, with minimum of 3
+  private val maxNumExecutorFailures = sparkConf.getInt("spark.yarn.max.executor.failures",
+    sparkConf.getInt("spark.yarn.max.worker.failures", math.max(args.numExecutors * 2, 3)))
+
+  @volatile private var finished = false
+  @volatile private var finalStatus = FinalApplicationStatus.UNDEFINED
+
+  private var reporterThread: Thread = _
+  private var allocator: YarnAllocator = _
+
+  // Fields used in client mode.
+  private var actorSystem: ActorSystem = null
+  private var actor: ActorRef = _
+
+  // Fields used in cluster mode.
+  private val sparkContextRef = new AtomicReference[SparkContext](null)
+
+  final def run(): Int = {
+    val appAttemptId = client.getAttemptId()
+
+    if (isDriver) {
+      // Set the web ui port to be ephemeral for yarn so we don't conflict with
+      // other spark processes running on the same box
+      System.setProperty("spark.ui.port", "0")
+
+      // Set the master property to match the requested mode.
+      System.setProperty("spark.master", "yarn-cluster")
+
+      // Propagate the application ID so that YarnClusterSchedulerBackend can pick it up.
+      System.setProperty("spark.yarn.app.id", appAttemptId.getApplicationId().toString())
+    }
+
+    logInfo("ApplicationAttemptId: " + appAttemptId)
+
+    val cleanupHook = new Runnable {
+      override def run() {
+        // If the SparkContext is still registered, shut it down as a best case effort in case
+        // users do not call sc.stop or do System.exit().
+        val sc = sparkContextRef.get()
+        if (sc != null) {
+          logInfo("Invoking sc stop from shutdown hook")
+          sc.stop()
+          finish(FinalApplicationStatus.SUCCEEDED)
+        }
+
+        // Cleanup the staging dir after the app is finished, or if it's the last attempt at
+        // running the AM.
+        val maxAppAttempts = client.getMaxRegAttempts(yarnConf)
+        val isLastAttempt = client.getAttemptId().getAttemptId() >= maxAppAttempts
+        if (finished || isLastAttempt) {
+          cleanupStagingDir()
+        }
+      }
+    }
+    // Use priority 30 as it's higher than HDFS. It's the same priority MapReduce is using.
+    ShutdownHookManager.get().addShutdownHook(cleanupHook, 30)
+
+    // Call this to force generation of secret so it gets populated into the
+    // Hadoop UGI. This has to happen before the startUserClass which does a
+    // doAs in order for the credentials to be passed on to the executor containers.
+    val securityMgr = new SecurityManager(sparkConf)
+
+    if (isDriver) {
+      runDriver(securityMgr)
+    } else {
+      runExecutorLauncher(securityMgr)
+    }
+
+    if (finalStatus != FinalApplicationStatus.UNDEFINED) {
+      finish(finalStatus)
+      0
+    } else {
+      1
+    }
+  }
+
+  final def finish(status: FinalApplicationStatus, diagnostics: String = null) = synchronized {
+    if (!finished) {
+      logInfo(s"Finishing ApplicationMaster with $status" +
+        Option(diagnostics).map(msg => s" (diag message: $msg)").getOrElse(""))
+      finished = true
+      finalStatus = status
+      try {
+        if (Thread.currentThread() != reporterThread) {
+          reporterThread.interrupt()
+          reporterThread.join()
+        }
+      } finally {
+        client.shutdown(status, Option(diagnostics).getOrElse(""))
+      }
+    }
+  }
+
+  private def sparkContextInitialized(sc: SparkContext) = {
+    sparkContextRef.synchronized {
+      sparkContextRef.compareAndSet(null, sc)
+      sparkContextRef.notifyAll()
+    }
+  }
+
+  private def sparkContextStopped(sc: SparkContext) = {
+    sparkContextRef.compareAndSet(sc, null)
+  }
+
+  private def registerAM(uiAddress: String, securityMgr: SecurityManager) = {
+    val sc = sparkContextRef.get()
+
+    val appId = client.getAttemptId().getApplicationId().toString()
+    val historyAddress =
+      sparkConf.getOption("spark.yarn.historyServer.address")
+        .map { address => s"${address}${HistoryServer.UI_PATH_PREFIX}/${appId}" }
+        .getOrElse("")
+
+    allocator = client.register(yarnConf,
+      if (sc != null) sc.getConf else sparkConf,
+      if (sc != null) sc.preferredNodeLocationData else Map(),
+      uiAddress,
+      historyAddress,
+      securityMgr)
+
+    allocator.allocateResources()
+    reporterThread = launchReporterThread()
+  }
+
+  private def runDriver(securityMgr: SecurityManager): Unit = {
+    addAmIpFilter()
+    val userThread = startUserClass()
+
+    // This a bit hacky, but we need to wait until the spark.driver.port property has
+    // been set by the Thread executing the user class.
+    val sc = waitForSparkContextInitialized()
+
+    // If there is no SparkContext at this point, just fail the app.
+    if (sc == null) {
+      finish(FinalApplicationStatus.FAILED, "Timed out waiting for SparkContext.")
+    } else {
+      registerAM(sc.ui.map(_.appUIAddress).getOrElse(""), securityMgr)
+      try {
+        userThread.join()
+      } finally {
+        // In cluster mode, ask the reporter thread to stop since the user app is finished.
+        reporterThread.interrupt()
+      }
+    }
+  }
+
+  private def runExecutorLauncher(securityMgr: SecurityManager): Unit = {
+    actorSystem = AkkaUtils.createActorSystem("sparkYarnAM", Utils.localHostName, 0,
+      conf = sparkConf, securityManager = securityMgr)._1
+    actor = waitForSparkDriver()
+    addAmIpFilter()
+    registerAM(sparkConf.get("spark.driver.appUIAddress", ""), securityMgr)
+
+    // In client mode the actor will stop the reporter thread.
+    reporterThread.join()
+    finalStatus = FinalApplicationStatus.SUCCEEDED
+  }
+
+  private def launchReporterThread(): Thread = {
+    // Ensure that progress is sent before YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS elapses.
+    val expiryInterval = yarnConf.getInt(YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS, 120000)
+
+    // we want to be reasonably responsive without causing too many requests to RM.
+    val schedulerInterval =
+      sparkConf.getLong("spark.yarn.scheduler.heartbeat.interval-ms", 5000)
+
+    // must be <= expiryInterval / 2.
+    val interval = math.max(0, math.min(expiryInterval / 2, schedulerInterval))
+
+    val t = new Thread {
+      override def run() {
+        while (!finished) {
+          checkNumExecutorsFailed()
+          if (!finished) {
+            logDebug("Sending progress")
+            allocator.allocateResources()
+            try {
+              Thread.sleep(interval)
+            } catch {
+              case e: InterruptedException =>
+            }
+          }
+        }
+      }
+    }
+    // setting to daemon status, though this is usually not a good idea.
+    t.setDaemon(true)
+    t.setName("Reporter")
+    t.start()
+    logInfo("Started progress reporter thread - sleep time : " + interval)
+    t
+  }
+
+  /**
+   * Clean up the staging directory.
+   */
+  private def cleanupStagingDir() {
+    val fs = FileSystem.get(yarnConf)
+    var stagingDirPath: Path = null
+    try {
+      val preserveFiles = sparkConf.get("spark.yarn.preserve.staging.files", "false").toBoolean
+      if (!preserveFiles) {
+        stagingDirPath = new Path(System.getenv("SPARK_YARN_STAGING_DIR"))
+        if (stagingDirPath == null) {
+          logError("Staging directory is null")
+          return
+        }
+        logInfo("Deleting staging directory " + stagingDirPath)
+        fs.delete(stagingDirPath, true)
+      }
+    } catch {
+      case ioe: IOException =>
+        logError("Failed to cleanup staging dir " + stagingDirPath, ioe)
+    }
+  }
+
+  private def waitForSparkContextInitialized(): SparkContext = {
+    logInfo("Waiting for spark context initialization")
+    try {
+      sparkContextRef.synchronized {
+        var count = 0
+        val waitTime = 10000L
+        val numTries = sparkConf.getInt("spark.yarn.ApplicationMaster.waitTries", 10)
+        while (sparkContextRef.get() == null && count < numTries && !finished) {
+          logInfo("Waiting for spark context initialization ... " + count)
+          count = count + 1
+          sparkContextRef.wait(waitTime)
+        }
+
+        val sparkContext = sparkContextRef.get()
+        assert(sparkContext != null || count >= numTries)
+        if (sparkContext == null) {
+          logError(
+            "Unable to retrieve sparkContext inspite of waiting for %d, numTries = %d".format(
+              count * waitTime, numTries))
+        }
+        sparkContext
+      }
+    }
+  }
+
+  private def waitForSparkDriver(): ActorRef = {
+    logInfo("Waiting for Spark driver to be reachable.")
+    var driverUp = false
+    val hostport = args.userArgs(0)
+    val (driverHost, driverPort) = Utils.parseHostPort(hostport)
+    while (!driverUp) {
+      try {
+        val socket = new Socket(driverHost, driverPort)
+        socket.close()
+        logInfo("Driver now available: %s:%s".format(driverHost, driverPort))
+        driverUp = true
+      } catch {
+        case e: Exception =>
+          logError("Failed to connect to driver at %s:%s, retrying ...".
+            format(driverHost, driverPort))
+          Thread.sleep(100)
+      }
+    }
+    sparkConf.set("spark.driver.host", driverHost)
+    sparkConf.set("spark.driver.port", driverPort.toString)
+
+    val driverUrl = "akka.tcp://%s@%s:%s/user/%s".format(
+      SparkEnv.driverActorSystemName,
+      driverHost,
+      driverPort.toString,
+      CoarseGrainedSchedulerBackend.ACTOR_NAME)
+    actorSystem.actorOf(Props(new MonitorActor(driverUrl)), name = "YarnAM")
+  }
+
+  private def checkNumExecutorsFailed() = {
+    if (allocator.getNumExecutorsFailed >= maxNumExecutorFailures) {
+      finish(FinalApplicationStatus.FAILED, "Max number of executor failures reached.")
+
+      val sc = sparkContextRef.get()
+      if (sc != null) {
+        logInfo("Invoking sc stop from checkNumExecutorsFailed")
+        sc.stop()
+      }
+    }
+  }
+
+  /** Add the Yarn IP filter that is required for properly securing the UI. */
+  private def addAmIpFilter() = {
+    val amFilter = "org.apache.hadoop.yarn.server.webproxy.amfilter.AmIpFilter"
+    val proxy = client.getProxyHostAndPort(yarnConf)
+    val parts = proxy.split(":")
+    val proxyBase = System.getenv(ApplicationConstants.APPLICATION_WEB_PROXY_BASE_ENV)
+    val uriBase = "http://" + proxy + proxyBase
+    val params = "PROXY_HOST=" + parts(0) + "," + "PROXY_URI_BASE=" + uriBase
+
+    if (isDriver) {
+      System.setProperty("spark.ui.filters", amFilter)
+      System.setProperty(s"spark.$amFilter.params", params)
+    } else {
+      actor ! AddWebUIFilter(amFilter, params, proxyBase)
+    }
+  }
+
+  private def startUserClass(): Thread = {
+    logInfo("Starting the user JAR in a separate Thread")
+    System.setProperty("spark.executor.instances", args.numExecutors.toString)
+    val mainMethod = Class.forName(args.userClass, false,
+      Thread.currentThread.getContextClassLoader).getMethod("main", classOf[Array[String]])
+
+    val t = new Thread {
+      override def run() {
+        var status = FinalApplicationStatus.FAILED
+        try {
+          // Copy
+          val mainArgs = new Array[String](args.userArgs.size)
+          args.userArgs.copyToArray(mainArgs, 0, args.userArgs.size)
+          mainMethod.invoke(null, mainArgs)
+          // Some apps have "System.exit(0)" at the end.  The user thread will stop here unless
+          // it has an uncaught exception thrown out.  It needs a shutdown hook to set SUCCEEDED.
+          status = FinalApplicationStatus.SUCCEEDED
+        } finally {
+          logDebug("Finishing main")
+        }
+        finalStatus = status
+      }
+    }
+    t.setName("Driver")
+    t.start()
+    t
+  }
+
+  // Actor used to monitor the driver when running in client deploy mode.
+  private class MonitorActor(driverUrl: String) extends Actor {
+
+    var driver: ActorSelection = _
+
+    override def preStart() = {
+      logInfo("Listen to driver: " + driverUrl)
+      driver = context.actorSelection(driverUrl)
+      // Send a hello message to establish the connection, after which
+      // we can monitor Lifecycle Events.
+      driver ! "Hello"
+      context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent])
+    }
+
+    override def receive = {
+      case x: DisassociatedEvent =>
+        logInfo(s"Driver terminated or disconnected! Shutting down. $x")
+        finish(FinalApplicationStatus.SUCCEEDED)
+      case x: AddWebUIFilter =>
+        logInfo(s"Add WebUI Filter. $x")
+        driver ! x
+    }
+
+  }
+
+}
+
+object ApplicationMaster extends Logging {
+
+  private var master: ApplicationMaster = _
+
+  def main(args: Array[String]) = {
+    SignalLogger.register(log)
+    val amArgs = new ApplicationMasterArguments(args)
+    SparkHadoopUtil.get.runAsSparkUser { () =>
+      master = new ApplicationMaster(amArgs, new YarnRMClientImpl(amArgs))
+      System.exit(master.run())
+    }
+  }
+
+  private[spark] def sparkContextInitialized(sc: SparkContext) = {
+    master.sparkContextInitialized(sc)
+  }
+
+  private[spark] def sparkContextStopped(sc: SparkContext) = {
+    master.sparkContextStopped(sc)
+  }
+
+}
+
+/**
+ * This object does not provide any special functionality. It exists so that it's easy to tell
+ * apart the client-mode AM from the cluster-mode AM when using tools such as ps or jps.
+ */
+object ExecutorLauncher {
+
+  def main(args: Array[String]) = {
+    ApplicationMaster.main(args)
+  }
+
+}
diff --git a/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala
index 833e249f9f612645400966abf9d5a98e4e50a1c9..40d9bffe8e6b3c3b34d37e84ae735ce604a8375a 100644
--- a/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala
+++ b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala
@@ -55,8 +55,10 @@ private[spark] class YarnClientSchedulerBackend(
     val driverHost = conf.get("spark.driver.host")
     val driverPort = conf.get("spark.driver.port")
     val hostport = driverHost + ":" + driverPort
-    conf.set("spark.driver.appUIAddress", sc.ui.appUIHostPort)
-    conf.set("spark.driver.appUIHistoryAddress", YarnSparkHadoopUtil.getUIHistoryAddress(sc, conf))
+    sc.ui.foreach { ui =>
+      conf.set("spark.driver.appUIAddress", ui.appUIHostPort)
+      conf.set("spark.driver.appUIHistoryAddress", YarnSparkHadoopUtil.getUIHistoryAddress(sc, conf))
+    }

     val argsArrayBuf = new ArrayBuffer[String]()
     argsArrayBuf += (
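
The snippet below is not part of the patch; it is a minimal sketch of the guard pattern the diff introduces, under the assumption that the caller lives inside the org.apache.spark package (the `ui` field is `private[spark]`). The object name is made up for illustration.

package org.apache.spark

// Minimal sketch (illustration only): with "spark.ui.enabled" set to false,
// sc.ui is None, so callers fall back to a default instead of dereferencing it.
object UiOptionSketch {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
      .setMaster("local")
      .setAppName("ui-option-sketch")
      .set("spark.ui.enabled", "false")  // same flag the test settings use
    val sc = new SparkContext(conf)

    // Mirrors SimrSchedulerBackend / SparkDeploySchedulerBackend above:
    // map over the Option rather than calling .get on a possibly-absent UI.
    val appUIAddress = sc.ui.map(_.appUIAddress).getOrElse("")
    println(s"appUIAddress = '$appUIAddress'")  // empty string when the UI is disabled

    sc.stop()
  }
}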