diff --git a/core/src/main/scala/spark/ui/jobs/JobProgressUI.scala b/core/src/main/scala/spark/ui/jobs/JobProgressUI.scala
index 36b1cd00edd234de55f099f64ce6967716be9d9d..84730cc091b3145851c1ebf1920fce3cac856305 100644
--- a/core/src/main/scala/spark/ui/jobs/JobProgressUI.scala
+++ b/core/src/main/scala/spark/ui/jobs/JobProgressUI.scala
@@ -51,7 +51,7 @@ private[spark] class JobProgressListener extends SparkListener {
   val stageToTasksComplete = HashMap[Int, Int]()
   val stageToTasksFailed = HashMap[Int, Int]()
   val stageToTaskInfos =
-    HashMap[Int, ArrayBuffer[(TaskInfo, TaskMetrics, Option[ExceptionFailure])]]()
+    HashMap[Int, ArrayBuffer[(TaskInfo, Option[TaskMetrics], Option[ExceptionFailure])]]()
 
   override def onJobStart(jobStart: SparkListenerJobStart) {}
 
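Why the value type widens from `TaskMetrics` to `Option[TaskMetrics]`: a task that fails before reporting anything has no metrics to record, and the old code's `e.metrics.get` in `onTaskEnd` (next hunk) would throw `NoSuchElementException` in that case. A minimal standalone sketch of the idea; the stripped-down `TaskMetrics` and `ExceptionFailure` case classes here are stand-ins, not Spark's real classes:

```scala
// Minimal standalone sketch of the motivation, NOT Spark's real classes:
// a task that fails early may never report metrics, so None models
// "no metrics" without resorting to null or a throwing .get.
case class TaskMetrics(executorRunTime: Long)
case class ExceptionFailure(metrics: Option[TaskMetrics])

object MetricsOptionSketch extends App {
  // A successful task always carries metrics; a failure may not.
  val succeeded: Option[TaskMetrics] = Some(TaskMetrics(executorRunTime = 120L))
  val failedEarly: Option[TaskMetrics] = ExceptionFailure(metrics = None).metrics

  // Callers handle absence explicitly instead of risking a
  // NoSuchElementException from e.metrics.get.
  def runTime(metrics: Option[TaskMetrics]): String =
    metrics.map(m => s"${m.executorRunTime} ms").getOrElse("n/a")

  println(runTime(succeeded))   // 120 ms
  println(runTime(failedEarly)) // n/a
}
```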
@@ -78,17 +78,17 @@ private[spark] class JobProgressListener extends SparkListener {
 
   override def onTaskEnd(taskEnd: SparkListenerTaskEnd) {
     val sid = taskEnd.task.stageId
-    val (failureInfo, metrics): (Option[ExceptionFailure], TaskMetrics) =
+    val (failureInfo, metrics): (Option[ExceptionFailure], Option[TaskMetrics]) =
       taskEnd.reason match {
         case e: ExceptionFailure =>
           stageToTasksFailed(sid) = stageToTasksFailed.getOrElse(sid, 0) + 1
-          (Some(e), e.metrics.get)
+          (Some(e), e.metrics)
         case _ =>
           stageToTasksComplete(sid) = stageToTasksComplete.getOrElse(sid, 0) + 1
-          (None, taskEnd.taskMetrics)
+          (None, Some(taskEnd.taskMetrics))
       }
     val taskList = stageToTaskInfos.getOrElse(
-      sid, ArrayBuffer[(TaskInfo, TaskMetrics, Option[ExceptionFailure])]())
+      sid, ArrayBuffer[(TaskInfo, Option[TaskMetrics], Option[ExceptionFailure])]())
     taskList += ((taskEnd.taskInfo, metrics, failureInfo))
     stageToTaskInfos(sid) = taskList
   }
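After the change, both branches of the match produce an `Option[TaskMetrics]`: a failure passes `e.metrics` through as-is, and a success wraps `taskEnd.taskMetrics` in `Some`. The surrounding get-or-create-then-append idiom on `stageToTaskInfos` is untouched; a self-contained sketch of it, with a simplified `(String, Option[Long], Option[String])` tuple standing in for `(TaskInfo, Option[TaskMetrics], Option[ExceptionFailure])`:

```scala
import scala.collection.mutable.{ArrayBuffer, HashMap}

// Self-contained sketch of the get-or-create-then-append idiom used in
// onTaskEnd; the tuple type is a simplified stand-in, not Spark's classes.
object TaskListSketch extends App {
  type Entry = (String, Option[Long], Option[String])
  val stageToTaskInfos = HashMap[Int, ArrayBuffer[Entry]]()

  def record(sid: Int, entry: Entry) {
    // Fetch the stage's buffer (or a fresh one), append, write it back.
    val taskList = stageToTaskInfos.getOrElse(sid, ArrayBuffer[Entry]())
    taskList += entry
    stageToTaskInfos(sid) = taskList
  }

  record(0, ("task-0", Some(42L), None))                // succeeded
  record(0, ("task-1", None, Some("ExceptionFailure"))) // failed, no metrics
  println(stageToTaskInfos(0).length) // 2
}
```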
@@ -111,7 +111,7 @@ private[spark] class JobProgressListener extends SparkListener {
   def hasShuffleRead(stageID: Int): Boolean = {
     // This is written in a slightly complicated way to avoid having to scan all tasks
     for (s <- stageToTaskInfos.get(stageID).getOrElse(Seq())) {
-      if (s._2 != null) return s._2.shuffleReadMetrics.isDefined
+      if (s._2 != null) return s._2.flatMap(m => m.shuffleReadMetrics).isDefined
     }
     return false // No tasks have finished for this stage
   }
@@ -120,7 +120,7 @@ private[spark] class JobProgressListener extends SparkListener {
   def hasShuffleWrite(stageID: Int): Boolean = {
     // This is written in a slightly complicated way to avoid having to scan all tasks
     for (s <- stageToTaskInfos.get(stageID).getOrElse(Seq())) {
-      if (s._2 != null) return s._2.shuffleWriteMetrics.isDefined
+      if (s._2 != null) return s._2.flatMap(m => m.shuffleWriteMetrics).isDefined
     }
     return false // No tasks have finished for this stage
   }
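`hasShuffleRead` and `hasShuffleWrite` need the same fix: with `s._2` now an `Option[TaskMetrics]`, a plain `map` would yield a nested `Option[Option[ShuffleReadMetrics]]`, so `flatMap` collapses the two layers — "no metrics at all" and "metrics without shuffle data" both come out as `None`. (The retained `s._2 != null` guard is now vestigial: an `Option` stored in the buffer is never null.) A standalone sketch with stand-in types:

```scala
// Stand-in types, for illustration only; not Spark's real classes.
case class ShuffleReadMetrics(remoteBytesRead: Long)
case class TaskMetrics(shuffleReadMetrics: Option[ShuffleReadMetrics])

object FlatMapSketch extends App {
  // flatMap flattens Option[Option[...]] into a single Option.
  def hasShuffleRead(metrics: Option[TaskMetrics]): Boolean =
    metrics.flatMap(m => m.shuffleReadMetrics).isDefined

  val withShuffle = Some(TaskMetrics(Some(ShuffleReadMetrics(1024L))))

  println(hasShuffleRead(None))                    // false: no metrics at all
  println(hasShuffleRead(Some(TaskMetrics(None)))) // false: metrics, no shuffle
  println(hasShuffleRead(withShuffle))             // true
}
```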
diff --git a/core/src/main/scala/spark/ui/jobs/StagePage.scala b/core/src/main/scala/spark/ui/jobs/StagePage.scala
index 49e84880cf6438098b3a1d8d873581617c6f2aa5..51b82b6a8c2b54bf58edfaeb6592e7e5815382f1 100644
--- a/core/src/main/scala/spark/ui/jobs/StagePage.scala
+++ b/core/src/main/scala/spark/ui/jobs/StagePage.scala
@@ -52,7 +52,7 @@ private[spark] class StagePage(parent: JobProgressUI) {
       }
       else {
         val serviceTimes = validTasks.map{case (info, metrics, exception) =>
-          metrics.executorRunTime.toDouble}
+          metrics.get.executorRunTime.toDouble}
         val serviceQuantiles = "Duration" +: Distribution(serviceTimes).get.getQuantiles().map(
           ms => parent.formatDuration(ms.toLong))
 
@@ -61,13 +61,13 @@ private[spark] class StagePage(parent: JobProgressUI) {
 
         val shuffleReadSizes = validTasks.map {
           case(info, metrics, exception) =>
-            metrics.shuffleReadMetrics.map(_.remoteBytesRead).getOrElse(0L).toDouble
+            metrics.get.shuffleReadMetrics.map(_.remoteBytesRead).getOrElse(0L).toDouble
         }
         val shuffleReadQuantiles = "Shuffle Read (Remote)" +: getQuantileCols(shuffleReadSizes)
 
         val shuffleWriteSizes = validTasks.map {
           case(info, metrics, exception) =>
-            metrics.shuffleWriteMetrics.map(_.shuffleBytesWritten).getOrElse(0L).toDouble
+            metrics.get.shuffleWriteMetrics.map(_.shuffleBytesWritten).getOrElse(0L).toDouble
         }
         val shuffleWriteQuantiles = "Shuffle Write" +: getQuantileCols(shuffleWriteSizes)
 
@@ -87,21 +87,21 @@ private[spark] class StagePage(parent: JobProgressUI) {
   }
 
 
-  def taskRow(taskData: (TaskInfo, TaskMetrics, Option[ExceptionFailure])): Seq[Node] = {
+  def taskRow(taskData: (TaskInfo, Option[TaskMetrics], Option[ExceptionFailure])): Seq[Node] = {
     def fmtStackTrace(trace: Seq[StackTraceElement]): Seq[Node] =
       trace.map(e => <span style="display:block;">{e.toString}</span>)
     val (info, metrics, exception) = taskData
     <tr>
       <td>{info.taskId}</td>
-      <td sorttable_customkey={Option(metrics).map{m => m.executorRunTime.toString}.getOrElse("1")}>
-        {Option(metrics).map{m => parent.formatDuration(m.executorRunTime)}.getOrElse("")}
+      <td sorttable_customkey={metrics.map{m => m.executorRunTime.toString}.getOrElse("1")}>
+        {metrics.map{m => parent.formatDuration(m.executorRunTime)}.getOrElse("")}
       </td>
       <td>{info.taskLocality}</td>
       <td>{info.hostPort}</td>
       <td>{dateFmt.format(new Date(info.launchTime))}</td>
-      {Option(metrics).flatMap{m => m.shuffleReadMetrics}.map{s =>
+      {metrics.flatMap{m => m.shuffleReadMetrics}.map{s =>
         <td>{Utils.memoryBytesToString(s.remoteBytesRead)}</td>}.getOrElse("")}
-      {Option(metrics).flatMap{m => m.shuffleWriteMetrics}.map{s =>
+      {metrics.flatMap{m => m.shuffleWriteMetrics}.map{s =>
         <td>{Utils.memoryBytesToString(s.shuffleBytesWritten)}</td>}.getOrElse("")}
       <td>{exception.map(e =>
         <span>
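In `taskRow`, the old code wrapped a possibly-null `metrics` in `Option(metrics)` before rendering; with the field already an `Option`, that wrapper drops away and the same `map`/`getOrElse` fallbacks apply directly, producing a blank cell when no metrics exist. A minimal rendering sketch without the XML literals:

```scala
// Stand-in TaskMetrics; durationCell mirrors the map/getOrElse fallback
// used for the run-time column (empty string when metrics are absent).
case class TaskMetrics(executorRunTime: Long)

object RenderSketch extends App {
  def durationCell(metrics: Option[TaskMetrics]): String =
    metrics.map(m => s"${m.executorRunTime} ms").getOrElse("")

  println(durationCell(Some(TaskMetrics(95L)))) // 95 ms
  println(durationCell(None))                   // (blank cell)
}
```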