diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala
index e6a48a06a03f596e4453994f8e629cfb6337676b..6af60d60d56d390fd0f84fa4d41250ba6d71a9a7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala
@@ -63,8 +63,34 @@ class HDFSMetadataLog[T <: AnyRef : ClassTag](sparkSession: SparkSession, path:
   val metadataPath = new Path(path)
   protected val fileManager = createFileManager()
 
-  if (!fileManager.exists(metadataPath)) {
-    fileManager.mkdirs(metadataPath)
+  runUninterruptiblyIfLocal {
+    if (!fileManager.exists(metadataPath)) {
+      fileManager.mkdirs(metadataPath)
+    }
+  }
+
+  private def runUninterruptiblyIfLocal[T](body: => T): T = {
+    if (fileManager.isLocalFileSystem && Thread.currentThread.isInstanceOf[UninterruptibleThread]) {
+      // When using a local file system, some file system APIs like "create" or "mkdirs" must be
+      // called in [[org.apache.spark.util.UninterruptibleThread]] so that interrupts can be
+      // disabled.
+      //
+      // This is because there is a potential dead-lock in Hadoop "Shell.runCommand" before
+      // 2.5.0 (HADOOP-10622). If the thread running "Shell.runCommand" is interrupted, then
+      // the thread can get deadlocked. In our case, file system APIs like "create" or "mkdirs"
+      // will call "Shell.runCommand" to set the file permission if using the local file system,
+      // and can get deadlocked if the stream execution thread is stopped by interrupt.
+      //
+      // Hence, we use "runUninterruptibly" here to disable interrupts here. (SPARK-14131)
+      Thread.currentThread.asInstanceOf[UninterruptibleThread].runUninterruptibly {
+        body
+      }
+    } else {
+      // For a distributed file system, such as HDFS or S3, if the network is broken, write
+      // operations may just hang until timeout. We should enable interrupts to allow stopping
+      // the query fast.
+      body
+    }
   }
 
   /**
@@ -109,39 +135,14 @@ class HDFSMetadataLog[T <: AnyRef : ClassTag](sparkSession: SparkSession, path:
   override def add(batchId: Long, metadata: T): Boolean = {
     get(batchId).map(_ => false).getOrElse {
       // Only write metadata when the batch has not yet been written
-      if (fileManager.isLocalFileSystem) {
-        Thread.currentThread match {
-          case ut: UninterruptibleThread =>
-            // When using a local file system, "writeBatch" must be called on a
-            // [[org.apache.spark.util.UninterruptibleThread]] so that interrupts can be disabled
-            // while writing the batch file.
-            //
-            // This is because Hadoop "Shell.runCommand" swallows InterruptException (HADOOP-14084).
-            // If the user tries to stop a query, and the thread running "Shell.runCommand" is
-            // interrupted, then InterruptException will be dropped and the query will be still
-            // running. (Note: `writeBatch` creates a file using HDFS APIs and will call
-            // "Shell.runCommand" to set the file permission if using the local file system)
-            //
-            // Hence, we make sure that "writeBatch" is called on [[UninterruptibleThread]] which
-            // allows us to disable interrupts here, in order to propagate the interrupt state
-            // correctly. Also see SPARK-19599.
-            ut.runUninterruptibly { writeBatch(batchId, metadata) }
-          case _ =>
-            throw new IllegalStateException(
-              "HDFSMetadataLog.add() on a local file system must be executed on " +
-                "a o.a.spark.util.UninterruptibleThread")
-        }
-      } else {
-        // For a distributed file system, such as HDFS or S3, if the network is broken, write
-        // operations may just hang until timeout. We should enable interrupts to allow stopping
-        // the query fast.
+      runUninterruptiblyIfLocal {
         writeBatch(batchId, metadata)
       }
       true
     }
   }
 
-  def writeTempBatch(metadata: T): Option[Path] = {
+  private def writeTempBatch(metadata: T): Option[Path] = {
     while (true) {
       val tempPath = new Path(metadataPath, s".${UUID.randomUUID.toString}.tmp")
       try {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
index 9346a6769d4f52588a5214c552df0892ac4afb1d..93face4390acba9eca84b0aaf5e0d0521d8caa4e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.streaming
 
 import java.util.UUID
 import java.util.concurrent.{CountDownLatch, TimeUnit}
+import java.util.concurrent.atomic.AtomicReference
 import java.util.concurrent.locks.ReentrantLock
 
 import scala.collection.mutable.ArrayBuffer
@@ -157,8 +158,7 @@ class StreamExecution(
   }
 
   /** Defines the internal state of execution */
-  @volatile
-  private var state: State = INITIALIZING
+  private val state = new AtomicReference[State](INITIALIZING)
 
   @volatile
   var lastExecution: IncrementalExecution = _
@@ -178,8 +178,9 @@ class StreamExecution(
 
   /**
    * The thread that runs the micro-batches of this stream. Note that this thread must be
-   * [[org.apache.spark.util.UninterruptibleThread]] to avoid swallowing `InterruptException` when
-   * using [[HDFSMetadataLog]]. See SPARK-19599 for more details.
+   * [[org.apache.spark.util.UninterruptibleThread]] to workaround KAFKA-1894: interrupting a
+   * running `KafkaConsumer` may cause endless loop, and HADOOP-10622: interrupting
+   * `Shell.runCommand` causes deadlock. (SPARK-14131)
    */
   val microBatchThread =
     new StreamExecutionThread(s"stream execution thread for $prettyIdString") {
@@ -200,10 +201,10 @@ class StreamExecution(
   val offsetLog = new OffsetSeqLog(sparkSession, checkpointFile("offsets"))
 
   /** Whether all fields of the query have been initialized */
-  private def isInitialized: Boolean = state != INITIALIZING
+  private def isInitialized: Boolean = state.get != INITIALIZING
 
   /** Whether the query is currently active or not */
-  override def isActive: Boolean = state != TERMINATED
+  override def isActive: Boolean = state.get != TERMINATED
 
   /** Returns the [[StreamingQueryException]] if the query was terminated by an exception. */
   override def exception: Option[StreamingQueryException] = Option(streamDeathCause)
@@ -249,53 +250,56 @@ class StreamExecution(
       updateStatusMessage("Initializing sources")
       // force initialization of the logical plan so that the sources can be created
       logicalPlan
-      state = ACTIVE
-      // Unblock `awaitInitialization`
-      initializationLatch.countDown()
-
-      triggerExecutor.execute(() => {
-        startTrigger()
-
-        val isTerminated =
-          if (isActive) {
-            reportTimeTaken("triggerExecution") {
-              if (currentBatchId < 0) {
-                // We'll do this initialization only once
-                populateStartOffsets()
-                logDebug(s"Stream running from $committedOffsets to $availableOffsets")
-              } else {
-                constructNextBatch()
+      if (state.compareAndSet(INITIALIZING, ACTIVE)) {
+        // Unblock `awaitInitialization`
+        initializationLatch.countDown()
+
+        triggerExecutor.execute(() => {
+          startTrigger()
+
+          val continueToRun =
+            if (isActive) {
+              reportTimeTaken("triggerExecution") {
+                if (currentBatchId < 0) {
+                  // We'll do this initialization only once
+                  populateStartOffsets()
+                  logDebug(s"Stream running from $committedOffsets to $availableOffsets")
+                } else {
+                  constructNextBatch()
+                }
+                if (dataAvailable) {
+                  currentStatus = currentStatus.copy(isDataAvailable = true)
+                  updateStatusMessage("Processing new data")
+                  runBatch()
+                }
               }
+
+              // Report trigger as finished and construct progress object.
+              finishTrigger(dataAvailable)
               if (dataAvailable) {
-                currentStatus = currentStatus.copy(isDataAvailable = true)
-                updateStatusMessage("Processing new data")
-                runBatch()
+                // We'll increase currentBatchId after we complete processing current batch's data
+                currentBatchId += 1
+              } else {
+                currentStatus = currentStatus.copy(isDataAvailable = false)
+                updateStatusMessage("Waiting for data to arrive")
+                Thread.sleep(pollingDelayMs)
               }
-            }
-
-            // Report trigger as finished and construct progress object.
-            finishTrigger(dataAvailable)
-            if (dataAvailable) {
-              // We'll increase currentBatchId after we complete processing current batch's data
-              currentBatchId += 1
+              true
             } else {
-              currentStatus = currentStatus.copy(isDataAvailable = false)
-              updateStatusMessage("Waiting for data to arrive")
-              Thread.sleep(pollingDelayMs)
+              false
             }
-            true
-          } else {
-            false
-          }
 
-        // Update committed offsets.
-        committedOffsets ++= availableOffsets
-        updateStatusMessage("Waiting for next trigger")
-        isTerminated
-      })
-      updateStatusMessage("Stopped")
+          // Update committed offsets.
+          committedOffsets ++= availableOffsets
+          updateStatusMessage("Waiting for next trigger")
+          continueToRun
+        })
+        updateStatusMessage("Stopped")
+      } else {
+        // `stop()` is already called. Let `finally` finish the cleanup.
+      }
     } catch {
-      case _: InterruptedException if state == TERMINATED => // interrupted by stop()
+      case _: InterruptedException if state.get == TERMINATED => // interrupted by stop()
         updateStatusMessage("Stopped")
       case e: Throwable =>
         streamDeathCause = new StreamingQueryException(
@@ -318,7 +322,7 @@ class StreamExecution(
       initializationLatch.countDown()
 
       try {
-        state = TERMINATED
+        state.set(TERMINATED)
         currentStatus = status.copy(isTriggerActive = false, isDataAvailable = false)
 
         // Update metrics and status
@@ -562,7 +566,7 @@ class StreamExecution(
   override def stop(): Unit = {
     // Set the state to TERMINATED so that the batching thread knows that it was interrupted
     // intentionally
-    state = TERMINATED
+    state.set(TERMINATED)
     if (microBatchThread.isAlive) {
       microBatchThread.interrupt()
       microBatchThread.join()