diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala
index 069e41b6cedd6945317326741029d476b3586cb4..698f07b0a187f4897c34e1fa04fba60935741af9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala
@@ -32,6 +32,7 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.network.util.JavaUtils
 import org.apache.spark.serializer.JavaSerializer
 import org.apache.spark.sql.SparkSession
+import org.apache.spark.util.UninterruptibleThread
 
 
 /**
@@ -91,18 +92,30 @@ class HDFSMetadataLog[T: ClassTag](sparkSession: SparkSession, path: String)
     serializer.deserialize[T](ByteBuffer.wrap(bytes))
   }
 
+  /**
+   * Store the metadata for the specified batchId and return `true` if successful. If the
+   * batchId's metadata has already been stored, this method will return `false`.
+   *
+   * Note that this method must be called on a [[org.apache.spark.util.UninterruptibleThread]]
+   * so that interrupts can be disabled while writing the batch file. This is needed because
+   * there is a potential deadlock in Hadoop's "Shell.runCommand" before 2.5.0 (HADOOP-10622):
+   * a thread that is interrupted while running "Shell.runCommand" can hang forever. In our
+   * case, `writeBatch` creates a file using the HDFS API and calls "Shell.runCommand" to set
+   * the file permissions, so it can deadlock if the stream execution thread is stopped by an
+   * interrupt. Hence, we require this method to be called on an [[UninterruptibleThread]],
+   * which allows us to disable interrupts here. Also see SPARK-14131.
+   */
   override def add(batchId: Long, metadata: T): Boolean = {
     get(batchId).map(_ => false).getOrElse {
-      // Only write metadata when the batch has not yet been written.
-      try {
-        writeBatch(batchId, serialize(metadata))
-        true
-      } catch {
-        case e: IOException if "java.lang.InterruptedException" == e.getMessage =>
-          // create may convert InterruptedException to IOException. Let's convert it back to
-          // InterruptedException so that this failure won't crash StreamExecution
-          throw new InterruptedException("Creating file is interrupted")
+      // Only write metadata when the batch has not yet been written.
+      Thread.currentThread match {
+        case ut: UninterruptibleThread =>
+          ut.runUninterruptibly { writeBatch(batchId, serialize(metadata)) }
+        case _ =>
+          throw new IllegalStateException(
+            "HDFSMetadataLog.add() must be executed on an o.a.spark.util.UninterruptibleThread")
       }
+      true
     }
   }
 
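For reviewers unfamiliar with `UninterruptibleThread`, here is a minimal sketch (not part of this patch) of the semantics `add()` now relies on: an `interrupt()` delivered while the thread is inside `runUninterruptibly` is deferred until that block exits, so the batch-file write cannot be cut short mid-way. The object name and the `CountDownLatch` handshake are illustrative only.

```scala
import java.util.concurrent.CountDownLatch

import org.apache.spark.util.UninterruptibleThread

// Illustrative only: an interrupt arriving inside runUninterruptibly is
// deferred, so the simulated batch-file write below cannot be interrupted.
object UninterruptibleSketch {
  def main(args: Array[String]): Unit = {
    val entered = new CountDownLatch(1)
    val worker = new UninterruptibleThread("sketch-worker") {
      override def run(): Unit = {
        runUninterruptibly {
          entered.countDown()
          Thread.sleep(500) // stands in for writeBatch(batchId, ...)
          println("critical section completed despite the interrupt")
        }
        // The deferred interrupt is re-delivered once the block exits.
        println(s"interrupt status afterwards: ${Thread.currentThread.isInterrupted}")
      }
    }
    worker.start()
    entered.await()
    worker.interrupt() // deferred until runUninterruptibly returns
    worker.join()
  }
}
```

This is also the property the removed `IOException`/`InterruptedException` translation was papering over: with interrupts disabled during the write, the HDFS `create` call can no longer be interrupted mid-way, so the conversion back to `InterruptedException` is no longer needed.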
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
index c90dcc568074392944591dcb6a912d962e4f8bb6..af2229a46bebbf418a7ff1ee5c30a045c48e3221 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
@@ -110,7 +110,11 @@ class StreamExecution(
   /* Get the call site in the caller thread; will pass this into the micro batch thread */
   private val callSite = Utils.getCallSite()
 
-  /** The thread that runs the micro-batches of this stream. */
+  /**
+   * The thread that runs the micro-batches of this stream. Note that this thread must be an
+   * [[org.apache.spark.util.UninterruptibleThread]] to avoid potential deadlocks when using
+   * [[HDFSMetadataLog]]. See SPARK-14131 for more details.
+   */
   private[sql] val microBatchThread =
     new UninterruptibleThread(s"stream execution thread for $name") {
       override def run(): Unit = {
@@ -269,19 +273,11 @@ class StreamExecution(
    * batchId counter is incremented and a new log entry is written with the newest offsets.
    */
   private def constructNextBatch(): Unit = {
-    // There is a potential dead-lock in Hadoop "Shell.runCommand" before 2.5.0 (HADOOP-10622).
-    // If we interrupt some thread running Shell.runCommand, we may hit this issue.
-    // As "FileStreamSource.getOffset" will create a file using HDFS API and call "Shell.runCommand"
-    // to set the file permission, we should not interrupt "microBatchThread" when running this
-    // method. See SPARK-14131.
-    //
     // Check to see what new data is available.
     val hasNewData = {
       awaitBatchLock.lock()
       try {
-        val newData = microBatchThread.runUninterruptibly {
-          uniqueSources.flatMap(s => s.getOffset.map(o => s -> o))
-        }
+        val newData = uniqueSources.flatMap(s => s.getOffset.map(o => s -> o))
         availableOffsets ++= newData
 
         if (dataAvailable) {
@@ -295,16 +291,8 @@ class StreamExecution(
       }
     }
     if (hasNewData) {
-      // There is a potential dead-lock in Hadoop "Shell.runCommand" before 2.5.0 (HADOOP-10622).
-      // If we interrupt some thread running Shell.runCommand, we may hit this issue.
-      // As "offsetLog.add" will create a file using HDFS API and call "Shell.runCommand" to set
-      // the file permission, we should not interrupt "microBatchThread" when running this method.
-      // See SPARK-14131.
-      microBatchThread.runUninterruptibly {
-        assert(
-          offsetLog.add(currentBatchId, availableOffsets.toCompositeOffset(sources)),
-          s"Concurrent update to the log.  Multiple streaming jobs detected for $currentBatchId")
-      }
+      assert(offsetLog.add(currentBatchId, availableOffsets.toCompositeOffset(sources)),
+        s"Concurrent update to the log. Multiple streaming jobs detected for $currentBatchId")
       logInfo(s"Committed offsets for batch $currentBatchId.")
     } else {
       awaitBatchLock.lock()
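The design choice here is to centralize the uninterruptible section inside `HDFSMetadataLog.add` rather than at every call site, which is why the two `runUninterruptibly` wrappers above can simply be deleted. A condensed, hypothetical sketch of the resulting interaction between the stream thread and `stop()`-style interrupts (`activeFlag` and `writeOneBatchStub` are invented stand-ins, not real `StreamExecution` members):

```scala
import java.util.concurrent.atomic.AtomicBoolean

import org.apache.spark.util.UninterruptibleThread

// Hypothetical stand-in for the stream's main loop; all names are illustrative.
object StreamThreadSketch {
  private val activeFlag = new AtomicBoolean(true)

  private def writeOneBatchStub(): Unit = {
    // In the real code this is constructNextBatch(), whose offsetLog.add(...)
    // now enters runUninterruptibly internally.
    Thread.sleep(100)
  }

  def main(args: Array[String]): Unit = {
    val microBatchThread = new UninterruptibleThread("stream execution thread") {
      override def run(): Unit =
        try {
          while (activeFlag.get()) writeOneBatchStub()
        } catch {
          case _: InterruptedException => // stream stopped
        }
    }
    microBatchThread.start()
    Thread.sleep(300)
    activeFlag.set(false)
    // If the thread is inside a metadata-log write, this interrupt is deferred
    // until that critical section finishes, avoiding HADOOP-10622.
    microBatchThread.interrupt()
    microBatchThread.join()
  }
}
```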
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLogSuite.scala
index a7b2cfe7d0a49cc3ba7a1e3d375510a9dfd5fe7d..39fd1f0cd37bbe97e5d7148c244bac01d5e9b797 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLogSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLogSuite.scala
@@ -190,7 +190,7 @@ class FileStreamSinkLogSuite extends SparkFunSuite with SharedSQLContext {
     }
   }
 
-  test("compact") {
+  testWithUninterruptibleThread("compact") {
     withSQLConf(SQLConf.FILE_SINK_LOG_COMPACT_INTERVAL.key -> "3") {
       withFileStreamSinkLog { sinkLog =>
         for (batchId <- 0 to 10) {
@@ -210,7 +210,7 @@ class FileStreamSinkLogSuite extends SparkFunSuite with SharedSQLContext {
     }
   }
 
-  test("delete expired file") {
+  testWithUninterruptibleThread("delete expired file") {
     // Set FILE_SINK_LOG_CLEANUP_DELAY to 0 so that we can detect the deletion behavior
     // deterministically
     withSQLConf(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala
index ef2b479a5636fdae7c65c11f6ede61135efd24dd..ab5a2d253b94a157aa88e5a307fa1bdeec0ee2c4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala
@@ -33,6 +33,7 @@ import org.apache.spark.{SparkConf, SparkFunSuite}
 import org.apache.spark.sql.execution.streaming.FakeFileSystem._
 import org.apache.spark.sql.execution.streaming.HDFSMetadataLog.{FileContextManager, FileManager, FileSystemManager}
 import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.util.UninterruptibleThread
 
 class HDFSMetadataLogSuite extends SparkFunSuite with SharedSQLContext {
 
@@ -56,7 +57,7 @@ class HDFSMetadataLogSuite extends SparkFunSuite with SharedSQLContext {
     }
   }
 
-  test("HDFSMetadataLog: basic") {
+  testWithUninterruptibleThread("HDFSMetadataLog: basic") {
     withTempDir { temp =>
       val dir = new File(temp, "dir") // use a non-existent dir to test whether the log creates it
       val metadataLog = new HDFSMetadataLog[String](spark, dir.getAbsolutePath)
@@ -81,7 +82,8 @@ class HDFSMetadataLogSuite extends SparkFunSuite with SharedSQLContext {
     }
   }
 
-  testQuietly("HDFSMetadataLog: fallback from FileContext to FileSystem") {
+  testWithUninterruptibleThread(
+    "HDFSMetadataLog: fallback from FileContext to FileSystem", quietly = true) {
     spark.conf.set(
       s"fs.$scheme.impl",
       classOf[FakeFileSystem].getName)
@@ -101,7 +103,7 @@ class HDFSMetadataLogSuite extends SparkFunSuite with SharedSQLContext {
     }
   }
 
-  test("HDFSMetadataLog: restart") {
+  testWithUninterruptibleThread("HDFSMetadataLog: restart") {
     withTempDir { temp =>
       val metadataLog = new HDFSMetadataLog[String](spark, temp.getAbsolutePath)
       assert(metadataLog.add(0, "batch0"))
@@ -124,7 +126,7 @@ class HDFSMetadataLogSuite extends SparkFunSuite with SharedSQLContext {
       val waiter = new Waiter
       val maxBatchId = 100
       for (id <- 0 until 10) {
-        new Thread() {
+        new UninterruptibleThread(s"HDFSMetadataLog: metadata directory collision - thread $id") {
           override def run(): Unit = waiter {
             val metadataLog =
               new HDFSMetadataLog[String](spark, temp.getAbsolutePath)
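The thread-type change in this collision test matters because `add()` now rejects plain threads outright instead of writing. A small illustration, assuming a `log: HDFSMetadataLog[String]` constructed elsewhere (the batch ids, thread name, and `demo` wrapper are invented):

```scala
import org.apache.spark.sql.execution.streaming.HDFSMetadataLog
import org.apache.spark.util.UninterruptibleThread

// Illustration only: a plain Thread now fails fast; an UninterruptibleThread writes.
object AddThreadCheckDemo {
  def demo(log: HDFSMetadataLog[String]): Unit = {
    new Thread {
      override def run(): Unit =
        try log.add(0, "batch0")
        catch { case e: IllegalStateException => println(s"rejected: ${e.getMessage}") }
    }.start()

    new UninterruptibleThread("writer-0") {
      override def run(): Unit = assert(log.add(1, "batch1")) // writes batch 1
    }.start()
  }
}
```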
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
index 5286ee5bc23d33b857144769f7722368048cdd5b..d4d8e3e4e83d5cb74d9630e8a396f9ab47bb4acd 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
@@ -22,6 +22,7 @@ import java.util.UUID
 
 import scala.language.implicitConversions
 import scala.util.Try
+import scala.util.control.NonFatal
 
 import org.apache.hadoop.conf.Configuration
 import org.scalatest.BeforeAndAfterAll
@@ -34,7 +35,7 @@ import org.apache.spark.sql.catalyst.FunctionIdentifier
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.util._
 import org.apache.spark.sql.execution.FilterExec
-import org.apache.spark.util.Utils
+import org.apache.spark.util.{UninterruptibleThread, Utils}
 
 /**
  * Helper trait that should be extended by all SQL test suites.
@@ -247,6 +248,46 @@ private[sql] trait SQLTestUtils
       }
     }
   }
+
+  /** Run a test on a separate [[UninterruptibleThread]]. */
+  protected def testWithUninterruptibleThread(name: String, quietly: Boolean = false)
+    (body: => Unit): Unit = {
+    val timeoutMillis = 10000
+    @volatile var ex: Throwable = null  // written on the test thread, read back after join()
+
+    def runOnThread(): Unit = {
+      val thread = new UninterruptibleThread(s"Testing thread for test $name") {
+        override def run(): Unit = {
+          try {
+            body
+          } catch {
+            case NonFatal(e) =>
+              ex = e
+          }
+        }
+      }
+      thread.setDaemon(true)
+      thread.start()
+      thread.join(timeoutMillis)
+      if (thread.isAlive) {
+        thread.interrupt()
+        // If this interrupt does not work, then this thread is most likely running something
+        // that is not interruptible. There is little point in waiting for it to terminate;
+        // we would rather let the JVM kill the daemon thread on exit.
+        fail(
+          s"Test '$name' running on o.a.s.util.UninterruptibleThread timed out after" +
+            s" $timeoutMillis ms")
+      } else if (ex != null) {
+        throw ex
+      }
+    }
+
+    if (quietly) {
+      testQuietly(name) { runOnThread() }
+    } else {
+      test(name) { runOnThread() }
+    }
+  }
 }
 
 private[sql] object SQLTestUtils {
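For reference, a hypothetical suite using the new helper (the suite and test names are invented; `SharedSQLContext` mixes in `SQLTestUtils`, so the helper is in scope):

```scala
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.util.UninterruptibleThread

// Invented example suite; not part of this patch.
class UninterruptibleHelperExampleSuite extends SparkFunSuite with SharedSQLContext {

  testWithUninterruptibleThread("body runs on an UninterruptibleThread") {
    assert(Thread.currentThread().isInstanceOf[UninterruptibleThread])
  }

  // quietly = true routes through testQuietly, muting expected error logs.
  testWithUninterruptibleThread("noisy failure path", quietly = true) {
    intercept[IllegalStateException] {
      throw new IllegalStateException("expected failure inside the test body")
    }
  }
}
```

Note that failures are rethrown on the calling thread (via the captured `ex`), so ScalaTest reports them against the right test, and a hung body fails with a timeout after 10 seconds rather than blocking the suite.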