diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcTimeout.scala b/core/src/main/scala/org/apache/spark/rpc/RpcTimeout.scala
index 2761d39e37029146624914c8b1a35528e7400b70..efd26486abaec988d4497c34e7e0b06f6dd18139 100644
--- a/core/src/main/scala/org/apache/spark/rpc/RpcTimeout.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/RpcTimeout.scala
@@ -24,7 +24,7 @@ import scala.concurrent.duration._
 import scala.util.control.NonFatal
 
 import org.apache.spark.{SparkConf, SparkException}
-import org.apache.spark.util.Utils
+import org.apache.spark.util.{ThreadUtils, Utils}
 
 /**
  * An exception thrown if RpcTimeout modifies a [[TimeoutException]].
@@ -72,15 +72,9 @@ private[spark] class RpcTimeout(val duration: FiniteDuration, val timeoutProp: S
    *         is still not ready
    */
   def awaitResult[T](future: Future[T]): T = {
-    val wrapAndRethrow: PartialFunction[Throwable, T] = {
-      case NonFatal(t) =>
-        throw new SparkException("Exception thrown in awaitResult", t)
-    }
     try {
-      // scalastyle:off awaitresult
-      Await.result(future, duration)
-      // scalastyle:on awaitresult
-    } catch addMessageIfTimeout.orElse(wrapAndRethrow)
+      ThreadUtils.awaitResult(future, duration)
+    } catch addMessageIfTimeout
   }
 }
 
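With this change, `RpcTimeout.awaitResult` delegates the blocking wait to `ThreadUtils.awaitResult` and only layers `addMessageIfTimeout` on top. A minimal sketch of the resulting call-site behavior, not part of this patch: it assumes code living under the org.apache.spark.rpc package (RpcTimeout is private[spark]), and the object name and timeout value are illustrative.

  package org.apache.spark.rpc

  import java.util.concurrent.TimeoutException

  import scala.concurrent.Promise

  import org.apache.spark.SparkConf

  // Hypothetical driver, for illustration only.
  object RpcTimeoutSketch {
    def main(args: Array[String]): Unit = {
      val conf = new SparkConf().set("spark.rpc.askTimeout", "5ms")
      val timeout = RpcTimeout(conf, "spark.rpc.askTimeout")

      // A future that never completes, to force the timeout path.
      val never = Promise[String]().future
      try {
        timeout.awaitResult(never) // now blocks via ThreadUtils.awaitResult
      } catch {
        // addMessageIfTimeout rethrows the bare TimeoutException as an
        // RpcTimeoutException (a TimeoutException subclass) whose message
        // names the controlling configuration property.
        case e: TimeoutException => println(e.getMessage)
      }
    }
  }
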
diff --git a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala
index 60a6e82c6f90d33de43e5a55005f25508451b6aa..1aa4456ed01b4d135d4959c0ceaa875ddc42079b 100644
--- a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala
+++ b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala
@@ -19,7 +19,7 @@ package org.apache.spark.util
 
 import java.util.concurrent._
 
-import scala.concurrent.{Await, Awaitable, ExecutionContext, ExecutionContextExecutor}
+import scala.concurrent.{Awaitable, ExecutionContext, ExecutionContextExecutor}
 import scala.concurrent.duration.Duration
 import scala.concurrent.forkjoin.{ForkJoinPool => SForkJoinPool, ForkJoinWorkerThread => SForkJoinWorkerThread}
 import scala.util.control.NonFatal
@@ -180,39 +180,30 @@ private[spark] object ThreadUtils {
 
   // scalastyle:off awaitresult
   /**
-   * Preferred alternative to `Await.result()`. This method wraps and re-throws any exceptions
-   * thrown by the underlying `Await` call, ensuring that this thread's stack trace appears in
-   * logs.
-   */
-  @throws(classOf[SparkException])
-  def awaitResult[T](awaitable: Awaitable[T], atMost: Duration): T = {
-    try {
-      Await.result(awaitable, atMost)
-      // scalastyle:on awaitresult
-    } catch {
-      case NonFatal(t) =>
-        throw new SparkException("Exception thrown in awaitResult: ", t)
-    }
-  }
-
-  /**
-   * Calls `Awaitable.result` directly to avoid using `ForkJoinPool`'s `BlockingContext`, wraps
-   * and re-throws any exceptions with nice stack track.
+   * Preferred alternative to `Await.result()`.
+   *
+   * This method wraps and re-throws any exceptions thrown by the underlying `Await` call, ensuring
+   * that this thread's stack trace appears in logs.
    *
-   * Codes running in the user's thread may be in a thread of Scala ForkJoinPool. As concurrent
-   * executions in ForkJoinPool may see some [[ThreadLocal]] value unexpectedly, this method
-   * basically prevents ForkJoinPool from running other tasks in the current waiting thread.
+   * In addition, it calls `Awaitable.result` directly to avoid using `ForkJoinPool`'s
+   * `BlockingContext`. Code running in the user's thread may be in a thread of Scala's
+   * ForkJoinPool. Because concurrent executions in a ForkJoinPool may observe unexpected
+   * [[ThreadLocal]] values, this method prevents the ForkJoinPool from running other tasks on
+   * the current waiting thread. In general, prefer this method: many places in Spark use
+   * [[ThreadLocal]], and leaked [[ThreadLocal]]s are hard to debug.
    */
   @throws(classOf[SparkException])
-  def awaitResultInForkJoinSafely[T](awaitable: Awaitable[T], atMost: Duration): T = {
+  def awaitResult[T](awaitable: Awaitable[T], atMost: Duration): T = {
     try {
       // `awaitPermission` is not actually used anywhere so it's safe to pass in null here.
       // See SPARK-13747.
       val awaitPermission = null.asInstanceOf[scala.concurrent.CanAwait]
-      awaitable.result(Duration.Inf)(awaitPermission)
+      awaitable.result(atMost)(awaitPermission)
     } catch {
-      case NonFatal(t) =>
+      // TimeoutException is thrown in the current thread, so there is no need to wrap it.
+      case NonFatal(t) if !t.isInstanceOf[TimeoutException] =>
         throw new SparkException("Exception thrown in awaitResult: ", t)
     }
   }
+  // scalastyle:on awaitresult
 }
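
The consolidated `ThreadUtils.awaitResult` now has two distinct failure modes: non-fatal exceptions from the awaitable are wrapped in SparkException (so the waiting thread's stack trace appears in logs), while a TimeoutException propagates unwrapped. A minimal sketch exercising both paths (the object name is hypothetical; the API calls match this patch):

  import java.util.concurrent.TimeoutException

  import scala.concurrent.{ExecutionContext, Future, Promise}
  import scala.concurrent.duration._

  import org.apache.spark.SparkException
  import org.apache.spark.util.ThreadUtils

  // Illustrative sketch, not part of this patch.
  object AwaitResultSketch {
    def main(args: Array[String]): Unit = {
      implicit val ec: ExecutionContext = ExecutionContext.global

      // Failure path: the cause surfaces wrapped in SparkException.
      val failed = Future[Int] { throw new IllegalStateException("boom") }
      try ThreadUtils.awaitResult(failed, 1.second)
      catch {
        case e: SparkException =>
          assert(e.getCause.isInstanceOf[IllegalStateException])
      }

      // Timeout path: the TimeoutException is rethrown as-is, unwrapped.
      val never = Promise[Int]().future
      try ThreadUtils.awaitResult(never, 10.milliseconds)
      catch {
        case _: TimeoutException => // expected; no SparkException wrapper
      }
    }
  }
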
diff --git a/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala
index 58664e77d24a52d273f4db0ba83b766394d9df1e..b29a53cffeb51dfc8d2cdf50f00cc20c89d886e2 100644
--- a/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala
@@ -199,10 +199,9 @@ class AsyncRDDActionsSuite extends SparkFunSuite with BeforeAndAfterAll with Tim
     val f = sc.parallelize(1 to 100, 4)
               .mapPartitions(itr => { Thread.sleep(20); itr })
               .countAsync()
-    val e = intercept[SparkException] {
+    intercept[TimeoutException] {
       ThreadUtils.awaitResult(f, Duration(20, "milliseconds"))
     }
-    assert(e.getCause.isInstanceOf[TimeoutException])
   }
 
   private def testAsyncAction[R](action: RDD[Int] => FutureAction[R]): Unit = {
diff --git a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala
index 83288db92bb43feda155b7de2c5665c6d08b250b..8c4e389e86a7789adf4a570c1639f0ad3a6a3393 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala
@@ -158,10 +158,9 @@ class OutputCommitCoordinatorSuite extends SparkFunSuite with BeforeAndAfter {
       0 until rdd.partitions.size, resultHandler, () => Unit)
     // It's an error if the job completes successfully even though no committer was authorized,
     // so throw an exception if the job was allowed to complete.
-    val e = intercept[SparkException] {
+    intercept[TimeoutException] {
       ThreadUtils.awaitResult(futureAction, 5 seconds)
     }
-    assert(e.getCause.isInstanceOf[TimeoutException])
     assert(tempDir.list().size === 0)
   }
 
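Both test updates follow from the new contract: the timeout surfaces directly rather than as the cause of a SparkException. A standalone illustration of the before/after assertion pattern (names are illustrative, not taken from the suites):

  import java.util.concurrent.TimeoutException

  import scala.concurrent.Promise
  import scala.concurrent.duration._

  import org.apache.spark.util.ThreadUtils

  // Hypothetical sketch of the new test pattern.
  object TimeoutPatternSketch extends App {
    val hung = Promise[Unit]().future
    // Pre-patch: intercept[SparkException] { ... } and inspect e.getCause.
    // Post-patch: the TimeoutException itself is what callers catch.
    val thrown =
      try { ThreadUtils.awaitResult(hung, 20.milliseconds); None }
      catch { case e: TimeoutException => Some(e) }
    assert(thrown.isDefined)
  }
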
diff --git a/scalastyle-config.xml b/scalastyle-config.xml
index 48333851efb54af4a955c2a6baa7da91884fcdca..1f48d71cc7a2b93bedab47e3f407066426f85a10 100644
--- a/scalastyle-config.xml
+++ b/scalastyle-config.xml
@@ -200,7 +200,6 @@ This file is divided into 3 sections:
       // scalastyle:off awaitresult
       Await.result(...)
       // scalastyle:on awaitresult
-      If your codes use ThreadLocal and may run in threads created by the user, use ThreadUtils.awaitResultInForkJoinSafely instead.
     ]]></customMessage>
   </check>
 
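For reference, the retained rule still works as shown in the message above: a direct `Await.result` must be explicitly suppressed, which is what steers callers to `ThreadUtils.awaitResult`. A hedged sketch of a call site (not taken from the config file):

  import scala.concurrent.{Await, Future}
  import scala.concurrent.duration._

  def blockingWait[T](f: Future[T]): T = {
    // scalastyle:off awaitresult
    Await.result(f, 10.seconds) // flagged by the "awaitresult" rule unless suppressed
    // scalastyle:on awaitresult
  }
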
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
index e6f1de5cb05b842e4a5d4bd885424921fca57771..fb90799534497e3870faadbcaa0b0bc944192e23 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
@@ -578,7 +578,7 @@ case class SubqueryExec(name: String, child: SparkPlan) extends UnaryExecNode {
   }
 
   override def executeCollect(): Array[InternalRow] = {
-    ThreadUtils.awaitResultInForkJoinSafely(relationFuture, Duration.Inf)
+    ThreadUtils.awaitResult(relationFuture, Duration.Inf)
   }
 }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala
index ce5013daeb1f90dc39f759ef48868e9c2ab21669..7be5d31d4a7658f5aa4ad995fa6ad49e91e2a08f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala
@@ -128,8 +128,7 @@ case class BroadcastExchangeExec(
   }
 
   override protected[sql] def doExecuteBroadcast[T](): broadcast.Broadcast[T] = {
-    ThreadUtils.awaitResultInForkJoinSafely(relationFuture, timeout)
-      .asInstanceOf[broadcast.Broadcast[T]]
+    ThreadUtils.awaitResult(relationFuture, timeout).asInstanceOf[broadcast.Broadcast[T]]
   }
 }