Commit e1bcf6af authored by Reynold Xin

[SPARK-10827] replace volatile with Atomic* in AppClient.scala.

This is a follow-up to #9317 to replace volatile fields with AtomicBoolean and AtomicReference.

Author: Reynold Xin <rxin@databricks.com>

Closes #9611 from rxin/SPARK-10827.
parent 2d76e44b
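For readers unfamiliar with the pattern, the standalone sketch below illustrates the before/after shape of the change, not the AppClient code itself; the class and method names (RegistrationState, onRegistered) and the example app id are made up for illustration. A @volatile var that is read and written directly becomes an immutable val holding an AtomicReference or AtomicBoolean, accessed through get/set.

    import java.util.concurrent.atomic.{AtomicBoolean, AtomicReference}

    // Hypothetical toy class illustrating the conversion; names are invented.
    class RegistrationState {
      // Before: mutable fields published across threads via @volatile.
      // @volatile private var appId: String = null
      // @volatile private var registered = false

      // After: vals holding atomic containers; all reads and writes go through
      // get/set, and compareAndSet becomes available if it is needed later.
      private val appId = new AtomicReference[String]
      private val registered = new AtomicBoolean(false)

      def onRegistered(id: String): Unit = {
        appId.set(id)
        registered.set(true)
      }

      def describe(): String =
        if (registered.get) s"registered as ${appId.get}" else "not registered"
    }

    object RegistrationStateDemo extends App {
      val state = new RegistrationState
      println(state.describe())                 // not registered
      state.onRegistered("app-20151111-0001")   // illustrative id only
      println(state.describe())                 // registered as app-20151111-0001
    }

For plain reads and writes the two forms are equivalent in visibility guarantees; the Atomic* wrappers mainly make the fields vals and keep the door open for atomic read-modify-write operations, which appears to be the intent of this cleanup.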
@@ -18,6 +18,7 @@
 package org.apache.spark.deploy.client

 import java.util.concurrent._
+import java.util.concurrent.atomic.{AtomicBoolean, AtomicReference}
 import java.util.concurrent.{Future => JFuture, ScheduledFuture => JScheduledFuture}

 import scala.util.control.NonFatal
@@ -49,9 +50,9 @@ private[spark] class AppClient(
   private val REGISTRATION_TIMEOUT_SECONDS = 20
   private val REGISTRATION_RETRIES = 3

-  @volatile private var endpoint: RpcEndpointRef = null
-  @volatile private var appId: String = null
-  @volatile private var registered = false
+  private val endpoint = new AtomicReference[RpcEndpointRef]
+  private val appId = new AtomicReference[String]
+  private val registered = new AtomicBoolean(false)

   private class ClientEndpoint(override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint
     with Logging {
@@ -59,16 +60,17 @@ private[spark] class AppClient(
     private var master: Option[RpcEndpointRef] = None
     // To avoid calling listener.disconnected() multiple times
     private var alreadyDisconnected = false
-    @volatile private var alreadyDead = false // To avoid calling listener.dead() multiple times
-    @volatile private var registerMasterFutures: Array[JFuture[_]] = null
-    @volatile private var registrationRetryTimer: JScheduledFuture[_] = null
+    // To avoid calling listener.dead() multiple times
+    private val alreadyDead = new AtomicBoolean(false)
+    private val registerMasterFutures = new AtomicReference[Array[JFuture[_]]]
+    private val registrationRetryTimer = new AtomicReference[JScheduledFuture[_]]

     // A thread pool for registering with masters. Because registering with a master is a blocking
     // action, this thread pool must be able to create "masterRpcAddresses.size" threads at the same
     // time so that we can register with all masters.
     private val registerMasterThreadPool = new ThreadPoolExecutor(
       0,
-      masterRpcAddresses.size, // Make sure we can register with all masters at the same time
+      masterRpcAddresses.length, // Make sure we can register with all masters at the same time
       60L, TimeUnit.SECONDS,
       new SynchronousQueue[Runnable](),
       ThreadUtils.namedThreadFactory("appclient-register-master-threadpool"))
@@ -100,7 +102,7 @@
       for (masterAddress <- masterRpcAddresses) yield {
         registerMasterThreadPool.submit(new Runnable {
           override def run(): Unit = try {
-            if (registered) {
+            if (registered.get) {
               return
             }
             logInfo("Connecting to master " + masterAddress.toSparkURL + "...")
@@ -123,22 +125,22 @@
      * nthRetry means this is the nth attempt to register with master.
      */
     private def registerWithMaster(nthRetry: Int) {
-      registerMasterFutures = tryRegisterAllMasters()
-      registrationRetryTimer = registrationRetryThread.scheduleAtFixedRate(new Runnable {
+      registerMasterFutures.set(tryRegisterAllMasters())
+      registrationRetryTimer.set(registrationRetryThread.scheduleAtFixedRate(new Runnable {
         override def run(): Unit = {
           Utils.tryOrExit {
-            if (registered) {
-              registerMasterFutures.foreach(_.cancel(true))
+            if (registered.get) {
+              registerMasterFutures.get.foreach(_.cancel(true))
               registerMasterThreadPool.shutdownNow()
             } else if (nthRetry >= REGISTRATION_RETRIES) {
               markDead("All masters are unresponsive! Giving up.")
             } else {
-              registerMasterFutures.foreach(_.cancel(true))
+              registerMasterFutures.get.foreach(_.cancel(true))
               registerWithMaster(nthRetry + 1)
             }
           }
         }
-      }, REGISTRATION_TIMEOUT_SECONDS, REGISTRATION_TIMEOUT_SECONDS, TimeUnit.SECONDS)
+      }, REGISTRATION_TIMEOUT_SECONDS, REGISTRATION_TIMEOUT_SECONDS, TimeUnit.SECONDS))
     }

     /**
@@ -163,10 +165,10 @@
         //    RegisteredApplications due to an unstable network.
         // 2. Receive multiple RegisteredApplication from different masters because the master is
         //    changing.
-        appId = appId_
-        registered = true
+        appId.set(appId_)
+        registered.set(true)
         master = Some(masterRef)
-        listener.connected(appId)
+        listener.connected(appId.get)

       case ApplicationRemoved(message) =>
         markDead("Master removed our application: %s".format(message))
@@ -178,7 +180,7 @@
           cores))
         // FIXME if changing master and `ExecutorAdded` happen at the same time (the order is not
         // guaranteed), `ExecutorStateChanged` may be sent to a dead master.
-        sendToMaster(ExecutorStateChanged(appId, id, ExecutorState.RUNNING, None, None))
+        sendToMaster(ExecutorStateChanged(appId.get, id, ExecutorState.RUNNING, None, None))
         listener.executorAdded(fullId, workerId, hostPort, cores, memory)

       case ExecutorUpdated(id, state, message, exitStatus) =>
@@ -193,13 +195,13 @@
         logInfo("Master has changed, new master is at " + masterRef.address.toSparkURL)
         master = Some(masterRef)
         alreadyDisconnected = false
-        masterRef.send(MasterChangeAcknowledged(appId))
+        masterRef.send(MasterChangeAcknowledged(appId.get))
     }

     override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
       case StopAppClient =>
         markDead("Application has been stopped.")
-        sendToMaster(UnregisterApplication(appId))
+        sendToMaster(UnregisterApplication(appId.get))
         context.reply(true)
         stop()

@@ -263,18 +265,18 @@
     }

     def markDead(reason: String) {
-      if (!alreadyDead) {
+      if (!alreadyDead.get) {
         listener.dead(reason)
-        alreadyDead = true
+        alreadyDead.set(true)
       }
     }

     override def onStop(): Unit = {
-      if (registrationRetryTimer != null) {
-        registrationRetryTimer.cancel(true)
+      if (registrationRetryTimer.get != null) {
+        registrationRetryTimer.get.cancel(true)
       }
       registrationRetryThread.shutdownNow()
-      registerMasterFutures.foreach(_.cancel(true))
+      registerMasterFutures.get.foreach(_.cancel(true))
       registerMasterThreadPool.shutdownNow()
       askAndReplyThreadPool.shutdownNow()
     }
@@ -283,19 +285,19 @@

   def start() {
     // Just launch an rpcEndpoint; it will call back into the listener.
-    endpoint = rpcEnv.setupEndpoint("AppClient", new ClientEndpoint(rpcEnv))
+    endpoint.set(rpcEnv.setupEndpoint("AppClient", new ClientEndpoint(rpcEnv)))
   }

   def stop() {
-    if (endpoint != null) {
+    if (endpoint.get != null) {
       try {
         val timeout = RpcUtils.askRpcTimeout(conf)
-        timeout.awaitResult(endpoint.ask[Boolean](StopAppClient))
+        timeout.awaitResult(endpoint.get.ask[Boolean](StopAppClient))
       } catch {
         case e: TimeoutException =>
           logInfo("Stop request to Master timed out; it may already be shut down.")
       }
-      endpoint = null
+      endpoint.set(null)
     }
   }
@@ -306,8 +308,8 @@
    * @return whether the request is acknowledged.
    */
   def requestTotalExecutors(requestedTotal: Int): Boolean = {
-    if (endpoint != null && appId != null) {
-      endpoint.askWithRetry[Boolean](RequestExecutors(appId, requestedTotal))
+    if (endpoint.get != null && appId.get != null) {
+      endpoint.get.askWithRetry[Boolean](RequestExecutors(appId.get, requestedTotal))
     } else {
       logWarning("Attempted to request executors before driver fully initialized.")
       false
@@ -319,8 +321,8 @@
    * @return whether the kill request is acknowledged.
    */
   def killExecutors(executorIds: Seq[String]): Boolean = {
-    if (endpoint != null && appId != null) {
-      endpoint.askWithRetry[Boolean](KillExecutors(appId, executorIds))
+    if (endpoint.get != null && appId.get != null) {
+      endpoint.get.askWithRetry[Boolean](KillExecutors(appId.get, executorIds))
     } else {
       logWarning("Attempted to kill executors before driver fully initialized.")
       false