Commit ba80eaf7 authored by Shixiong Zhu

[SPARK-18280][CORE] Fix potential deadlock in `StandaloneSchedulerBackend.dead`

## What changes were proposed in this pull request?

"StandaloneSchedulerBackend.dead" is called in a RPC thread, so it should not call "SparkContext.stop" in the same thread. "SparkContext.stop" will block until all RPC threads exit, if it's called inside a RPC thread, it will be dead-lock.

This PR adds a thread-local flag to RPC threads; `SparkContext.stop` uses it to decide whether to launch a new thread that performs the actual shutdown.
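
To make the mechanism concrete, here is a minimal, self-contained sketch of the pattern (hypothetical names such as `inRpcThread`, `runAsRpcThread`, and `doStop`; this is not Spark's actual code): a `DynamicVariable` serves as a thread-local flag that the message loop sets, and `stop()` consults it to decide whether to hand the blocking shutdown to a fresh thread.

```scala
import scala.util.DynamicVariable

// Sketch of the fix's pattern; all names here are illustrative, not Spark's.
object StopHandoffSketch {
  // Thread-local flag: true only on threads that mark themselves as RPC threads.
  private val inRpcThread = new DynamicVariable[Boolean](false)

  // Simulates an RPC message loop: the flag is set for the duration of `body`.
  def runAsRpcThread(body: => Unit): Unit = {
    val t = new Thread(new Runnable {
      override def run(): Unit = inRpcThread.withValue(true)(body)
    })
    t.start()
    t.join()
  }

  // The blocking shutdown; in the real patch this is the original body of
  // `stop`, renamed to `_stop`.
  private def doStop(): Unit =
    println(s"stopping from ${Thread.currentThread().getName}")

  def stop(): Unit = {
    if (inRpcThread.value) {
      // Stopping here would wait for RPC threads (including this one) to exit,
      // so hand the blocking work to a fresh daemon thread instead.
      val stopper = new Thread("stop-sketch") {
        setDaemon(true)
        override def run(): Unit = doStop()
      }
      stopper.start()
    } else {
      doStop()
    }
  }

  def main(args: Array[String]): Unit = {
    runAsRpcThread(stop()) // requested from an "RPC" thread: delegated
    stop()                 // requested from an ordinary thread: runs inline
  }
}
```

Marking the stopper thread as a daemon matches the patch below and keeps the shutdown helper from pinning the JVM alive on its own.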

## How was this patch tested?

Jenkins

Author: Shixiong Zhu <shixiong@databricks.com>

Closes #15775 from zsxwing/SPARK-18280.
parent 21bbf94b
core/src/main/scala/org/apache/spark/SparkContext.scala:
```diff
@@ -1757,8 +1757,26 @@ class SparkContext(config: SparkConf) extends Logging {
    */
   def listJars(): Seq[String] = addedJars.keySet.toSeq
 
-  // Shut down the SparkContext.
-  def stop() {
+  /**
+   * Shut down the SparkContext.
+   */
+  def stop(): Unit = {
+    if (env.rpcEnv.isInRPCThread) {
+      // `stop` will block until all RPC threads exit, so we cannot call stop inside a RPC thread.
+      // We should launch a new thread to call `stop` to avoid dead-lock.
+      new Thread("stop-spark-context") {
+        setDaemon(true)
+        override def run(): Unit = {
+          _stop()
+        }
+      }.start()
+    } else {
+      _stop()
+    }
+  }
+
+  private def _stop() {
     if (LiveListenerBus.withinListenerThread.value) {
       throw new SparkException(
         s"Cannot stop SparkContext within listener thread of ${LiveListenerBus.name}")
```
core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala:
```diff
@@ -147,6 +147,10 @@ private[spark] abstract class RpcEnv(conf: SparkConf) {
    */
  def openChannel(uri: String): ReadableByteChannel
 
+  /**
+   * Return if the current thread is a RPC thread.
+   */
+  def isInRPCThread: Boolean
 }
 
 /**
```
core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala:
```diff
@@ -201,6 +201,7 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging {
   /** Message loop used for dispatching messages. */
   private class MessageLoop extends Runnable {
     override def run(): Unit = {
+      NettyRpcEnv.rpcThreadFlag.value = true
       try {
         while (true) {
           try {
```
core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala:
```diff
@@ -408,10 +408,13 @@ private[netty] class NettyRpcEnv(
   }
 
+  override def isInRPCThread: Boolean = NettyRpcEnv.rpcThreadFlag.value
 }
 
 private[netty] object NettyRpcEnv extends Logging {
 
+  private[netty] val rpcThreadFlag = new DynamicVariable[Boolean](false)
+
   /**
    * When deserializing the [[NettyRpcEndpointRef]], it needs a reference to [[NettyRpcEnv]].
    * Use `currentEnv` to wrap the deserialization codes. E.g.,
```
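
The flag works because `DynamicVariable` is backed by an `InheritableThreadLocal`: each thread reads its own value, defaulting to `false` for threads that never set it. One subtlety (standard-library behavior, not Spark code) is that a thread spawned *after* the flag is set inherits the spawning thread's value; this is harmless here because the hand-off thread calls `_stop` directly rather than re-entering `stop`:

```scala
import scala.util.DynamicVariable

// Standalone check of DynamicVariable's per-thread semantics (not Spark code).
object FlagSemanticsDemo {
  val flag = new DynamicVariable[Boolean](false)

  def main(args: Array[String]): Unit = {
    flag.value = true   // sets the value for the current thread only
    println(flag.value) // true

    // A thread created now inherits the current value (InheritableThreadLocal).
    val child = new Thread(new Runnable {
      override def run(): Unit = println(flag.value) // true: inherited at creation
    })
    child.start()
    child.join()
  }
}
```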
core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala:
```diff
@@ -870,6 +870,19 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll {
     verify(endpoint, never()).onDisconnected(any())
     verify(endpoint, never()).onNetworkError(any(), any())
   }
 
+  test("isInRPCThread") {
+    val rpcEndpointRef = env.setupEndpoint("isInRPCThread", new RpcEndpoint {
+      override val rpcEnv = env
+
+      override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
+        case m => context.reply(rpcEnv.isInRPCThread)
+      }
+    })
+    assert(rpcEndpointRef.askWithRetry[Boolean]("hello") === true)
+    assert(env.isInRPCThread === false)
+    env.stop(rpcEndpointRef)
+  }
 }
 
 class UnserializableClass
```