diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala
index 8870187f2219cb7ed44070991129c053fbf491f4..7370f9feb68cd2bcd20ca6ae70204601c25e2eeb 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala
@@ -329,13 +329,12 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging {
   override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = synchronized {
     val taskInfo = taskStart.taskInfo
     if (taskInfo != null) {
-      val metrics = TaskMetrics.empty
       val stageData = stageIdToData.getOrElseUpdate((taskStart.stageId, taskStart.stageAttemptId), {
         logWarning("Task start for unknown stage " + taskStart.stageId)
         new StageUIData
       })
       stageData.numActiveTasks += 1
-      stageData.taskData.put(taskInfo.taskId, TaskUIData(taskInfo, Some(metrics)))
+      stageData.taskData.put(taskInfo.taskId, TaskUIData(taskInfo))
     }
     for (
       activeJobsDependentOnStage <- stageIdToActiveJobIds.get(taskStart.stageId);
@@ -405,7 +404,7 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging {
         updateAggregateMetrics(stageData, info.executorId, m, oldMetrics)
       }
 
-      val taskData = stageData.taskData.getOrElseUpdate(info.taskId, TaskUIData(info, None))
+      val taskData = stageData.taskData.getOrElseUpdate(info.taskId, TaskUIData(info))
       taskData.updateTaskInfo(info)
       taskData.updateTaskMetrics(taskMetrics)
       taskData.errorMessage = errorMessage
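
The point of these two hunks: onTaskStart previously allocated a fresh mutable TaskMetrics for every running task just to seed the UI entry, and onTaskEnd had to thread an explicit None through the two-argument apply. With the single-argument apply, the per-task allocation disappears and TaskUIData supplies its own default. A minimal before/after sketch, reusing the names from the hunks above rather than standing alone:

    // Before: one mutable TaskMetrics allocated on every task start.
    val metrics = TaskMetrics.empty
    stageData.taskData.put(taskInfo.taskId, TaskUIData(taskInfo, Some(metrics)))

    // After: no allocation here; the new TaskUIData starts out pointing at
    // the shared immutable TaskMetricsUIData.EMPTY (see UIData.scala below).
    stageData.taskData.put(taskInfo.taskId, TaskUIData(taskInfo))
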
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala
index ac1a74ad8029d142e79b51ec08b47b0968b51dcb..8d280bc00c3b3e01c66a7a8f0a5ac46911ece7a8 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala
@@ -112,9 +112,9 @@ private[spark] object UIData {
   /**
    * These are kept mutable and reused throughout a task's lifetime to avoid excessive reallocation.
    */
-  class TaskUIData private(
-      private var _taskInfo: TaskInfo,
-      private var _metrics: Option[TaskMetricsUIData]) {
+  class TaskUIData private(private var _taskInfo: TaskInfo) {
+
+    private[this] var _metrics: Option[TaskMetricsUIData] = Some(TaskMetricsUIData.EMPTY)
 
     var errorMessage: Option[String] = None
 
@@ -127,7 +127,7 @@ private[spark] object UIData {
     }
 
     def updateTaskMetrics(metrics: Option[TaskMetrics]): Unit = {
-      _metrics = TaskUIData.toTaskMetricsUIData(metrics)
+      _metrics = metrics.map(TaskMetricsUIData.fromTaskMetrics)
     }
 
     def taskDuration: Option[Long] = {
@@ -140,28 +140,8 @@ private[spark] object UIData {
   }
 
   object TaskUIData {
-    def apply(taskInfo: TaskInfo, metrics: Option[TaskMetrics]): TaskUIData = {
-      new TaskUIData(dropInternalAndSQLAccumulables(taskInfo), toTaskMetricsUIData(metrics))
-    }
-
-    private def toTaskMetricsUIData(metrics: Option[TaskMetrics]): Option[TaskMetricsUIData] = {
-      metrics.map { m =>
-        TaskMetricsUIData(
-          executorDeserializeTime = m.executorDeserializeTime,
-          executorDeserializeCpuTime = m.executorDeserializeCpuTime,
-          executorRunTime = m.executorRunTime,
-          executorCpuTime = m.executorCpuTime,
-          resultSize = m.resultSize,
-          jvmGCTime = m.jvmGCTime,
-          resultSerializationTime = m.resultSerializationTime,
-          memoryBytesSpilled = m.memoryBytesSpilled,
-          diskBytesSpilled = m.diskBytesSpilled,
-          peakExecutionMemory = m.peakExecutionMemory,
-          inputMetrics = InputMetricsUIData(m.inputMetrics),
-          outputMetrics = OutputMetricsUIData(m.outputMetrics),
-          shuffleReadMetrics = ShuffleReadMetricsUIData(m.shuffleReadMetrics),
-          shuffleWriteMetrics = ShuffleWriteMetricsUIData(m.shuffleWriteMetrics))
-      }
+    def apply(taskInfo: TaskInfo): TaskUIData = {
+      new TaskUIData(dropInternalAndSQLAccumulables(taskInfo))
     }
 
     /**
@@ -206,6 +186,28 @@ private[spark] object UIData {
       shuffleReadMetrics: ShuffleReadMetricsUIData,
       shuffleWriteMetrics: ShuffleWriteMetricsUIData)
 
+  object TaskMetricsUIData {
+    def fromTaskMetrics(m: TaskMetrics): TaskMetricsUIData = {
+      TaskMetricsUIData(
+        executorDeserializeTime = m.executorDeserializeTime,
+        executorDeserializeCpuTime = m.executorDeserializeCpuTime,
+        executorRunTime = m.executorRunTime,
+        executorCpuTime = m.executorCpuTime,
+        resultSize = m.resultSize,
+        jvmGCTime = m.jvmGCTime,
+        resultSerializationTime = m.resultSerializationTime,
+        memoryBytesSpilled = m.memoryBytesSpilled,
+        diskBytesSpilled = m.diskBytesSpilled,
+        peakExecutionMemory = m.peakExecutionMemory,
+        inputMetrics = InputMetricsUIData(m.inputMetrics),
+        outputMetrics = OutputMetricsUIData(m.outputMetrics),
+        shuffleReadMetrics = ShuffleReadMetricsUIData(m.shuffleReadMetrics),
+        shuffleWriteMetrics = ShuffleWriteMetricsUIData(m.shuffleWriteMetrics))
+    }
+
+    val EMPTY: TaskMetricsUIData = fromTaskMetrics(TaskMetrics.empty)
+  }
+
   case class InputMetricsUIData(bytesRead: Long, recordsRead: Long)
   object InputMetricsUIData {
     def apply(metrics: InputMetrics): InputMetricsUIData = {
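
With the field-by-field conversion moved from a private helper on TaskUIData into the TaskMetricsUIData companion object, a snapshot can now be taken directly from any caller in the package. A short sketch of how the two new entry points behave; note that UIData is private[spark], so this only compiles inside the org.apache.spark package tree, and the value m below is a stand-in for a real executor-side TaskMetrics:

    import org.apache.spark.executor.TaskMetrics
    import org.apache.spark.ui.jobs.UIData.TaskMetricsUIData

    // fromTaskMetrics copies a mutable TaskMetrics into the immutable
    // UI-facing case class, field by field.
    val m: TaskMetrics = TaskMetrics.empty
    val snapshot: TaskMetricsUIData = TaskMetricsUIData.fromTaskMetrics(m)

    // EMPTY is evaluated once (a val on the companion). Because
    // TaskMetricsUIData holds only immutable fields, that single instance
    // can safely back every freshly started task, replacing the old
    // per-task TaskMetrics allocation.
    assert(snapshot == TaskMetricsUIData.EMPTY)
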
diff --git a/core/src/test/scala/org/apache/spark/status/api/v1/AllStagesResourceSuite.scala b/core/src/test/scala/org/apache/spark/status/api/v1/AllStagesResourceSuite.scala
index 1bfb0c1547ec4eca6c2c46ff3050df482387ffb3..82bd7c4ff6604ceb707dc5b794f2bb8ef94c779a 100644
--- a/core/src/test/scala/org/apache/spark/status/api/v1/AllStagesResourceSuite.scala
+++ b/core/src/test/scala/org/apache/spark/status/api/v1/AllStagesResourceSuite.scala
@@ -31,7 +31,7 @@ class AllStagesResourceSuite extends SparkFunSuite {
     val tasks = new LinkedHashMap[Long, TaskUIData]
     taskLaunchTimes.zipWithIndex.foreach { case (time, idx) =>
       tasks(idx.toLong) = TaskUIData(
-        new TaskInfo(idx, idx, 1, time, "", "", TaskLocality.ANY, false), None)
+        new TaskInfo(idx, idx, 1, time, "", "", TaskLocality.ANY, false))
     }
 
     val stageUiData = new StageUIData()
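
For reference, the updated construction pattern end to end; a sketch using only names visible in this diff, with placeholder executor/host strings as in the test above and a hypothetical error message:

    import org.apache.spark.executor.TaskMetrics
    import org.apache.spark.scheduler.{TaskInfo, TaskLocality}
    import org.apache.spark.ui.jobs.UIData.TaskUIData

    val info = new TaskInfo(0L, 0, 1, 0L, "", "", TaskLocality.ANY, false)
    val task = TaskUIData(info)  // metrics now default to Some(TaskMetricsUIData.EMPTY),
                                 // where the old two-argument apply took an explicit None
    task.updateTaskMetrics(Some(TaskMetrics.empty))  // later: swap in a real snapshot
    task.errorMessage = Some("example failure")      // error text is set the same way as before
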