diff --git a/core/src/main/resources/org/apache/spark/ui/static/webui.css b/core/src/main/resources/org/apache/spark/ui/static/webui.css
index a2220e761ac98aec57e264396ae3a60396cfc109..db57712c835032c71ff302fa339574c01946744b 100644
--- a/core/src/main/resources/org/apache/spark/ui/static/webui.css
+++ b/core/src/main/resources/org/apache/spark/ui/static/webui.css
@@ -120,6 +120,20 @@ pre {
   border: none;
 }
 
+.stacktrace-details {
+  max-height: 300px;
+  overflow-y: auto;
+  margin: 0;
+  transition: max-height 0.5s ease-out, padding 0.5s ease-out;
+}
+
+.stacktrace-details.collapsed {
+  max-height: 0;
+  padding-top: 0;
+  padding-bottom: 0;
+  border: none;
+}
+
 span.expand-additional-metrics {
   cursor: pointer;
 }
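
These two rules drive the expand/collapse animation: toggling the `collapsed` class on a `.stacktrace-details` element transitions its max-height between 0 and 300px. A minimal sketch of the markup the UI code later in this diff emits against these rules (the error string is illustrative):

    // Scala XML literal, matching the pattern used in StagePage/StageTable below.
    // Clicking the span toggles "collapsed" on the sibling div; the CSS transition
    // animates it open and closed.
    <span onclick="this.parentNode.querySelector('.stacktrace-details').classList.toggle('collapsed')"
          class="expand-details">
      +details
    </span> ++
      <div class="stacktrace-details collapsed">
        <pre>{"java.lang.RuntimeException: oops\n  at Foo.bar(Foo.scala:7)"}</pre>
      </div>
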
diff --git a/core/src/main/scala/org/apache/spark/TaskEndReason.scala b/core/src/main/scala/org/apache/spark/TaskEndReason.scala
index f45b463fb6f623161ef56bac329da16ebbc40cd0..af5fd8e0ac00c5c71ae24a31affa497b9e45d37b 100644
--- a/core/src/main/scala/org/apache/spark/TaskEndReason.scala
+++ b/core/src/main/scala/org/apache/spark/TaskEndReason.scala
@@ -83,15 +83,48 @@ case class FetchFailed(
  * :: DeveloperApi ::
  * Task failed due to a runtime exception. This is the most common failure case and also captures
  * user program exceptions.
+ *
+ * `stackTrace` contains the stack trace of the exception itself. It is kept only for backward
+ * compatibility. Prefer `this(e: Throwable, metrics: Option[TaskMetrics])` to create an
+ * `ExceptionFailure`, as it populates both fields and handles backward compatibility properly.
+ *
+ * `fullStackTrace` is a better representation of the stack trace because it contains the whole
+ * stack trace, including the exception itself and all of its causes.
  */
 @DeveloperApi
 case class ExceptionFailure(
     className: String,
     description: String,
     stackTrace: Array[StackTraceElement],
+    fullStackTrace: String,
     metrics: Option[TaskMetrics])
   extends TaskFailedReason {
-  override def toErrorString: String = Utils.exceptionString(className, description, stackTrace)
+
+  private[spark] def this(e: Throwable, metrics: Option[TaskMetrics]) {
+    this(e.getClass.getName, e.getMessage, e.getStackTrace, Utils.exceptionString(e), metrics)
+  }
+
+  override def toErrorString: String =
+    if (fullStackTrace == null) {
+      // fullStackTrace was added in 1.2.0
+      // If fullStackTrace is null, use the old error string for backward compatibility
+      exceptionString(className, description, stackTrace)
+    } else {
+      fullStackTrace
+    }
+
+  /**
+   * Return a nice string representation of the exception, including the stack trace.
+   * Note: It does not include the exception's causes, and is only used for backward compatibility.
+   */
+  private def exceptionString(
+      className: String,
+      description: String,
+      stackTrace: Array[StackTraceElement]): String = {
+    val desc = if (description == null) "" else description
+    val st = if (stackTrace == null) "" else stackTrace.map("        " + _).mkString("\n")
+    s"$className: $desc\n$st"
+  }
 }
 
 /**
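
The new auxiliary constructor is the preferred entry point: it derives all four error fields from the Throwable in one place. A minimal sketch of both paths (illustrative only; the constructor is private[spark], so this compiles only inside Spark):

    val e = new RuntimeException("oops", new IllegalStateException("root cause"))

    // Preferred: fills stackTrace and fullStackTrace consistently.
    val modern = new ExceptionFailure(e, None)
    // modern.toErrorString == Utils.exceptionString(e), causes included.

    // Deserialized pre-1.2.0 events arrive with fullStackTrace == null and fall
    // back to the legacy "className: description\n<indented frames>" format.
    val legacy = ExceptionFailure(e.getClass.getName, e.getMessage, e.getStackTrace, null, None)
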
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index 96114571d6c7754dc5dcda4e1a0767a498d07d0c..caf4d76713d490242020cf92892e2924753d26ab 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -263,7 +263,7 @@ private[spark] class Executor(
             m.executorRunTime = serviceTime
             m.jvmGCTime = gcTime - startGCTime
           }
-          val reason = ExceptionFailure(t.getClass.getName, t.getMessage, t.getStackTrace, metrics)
+          val reason = new ExceptionFailure(t, metrics)
           execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))
 
           // Don't forcibly exit unless the exception was inherently fatal, to avoid
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 96114c0423a9e468f61a674f132bb5dfbe9d0ac4..22449517d100fdd587b9a09181176818231ac63d 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -1063,7 +1063,7 @@ class DAGScheduler(
         if (runningStages.contains(failedStage)) {
           logInfo(s"Marking $failedStage (${failedStage.name}) as failed " +
             s"due to a fetch failure from $mapStage (${mapStage.name})")
-          markStageAsFinished(failedStage, Some("Fetch failure: " + failureMessage))
+          markStageAsFinished(failedStage, Some(failureMessage))
           runningStages -= failedStage
         }
 
@@ -1094,7 +1094,7 @@ class DAGScheduler(
           handleExecutorLost(bmAddress.executorId, fetchFailed = true, Some(task.epoch))
         }
 
-      case ExceptionFailure(className, description, stackTrace, metrics) =>
+      case ExceptionFailure(className, description, stackTrace, fullStackTrace, metrics) =>
         // Do nothing here, left up to the TaskScheduler to decide how to handle user failures
 
       case TaskResultLost =>
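
Because `fullStackTrace` is a case-class field, it participates in the extractor, so every pattern match on ExceptionFailure gains an extra binding, as in the hunk above. A sketch of the updated shape:

    // reason: TaskEndReason (illustrative)
    reason match {
      case ExceptionFailure(className, description, stackTrace, fullStackTrace, metrics) =>
        // left to the TaskScheduler to decide how to handle user failures
      case _ =>
    }
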
diff --git a/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala b/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala
index 0c1b6f4defdb34f0d967606e1610f293aacc8f8c..be184464e0ae97b54671f96473066029a20494e9 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala
@@ -20,6 +20,7 @@ package org.apache.spark.shuffle
 
 import org.apache.spark.storage.BlockManagerId
 import org.apache.spark.{FetchFailed, TaskEndReason}
+import org.apache.spark.util.Utils
 
 /**
  * Failed to fetch a shuffle block. The executor catches this exception and propagates it
@@ -32,10 +32,21 @@ private[spark] class FetchFailedException(
     shuffleId: Int,
     mapId: Int,
     reduceId: Int,
-    message: String)
-  extends Exception(message) {
+    message: String,
+    cause: Throwable = null)
+  extends Exception(message, cause) {
+
+  def this(
+      bmAddress: BlockManagerId,
+      shuffleId: Int,
+      mapId: Int,
+      reduceId: Int,
+      cause: Throwable) {
+    this(bmAddress, shuffleId, mapId, reduceId, cause.getMessage, cause)
+  }
 
-  def toTaskEndReason: TaskEndReason = FetchFailed(bmAddress, shuffleId, mapId, reduceId, message)
+  def toTaskEndReason: TaskEndReason = FetchFailed(bmAddress, shuffleId, mapId, reduceId,
+    Utils.exceptionString(this))
 }
 
 /**
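
With the new `cause` parameter the original exception rides along inside the FetchFailedException, and `toTaskEndReason` serializes the whole chain rather than just the message. A sketch, assuming a BlockManagerId named bmAddress is in scope (the class is private[spark]):

    val cause = new java.io.IOException("connection reset by peer")
    val ffe = new FetchFailedException(bmAddress, shuffleId = 1, mapId = 2, reduceId = 0, cause)
    // The FetchFailed reason now carries the full stack trace of the cause chain.
    val reason = ffe.toTaskEndReason
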
diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala
index 0d5247f4176d4d3e7677fbc22e0515045f670b43..e3e7434df45b03ebf55549c65436022bb900d8fa 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala
@@ -25,7 +25,7 @@ import org.apache.spark._
 import org.apache.spark.serializer.Serializer
 import org.apache.spark.shuffle.FetchFailedException
 import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockFetcherIterator, ShuffleBlockId}
-import org.apache.spark.util.{CompletionIterator, Utils}
+import org.apache.spark.util.CompletionIterator
 
 private[hash] object BlockStoreShuffleFetcher extends Logging {
   def fetch[T](
@@ -64,8 +64,7 @@ private[hash] object BlockStoreShuffleFetcher extends Logging {
           blockId match {
             case ShuffleBlockId(shufId, mapId, _) =>
               val address = statuses(mapId.toInt)._1
-              throw new FetchFailedException(address, shufId.toInt, mapId.toInt, reduceId,
-                Utils.exceptionString(e))
+              throw new FetchFailedException(address, shufId.toInt, mapId.toInt, reduceId, e)
             case _ =>
               throw new SparkException(
                 "Failed to get block " + blockId + ", which is not a shuffle block", e)
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
index 63ed5fc4949c2247ee558d878e53456bcff5122e..250bddbe2f262025fa02b627d7786b5b37c18a7f 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
@@ -22,6 +22,8 @@ import javax.servlet.http.HttpServletRequest
 
 import scala.xml.{Node, Unparsed}
 
+import org.apache.commons.lang3.StringEscapeUtils
+
 import org.apache.spark.executor.TaskMetrics
 import org.apache.spark.ui.{ToolTips, WebUIPage, UIUtils}
 import org.apache.spark.ui.jobs.UIData._
@@ -436,13 +438,37 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") {
             {diskBytesSpilledReadable}
           </td>
         }}
-        <td>
-          {errorMessage.map { e => <pre>{e}</pre> }.getOrElse("")}
-        </td>
+        {errorMessageCell(errorMessage)}
       </tr>
     }
   }
 
+  private def errorMessageCell(errorMessage: Option[String]): Seq[Node] = {
+    val error = errorMessage.getOrElse("")
+    val isMultiline = error.indexOf('\n') >= 0
+    // Display the first line by default
+    val errorSummary = StringEscapeUtils.escapeHtml4(
+      if (isMultiline) {
+        error.substring(0, error.indexOf('\n'))
+      } else {
+        error
+      })
+    val details = if (isMultiline) {
+      // scalastyle:off
+      <span onclick="this.parentNode.querySelector('.stacktrace-details').classList.toggle('collapsed')"
+            class="expand-details">
+        +details
+      </span> ++
+        <div class="stacktrace-details collapsed">
+          <pre>{error}</pre>
+        </div>
+      // scalastyle:on
+    } else {
+      ""
+    }
+    <td>{errorSummary}{details}</td>
+  }
+
   private def getSchedulerDelay(info: TaskInfo, metrics: TaskMetrics): Long = {
     val totalExecutionTime = {
       if (info.gettingResultTime > 0) {
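
`errorMessageCell` keeps the task table compact: only the first line of a multi-line error is shown (HTML-escaped), with the full text tucked into a collapsed details block. Illustrative behavior (the method is private to StagePage):

    // Multi-line error: summary line plus a "+details" toggle and a hidden <pre>.
    errorMessageCell(Some("java.lang.RuntimeException: oops\n  at Foo.bar(Foo.scala:7)"))

    // Single-line error: just the escaped text inside <td>.
    errorMessageCell(Some("TaskKilled"))
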
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala
index 4ee7f08ab47a24fb68380f3665b7ab31dd82b5d8..3b4866e05956d1224ab71a99d170b98653c3b9fc 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala
@@ -22,6 +22,8 @@ import scala.xml.Text
 
 import java.util.Date
 
+import org.apache.commons.lang3.StringEscapeUtils
+
 import org.apache.spark.scheduler.StageInfo
 import org.apache.spark.ui.{ToolTips, UIUtils}
 import org.apache.spark.util.Utils
@@ -195,7 +197,29 @@ private[ui] class FailedStageTable(
 
   override protected def stageRow(s: StageInfo): Seq[Node] = {
     val basicColumns = super.stageRow(s)
-    val failureReason = <td valign="middle"><pre>{s.failureReason.getOrElse("")}</pre></td>
-    basicColumns ++ failureReason
+    val failureReason = s.failureReason.getOrElse("")
+    val isMultiline = failureReason.indexOf('\n') >= 0
+    // Display the first line by default
+    val failureReasonSummary = StringEscapeUtils.escapeHtml4(
+      if (isMultiline) {
+        failureReason.substring(0, failureReason.indexOf('\n'))
+      } else {
+        failureReason
+      })
+    val details = if (isMultiline) {
+      // scalastyle:off
+      <span onclick="this.parentNode.querySelector('.stacktrace-details').classList.toggle('collapsed')"
+            class="expand-details">
+        +details
+      </span> ++
+        <div class="stacktrace-details collapsed">
+          <pre>{failureReason}</pre>
+        </div>
+      // scalastyle:on
+    } else {
+      ""
+    }
+    val failureReasonHtml = <td valign="middle">{failureReasonSummary}{details}</td>
+    basicColumns ++ failureReasonHtml
   }
 }
diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
index f7ae1f7f334de2ce9643e44456feb2731e4ebd6b..f15d0c856663f41643052bbff998e846697c7012 100644
--- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
+++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
@@ -287,6 +287,7 @@ private[spark] object JsonProtocol {
         ("Class Name" -> exceptionFailure.className) ~
         ("Description" -> exceptionFailure.description) ~
         ("Stack Trace" -> stackTrace) ~
+        ("Full Stack Trace" -> exceptionFailure.fullStackTrace) ~
         ("Metrics" -> metrics)
       case ExecutorLostFailure(executorId) =>
         ("Executor ID" -> executorId)
@@ -637,8 +638,10 @@ private[spark] object JsonProtocol {
         val className = (json \ "Class Name").extract[String]
         val description = (json \ "Description").extract[String]
         val stackTrace = stackTraceFromJson(json \ "Stack Trace")
+        val fullStackTrace = Utils.jsonOption(json \ "Full Stack Trace")
+          .map(_.extract[String]).orNull
         val metrics = Utils.jsonOption(json \ "Metrics").map(taskMetricsFromJson)
-        new ExceptionFailure(className, description, stackTrace, metrics)
+        ExceptionFailure(className, description, stackTrace, fullStackTrace, metrics)
       case `taskResultLost` => TaskResultLost
       case `taskKilled` => TaskKilled
       case `executorLostFailure` =>
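
The new field round-trips through JSON, while events written by older versions simply lack the "Full Stack Trace" key and deserialize with a null `fullStackTrace`. A sketch of the round trip (JsonProtocol is private[spark]):

    val failure = new ExceptionFailure(new RuntimeException("oops"), None)
    val json = JsonProtocol.taskEndReasonToJson(failure)
    val back = JsonProtocol.taskEndReasonFromJson(json)
    // back is an ExceptionFailure whose fullStackTrace equals failure.fullStackTrace;
    // removing the "Full Stack Trace" field first would yield fullStackTrace == null.
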
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index 2cbd38d72caa1e2dff5bd98a636ff81b0123ebf6..a14d6125484fe8d003d3ef149d67865737315409 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -1599,19 +1599,19 @@ private[spark] object Utils extends Logging {
       .orNull
   }
 
-  /** Return a nice string representation of the exception, including the stack trace. */
+  /**
+   * Return a nice string representation of the exception. It calls `printStackTrace` so that
+   * the generated stack trace includes the exception itself and all of its causes.
+   */
   def exceptionString(e: Throwable): String = {
-    if (e == null) "" else exceptionString(getFormattedClassName(e), e.getMessage, e.getStackTrace)
-  }
-
-  /** Return a nice string representation of the exception, including the stack trace. */
-  def exceptionString(
-      className: String,
-      description: String,
-      stackTrace: Array[StackTraceElement]): String = {
-    val desc = if (description == null) "" else description
-    val st = if (stackTrace == null) "" else stackTrace.map("        " + _).mkString("\n")
-    s"$className: $desc\n$st"
+    if (e == null) {
+      ""
+    } else {
+      // Use e.printStackTrace here because e.getStackTrace doesn't include the cause
+      val stringWriter = new StringWriter()
+      e.printStackTrace(new PrintWriter(stringWriter))
+      stringWriter.toString
+    }
   }
 
   /** Return a thread dump of all threads' stacktraces.  Used to capture dumps for the web UI */
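
The rewritten `exceptionString` delegates to printStackTrace, so nested causes now appear in the output where the old format stopped at the outer exception's frames. For example:

    val e = new RuntimeException("outer", new IllegalStateException("inner"))
    println(Utils.exceptionString(e))
    // java.lang.RuntimeException: outer
    //     at ...
    // Caused by: java.lang.IllegalStateException: inner
    //     at ...
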
diff --git a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala
index 2efbae689771a5da1a3e1951e697bc1fee5d4307..2608ad4b32e1e28e1155d9bbc642f5c1aec25d97 100644
--- a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala
@@ -116,7 +116,7 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc
     val taskFailedReasons = Seq(
       Resubmitted,
       new FetchFailed(null, 0, 0, 0, "ignored"),
-      new ExceptionFailure("Exception", "description", null, None),
+      ExceptionFailure("Exception", "description", null, null, None),
       TaskResultLost,
       TaskKilled,
       ExecutorLostFailure("0"),
diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala
index aec1e409db95c90feabfd94542a64e10289a7e02..39e69851e7e3c2b56dd774989f67fe031247f62f 100644
--- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala
@@ -109,7 +109,7 @@ class JsonProtocolSuite extends FunSuite {
     // TaskEndReason
     val fetchFailed = FetchFailed(BlockManagerId("With or", "without you", 15), 17, 18, 19,
       "Some exception")
-    val exceptionFailure = ExceptionFailure("To be", "or not to be", stackTrace, None)
+    val exceptionFailure = new ExceptionFailure(exception, None)
     testTaskEndReason(Success)
     testTaskEndReason(Resubmitted)
     testTaskEndReason(fetchFailed)
@@ -127,6 +127,13 @@ class JsonProtocolSuite extends FunSuite {
     testBlockId(StreamBlockId(1, 2L))
   }
 
+  test("ExceptionFailure backward compatibility") {
+    val exceptionFailure = ExceptionFailure("To be", "or not to be", stackTrace, null, None)
+    val oldEvent = JsonProtocol.taskEndReasonToJson(exceptionFailure)
+      .removeField({ _._1 == "Full Stack Trace" })
+    assertEquals(exceptionFailure, JsonProtocol.taskEndReasonFromJson(oldEvent))
+  }
+
   test("StageInfo backward compatibility") {
     val info = makeStageInfo(1, 2, 3, 4L, 5L)
     val newJson = JsonProtocol.stageInfoToJson(info)
@@ -422,6 +429,7 @@ class JsonProtocolSuite extends FunSuite {
         assert(r1.className === r2.className)
         assert(r1.description === r2.description)
         assertSeqEquals(r1.stackTrace, r2.stackTrace, assertStackTraceElementEquals)
+        assert(r1.fullStackTrace === r2.fullStackTrace)
         assertOptionEquals(r1.metrics, r2.metrics, assertTaskMetricsEquals)
       case (TaskResultLost, TaskResultLost) =>
       case (TaskKilled, TaskKilled) =>