From 26432df9cc6ffe569583aa628c6ecd7050b38316 Mon Sep 17 00:00:00 2001
From: Shixiong Zhu <shixiong@databricks.com>
Date: Thu, 8 Dec 2016 11:54:04 -0800
Subject: [PATCH] [SPARK-18751][CORE] Fix deadlock when SparkContext.stop is
 called in Utils.tryOrStopSparkContext

## What changes were proposed in this pull request?

When `SparkContext.stop` is called in `Utils.tryOrStopSparkContext` (the following three places), it will cause deadlock because the `stop` method needs to wait for the thread running `stop` to exit.

- ContextCleaner.keepCleaning
- LiveListenerBus.listenerThread.run
- TaskSchedulerImpl.start

This PR adds `SparkContext.stopInNewThread` and uses it to eliminate the potential deadlock. I also removed my changes in #15775 since they are not necessary now.

## How was this patch tested?

Jenkins

Author: Shixiong Zhu <shixiong@databricks.com>

Closes #16178 from zsxwing/fix-stop-deadlock.
---
 .../scala/org/apache/spark/SparkContext.scala | 35 +++++++++++--------
 .../scala/org/apache/spark/rpc/RpcEnv.scala   |  5 ---
 .../apache/spark/rpc/netty/Dispatcher.scala   |  1 -
 .../apache/spark/rpc/netty/NettyRpcEnv.scala  |  5 ---
 .../apache/spark/scheduler/DAGScheduler.scala |  2 +-
 .../cluster/StandaloneSchedulerBackend.scala  |  2 +-
 .../scala/org/apache/spark/util/Utils.scala   |  2 +-
 .../org/apache/spark/rpc/RpcEnvSuite.scala    | 13 -------
 8 files changed, 23 insertions(+), 42 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index be4dae19df..b42820a8ee 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -1760,25 +1760,30 @@ class SparkContext(config: SparkConf) extends Logging {
   def listJars(): Seq[String] = addedJars.keySet.toSeq
 
   /**
-   * Shut down the SparkContext.
+   * When stopping SparkContext inside Spark components, it's easy to cause dead-lock since Spark
+   * may wait for some internal threads to finish. It's better to use this method to stop
+   * SparkContext instead.
    */
-  def stop(): Unit = {
-    if (env.rpcEnv.isInRPCThread) {
-      // `stop` will block until all RPC threads exit, so we cannot call stop inside a RPC thread.
-      // We should launch a new thread to call `stop` to avoid dead-lock.
-      new Thread("stop-spark-context") {
-        setDaemon(true)
-
-        override def run(): Unit = {
-          _stop()
+  private[spark] def stopInNewThread(): Unit = {
+    new Thread("stop-spark-context") {
+      setDaemon(true)
+
+      override def run(): Unit = {
+        try {
+          SparkContext.this.stop()
+        } catch {
+          case e: Throwable =>
+            logError(e.getMessage, e)
+            throw e
         }
-      }.start()
-    } else {
-      _stop()
-    }
+      }
+    }.start()
   }
 
-  private def _stop() {
+  /**
+   * Shut down the SparkContext.
+   */
+  def stop(): Unit = {
     if (LiveListenerBus.withinListenerThread.value) {
       throw new SparkException(
         s"Cannot stop SparkContext within listener thread of ${LiveListenerBus.name}")
diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala
index bbc4163814..530743c036 100644
--- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala
@@ -146,11 +146,6 @@ private[spark] abstract class RpcEnv(conf: SparkConf) {
    * @param uri URI with location of the file.
    */
   def openChannel(uri: String): ReadableByteChannel
-
-  /**
-   * Return if the current thread is a RPC thread.
-   */
-  def isInRPCThread: Boolean
 }
 
 /**
diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala
index 67baabd2cb..a02cf30a5d 100644
--- a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala
@@ -201,7 +201,6 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging {
   /** Message loop used for dispatching messages. */
   private class MessageLoop extends Runnable {
     override def run(): Unit = {
-      NettyRpcEnv.rpcThreadFlag.value = true
       try {
         while (true) {
           try {
diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
index 0b8cd144a2..e56943da13 100644
--- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
@@ -407,14 +407,9 @@ private[netty] class NettyRpcEnv(
     }
 
   }
-
-  override def isInRPCThread: Boolean = NettyRpcEnv.rpcThreadFlag.value
 }
 
 private[netty] object NettyRpcEnv extends Logging {
-
-  private[netty] val rpcThreadFlag = new DynamicVariable[Boolean](false)
-
   /**
    * When deserializing the [[NettyRpcEndpointRef]], it needs a reference to [[NettyRpcEnv]].
    * Use `currentEnv` to wrap the deserialization codes. E.g.,
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 7fde34d897..9378f15b7b 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -1661,7 +1661,7 @@ private[scheduler] class DAGSchedulerEventProcessLoop(dagScheduler: DAGScheduler
     } catch {
       case t: Throwable => logError("DAGScheduler failed to cancel all jobs.", t)
     }
-    dagScheduler.sc.stop()
+    dagScheduler.sc.stopInNewThread()
   }
 
   override def onStop(): Unit = {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala
index 368cd30a2e..7befdb0c1f 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala
@@ -139,7 +139,7 @@ private[spark] class StandaloneSchedulerBackend(
         scheduler.error(reason)
       } finally {
         // Ensure the application terminates, as we can no longer run jobs.
-        sc.stop()
+        sc.stopInNewThread()
       }
     }
   }
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index 91f5606127..c6ad154167 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -1249,7 +1249,7 @@ private[spark] object Utils extends Logging {
         val currentThreadName = Thread.currentThread().getName
         if (sc != null) {
           logError(s"uncaught error in thread $currentThreadName, stopping SparkContext", t)
-          sc.stop()
+          sc.stopInNewThread()
         }
         if (!NonFatal(t)) {
           logError(s"throw uncaught fatal error in thread $currentThreadName", t)
diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala
index aa0705987d..acdf21df9a 100644
--- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala
@@ -870,19 +870,6 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll {
     verify(endpoint, never()).onDisconnected(any())
     verify(endpoint, never()).onNetworkError(any(), any())
   }
-
-  test("isInRPCThread") {
-    val rpcEndpointRef = env.setupEndpoint("isInRPCThread", new RpcEndpoint {
-      override val rpcEnv = env
-
-      override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
-        case m => context.reply(rpcEnv.isInRPCThread)
-      }
-    })
-    assert(rpcEndpointRef.askWithRetry[Boolean]("hello") === true)
-    assert(env.isInRPCThread === false)
-    env.stop(rpcEndpointRef)
-  }
 }
 
 class UnserializableClass
-- 
GitLab