diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
index d3d79a27ea1c6c22604549af11b88a7018e6392a..128a82579b800ece5bea43a5f7030209d863def7 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
@@ -444,13 +444,7 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
   @Override
   public Option<MapStatus> stop(boolean success) {
     try {
-      // Update task metrics from accumulators (null in UnsafeShuffleWriterSuite)
-      Map<String, Accumulator<Object>> internalAccumulators =
-        taskContext.internalMetricsToAccumulators();
-      if (internalAccumulators != null) {
-        internalAccumulators.apply(InternalAccumulator.PEAK_EXECUTION_MEMORY())
-          .add(getPeakMemoryUsedBytes());
-      }
+      taskContext.taskMetrics().incPeakExecutionMemory(getPeakMemoryUsedBytes());
 
       if (stopping) {
         return Option.apply(null);
diff --git a/core/src/main/scala/org/apache/spark/Accumulable.scala b/core/src/main/scala/org/apache/spark/Accumulable.scala
index a456d420b8d6a5040306db48d1a65fdc32cf4cac..bde136141f40de753c5ff97d123b33fbc613e150 100644
--- a/core/src/main/scala/org/apache/spark/Accumulable.scala
+++ b/core/src/main/scala/org/apache/spark/Accumulable.scala
@@ -35,40 +35,67 @@ import org.apache.spark.util.Utils
  * [[org.apache.spark.Accumulator]]. They won't always be the same, though -- e.g., imagine you are
  * accumulating a set. You will add items to the set, and you will union two sets together.
  *
+ * All accumulators created on the driver to be used on the executors must be registered with
+ * [[Accumulators]]. This is already done automatically for accumulators created by the user.
+ * Internal accumulators must be explicitly registered by the caller.
+ *
+ * Operations are not thread-safe.
+ *
+ * @param id ID of this accumulator; for internal use only.
  * @param initialValue initial value of accumulator
  * @param param helper object defining how to add elements of type `R` and `T`
  * @param name human-readable name for use in Spark's web UI
  * @param internal if this [[Accumulable]] is internal. Internal [[Accumulable]]s will be reported
  *                 to the driver via heartbeats. For internal [[Accumulable]]s, `R` must be
  *                 thread safe so that they can be reported correctly.
+ * @param countFailedValues whether to accumulate values from failed tasks. This is set to true
+ *                          for system and time metrics like serialization time or bytes spilled,
+ *                          and false for things with absolute values like number of input rows.
+ *                          This should be used for internal metrics only.
  * @tparam R the full accumulated data (result type)
  * @tparam T partial data that can be added in
  */
-class Accumulable[R, T] private[spark] (
-    initialValue: R,
+class Accumulable[R, T] private (
+    val id: Long,
+    @transient initialValue: R,
     param: AccumulableParam[R, T],
     val name: Option[String],
-    internal: Boolean)
+    internal: Boolean,
+    private[spark] val countFailedValues: Boolean)
   extends Serializable {
 
   private[spark] def this(
-      @transient initialValue: R, param: AccumulableParam[R, T], internal: Boolean) = {
-    this(initialValue, param, None, internal)
+      initialValue: R,
+      param: AccumulableParam[R, T],
+      name: Option[String],
+      internal: Boolean,
+      countFailedValues: Boolean) = {
+    this(Accumulators.newId(), initialValue, param, name, internal, countFailedValues)
   }
 
-  def this(@transient initialValue: R, param: AccumulableParam[R, T], name: Option[String]) =
-    this(initialValue, param, name, false)
+  private[spark] def this(
+      initialValue: R,
+      param: AccumulableParam[R, T],
+      name: Option[String],
+      internal: Boolean) = {
+    this(initialValue, param, name, internal, false /* countFailedValues */)
+  }
 
-  def this(@transient initialValue: R, param: AccumulableParam[R, T]) =
-    this(initialValue, param, None)
+  def this(initialValue: R, param: AccumulableParam[R, T], name: Option[String]) =
+    this(initialValue, param, name, false /* internal */)
 
-  val id: Long = Accumulators.newId
+  def this(initialValue: R, param: AccumulableParam[R, T]) = this(initialValue, param, None)
 
-  @volatile @transient private var value_ : R = initialValue // Current value on master
-  val zero = param.zero(initialValue)  // Zero value to be passed to workers
+  @volatile @transient private var value_ : R = initialValue // Current value on driver
+  val zero = param.zero(initialValue) // Zero value to be passed to executors
   private var deserialized = false
 
-  Accumulators.register(this)
+  // In many places we create internal accumulators without access to the active context cleaner,
+  // so if we register them here then we may never unregister these accumulators. To avoid memory
+  // leaks, we require the caller to explicitly register internal accumulators elsewhere.
+  if (!internal) {
+    Accumulators.register(this)
+  }
 
   /**
    * If this [[Accumulable]] is internal. Internal [[Accumulable]]s will be reported to the driver
@@ -77,6 +104,17 @@ class Accumulable[R, T] private[spark] (
    */
   private[spark] def isInternal: Boolean = internal
 
+  /**
+   * Return a copy of this [[Accumulable]].
+   *
+   * The copy will have the same ID as the original and will not be registered with
+   * [[Accumulators]] again. This method exists so that the caller can avoid passing the
+   * same mutable instance around.
+   */
+  private[spark] def copy(): Accumulable[R, T] = {
+    new Accumulable[R, T](id, initialValue, param, name, internal, countFailedValues)
+  }
+
   /**
    * Add more data to this accumulator / accumulable
    * @param term the data to add
@@ -106,7 +144,7 @@ class Accumulable[R, T] private[spark] (
   def merge(term: R) { value_ = param.addInPlace(value_, term)}
 
   /**
-   * Access the accumulator's current value; only allowed on master.
+   * Access the accumulator's current value; only allowed on driver.
    */
   def value: R = {
     if (!deserialized) {
@@ -128,7 +166,7 @@ class Accumulable[R, T] private[spark] (
   def localValue: R = value_
 
   /**
-   * Set the accumulator's value; only allowed on master.
+   * Set the accumulator's value; only allowed on driver.
    */
   def value_= (newValue: R) {
     if (!deserialized) {
@@ -139,22 +177,24 @@ class Accumulable[R, T] private[spark] (
   }
 
   /**
-   * Set the accumulator's value; only allowed on master
+   * Set the accumulator's value. For internal use only.
    */
-  def setValue(newValue: R) {
-    this.value = newValue
-  }
+  def setValue(newValue: R): Unit = { value_ = newValue }
+
+  /**
+   * Set the accumulator's value. For internal use only.
+   */
+  private[spark] def setValueAny(newValue: Any): Unit = { setValue(newValue.asInstanceOf[R]) }
 
   // Called by Java when deserializing an object
   private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException {
     in.defaultReadObject()
     value_ = zero
     deserialized = true
+
     // Automatically register the accumulator when it is deserialized with the task closure.
-    //
-    // Note internal accumulators sent with task are deserialized before the TaskContext is created
-    // and are registered in the TaskContext constructor. Other internal accumulators, such SQL
-    // metrics, still need to register here.
+    // This is for external accumulators and internal ones that do not represent task level
+    // metrics, e.g. internal SQL metrics, which are per-operator.
     val taskContext = TaskContext.get()
     if (taskContext != null) {
       taskContext.registerAccumulator(this)
diff --git a/core/src/main/scala/org/apache/spark/Accumulator.scala b/core/src/main/scala/org/apache/spark/Accumulator.scala
index 007136e6ae3497b43095001b11dd4fc4a4e63e74..558bd447e22c5f0df1a696198e5a086c3a995652 100644
--- a/core/src/main/scala/org/apache/spark/Accumulator.scala
+++ b/core/src/main/scala/org/apache/spark/Accumulator.scala
@@ -17,9 +17,14 @@
 
 package org.apache.spark
 
-import scala.collection.{mutable, Map}
+import java.util.concurrent.atomic.AtomicLong
+import javax.annotation.concurrent.GuardedBy
+
+import scala.collection.mutable
 import scala.ref.WeakReference
 
+import org.apache.spark.storage.{BlockId, BlockStatus}
+
 
 /**
  * A simpler value of [[Accumulable]] where the result type being accumulated is the same
@@ -49,14 +54,18 @@ import scala.ref.WeakReference
  *
  * @param initialValue initial value of accumulator
  * @param param helper object defining how to add elements of type `T`
+ * @param name human-readable name associated with this accumulator
+ * @param internal whether this accumulator is used internally within Spark only
+ * @param countFailedValues whether to accumulate values from failed tasks
  * @tparam T result type
  */
 class Accumulator[T] private[spark] (
     @transient private[spark] val initialValue: T,
     param: AccumulatorParam[T],
     name: Option[String],
-    internal: Boolean)
-  extends Accumulable[T, T](initialValue, param, name, internal) {
+    internal: Boolean,
+    override val countFailedValues: Boolean = false)
+  extends Accumulable[T, T](initialValue, param, name, internal, countFailedValues) {
 
   def this(initialValue: T, param: AccumulatorParam[T], name: Option[String]) = {
     this(initialValue, param, name, false)
@@ -75,43 +84,63 @@ private[spark] object Accumulators extends Logging {
    * This global map holds the original accumulator objects that are created on the driver.
    * It keeps weak references to these objects so that accumulators can be garbage-collected
    * once the RDDs and user-code that reference them are cleaned up.
+   * TODO: Don't use a global map; these should be tied to a SparkContext at the very least.
    */
+  @GuardedBy("Accumulators")
   val originals = mutable.Map[Long, WeakReference[Accumulable[_, _]]]()
 
-  private var lastId: Long = 0
+  private val nextId = new AtomicLong(0L)
 
-  def newId(): Long = synchronized {
-    lastId += 1
-    lastId
-  }
+  /**
+   * Return a globally unique ID for a new [[Accumulable]].
+   * Note: Once you copy the [[Accumulable]] the ID is no longer unique.
+   */
+  def newId(): Long = nextId.getAndIncrement
 
+  /**
+   * Register an [[Accumulable]] created on the driver such that it can be used on the executors.
+   *
+   * All accumulators registered here can later be used as a container for accumulating partial
+   * values across multiple tasks. This is what [[org.apache.spark.scheduler.DAGScheduler]] does.
+   * Note: if an accumulator is registered here, it should also be registered with the active
+   * context cleaner for cleanup so as to avoid memory leaks.
+   *
+   * If an [[Accumulable]] with the same ID was already registered, this does nothing instead
+   * of overwriting it. This happens when we copy accumulators, e.g. when we reconstruct
+   * [[org.apache.spark.executor.TaskMetrics]] from accumulator updates.
+   */
   def register(a: Accumulable[_, _]): Unit = synchronized {
-    originals(a.id) = new WeakReference[Accumulable[_, _]](a)
+    if (!originals.contains(a.id)) {
+      originals(a.id) = new WeakReference[Accumulable[_, _]](a)
+    }
   }
 
-  def remove(accId: Long) {
-    synchronized {
-      originals.remove(accId)
-    }
+  /**
+   * Unregister the [[Accumulable]] with the given ID, if any.
+   */
+  def remove(accId: Long): Unit = synchronized {
+    originals.remove(accId)
   }
 
-  // Add values to the original accumulators with some given IDs
-  def add(values: Map[Long, Any]): Unit = synchronized {
-    for ((id, value) <- values) {
-      if (originals.contains(id)) {
-        // Since we are now storing weak references, we must check whether the underlying data
-        // is valid.
-        originals(id).get match {
-          case Some(accum) => accum.asInstanceOf[Accumulable[Any, Any]] ++= value
-          case None =>
-            throw new IllegalAccessError("Attempted to access garbage collected Accumulator.")
-        }
-      } else {
-        logWarning(s"Ignoring accumulator update for unknown accumulator id $id")
+  /**
+   * Return the [[Accumulable]] registered with the given ID, if any.
+   */
+  def get(id: Long): Option[Accumulable[_, _]] = synchronized {
+    originals.get(id).map { weakRef =>
+      // Since we are storing weak references, we must check whether the underlying data is valid.
+      weakRef.get.getOrElse {
+        throw new IllegalAccessError(s"Attempted to access garbage collected accumulator $id")
       }
     }
   }
 
+  /**
+   * Clear all registered [[Accumulable]]s. For testing only.
+   */
+  def clear(): Unit = synchronized {
+    originals.clear()
+  }
+
 }
 
 
@@ -156,5 +185,23 @@ object AccumulatorParam {
     def zero(initialValue: Float): Float = 0f
   }
 
-  // TODO: Add AccumulatorParams for other types, e.g. lists and strings
+  // Note: when merging values, this param just adopts the newer value. This is used only
+  // internally for things that shouldn't really be accumulated across tasks, like input
+  // read method, which should be the same across all tasks in the same stage.
+  private[spark] object StringAccumulatorParam extends AccumulatorParam[String] {
+    def addInPlace(t1: String, t2: String): String = t2
+    def zero(initialValue: String): String = ""
+  }
+
+  // Note: this is expensive as it makes a copy of the list every time the caller adds an item.
+  // A better way to use this is to first accumulate the values yourself then them all at once.
+  private[spark] class ListAccumulatorParam[T] extends AccumulatorParam[Seq[T]] {
+    def addInPlace(t1: Seq[T], t2: Seq[T]): Seq[T] = t1 ++ t2
+    def zero(initialValue: Seq[T]): Seq[T] = Seq.empty[T]
+  }
+
+  // For the internal metric that records what blocks are updated in a particular task
+  private[spark] object UpdatedBlockStatusesAccumulatorParam
+    extends ListAccumulatorParam[(BlockId, BlockStatus)]
+
 }
diff --git a/core/src/main/scala/org/apache/spark/Aggregator.scala b/core/src/main/scala/org/apache/spark/Aggregator.scala
index 62629000cfc23d65cd974cf0230997d2aceeb685..e493d9a3cf9ccd8173e84d44f09763da0063a2ae 100644
--- a/core/src/main/scala/org/apache/spark/Aggregator.scala
+++ b/core/src/main/scala/org/apache/spark/Aggregator.scala
@@ -57,8 +57,7 @@ case class Aggregator[K, V, C] (
     Option(context).foreach { c =>
       c.taskMetrics().incMemoryBytesSpilled(map.memoryBytesSpilled)
       c.taskMetrics().incDiskBytesSpilled(map.diskBytesSpilled)
-      c.internalMetricsToAccumulators(
-        InternalAccumulator.PEAK_EXECUTION_MEMORY).add(map.peakMemoryUsedBytes)
+      c.taskMetrics().incPeakExecutionMemory(map.peakMemoryUsedBytes)
     }
   }
 }
diff --git a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala
index e03977828b86dd9f411a44ac2ee648ffebd22d7f..45b20c0e8d60c97e7bf4185b8e6838f3c63b5d33 100644
--- a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala
+++ b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala
@@ -35,7 +35,7 @@ import org.apache.spark.util.{Clock, SystemClock, ThreadUtils, Utils}
  */
 private[spark] case class Heartbeat(
     executorId: String,
-    taskMetrics: Array[(Long, TaskMetrics)], // taskId -> TaskMetrics
+    accumUpdates: Array[(Long, Seq[AccumulableInfo])], // taskId -> accum updates
     blockManagerId: BlockManagerId)
 
 /**
@@ -119,14 +119,14 @@ private[spark] class HeartbeatReceiver(sc: SparkContext, clock: Clock)
       context.reply(true)
 
     // Messages received from executors
-    case heartbeat @ Heartbeat(executorId, taskMetrics, blockManagerId) =>
+    case heartbeat @ Heartbeat(executorId, accumUpdates, blockManagerId) =>
       if (scheduler != null) {
         if (executorLastSeen.contains(executorId)) {
           executorLastSeen(executorId) = clock.getTimeMillis()
           eventLoopThread.submit(new Runnable {
             override def run(): Unit = Utils.tryLogNonFatalError {
               val unknownExecutor = !scheduler.executorHeartbeatReceived(
-                executorId, taskMetrics, blockManagerId)
+                executorId, accumUpdates, blockManagerId)
               val response = HeartbeatResponse(reregisterBlockManager = unknownExecutor)
               context.reply(response)
             }
diff --git a/core/src/main/scala/org/apache/spark/InternalAccumulator.scala b/core/src/main/scala/org/apache/spark/InternalAccumulator.scala
index 6ea997c079f33513ab14e60404e61b3fc5af404b..c191122c0630a09cf78fd01558610416f5a4b8ea 100644
--- a/core/src/main/scala/org/apache/spark/InternalAccumulator.scala
+++ b/core/src/main/scala/org/apache/spark/InternalAccumulator.scala
@@ -17,23 +17,169 @@
 
 package org.apache.spark
 
+import org.apache.spark.storage.{BlockId, BlockStatus}
 
-// This is moved to its own file because many more things will be added to it in SPARK-10620.
+
+/**
+ * A collection of fields and methods concerned with internal accumulators that represent
+ * task level metrics.
+ */
 private[spark] object InternalAccumulator {
-  val PEAK_EXECUTION_MEMORY = "peakExecutionMemory"
-  val TEST_ACCUMULATOR = "testAccumulator"
-
-  // For testing only.
-  // This needs to be a def since we don't want to reuse the same accumulator across stages.
-  private def maybeTestAccumulator: Option[Accumulator[Long]] = {
-    if (sys.props.contains("spark.testing")) {
-      Some(new Accumulator(
-        0L, AccumulatorParam.LongAccumulatorParam, Some(TEST_ACCUMULATOR), internal = true))
-    } else {
-      None
+
+  import AccumulatorParam._
+
+  // Prefixes used in names of internal task level metrics
+  val METRICS_PREFIX = "internal.metrics."
+  val SHUFFLE_READ_METRICS_PREFIX = METRICS_PREFIX + "shuffle.read."
+  val SHUFFLE_WRITE_METRICS_PREFIX = METRICS_PREFIX + "shuffle.write."
+  val OUTPUT_METRICS_PREFIX = METRICS_PREFIX + "output."
+  val INPUT_METRICS_PREFIX = METRICS_PREFIX + "input."
+
+  // Names of internal task level metrics
+  val EXECUTOR_DESERIALIZE_TIME = METRICS_PREFIX + "executorDeserializeTime"
+  val EXECUTOR_RUN_TIME = METRICS_PREFIX + "executorRunTime"
+  val RESULT_SIZE = METRICS_PREFIX + "resultSize"
+  val JVM_GC_TIME = METRICS_PREFIX + "jvmGCTime"
+  val RESULT_SERIALIZATION_TIME = METRICS_PREFIX + "resultSerializationTime"
+  val MEMORY_BYTES_SPILLED = METRICS_PREFIX + "memoryBytesSpilled"
+  val DISK_BYTES_SPILLED = METRICS_PREFIX + "diskBytesSpilled"
+  val PEAK_EXECUTION_MEMORY = METRICS_PREFIX + "peakExecutionMemory"
+  val UPDATED_BLOCK_STATUSES = METRICS_PREFIX + "updatedBlockStatuses"
+  val TEST_ACCUM = METRICS_PREFIX + "testAccumulator"
+
+  // scalastyle:off
+
+  // Names of shuffle read metrics
+  object shuffleRead {
+    val REMOTE_BLOCKS_FETCHED = SHUFFLE_READ_METRICS_PREFIX + "remoteBlocksFetched"
+    val LOCAL_BLOCKS_FETCHED = SHUFFLE_READ_METRICS_PREFIX + "localBlocksFetched"
+    val REMOTE_BYTES_READ = SHUFFLE_READ_METRICS_PREFIX + "remoteBytesRead"
+    val LOCAL_BYTES_READ = SHUFFLE_READ_METRICS_PREFIX + "localBytesRead"
+    val FETCH_WAIT_TIME = SHUFFLE_READ_METRICS_PREFIX + "fetchWaitTime"
+    val RECORDS_READ = SHUFFLE_READ_METRICS_PREFIX + "recordsRead"
+  }
+
+  // Names of shuffle write metrics
+  object shuffleWrite {
+    val BYTES_WRITTEN = SHUFFLE_WRITE_METRICS_PREFIX + "bytesWritten"
+    val RECORDS_WRITTEN = SHUFFLE_WRITE_METRICS_PREFIX + "recordsWritten"
+    val WRITE_TIME = SHUFFLE_WRITE_METRICS_PREFIX + "writeTime"
+  }
+
+  // Names of output metrics
+  object output {
+    val WRITE_METHOD = OUTPUT_METRICS_PREFIX + "writeMethod"
+    val BYTES_WRITTEN = OUTPUT_METRICS_PREFIX + "bytesWritten"
+    val RECORDS_WRITTEN = OUTPUT_METRICS_PREFIX + "recordsWritten"
+  }
+
+  // Names of input metrics
+  object input {
+    val READ_METHOD = INPUT_METRICS_PREFIX + "readMethod"
+    val BYTES_READ = INPUT_METRICS_PREFIX + "bytesRead"
+    val RECORDS_READ = INPUT_METRICS_PREFIX + "recordsRead"
+  }
+
+  // scalastyle:on
+
+  /**
+   * Create an internal [[Accumulator]] by name, which must begin with [[METRICS_PREFIX]].
+   */
+  def create(name: String): Accumulator[_] = {
+    require(name.startsWith(METRICS_PREFIX),
+      s"internal accumulator name must start with '$METRICS_PREFIX': $name")
+    getParam(name) match {
+      case p @ LongAccumulatorParam => newMetric[Long](0L, name, p)
+      case p @ IntAccumulatorParam => newMetric[Int](0, name, p)
+      case p @ StringAccumulatorParam => newMetric[String]("", name, p)
+      case p @ UpdatedBlockStatusesAccumulatorParam =>
+        newMetric[Seq[(BlockId, BlockStatus)]](Seq(), name, p)
+      case p => throw new IllegalArgumentException(
+        s"unsupported accumulator param '${p.getClass.getSimpleName}' for metric '$name'.")
+    }
+  }
+
+  /**
+   * Get the [[AccumulatorParam]] associated with the internal metric name,
+   * which must begin with [[METRICS_PREFIX]].
+   */
+  def getParam(name: String): AccumulatorParam[_] = {
+    require(name.startsWith(METRICS_PREFIX),
+      s"internal accumulator name must start with '$METRICS_PREFIX': $name")
+    name match {
+      case UPDATED_BLOCK_STATUSES => UpdatedBlockStatusesAccumulatorParam
+      case shuffleRead.LOCAL_BLOCKS_FETCHED => IntAccumulatorParam
+      case shuffleRead.REMOTE_BLOCKS_FETCHED => IntAccumulatorParam
+      case input.READ_METHOD => StringAccumulatorParam
+      case output.WRITE_METHOD => StringAccumulatorParam
+      case _ => LongAccumulatorParam
     }
   }
 
+  /**
+   * Accumulators for tracking internal metrics.
+   */
+  def create(): Seq[Accumulator[_]] = {
+    Seq[String](
+      EXECUTOR_DESERIALIZE_TIME,
+      EXECUTOR_RUN_TIME,
+      RESULT_SIZE,
+      JVM_GC_TIME,
+      RESULT_SERIALIZATION_TIME,
+      MEMORY_BYTES_SPILLED,
+      DISK_BYTES_SPILLED,
+      PEAK_EXECUTION_MEMORY,
+      UPDATED_BLOCK_STATUSES).map(create) ++
+      createShuffleReadAccums() ++
+      createShuffleWriteAccums() ++
+      createInputAccums() ++
+      createOutputAccums() ++
+      sys.props.get("spark.testing").map(_ => create(TEST_ACCUM)).toSeq
+  }
+
+  /**
+   * Accumulators for tracking shuffle read metrics.
+   */
+  def createShuffleReadAccums(): Seq[Accumulator[_]] = {
+    Seq[String](
+      shuffleRead.REMOTE_BLOCKS_FETCHED,
+      shuffleRead.LOCAL_BLOCKS_FETCHED,
+      shuffleRead.REMOTE_BYTES_READ,
+      shuffleRead.LOCAL_BYTES_READ,
+      shuffleRead.FETCH_WAIT_TIME,
+      shuffleRead.RECORDS_READ).map(create)
+  }
+
+  /**
+   * Accumulators for tracking shuffle write metrics.
+   */
+  def createShuffleWriteAccums(): Seq[Accumulator[_]] = {
+    Seq[String](
+      shuffleWrite.BYTES_WRITTEN,
+      shuffleWrite.RECORDS_WRITTEN,
+      shuffleWrite.WRITE_TIME).map(create)
+  }
+
+  /**
+   * Accumulators for tracking input metrics.
+   */
+  def createInputAccums(): Seq[Accumulator[_]] = {
+    Seq[String](
+      input.READ_METHOD,
+      input.BYTES_READ,
+      input.RECORDS_READ).map(create)
+  }
+
+  /**
+   * Accumulators for tracking output metrics.
+   */
+  def createOutputAccums(): Seq[Accumulator[_]] = {
+    Seq[String](
+      output.WRITE_METHOD,
+      output.BYTES_WRITTEN,
+      output.RECORDS_WRITTEN).map(create)
+  }
+
   /**
    * Accumulators for tracking internal metrics.
    *
@@ -41,18 +187,23 @@ private[spark] object InternalAccumulator {
    * add to the same set of accumulators. We do this to report the distribution of accumulator
    * values across all tasks within each stage.
    */
-  def create(sc: SparkContext): Seq[Accumulator[Long]] = {
-    val internalAccumulators = Seq(
-      // Execution memory refers to the memory used by internal data structures created
-      // during shuffles, aggregations and joins. The value of this accumulator should be
-      // approximately the sum of the peak sizes across all such data structures created
-      // in this task. For SQL jobs, this only tracks all unsafe operators and ExternalSort.
-      new Accumulator(
-        0L, AccumulatorParam.LongAccumulatorParam, Some(PEAK_EXECUTION_MEMORY), internal = true)
-    ) ++ maybeTestAccumulator.toSeq
-    internalAccumulators.foreach { accumulator =>
-      sc.cleaner.foreach(_.registerAccumulatorForCleanup(accumulator))
+  def create(sc: SparkContext): Seq[Accumulator[_]] = {
+    val accums = create()
+    accums.foreach { accum =>
+      Accumulators.register(accum)
+      sc.cleaner.foreach(_.registerAccumulatorForCleanup(accum))
     }
-    internalAccumulators
+    accums
+  }
+
+  /**
+   * Create a new accumulator representing an internal task metric.
+   */
+  private def newMetric[T](
+      initialValue: T,
+      name: String,
+      param: AccumulatorParam[T]): Accumulator[T] = {
+    new Accumulator[T](initialValue, param, Some(name), internal = true, countFailedValues = true)
   }
+
 }
diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala
index 7704abc1340961b959748d2921f2d34dc4354077..9f49cf1c4c9bdcc8e5c2ca8dba28c72f4fc59067 100644
--- a/core/src/main/scala/org/apache/spark/TaskContext.scala
+++ b/core/src/main/scala/org/apache/spark/TaskContext.scala
@@ -64,7 +64,7 @@ object TaskContext {
    * An empty task context that does not represent an actual task.
    */
   private[spark] def empty(): TaskContextImpl = {
-    new TaskContextImpl(0, 0, 0, 0, null, null, Seq.empty)
+    new TaskContextImpl(0, 0, 0, 0, null, null)
   }
 
 }
@@ -138,7 +138,6 @@ abstract class TaskContext extends Serializable {
    */
   def taskAttemptId(): Long
 
-  /** ::DeveloperApi:: */
   @DeveloperApi
   def taskMetrics(): TaskMetrics
 
@@ -161,20 +160,4 @@ abstract class TaskContext extends Serializable {
    */
   private[spark] def registerAccumulator(a: Accumulable[_, _]): Unit
 
-  /**
-   * Return the local values of internal accumulators that belong to this task. The key of the Map
-   * is the accumulator id and the value of the Map is the latest accumulator local value.
-   */
-  private[spark] def collectInternalAccumulators(): Map[Long, Any]
-
-  /**
-   * Return the local values of accumulators that belong to this task. The key of the Map is the
-   * accumulator id and the value of the Map is the latest accumulator local value.
-   */
-  private[spark] def collectAccumulators(): Map[Long, Any]
-
-  /**
-   * Accumulators for tracking internal metrics indexed by the name.
-   */
-  private[spark] val internalMetricsToAccumulators: Map[String, Accumulator[Long]]
 }
diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
index 94ff884b742b84b4ff0dde58be69657a59eb13b8..27ca46f73d8cada058fab6a240273bfca6631e2f 100644
--- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
+++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark
 
-import scala.collection.mutable.{ArrayBuffer, HashMap}
+import scala.collection.mutable.ArrayBuffer
 
 import org.apache.spark.executor.TaskMetrics
 import org.apache.spark.memory.TaskMemoryManager
@@ -32,11 +32,15 @@ private[spark] class TaskContextImpl(
     override val attemptNumber: Int,
     override val taskMemoryManager: TaskMemoryManager,
     @transient private val metricsSystem: MetricsSystem,
-    internalAccumulators: Seq[Accumulator[Long]],
-    val taskMetrics: TaskMetrics = TaskMetrics.empty)
+    initialAccumulators: Seq[Accumulator[_]] = InternalAccumulator.create())
   extends TaskContext
   with Logging {
 
+  /**
+   * Metrics associated with this task.
+   */
+  override val taskMetrics: TaskMetrics = new TaskMetrics(initialAccumulators)
+
   // List of callback functions to execute when the task completes.
   @transient private val onCompleteCallbacks = new ArrayBuffer[TaskCompletionListener]
 
@@ -91,24 +95,8 @@ private[spark] class TaskContextImpl(
   override def getMetricsSources(sourceName: String): Seq[Source] =
     metricsSystem.getSourcesByName(sourceName)
 
-  @transient private val accumulators = new HashMap[Long, Accumulable[_, _]]
-
-  private[spark] override def registerAccumulator(a: Accumulable[_, _]): Unit = synchronized {
-    accumulators(a.id) = a
-  }
-
-  private[spark] override def collectInternalAccumulators(): Map[Long, Any] = synchronized {
-    accumulators.filter(_._2.isInternal).mapValues(_.localValue).toMap
+  private[spark] override def registerAccumulator(a: Accumulable[_, _]): Unit = {
+    taskMetrics.registerAccumulator(a)
   }
 
-  private[spark] override def collectAccumulators(): Map[Long, Any] = synchronized {
-    accumulators.mapValues(_.localValue).toMap
-  }
-
-  private[spark] override val internalMetricsToAccumulators: Map[String, Accumulator[Long]] = {
-    // Explicitly register internal accumulators here because these are
-    // not captured in the task closure and are already deserialized
-    internalAccumulators.foreach(registerAccumulator)
-    internalAccumulators.map { a => (a.name.get, a) }.toMap
-  }
 }
diff --git a/core/src/main/scala/org/apache/spark/TaskEndReason.scala b/core/src/main/scala/org/apache/spark/TaskEndReason.scala
index 13241b77bf97b842c464173624feaf1aca994520..68340cc704daee06030d612df3fb2c042e76f663 100644
--- a/core/src/main/scala/org/apache/spark/TaskEndReason.scala
+++ b/core/src/main/scala/org/apache/spark/TaskEndReason.scala
@@ -19,8 +19,11 @@ package org.apache.spark
 
 import java.io.{ObjectInputStream, ObjectOutputStream}
 
+import scala.util.Try
+
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.executor.TaskMetrics
+import org.apache.spark.scheduler.AccumulableInfo
 import org.apache.spark.storage.BlockManagerId
 import org.apache.spark.util.Utils
 
@@ -115,22 +118,34 @@ case class ExceptionFailure(
     description: String,
     stackTrace: Array[StackTraceElement],
     fullStackTrace: String,
-    metrics: Option[TaskMetrics],
-    private val exceptionWrapper: Option[ThrowableSerializationWrapper])
+    exceptionWrapper: Option[ThrowableSerializationWrapper],
+    accumUpdates: Seq[AccumulableInfo] = Seq.empty[AccumulableInfo])
   extends TaskFailedReason {
 
+  @deprecated("use accumUpdates instead", "2.0.0")
+  val metrics: Option[TaskMetrics] = {
+    if (accumUpdates.nonEmpty) {
+      Try(TaskMetrics.fromAccumulatorUpdates(accumUpdates)).toOption
+    } else {
+      None
+    }
+  }
+
   /**
    * `preserveCause` is used to keep the exception itself so it is available to the
    * driver. This may be set to `false` in the event that the exception is not in fact
    * serializable.
    */
-  private[spark] def this(e: Throwable, metrics: Option[TaskMetrics], preserveCause: Boolean) {
-    this(e.getClass.getName, e.getMessage, e.getStackTrace, Utils.exceptionString(e), metrics,
-      if (preserveCause) Some(new ThrowableSerializationWrapper(e)) else None)
+  private[spark] def this(
+      e: Throwable,
+      accumUpdates: Seq[AccumulableInfo],
+      preserveCause: Boolean) {
+    this(e.getClass.getName, e.getMessage, e.getStackTrace, Utils.exceptionString(e),
+      if (preserveCause) Some(new ThrowableSerializationWrapper(e)) else None, accumUpdates)
   }
 
-  private[spark] def this(e: Throwable, metrics: Option[TaskMetrics]) {
-    this(e, metrics, preserveCause = true)
+  private[spark] def this(e: Throwable, accumUpdates: Seq[AccumulableInfo]) {
+    this(e, accumUpdates, preserveCause = true)
   }
 
   def exception: Option[Throwable] = exceptionWrapper.flatMap {
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
index 8ba3f5e24189914b4090273843143f3ed4a336b1..06b5101b1f566214bc688f95cd20410c827f6198 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
@@ -370,6 +370,14 @@ object SparkHadoopUtil {
 
   val SPARK_YARN_CREDS_COUNTER_DELIM = "-"
 
+  /**
+   * Number of records to update input metrics when reading from HadoopRDDs.
+   *
+   * Each update is potentially expensive because we need to use reflection to access the
+   * Hadoop FileSystem API of interest (only available in 2.5), so we should do this sparingly.
+   */
+  private[spark] val UPDATE_INPUT_METRICS_INTERVAL_RECORDS = 1000
+
   def get: SparkHadoopUtil = {
     // Check each time to support changing to/from YARN
     val yarnMode = java.lang.Boolean.valueOf(
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index 030ae41db4a625348bab60264fdfbf51dd67558d..51c000ea5c5747daf2bc439bdd2c83a89da5835f 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -31,7 +31,7 @@ import org.apache.spark._
 import org.apache.spark.deploy.SparkHadoopUtil
 import org.apache.spark.memory.TaskMemoryManager
 import org.apache.spark.rpc.RpcTimeout
-import org.apache.spark.scheduler.{DirectTaskResult, IndirectTaskResult, Task}
+import org.apache.spark.scheduler.{AccumulableInfo, DirectTaskResult, IndirectTaskResult, Task}
 import org.apache.spark.shuffle.FetchFailedException
 import org.apache.spark.storage.{StorageLevel, TaskResultBlockId}
 import org.apache.spark.util._
@@ -210,7 +210,7 @@ private[spark] class Executor(
         // Run the actual task and measure its runtime.
         taskStart = System.currentTimeMillis()
         var threwException = true
-        val (value, accumUpdates) = try {
+        val value = try {
           val res = task.run(
             taskAttemptId = taskId,
             attemptNumber = attemptNumber,
@@ -249,10 +249,11 @@ private[spark] class Executor(
           m.setExecutorRunTime((taskFinish - taskStart) - task.executorDeserializeTime)
           m.setJvmGCTime(computeTotalGcTime() - startGCTime)
           m.setResultSerializationTime(afterSerialization - beforeSerialization)
-          m.updateAccumulators()
         }
 
-        val directResult = new DirectTaskResult(valueBytes, accumUpdates, task.metrics.orNull)
+        // Note: accumulator updates must be collected after TaskMetrics is updated
+        val accumUpdates = task.collectAccumulatorUpdates()
+        val directResult = new DirectTaskResult(valueBytes, accumUpdates)
         val serializedDirectResult = ser.serialize(directResult)
         val resultSize = serializedDirectResult.limit
 
@@ -297,21 +298,25 @@ private[spark] class Executor(
           // the default uncaught exception handler, which will terminate the Executor.
           logError(s"Exception in $taskName (TID $taskId)", t)
 
-          val metrics: Option[TaskMetrics] = Option(task).flatMap { task =>
-            task.metrics.map { m =>
+          // Collect latest accumulator values to report back to the driver
+          val accumulatorUpdates: Seq[AccumulableInfo] =
+          if (task != null) {
+            task.metrics.foreach { m =>
               m.setExecutorRunTime(System.currentTimeMillis() - taskStart)
               m.setJvmGCTime(computeTotalGcTime() - startGCTime)
-              m.updateAccumulators()
-              m
             }
+            task.collectAccumulatorUpdates(taskFailed = true)
+          } else {
+            Seq.empty[AccumulableInfo]
           }
+
           val serializedTaskEndReason = {
             try {
-              ser.serialize(new ExceptionFailure(t, metrics))
+              ser.serialize(new ExceptionFailure(t, accumulatorUpdates))
             } catch {
               case _: NotSerializableException =>
                 // t is not serializable so just send the stacktrace
-                ser.serialize(new ExceptionFailure(t, metrics, false))
+                ser.serialize(new ExceptionFailure(t, accumulatorUpdates, preserveCause = false))
             }
           }
           execBackend.statusUpdate(taskId, TaskState.FAILED, serializedTaskEndReason)
@@ -418,33 +423,21 @@ private[spark] class Executor(
 
   /** Reports heartbeat and metrics for active tasks to the driver. */
   private def reportHeartBeat(): Unit = {
-    // list of (task id, metrics) to send back to the driver
-    val tasksMetrics = new ArrayBuffer[(Long, TaskMetrics)]()
+    // list of (task id, accumUpdates) to send back to the driver
+    val accumUpdates = new ArrayBuffer[(Long, Seq[AccumulableInfo])]()
     val curGCTime = computeTotalGcTime()
 
     for (taskRunner <- runningTasks.values().asScala) {
       if (taskRunner.task != null) {
         taskRunner.task.metrics.foreach { metrics =>
           metrics.mergeShuffleReadMetrics()
-          metrics.updateInputMetrics()
           metrics.setJvmGCTime(curGCTime - taskRunner.startGCTime)
-          metrics.updateAccumulators()
-
-          if (isLocal) {
-            // JobProgressListener will hold an reference of it during
-            // onExecutorMetricsUpdate(), then JobProgressListener can not see
-            // the changes of metrics any more, so make a deep copy of it
-            val copiedMetrics = Utils.deserialize[TaskMetrics](Utils.serialize(metrics))
-            tasksMetrics += ((taskRunner.taskId, copiedMetrics))
-          } else {
-            // It will be copied by serialization
-            tasksMetrics += ((taskRunner.taskId, metrics))
-          }
+          accumUpdates += ((taskRunner.taskId, metrics.accumulatorUpdates()))
         }
       }
     }
 
-    val message = Heartbeat(executorId, tasksMetrics.toArray, env.blockManager.blockManagerId)
+    val message = Heartbeat(executorId, accumUpdates.toArray, env.blockManager.blockManagerId)
     try {
       val response = heartbeatReceiverRef.askWithRetry[HeartbeatResponse](
           message, RpcTimeout(conf, "spark.executor.heartbeatInterval", "10s"))
diff --git a/core/src/main/scala/org/apache/spark/executor/InputMetrics.scala b/core/src/main/scala/org/apache/spark/executor/InputMetrics.scala
index 8f1d7f89a44b4091e7ea38f99901d7052372d169..ed9e157ce758bd12ec6cfa0ee0d288e9d0a444a6 100644
--- a/core/src/main/scala/org/apache/spark/executor/InputMetrics.scala
+++ b/core/src/main/scala/org/apache/spark/executor/InputMetrics.scala
@@ -17,13 +17,15 @@
 
 package org.apache.spark.executor
 
+import org.apache.spark.{Accumulator, InternalAccumulator}
 import org.apache.spark.annotation.DeveloperApi
 
 
 /**
  * :: DeveloperApi ::
- * Method by which input data was read.  Network means that the data was read over the network
+ * Method by which input data was read. Network means that the data was read over the network
  * from a remote block manager (which may have stored the data on-disk or in-memory).
+ * Operations are not thread-safe.
  */
 @DeveloperApi
 object DataReadMethod extends Enumeration with Serializable {
@@ -34,44 +36,75 @@ object DataReadMethod extends Enumeration with Serializable {
 
 /**
  * :: DeveloperApi ::
- * Metrics about reading input data.
+ * A collection of accumulators that represents metrics about reading data from external systems.
  */
 @DeveloperApi
-case class InputMetrics(readMethod: DataReadMethod.Value) {
+class InputMetrics private (
+    _bytesRead: Accumulator[Long],
+    _recordsRead: Accumulator[Long],
+    _readMethod: Accumulator[String])
+  extends Serializable {
+
+  private[executor] def this(accumMap: Map[String, Accumulator[_]]) {
+    this(
+      TaskMetrics.getAccum[Long](accumMap, InternalAccumulator.input.BYTES_READ),
+      TaskMetrics.getAccum[Long](accumMap, InternalAccumulator.input.RECORDS_READ),
+      TaskMetrics.getAccum[String](accumMap, InternalAccumulator.input.READ_METHOD))
+  }
 
   /**
-   * This is volatile so that it is visible to the updater thread.
+   * Create a new [[InputMetrics]] that is not associated with any particular task.
+   *
+   * This mainly exists because of SPARK-5225, where we are forced to use a dummy [[InputMetrics]]
+   * because we want to ignore metrics from a second read method. In the future, we should revisit
+   * whether this is needed.
+   *
+   * A better alternative is [[TaskMetrics.registerInputMetrics]].
    */
-  @volatile @transient var bytesReadCallback: Option[() => Long] = None
+  private[executor] def this() {
+    this(InternalAccumulator.createInputAccums()
+      .map { a => (a.name.get, a) }.toMap[String, Accumulator[_]])
+  }
 
   /**
-   * Total bytes read.
+   * Total number of bytes read.
    */
-  private var _bytesRead: Long = _
-  def bytesRead: Long = _bytesRead
-  def incBytesRead(bytes: Long): Unit = _bytesRead += bytes
+  def bytesRead: Long = _bytesRead.localValue
 
   /**
-   * Total records read.
+   * Total number of records read.
    */
-  private var _recordsRead: Long = _
-  def recordsRead: Long = _recordsRead
-  def incRecordsRead(records: Long): Unit = _recordsRead += records
+  def recordsRead: Long = _recordsRead.localValue
 
   /**
-   * Invoke the bytesReadCallback and mutate bytesRead.
+   * The source from which this task reads its input.
    */
-  def updateBytesRead() {
-    bytesReadCallback.foreach { c =>
-      _bytesRead = c()
-    }
+  def readMethod: DataReadMethod.Value = DataReadMethod.withName(_readMethod.localValue)
+
+  @deprecated("incrementing input metrics is for internal use only", "2.0.0")
+  def incBytesRead(v: Long): Unit = _bytesRead.add(v)
+  @deprecated("incrementing input metrics is for internal use only", "2.0.0")
+  def incRecordsRead(v: Long): Unit = _recordsRead.add(v)
+  private[spark] def setBytesRead(v: Long): Unit = _bytesRead.setValue(v)
+  private[spark] def setReadMethod(v: DataReadMethod.Value): Unit =
+    _readMethod.setValue(v.toString)
+
+}
+
+/**
+ * Deprecated methods to preserve case class matching behavior before Spark 2.0.
+ */
+object InputMetrics {
+
+  @deprecated("matching on InputMetrics will not be supported in the future", "2.0.0")
+  def apply(readMethod: DataReadMethod.Value): InputMetrics = {
+    val im = new InputMetrics
+    im.setReadMethod(readMethod)
+    im
   }
 
-  /**
-   * Register a function that can be called to get up-to-date information on how many bytes the task
-   * has read from an input source.
-   */
-  def setBytesReadCallback(f: Option[() => Long]) {
-    bytesReadCallback = f
+  @deprecated("matching on InputMetrics will not be supported in the future", "2.0.0")
+  def unapply(input: InputMetrics): Option[DataReadMethod.Value] = {
+    Some(input.readMethod)
   }
 }
diff --git a/core/src/main/scala/org/apache/spark/executor/OutputMetrics.scala b/core/src/main/scala/org/apache/spark/executor/OutputMetrics.scala
index ad132d004cde000ab81fbf221db8e9c557714015..0b37d559c746277478f0d8a07f2d65cb6c6639fb 100644
--- a/core/src/main/scala/org/apache/spark/executor/OutputMetrics.scala
+++ b/core/src/main/scala/org/apache/spark/executor/OutputMetrics.scala
@@ -17,12 +17,14 @@
 
 package org.apache.spark.executor
 
+import org.apache.spark.{Accumulator, InternalAccumulator}
 import org.apache.spark.annotation.DeveloperApi
 
 
 /**
  * :: DeveloperApi ::
  * Method by which output data was written.
+ * Operations are not thread-safe.
  */
 @DeveloperApi
 object DataWriteMethod extends Enumeration with Serializable {
@@ -33,21 +35,70 @@ object DataWriteMethod extends Enumeration with Serializable {
 
 /**
  * :: DeveloperApi ::
- * Metrics about writing output data.
+ * A collection of accumulators that represents metrics about writing data to external systems.
  */
 @DeveloperApi
-case class OutputMetrics(writeMethod: DataWriteMethod.Value) {
+class OutputMetrics private (
+    _bytesWritten: Accumulator[Long],
+    _recordsWritten: Accumulator[Long],
+    _writeMethod: Accumulator[String])
+  extends Serializable {
+
+  private[executor] def this(accumMap: Map[String, Accumulator[_]]) {
+    this(
+      TaskMetrics.getAccum[Long](accumMap, InternalAccumulator.output.BYTES_WRITTEN),
+      TaskMetrics.getAccum[Long](accumMap, InternalAccumulator.output.RECORDS_WRITTEN),
+      TaskMetrics.getAccum[String](accumMap, InternalAccumulator.output.WRITE_METHOD))
+  }
+
+  /**
+   * Create a new [[OutputMetrics]] that is not associated with any particular task.
+   *
+   * This is only used for preserving matching behavior on [[OutputMetrics]], which used to be
+   * a case class before Spark 2.0. Once we remove support for matching on [[OutputMetrics]]
+   * we can remove this constructor as well.
+   */
+  private[executor] def this() {
+    this(InternalAccumulator.createOutputAccums()
+      .map { a => (a.name.get, a) }.toMap[String, Accumulator[_]])
+  }
+
+  /**
+   * Total number of bytes written.
+   */
+  def bytesWritten: Long = _bytesWritten.localValue
+
   /**
-   * Total bytes written
+   * Total number of records written.
    */
-  private var _bytesWritten: Long = _
-  def bytesWritten: Long = _bytesWritten
-  private[spark] def setBytesWritten(value : Long): Unit = _bytesWritten = value
+  def recordsWritten: Long = _recordsWritten.localValue
 
   /**
-   * Total records written
+   * The source to which this task writes its output.
    */
-  private var _recordsWritten: Long = 0L
-  def recordsWritten: Long = _recordsWritten
-  private[spark] def setRecordsWritten(value: Long): Unit = _recordsWritten = value
+  def writeMethod: DataWriteMethod.Value = DataWriteMethod.withName(_writeMethod.localValue)
+
+  private[spark] def setBytesWritten(v: Long): Unit = _bytesWritten.setValue(v)
+  private[spark] def setRecordsWritten(v: Long): Unit = _recordsWritten.setValue(v)
+  private[spark] def setWriteMethod(v: DataWriteMethod.Value): Unit =
+    _writeMethod.setValue(v.toString)
+
+}
+
+/**
+ * Deprecated methods to preserve case class matching behavior before Spark 2.0.
+ */
+object OutputMetrics {
+
+  @deprecated("matching on OutputMetrics will not be supported in the future", "2.0.0")
+  def apply(writeMethod: DataWriteMethod.Value): OutputMetrics = {
+    val om = new OutputMetrics
+    om.setWriteMethod(writeMethod)
+    om
+  }
+
+  @deprecated("matching on OutputMetrics will not be supported in the future", "2.0.0")
+  def unapply(output: OutputMetrics): Option[DataWriteMethod.Value] = {
+    Some(output.writeMethod)
+  }
 }
diff --git a/core/src/main/scala/org/apache/spark/executor/ShuffleReadMetrics.scala b/core/src/main/scala/org/apache/spark/executor/ShuffleReadMetrics.scala
index e985b35ace62354294251f564dd0d3ef18ecf071..50bb645d974a358198854b8112c475fdb85b5f1d 100644
--- a/core/src/main/scala/org/apache/spark/executor/ShuffleReadMetrics.scala
+++ b/core/src/main/scala/org/apache/spark/executor/ShuffleReadMetrics.scala
@@ -17,71 +17,103 @@
 
 package org.apache.spark.executor
 
+import org.apache.spark.{Accumulator, InternalAccumulator}
 import org.apache.spark.annotation.DeveloperApi
 
 
 /**
  * :: DeveloperApi ::
- * Metrics pertaining to shuffle data read in a given task.
+ * A collection of accumulators that represent metrics about reading shuffle data.
+ * Operations are not thread-safe.
  */
 @DeveloperApi
-class ShuffleReadMetrics extends Serializable {
+class ShuffleReadMetrics private (
+    _remoteBlocksFetched: Accumulator[Int],
+    _localBlocksFetched: Accumulator[Int],
+    _remoteBytesRead: Accumulator[Long],
+    _localBytesRead: Accumulator[Long],
+    _fetchWaitTime: Accumulator[Long],
+    _recordsRead: Accumulator[Long])
+  extends Serializable {
+
+  private[executor] def this(accumMap: Map[String, Accumulator[_]]) {
+    this(
+      TaskMetrics.getAccum[Int](accumMap, InternalAccumulator.shuffleRead.REMOTE_BLOCKS_FETCHED),
+      TaskMetrics.getAccum[Int](accumMap, InternalAccumulator.shuffleRead.LOCAL_BLOCKS_FETCHED),
+      TaskMetrics.getAccum[Long](accumMap, InternalAccumulator.shuffleRead.REMOTE_BYTES_READ),
+      TaskMetrics.getAccum[Long](accumMap, InternalAccumulator.shuffleRead.LOCAL_BYTES_READ),
+      TaskMetrics.getAccum[Long](accumMap, InternalAccumulator.shuffleRead.FETCH_WAIT_TIME),
+      TaskMetrics.getAccum[Long](accumMap, InternalAccumulator.shuffleRead.RECORDS_READ))
+  }
+
   /**
-   * Number of remote blocks fetched in this shuffle by this task
+   * Create a new [[ShuffleReadMetrics]] that is not associated with any particular task.
+   *
+   * This mainly exists for legacy reasons, because we use dummy [[ShuffleReadMetrics]] in
+   * many places only to merge their values together later. In the future, we should revisit
+   * whether this is needed.
+   *
+   * A better alternative is [[TaskMetrics.registerTempShuffleReadMetrics]] followed by
+   * [[TaskMetrics.mergeShuffleReadMetrics]].
    */
-  private var _remoteBlocksFetched: Int = _
-  def remoteBlocksFetched: Int = _remoteBlocksFetched
-  private[spark] def incRemoteBlocksFetched(value: Int) = _remoteBlocksFetched += value
-  private[spark] def decRemoteBlocksFetched(value: Int) = _remoteBlocksFetched -= value
+  private[spark] def this() {
+    this(InternalAccumulator.createShuffleReadAccums().map { a => (a.name.get, a) }.toMap)
+  }
 
   /**
-   * Number of local blocks fetched in this shuffle by this task
+   * Number of remote blocks fetched in this shuffle by this task.
    */
-  private var _localBlocksFetched: Int = _
-  def localBlocksFetched: Int = _localBlocksFetched
-  private[spark] def incLocalBlocksFetched(value: Int) = _localBlocksFetched += value
-  private[spark] def decLocalBlocksFetched(value: Int) = _localBlocksFetched -= value
+  def remoteBlocksFetched: Int = _remoteBlocksFetched.localValue
 
   /**
-   * Time the task spent waiting for remote shuffle blocks. This only includes the time
-   * blocking on shuffle input data. For instance if block B is being fetched while the task is
-   * still not finished processing block A, it is not considered to be blocking on block B.
+   * Number of local blocks fetched in this shuffle by this task.
    */
-  private var _fetchWaitTime: Long = _
-  def fetchWaitTime: Long = _fetchWaitTime
-  private[spark] def incFetchWaitTime(value: Long) = _fetchWaitTime += value
-  private[spark] def decFetchWaitTime(value: Long) = _fetchWaitTime -= value
+  def localBlocksFetched: Int = _localBlocksFetched.localValue
 
   /**
-   * Total number of remote bytes read from the shuffle by this task
+   * Total number of remote bytes read from the shuffle by this task.
    */
-  private var _remoteBytesRead: Long = _
-  def remoteBytesRead: Long = _remoteBytesRead
-  private[spark] def incRemoteBytesRead(value: Long) = _remoteBytesRead += value
-  private[spark] def decRemoteBytesRead(value: Long) = _remoteBytesRead -= value
+  def remoteBytesRead: Long = _remoteBytesRead.localValue
 
   /**
    * Shuffle data that was read from the local disk (as opposed to from a remote executor).
    */
-  private var _localBytesRead: Long = _
-  def localBytesRead: Long = _localBytesRead
-  private[spark] def incLocalBytesRead(value: Long) = _localBytesRead += value
+  def localBytesRead: Long = _localBytesRead.localValue
 
   /**
-   * Total bytes fetched in the shuffle by this task (both remote and local).
+   * Time the task spent waiting for remote shuffle blocks. This only includes the time
+   * blocking on shuffle input data. For instance if block B is being fetched while the task is
+   * still not finished processing block A, it is not considered to be blocking on block B.
+   */
+  def fetchWaitTime: Long = _fetchWaitTime.localValue
+
+  /**
+   * Total number of records read from the shuffle by this task.
    */
-  def totalBytesRead: Long = _remoteBytesRead + _localBytesRead
+  def recordsRead: Long = _recordsRead.localValue
 
   /**
-   * Number of blocks fetched in this shuffle by this task (remote or local)
+   * Total bytes fetched in the shuffle by this task (both remote and local).
    */
-  def totalBlocksFetched: Int = _remoteBlocksFetched + _localBlocksFetched
+  def totalBytesRead: Long = remoteBytesRead + localBytesRead
 
   /**
-   * Total number of records read from the shuffle by this task
+   * Number of blocks fetched in this shuffle by this task (remote or local).
    */
-  private var _recordsRead: Long = _
-  def recordsRead: Long = _recordsRead
-  private[spark] def incRecordsRead(value: Long) = _recordsRead += value
-  private[spark] def decRecordsRead(value: Long) = _recordsRead -= value
+  def totalBlocksFetched: Int = remoteBlocksFetched + localBlocksFetched
+
+  private[spark] def incRemoteBlocksFetched(v: Int): Unit = _remoteBlocksFetched.add(v)
+  private[spark] def incLocalBlocksFetched(v: Int): Unit = _localBlocksFetched.add(v)
+  private[spark] def incRemoteBytesRead(v: Long): Unit = _remoteBytesRead.add(v)
+  private[spark] def incLocalBytesRead(v: Long): Unit = _localBytesRead.add(v)
+  private[spark] def incFetchWaitTime(v: Long): Unit = _fetchWaitTime.add(v)
+  private[spark] def incRecordsRead(v: Long): Unit = _recordsRead.add(v)
+
+  private[spark] def setRemoteBlocksFetched(v: Int): Unit = _remoteBlocksFetched.setValue(v)
+  private[spark] def setLocalBlocksFetched(v: Int): Unit = _localBlocksFetched.setValue(v)
+  private[spark] def setRemoteBytesRead(v: Long): Unit = _remoteBytesRead.setValue(v)
+  private[spark] def setLocalBytesRead(v: Long): Unit = _localBytesRead.setValue(v)
+  private[spark] def setFetchWaitTime(v: Long): Unit = _fetchWaitTime.setValue(v)
+  private[spark] def setRecordsRead(v: Long): Unit = _recordsRead.setValue(v)
+
 }
diff --git a/core/src/main/scala/org/apache/spark/executor/ShuffleWriteMetrics.scala b/core/src/main/scala/org/apache/spark/executor/ShuffleWriteMetrics.scala
index 24795f860087fcacc08626860c8aab963c4c9b4b..c7aaabb561bba7b14186c5777cbc952a16fe4f7f 100644
--- a/core/src/main/scala/org/apache/spark/executor/ShuffleWriteMetrics.scala
+++ b/core/src/main/scala/org/apache/spark/executor/ShuffleWriteMetrics.scala
@@ -17,40 +17,66 @@
 
 package org.apache.spark.executor
 
+import org.apache.spark.{Accumulator, InternalAccumulator}
 import org.apache.spark.annotation.DeveloperApi
 
 
 /**
  * :: DeveloperApi ::
- * Metrics pertaining to shuffle data written in a given task.
+ * A collection of accumulators that represent metrics about writing shuffle data.
+ * Operations are not thread-safe.
  */
 @DeveloperApi
-class ShuffleWriteMetrics extends Serializable {
+class ShuffleWriteMetrics private (
+    _bytesWritten: Accumulator[Long],
+    _recordsWritten: Accumulator[Long],
+    _writeTime: Accumulator[Long])
+  extends Serializable {
+
+  private[executor] def this(accumMap: Map[String, Accumulator[_]]) {
+    this(
+      TaskMetrics.getAccum[Long](accumMap, InternalAccumulator.shuffleWrite.BYTES_WRITTEN),
+      TaskMetrics.getAccum[Long](accumMap, InternalAccumulator.shuffleWrite.RECORDS_WRITTEN),
+      TaskMetrics.getAccum[Long](accumMap, InternalAccumulator.shuffleWrite.WRITE_TIME))
+  }
 
   /**
-   * Number of bytes written for the shuffle by this task
+   * Create a new [[ShuffleWriteMetrics]] that is not associated with any particular task.
+   *
+   * This mainly exists for legacy reasons, because we use dummy [[ShuffleWriteMetrics]] in
+   * many places only to merge their values together later. In the future, we should revisit
+   * whether this is needed.
+   *
+   * A better alternative is [[TaskMetrics.registerShuffleWriteMetrics]].
    */
-  @volatile private var _bytesWritten: Long = _
-  def bytesWritten: Long = _bytesWritten
-  private[spark] def incBytesWritten(value: Long) = _bytesWritten += value
-  private[spark] def decBytesWritten(value: Long) = _bytesWritten -= value
+  private[spark] def this() {
+    this(InternalAccumulator.createShuffleWriteAccums().map { a => (a.name.get, a) }.toMap)
+  }
 
   /**
-   * Time the task spent blocking on writes to disk or buffer cache, in nanoseconds
+   * Number of bytes written for the shuffle by this task.
    */
-  @volatile private var _writeTime: Long = _
-  def writeTime: Long = _writeTime
-  private[spark] def incWriteTime(value: Long) = _writeTime += value
-  private[spark] def decWriteTime(value: Long) = _writeTime -= value
+  def bytesWritten: Long = _bytesWritten.localValue
 
   /**
-   * Total number of records written to the shuffle by this task
+   * Total number of records written to the shuffle by this task.
    */
-  @volatile private var _recordsWritten: Long = _
-  def recordsWritten: Long = _recordsWritten
-  private[spark] def incRecordsWritten(value: Long) = _recordsWritten += value
-  private[spark] def decRecordsWritten(value: Long) = _recordsWritten -= value
-  private[spark] def setRecordsWritten(value: Long) = _recordsWritten = value
+  def recordsWritten: Long = _recordsWritten.localValue
+
+  /**
+   * Time the task spent blocking on writes to disk or buffer cache, in nanoseconds.
+   */
+  def writeTime: Long = _writeTime.localValue
+
+  private[spark] def incBytesWritten(v: Long): Unit = _bytesWritten.add(v)
+  private[spark] def incRecordsWritten(v: Long): Unit = _recordsWritten.add(v)
+  private[spark] def incWriteTime(v: Long): Unit = _writeTime.add(v)
+  private[spark] def decBytesWritten(v: Long): Unit = {
+    _bytesWritten.setValue(bytesWritten - v)
+  }
+  private[spark] def decRecordsWritten(v: Long): Unit = {
+    _recordsWritten.setValue(recordsWritten - v)
+  }
 
   // Legacy methods for backward compatibility.
   // TODO: remove these once we make this class private.
diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
index 32ef5a9b5606f277d78db795ec0101d2f4f61b97..8d10bf588ef1fc51f46ef2088f9ba0948a59e6bd 100644
--- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
+++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
@@ -17,90 +17,161 @@
 
 package org.apache.spark.executor
 
-import java.io.{IOException, ObjectInputStream}
-import java.util.concurrent.ConcurrentHashMap
-
+import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
 
+import org.apache.spark._
 import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.executor.DataReadMethod.DataReadMethod
+import org.apache.spark.scheduler.AccumulableInfo
 import org.apache.spark.storage.{BlockId, BlockStatus}
-import org.apache.spark.util.Utils
 
 
 /**
  * :: DeveloperApi ::
  * Metrics tracked during the execution of a task.
  *
- * This class is used to house metrics both for in-progress and completed tasks. In executors,
- * both the task thread and the heartbeat thread write to the TaskMetrics. The heartbeat thread
- * reads it to send in-progress metrics, and the task thread reads it to send metrics along with
- * the completed task.
+ * This class is wrapper around a collection of internal accumulators that represent metrics
+ * associated with a task. The local values of these accumulators are sent from the executor
+ * to the driver when the task completes. These values are then merged into the corresponding
+ * accumulator previously registered on the driver.
+ *
+ * The accumulator updates are also sent to the driver periodically (on executor heartbeat)
+ * and when the task failed with an exception. The [[TaskMetrics]] object itself should never
+ * be sent to the driver.
  *
- * So, when adding new fields, take into consideration that the whole object can be serialized for
- * shipping off at any time to consumers of the SparkListener interface.
+ * @param initialAccums the initial set of accumulators that this [[TaskMetrics]] depends on.
+ *                      Each accumulator in this initial set must be uniquely named and marked
+ *                      as internal. Additional accumulators registered later need not satisfy
+ *                      these requirements.
  */
 @DeveloperApi
-class TaskMetrics extends Serializable {
+class TaskMetrics(initialAccums: Seq[Accumulator[_]]) extends Serializable {
+
+  import InternalAccumulator._
+
+  // Needed for Java tests
+  def this() {
+    this(InternalAccumulator.create())
+  }
+
+  /**
+   * All accumulators registered with this task.
+   */
+  private val accums = new ArrayBuffer[Accumulable[_, _]]
+  accums ++= initialAccums
+
+  /**
+   * A map for quickly accessing the initial set of accumulators by name.
+   */
+  private val initialAccumsMap: Map[String, Accumulator[_]] = {
+    val map = new mutable.HashMap[String, Accumulator[_]]
+    initialAccums.foreach { a =>
+      val name = a.name.getOrElse {
+        throw new IllegalArgumentException(
+          "initial accumulators passed to TaskMetrics must be named")
+      }
+      require(a.isInternal,
+        s"initial accumulator '$name' passed to TaskMetrics must be marked as internal")
+      require(!map.contains(name),
+        s"detected duplicate accumulator name '$name' when constructing TaskMetrics")
+      map(name) = a
+    }
+    map.toMap
+  }
+
+  // Each metric is internally represented as an accumulator
+  private val _executorDeserializeTime = getAccum(EXECUTOR_DESERIALIZE_TIME)
+  private val _executorRunTime = getAccum(EXECUTOR_RUN_TIME)
+  private val _resultSize = getAccum(RESULT_SIZE)
+  private val _jvmGCTime = getAccum(JVM_GC_TIME)
+  private val _resultSerializationTime = getAccum(RESULT_SERIALIZATION_TIME)
+  private val _memoryBytesSpilled = getAccum(MEMORY_BYTES_SPILLED)
+  private val _diskBytesSpilled = getAccum(DISK_BYTES_SPILLED)
+  private val _peakExecutionMemory = getAccum(PEAK_EXECUTION_MEMORY)
+  private val _updatedBlockStatuses =
+    TaskMetrics.getAccum[Seq[(BlockId, BlockStatus)]](initialAccumsMap, UPDATED_BLOCK_STATUSES)
+
   /**
-   * Host's name the task runs on
+   * Time taken on the executor to deserialize this task.
    */
-  private var _hostname: String = _
-  def hostname: String = _hostname
-  private[spark] def setHostname(value: String) = _hostname = value
+  def executorDeserializeTime: Long = _executorDeserializeTime.localValue
 
   /**
-   * Time taken on the executor to deserialize this task
+   * Time the executor spends actually running the task (including fetching shuffle data).
    */
-  private var _executorDeserializeTime: Long = _
-  def executorDeserializeTime: Long = _executorDeserializeTime
-  private[spark] def setExecutorDeserializeTime(value: Long) = _executorDeserializeTime = value
+  def executorRunTime: Long = _executorRunTime.localValue
 
+  /**
+   * The number of bytes this task transmitted back to the driver as the TaskResult.
+   */
+  def resultSize: Long = _resultSize.localValue
 
   /**
-   * Time the executor spends actually running the task (including fetching shuffle data)
+   * Amount of time the JVM spent in garbage collection while executing this task.
    */
-  private var _executorRunTime: Long = _
-  def executorRunTime: Long = _executorRunTime
-  private[spark] def setExecutorRunTime(value: Long) = _executorRunTime = value
+  def jvmGCTime: Long = _jvmGCTime.localValue
 
   /**
-   * The number of bytes this task transmitted back to the driver as the TaskResult
+   * Amount of time spent serializing the task result.
    */
-  private var _resultSize: Long = _
-  def resultSize: Long = _resultSize
-  private[spark] def setResultSize(value: Long) = _resultSize = value
+  def resultSerializationTime: Long = _resultSerializationTime.localValue
 
+  /**
+   * The number of in-memory bytes spilled by this task.
+   */
+  def memoryBytesSpilled: Long = _memoryBytesSpilled.localValue
 
   /**
-   * Amount of time the JVM spent in garbage collection while executing this task
+   * The number of on-disk bytes spilled by this task.
    */
-  private var _jvmGCTime: Long = _
-  def jvmGCTime: Long = _jvmGCTime
-  private[spark] def setJvmGCTime(value: Long) = _jvmGCTime = value
+  def diskBytesSpilled: Long = _diskBytesSpilled.localValue
 
   /**
-   * Amount of time spent serializing the task result
+   * Peak memory used by internal data structures created during shuffles, aggregations and
+   * joins. The value of this accumulator should be approximately the sum of the peak sizes
+   * across all such data structures created in this task. For SQL jobs, this only tracks all
+   * unsafe operators and ExternalSort.
    */
-  private var _resultSerializationTime: Long = _
-  def resultSerializationTime: Long = _resultSerializationTime
-  private[spark] def setResultSerializationTime(value: Long) = _resultSerializationTime = value
+  def peakExecutionMemory: Long = _peakExecutionMemory.localValue
 
   /**
-   * The number of in-memory bytes spilled by this task
+   * Storage statuses of any blocks that have been updated as a result of this task.
    */
-  private var _memoryBytesSpilled: Long = _
-  def memoryBytesSpilled: Long = _memoryBytesSpilled
-  private[spark] def incMemoryBytesSpilled(value: Long): Unit = _memoryBytesSpilled += value
-  private[spark] def decMemoryBytesSpilled(value: Long): Unit = _memoryBytesSpilled -= value
+  def updatedBlockStatuses: Seq[(BlockId, BlockStatus)] = _updatedBlockStatuses.localValue
+
+  @deprecated("use updatedBlockStatuses instead", "2.0.0")
+  def updatedBlocks: Option[Seq[(BlockId, BlockStatus)]] = {
+    if (updatedBlockStatuses.nonEmpty) Some(updatedBlockStatuses) else None
+  }
+
+  // Setters and increment-ers
+  private[spark] def setExecutorDeserializeTime(v: Long): Unit =
+    _executorDeserializeTime.setValue(v)
+  private[spark] def setExecutorRunTime(v: Long): Unit = _executorRunTime.setValue(v)
+  private[spark] def setResultSize(v: Long): Unit = _resultSize.setValue(v)
+  private[spark] def setJvmGCTime(v: Long): Unit = _jvmGCTime.setValue(v)
+  private[spark] def setResultSerializationTime(v: Long): Unit =
+    _resultSerializationTime.setValue(v)
+  private[spark] def incMemoryBytesSpilled(v: Long): Unit = _memoryBytesSpilled.add(v)
+  private[spark] def incDiskBytesSpilled(v: Long): Unit = _diskBytesSpilled.add(v)
+  private[spark] def incPeakExecutionMemory(v: Long): Unit = _peakExecutionMemory.add(v)
+  private[spark] def incUpdatedBlockStatuses(v: Seq[(BlockId, BlockStatus)]): Unit =
+    _updatedBlockStatuses.add(v)
+  private[spark] def setUpdatedBlockStatuses(v: Seq[(BlockId, BlockStatus)]): Unit =
+    _updatedBlockStatuses.setValue(v)
 
   /**
-   * The number of on-disk bytes spilled by this task
+   * Get a Long accumulator from the given map by name, assuming it exists.
+   * Note: this only searches the initial set of accumulators passed into the constructor.
    */
-  private var _diskBytesSpilled: Long = _
-  def diskBytesSpilled: Long = _diskBytesSpilled
-  private[spark] def incDiskBytesSpilled(value: Long): Unit = _diskBytesSpilled += value
-  private[spark] def decDiskBytesSpilled(value: Long): Unit = _diskBytesSpilled -= value
+  private[spark] def getAccum(name: String): Accumulator[Long] = {
+    TaskMetrics.getAccum[Long](initialAccumsMap, name)
+  }
+
+
+  /* ========================== *
+   |        INPUT METRICS       |
+   * ========================== */
 
   private var _inputMetrics: Option[InputMetrics] = None
 
@@ -116,7 +187,8 @@ class TaskMetrics extends Serializable {
   private[spark] def registerInputMetrics(readMethod: DataReadMethod.Value): InputMetrics = {
     synchronized {
       val metrics = _inputMetrics.getOrElse {
-        val metrics = new InputMetrics(readMethod)
+        val metrics = new InputMetrics(initialAccumsMap)
+        metrics.setReadMethod(readMethod)
         _inputMetrics = Some(metrics)
         metrics
       }
@@ -128,18 +200,17 @@ class TaskMetrics extends Serializable {
       if (metrics.readMethod == readMethod) {
         metrics
       } else {
-        new InputMetrics(readMethod)
+        val m = new InputMetrics
+        m.setReadMethod(readMethod)
+        m
       }
     }
   }
 
-  /**
-   * This should only be used when recreating TaskMetrics, not when updating input metrics in
-   * executors
-   */
-  private[spark] def setInputMetrics(inputMetrics: Option[InputMetrics]) {
-    _inputMetrics = inputMetrics
-  }
+
+  /* ============================ *
+   |        OUTPUT METRICS        |
+   * ============================ */
 
   private var _outputMetrics: Option[OutputMetrics] = None
 
@@ -149,23 +220,24 @@ class TaskMetrics extends Serializable {
    */
   def outputMetrics: Option[OutputMetrics] = _outputMetrics
 
-  @deprecated("setting OutputMetrics is for internal use only", "2.0.0")
-  def outputMetrics_=(om: Option[OutputMetrics]): Unit = {
-    _outputMetrics = om
-  }
-
   /**
    * Get or create a new [[OutputMetrics]] associated with this task.
    */
   private[spark] def registerOutputMetrics(
       writeMethod: DataWriteMethod.Value): OutputMetrics = synchronized {
     _outputMetrics.getOrElse {
-      val metrics = new OutputMetrics(writeMethod)
+      val metrics = new OutputMetrics(initialAccumsMap)
+      metrics.setWriteMethod(writeMethod)
       _outputMetrics = Some(metrics)
       metrics
     }
   }
 
+
+  /* ================================== *
+   |        SHUFFLE READ METRICS        |
+   * ================================== */
+
   private var _shuffleReadMetrics: Option[ShuffleReadMetrics] = None
 
   /**
@@ -174,21 +246,13 @@ class TaskMetrics extends Serializable {
    */
   def shuffleReadMetrics: Option[ShuffleReadMetrics] = _shuffleReadMetrics
 
-  /**
-   * This should only be used when recreating TaskMetrics, not when updating read metrics in
-   * executors.
-   */
-  private[spark] def setShuffleReadMetrics(shuffleReadMetrics: Option[ShuffleReadMetrics]) {
-    _shuffleReadMetrics = shuffleReadMetrics
-  }
-
   /**
    * Temporary list of [[ShuffleReadMetrics]], one per shuffle dependency.
    *
    * A task may have multiple shuffle readers for multiple dependencies. To avoid synchronization
    * issues from readers in different threads, in-progress tasks use a [[ShuffleReadMetrics]] for
    * each dependency and merge these metrics before reporting them to the driver.
-  */
+   */
   @transient private lazy val tempShuffleReadMetrics = new ArrayBuffer[ShuffleReadMetrics]
 
   /**
@@ -210,19 +274,21 @@ class TaskMetrics extends Serializable {
    */
   private[spark] def mergeShuffleReadMetrics(): Unit = synchronized {
     if (tempShuffleReadMetrics.nonEmpty) {
-      val merged = new ShuffleReadMetrics
-      for (depMetrics <- tempShuffleReadMetrics) {
-        merged.incFetchWaitTime(depMetrics.fetchWaitTime)
-        merged.incLocalBlocksFetched(depMetrics.localBlocksFetched)
-        merged.incRemoteBlocksFetched(depMetrics.remoteBlocksFetched)
-        merged.incRemoteBytesRead(depMetrics.remoteBytesRead)
-        merged.incLocalBytesRead(depMetrics.localBytesRead)
-        merged.incRecordsRead(depMetrics.recordsRead)
-      }
-      _shuffleReadMetrics = Some(merged)
+      val metrics = new ShuffleReadMetrics(initialAccumsMap)
+      metrics.setRemoteBlocksFetched(tempShuffleReadMetrics.map(_.remoteBlocksFetched).sum)
+      metrics.setLocalBlocksFetched(tempShuffleReadMetrics.map(_.localBlocksFetched).sum)
+      metrics.setFetchWaitTime(tempShuffleReadMetrics.map(_.fetchWaitTime).sum)
+      metrics.setRemoteBytesRead(tempShuffleReadMetrics.map(_.remoteBytesRead).sum)
+      metrics.setLocalBytesRead(tempShuffleReadMetrics.map(_.localBytesRead).sum)
+      metrics.setRecordsRead(tempShuffleReadMetrics.map(_.recordsRead).sum)
+      _shuffleReadMetrics = Some(metrics)
     }
   }
 
+  /* =================================== *
+   |        SHUFFLE WRITE METRICS        |
+   * =================================== */
+
   private var _shuffleWriteMetrics: Option[ShuffleWriteMetrics] = None
 
   /**
@@ -230,86 +296,120 @@ class TaskMetrics extends Serializable {
    */
   def shuffleWriteMetrics: Option[ShuffleWriteMetrics] = _shuffleWriteMetrics
 
-  @deprecated("setting ShuffleWriteMetrics is for internal use only", "2.0.0")
-  def shuffleWriteMetrics_=(swm: Option[ShuffleWriteMetrics]): Unit = {
-    _shuffleWriteMetrics = swm
-  }
-
   /**
    * Get or create a new [[ShuffleWriteMetrics]] associated with this task.
    */
   private[spark] def registerShuffleWriteMetrics(): ShuffleWriteMetrics = synchronized {
     _shuffleWriteMetrics.getOrElse {
-      val metrics = new ShuffleWriteMetrics
+      val metrics = new ShuffleWriteMetrics(initialAccumsMap)
       _shuffleWriteMetrics = Some(metrics)
       metrics
     }
   }
 
-  private var _updatedBlockStatuses: Seq[(BlockId, BlockStatus)] =
-    Seq.empty[(BlockId, BlockStatus)]
 
-  /**
-   * Storage statuses of any blocks that have been updated as a result of this task.
-   */
-  def updatedBlockStatuses: Seq[(BlockId, BlockStatus)] = _updatedBlockStatuses
+  /* ========================== *
+   |        OTHER THINGS        |
+   * ========================== */
 
-  @deprecated("setting updated blocks is for internal use only", "2.0.0")
-  def updatedBlocks_=(ub: Option[Seq[(BlockId, BlockStatus)]]): Unit = {
-    _updatedBlockStatuses = ub.getOrElse(Seq.empty[(BlockId, BlockStatus)])
+  private[spark] def registerAccumulator(a: Accumulable[_, _]): Unit = {
+    accums += a
   }
 
-  private[spark] def incUpdatedBlockStatuses(v: Seq[(BlockId, BlockStatus)]): Unit = {
-    _updatedBlockStatuses ++= v
+  /**
+   * Return the latest updates of accumulators in this task.
+   *
+   * The [[AccumulableInfo.update]] field is always defined and the [[AccumulableInfo.value]]
+   * field is always empty, since this represents the partial updates recorded in this task,
+   * not the aggregated value across multiple tasks.
+   */
+  def accumulatorUpdates(): Seq[AccumulableInfo] = accums.map { a =>
+    new AccumulableInfo(a.id, a.name, Some(a.localValue), None, a.isInternal, a.countFailedValues)
   }
 
-  private[spark] def setUpdatedBlockStatuses(v: Seq[(BlockId, BlockStatus)]): Unit = {
-    _updatedBlockStatuses = v
+  // If we are reconstructing this TaskMetrics on the driver, some metrics may already be set.
+  // If so, initialize all relevant metrics classes so listeners can access them downstream.
+  {
+    var (hasShuffleRead, hasShuffleWrite, hasInput, hasOutput) = (false, false, false, false)
+    initialAccums
+      .filter { a => a.localValue != a.zero }
+      .foreach { a =>
+        a.name.get match {
+          case sr if sr.startsWith(SHUFFLE_READ_METRICS_PREFIX) => hasShuffleRead = true
+          case sw if sw.startsWith(SHUFFLE_WRITE_METRICS_PREFIX) => hasShuffleWrite = true
+          case in if in.startsWith(INPUT_METRICS_PREFIX) => hasInput = true
+          case out if out.startsWith(OUTPUT_METRICS_PREFIX) => hasOutput = true
+          case _ =>
+        }
+      }
+    if (hasShuffleRead) { _shuffleReadMetrics = Some(new ShuffleReadMetrics(initialAccumsMap)) }
+    if (hasShuffleWrite) { _shuffleWriteMetrics = Some(new ShuffleWriteMetrics(initialAccumsMap)) }
+    if (hasInput) { _inputMetrics = Some(new InputMetrics(initialAccumsMap)) }
+    if (hasOutput) { _outputMetrics = Some(new OutputMetrics(initialAccumsMap)) }
   }
 
-  @deprecated("use updatedBlockStatuses instead", "2.0.0")
-  def updatedBlocks: Option[Seq[(BlockId, BlockStatus)]] = {
-    if (_updatedBlockStatuses.nonEmpty) Some(_updatedBlockStatuses) else None
-  }
+}
 
-  private[spark] def updateInputMetrics(): Unit = synchronized {
-    inputMetrics.foreach(_.updateBytesRead())
-  }
+private[spark] object TaskMetrics extends Logging {
 
-  @throws(classOf[IOException])
-  private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException {
-    in.defaultReadObject()
-    // Get the hostname from cached data, since hostname is the order of number of nodes in
-    // cluster, so using cached hostname will decrease the object number and alleviate the GC
-    // overhead.
-    _hostname = TaskMetrics.getCachedHostName(_hostname)
-  }
-
-  private var _accumulatorUpdates: Map[Long, Any] = Map.empty
-  @transient private var _accumulatorsUpdater: () => Map[Long, Any] = null
+  def empty: TaskMetrics = new TaskMetrics
 
-  private[spark] def updateAccumulators(): Unit = synchronized {
-    _accumulatorUpdates = _accumulatorsUpdater()
+  /**
+   * Get an accumulator from the given map by name, assuming it exists.
+   */
+  def getAccum[T](accumMap: Map[String, Accumulator[_]], name: String): Accumulator[T] = {
+    require(accumMap.contains(name), s"metric '$name' is missing")
+    val accum = accumMap(name)
+    try {
+      // Note: we can't do pattern matching here because types are erased by compile time
+      accum.asInstanceOf[Accumulator[T]]
+    } catch {
+      case e: ClassCastException =>
+        throw new SparkException(s"accumulator $name was of unexpected type", e)
+    }
   }
 
   /**
-   * Return the latest updates of accumulators in this task.
+   * Construct a [[TaskMetrics]] object from a list of accumulator updates, called on driver only.
+   *
+   * Executors only send accumulator updates back to the driver, not [[TaskMetrics]]. However, we
+   * need the latter to post task end events to listeners, so we need to reconstruct the metrics
+   * on the driver.
+   *
+   * This assumes the provided updates contain the initial set of accumulators representing
+   * internal task level metrics.
    */
-  def accumulatorUpdates(): Map[Long, Any] = _accumulatorUpdates
-
-  private[spark] def setAccumulatorsUpdater(accumulatorsUpdater: () => Map[Long, Any]): Unit = {
-    _accumulatorsUpdater = accumulatorsUpdater
+  def fromAccumulatorUpdates(accumUpdates: Seq[AccumulableInfo]): TaskMetrics = {
+    // Initial accumulators are passed into the TaskMetrics constructor first because these
+    // are required to be uniquely named. The rest of the accumulators from this task are
+    // registered later because they need not satisfy this requirement.
+    val (initialAccumInfos, otherAccumInfos) = accumUpdates
+      .filter { info => info.update.isDefined }
+      .partition { info => info.name.exists(_.startsWith(InternalAccumulator.METRICS_PREFIX)) }
+    val initialAccums = initialAccumInfos.map { info =>
+      val accum = InternalAccumulator.create(info.name.get)
+      accum.setValueAny(info.update.get)
+      accum
+    }
+    // We don't know the types of the rest of the accumulators, so we try to find the same ones
+    // that were previously registered here on the driver and make copies of them. It is important
+    // that we copy the accumulators here since they are used across many tasks and we want to
+    // maintain a snapshot of their local task values when we post them to listeners downstream.
+    val otherAccums = otherAccumInfos.flatMap { info =>
+      val id = info.id
+      val acc = Accumulators.get(id).map { a =>
+        val newAcc = a.copy()
+        newAcc.setValueAny(info.update.get)
+        newAcc
+      }
+      if (acc.isEmpty) {
+        logWarning(s"encountered unregistered accumulator $id when reconstructing task metrics.")
+      }
+      acc
+    }
+    val metrics = new TaskMetrics(initialAccums)
+    otherAccums.foreach(metrics.registerAccumulator)
+    metrics
   }
-}
-
 
-private[spark] object TaskMetrics {
-  private val hostNameCache = new ConcurrentHashMap[String, String]()
-
-  def empty: TaskMetrics = new TaskMetrics
-
-  def getCachedHostName(host: String): String = {
-    val canonicalHost = hostNameCache.putIfAbsent(host, host)
-    if (canonicalHost != null) canonicalHost else host
-  }
 }
diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
index 3587e7eb1afaf1d2a96f49528d91f506b99d794d..d9b0824b38ecce0b6686d8c1d2591742e7d9d436 100644
--- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
@@ -153,8 +153,7 @@ class CoGroupedRDD[K: ClassTag](
     }
     context.taskMetrics().incMemoryBytesSpilled(map.memoryBytesSpilled)
     context.taskMetrics().incDiskBytesSpilled(map.diskBytesSpilled)
-    context.internalMetricsToAccumulators(
-      InternalAccumulator.PEAK_EXECUTION_MEMORY).add(map.peakMemoryUsedBytes)
+    context.taskMetrics().incPeakExecutionMemory(map.peakMemoryUsedBytes)
     new InterruptibleIterator(context,
       map.iterator.asInstanceOf[Iterator[(K, Array[Iterable[_]])]])
   }
diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
index a79ab86d49227ff0946cb62dbbd18086200ec9a1..3204e6adceca2c0ba6d806b92cc02a612acb229f 100644
--- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
@@ -212,6 +212,8 @@ class HadoopRDD[K, V](
       logInfo("Input split: " + split.inputSplit)
       val jobConf = getJobConf()
 
+      // TODO: there is a lot of duplicate code between this and NewHadoopRDD and SqlNewHadoopRDD
+
       val inputMetrics = context.taskMetrics().registerInputMetrics(DataReadMethod.Hadoop)
 
       // Sets the thread local variable for the file's name
@@ -222,14 +224,17 @@ class HadoopRDD[K, V](
 
       // Find a function that will return the FileSystem bytes read by this thread. Do this before
       // creating RecordReader, because RecordReader's constructor might read some bytes
-      val bytesReadCallback = inputMetrics.bytesReadCallback.orElse {
-        split.inputSplit.value match {
-          case _: FileSplit | _: CombineFileSplit =>
-            SparkHadoopUtil.get.getFSBytesReadOnThreadCallback()
-          case _ => None
+      val getBytesReadCallback: Option[() => Long] = split.inputSplit.value match {
+        case _: FileSplit | _: CombineFileSplit =>
+          SparkHadoopUtil.get.getFSBytesReadOnThreadCallback()
+        case _ => None
+      }
+
+      def updateBytesRead(): Unit = {
+        getBytesReadCallback.foreach { getBytesRead =>
+          inputMetrics.setBytesRead(getBytesRead())
         }
       }
-      inputMetrics.setBytesReadCallback(bytesReadCallback)
 
       var reader: RecordReader[K, V] = null
       val inputFormat = getInputFormat(jobConf)
@@ -252,6 +257,9 @@ class HadoopRDD[K, V](
         if (!finished) {
           inputMetrics.incRecordsRead(1)
         }
+        if (inputMetrics.recordsRead % SparkHadoopUtil.UPDATE_INPUT_METRICS_INTERVAL_RECORDS == 0) {
+          updateBytesRead()
+        }
         (key, value)
       }
 
@@ -272,8 +280,8 @@ class HadoopRDD[K, V](
           } finally {
             reader = null
           }
-          if (bytesReadCallback.isDefined) {
-            inputMetrics.updateBytesRead()
+          if (getBytesReadCallback.isDefined) {
+            updateBytesRead()
           } else if (split.inputSplit.value.isInstanceOf[FileSplit] ||
                      split.inputSplit.value.isInstanceOf[CombineFileSplit]) {
             // If we can't get the bytes read from the FS stats, fall back to the split size,
diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
index 5cc9c81cc6749cc069d4b423999f3fdca2a5ed4c..4d2816e335fe3846e8eb5b30bd37a2a6d0177c26 100644
--- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
@@ -133,14 +133,17 @@ class NewHadoopRDD[K, V](
 
       // Find a function that will return the FileSystem bytes read by this thread. Do this before
       // creating RecordReader, because RecordReader's constructor might read some bytes
-      val bytesReadCallback = inputMetrics.bytesReadCallback.orElse {
-        split.serializableHadoopSplit.value match {
-          case _: FileSplit | _: CombineFileSplit =>
-            SparkHadoopUtil.get.getFSBytesReadOnThreadCallback()
-          case _ => None
+      val getBytesReadCallback: Option[() => Long] = split.serializableHadoopSplit.value match {
+        case _: FileSplit | _: CombineFileSplit =>
+          SparkHadoopUtil.get.getFSBytesReadOnThreadCallback()
+        case _ => None
+      }
+
+      def updateBytesRead(): Unit = {
+        getBytesReadCallback.foreach { getBytesRead =>
+          inputMetrics.setBytesRead(getBytesRead())
         }
       }
-      inputMetrics.setBytesReadCallback(bytesReadCallback)
 
       val format = inputFormatClass.newInstance
       format match {
@@ -182,6 +185,9 @@ class NewHadoopRDD[K, V](
         if (!finished) {
           inputMetrics.incRecordsRead(1)
         }
+        if (inputMetrics.recordsRead % SparkHadoopUtil.UPDATE_INPUT_METRICS_INTERVAL_RECORDS == 0) {
+          updateBytesRead()
+        }
         (reader.getCurrentKey, reader.getCurrentValue)
       }
 
@@ -201,8 +207,8 @@ class NewHadoopRDD[K, V](
           } finally {
             reader = null
           }
-          if (bytesReadCallback.isDefined) {
-            inputMetrics.updateBytesRead()
+          if (getBytesReadCallback.isDefined) {
+            updateBytesRead()
           } else if (split.serializableHadoopSplit.value.isInstanceOf[FileSplit] ||
                      split.serializableHadoopSplit.value.isInstanceOf[CombineFileSplit]) {
             // If we can't get the bytes read from the FS stats, fall back to the split size,
diff --git a/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala
index 146cfb9ba80373d65f77bf328db41ac9b1b5ff8b..9d45fff9213c69b8682f7c1e358ee9b37734e161 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala
@@ -19,47 +19,58 @@ package org.apache.spark.scheduler
 
 import org.apache.spark.annotation.DeveloperApi
 
+
 /**
  * :: DeveloperApi ::
  * Information about an [[org.apache.spark.Accumulable]] modified during a task or stage.
+ *
+ * Note: once this is JSON serialized the types of `update` and `value` will be lost and be
+ * cast to strings. This is because the user can define an accumulator of any type and it will
+ * be difficult to preserve the type in consumers of the event log. This does not apply to
+ * internal accumulators that represent task level metrics.
+ *
+ * @param id accumulator ID
+ * @param name accumulator name
+ * @param update partial value from a task, may be None if used on driver to describe a stage
+ * @param value total accumulated value so far, maybe None if used on executors to describe a task
+ * @param internal whether this accumulator was internal
+ * @param countFailedValues whether to count this accumulator's partial value if the task failed
  */
 @DeveloperApi
-class AccumulableInfo private[spark] (
-    val id: Long,
-    val name: String,
-    val update: Option[String], // represents a partial update within a task
-    val value: String,
-    val internal: Boolean) {
-
-  override def equals(other: Any): Boolean = other match {
-    case acc: AccumulableInfo =>
-      this.id == acc.id && this.name == acc.name &&
-        this.update == acc.update && this.value == acc.value &&
-        this.internal == acc.internal
-    case _ => false
-  }
+case class AccumulableInfo private[spark] (
+    id: Long,
+    name: Option[String],
+    update: Option[Any], // represents a partial update within a task
+    value: Option[Any],
+    private[spark] val internal: Boolean,
+    private[spark] val countFailedValues: Boolean)
 
-  override def hashCode(): Int = {
-    val state = Seq(id, name, update, value, internal)
-    state.map(_.hashCode).reduceLeft(31 * _ + _)
-  }
-}
 
+/**
+ * A collection of deprecated constructors. This will be removed soon.
+ */
 object AccumulableInfo {
+
+  @deprecated("do not create AccumulableInfo", "2.0.0")
   def apply(
       id: Long,
       name: String,
       update: Option[String],
       value: String,
       internal: Boolean): AccumulableInfo = {
-    new AccumulableInfo(id, name, update, value, internal)
+    new AccumulableInfo(
+      id, Option(name), update, Option(value), internal, countFailedValues = false)
   }
 
+  @deprecated("do not create AccumulableInfo", "2.0.0")
   def apply(id: Long, name: String, update: Option[String], value: String): AccumulableInfo = {
-    new AccumulableInfo(id, name, update, value, internal = false)
+    new AccumulableInfo(
+      id, Option(name), update, Option(value), internal = false, countFailedValues = false)
   }
 
+  @deprecated("do not create AccumulableInfo", "2.0.0")
   def apply(id: Long, name: String, value: String): AccumulableInfo = {
-    new AccumulableInfo(id, name, None, value, internal = false)
+    new AccumulableInfo(
+      id, Option(name), None, Option(value), internal = false, countFailedValues = false)
   }
 }
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 6b01a10fc136e260f1dd9a500cd51ada004e7855..897479b50010daa586e5c7c13bd0202237d96e21 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -208,11 +208,10 @@ class DAGScheduler(
       task: Task[_],
       reason: TaskEndReason,
       result: Any,
-      accumUpdates: Map[Long, Any],
-      taskInfo: TaskInfo,
-      taskMetrics: TaskMetrics): Unit = {
+      accumUpdates: Seq[AccumulableInfo],
+      taskInfo: TaskInfo): Unit = {
     eventProcessLoop.post(
-      CompletionEvent(task, reason, result, accumUpdates, taskInfo, taskMetrics))
+      CompletionEvent(task, reason, result, accumUpdates, taskInfo))
   }
 
   /**
@@ -222,9 +221,10 @@ class DAGScheduler(
    */
   def executorHeartbeatReceived(
       execId: String,
-      taskMetrics: Array[(Long, Int, Int, TaskMetrics)], // (taskId, stageId, stateAttempt, metrics)
+      // (taskId, stageId, stageAttemptId, accumUpdates)
+      accumUpdates: Array[(Long, Int, Int, Seq[AccumulableInfo])],
       blockManagerId: BlockManagerId): Boolean = {
-    listenerBus.post(SparkListenerExecutorMetricsUpdate(execId, taskMetrics))
+    listenerBus.post(SparkListenerExecutorMetricsUpdate(execId, accumUpdates))
     blockManagerMaster.driverEndpoint.askWithRetry[Boolean](
       BlockManagerHeartbeat(blockManagerId), new RpcTimeout(600 seconds, "BlockManagerHeartbeat"))
   }
@@ -1074,39 +1074,43 @@ class DAGScheduler(
     }
   }
 
-  /** Merge updates from a task to our local accumulator values */
+  /**
+   * Merge local values from a task into the corresponding accumulators previously registered
+   * here on the driver.
+   *
+   * Although accumulators themselves are not thread-safe, this method is called only from one
+   * thread, the one that runs the scheduling loop. This means we only handle one task
+   * completion event at a time so we don't need to worry about locking the accumulators.
+   * This still doesn't stop the caller from updating the accumulator outside the scheduler,
+   * but that's not our problem since there's nothing we can do about that.
+   */
   private def updateAccumulators(event: CompletionEvent): Unit = {
     val task = event.task
     val stage = stageIdToStage(task.stageId)
-    if (event.accumUpdates != null) {
-      try {
-        Accumulators.add(event.accumUpdates)
-
-        event.accumUpdates.foreach { case (id, partialValue) =>
-          // In this instance, although the reference in Accumulators.originals is a WeakRef,
-          // it's guaranteed to exist since the event.accumUpdates Map exists
-
-          val acc = Accumulators.originals(id).get match {
-            case Some(accum) => accum.asInstanceOf[Accumulable[Any, Any]]
-            case None => throw new NullPointerException("Non-existent reference to Accumulator")
-          }
-
-          // To avoid UI cruft, ignore cases where value wasn't updated
-          if (acc.name.isDefined && partialValue != acc.zero) {
-            val name = acc.name.get
-            val value = s"${acc.value}"
-            stage.latestInfo.accumulables(id) =
-              new AccumulableInfo(id, name, None, value, acc.isInternal)
-            event.taskInfo.accumulables +=
-              new AccumulableInfo(id, name, Some(s"$partialValue"), value, acc.isInternal)
-          }
+    try {
+      event.accumUpdates.foreach { ainfo =>
+        assert(ainfo.update.isDefined, "accumulator from task should have a partial value")
+        val id = ainfo.id
+        val partialValue = ainfo.update.get
+        // Find the corresponding accumulator on the driver and update it
+        val acc: Accumulable[Any, Any] = Accumulators.get(id) match {
+          case Some(accum) => accum.asInstanceOf[Accumulable[Any, Any]]
+          case None =>
+            throw new SparkException(s"attempted to access non-existent accumulator $id")
+        }
+        acc ++= partialValue
+        // To avoid UI cruft, ignore cases where value wasn't updated
+        if (acc.name.isDefined && partialValue != acc.zero) {
+          val name = acc.name
+          stage.latestInfo.accumulables(id) = new AccumulableInfo(
+            id, name, None, Some(acc.value), acc.isInternal, acc.countFailedValues)
+          event.taskInfo.accumulables += new AccumulableInfo(
+            id, name, Some(partialValue), Some(acc.value), acc.isInternal, acc.countFailedValues)
         }
-      } catch {
-        // If we see an exception during accumulator update, just log the
-        // error and move on.
-        case e: Exception =>
-          logError(s"Failed to update accumulators for $task", e)
       }
+    } catch {
+      case NonFatal(e) =>
+        logError(s"Failed to update accumulators for task ${task.partitionId}", e)
     }
   }
 
@@ -1116,6 +1120,7 @@ class DAGScheduler(
    */
   private[scheduler] def handleTaskCompletion(event: CompletionEvent) {
     val task = event.task
+    val taskId = event.taskInfo.id
     val stageId = task.stageId
     val taskType = Utils.getFormattedClassName(task)
 
@@ -1125,12 +1130,26 @@ class DAGScheduler(
       event.taskInfo.attemptNumber, // this is a task attempt number
       event.reason)
 
-    // The success case is dealt with separately below, since we need to compute accumulator
-    // updates before posting.
+    // Reconstruct task metrics. Note: this may be null if the task has failed.
+    val taskMetrics: TaskMetrics =
+      if (event.accumUpdates.nonEmpty) {
+        try {
+          TaskMetrics.fromAccumulatorUpdates(event.accumUpdates)
+        } catch {
+          case NonFatal(e) =>
+            logError(s"Error when attempting to reconstruct metrics for task $taskId", e)
+            null
+        }
+      } else {
+        null
+      }
+
+    // The success case is dealt with separately below.
+    // TODO: Why post it only for failed tasks in cancelled stages? Clarify semantics here.
     if (event.reason != Success) {
       val attemptId = task.stageAttemptId
-      listenerBus.post(SparkListenerTaskEnd(stageId, attemptId, taskType, event.reason,
-        event.taskInfo, event.taskMetrics))
+      listenerBus.post(SparkListenerTaskEnd(
+        stageId, attemptId, taskType, event.reason, event.taskInfo, taskMetrics))
     }
 
     if (!stageIdToStage.contains(task.stageId)) {
@@ -1142,7 +1161,7 @@ class DAGScheduler(
     event.reason match {
       case Success =>
         listenerBus.post(SparkListenerTaskEnd(stageId, stage.latestInfo.attemptId, taskType,
-          event.reason, event.taskInfo, event.taskMetrics))
+          event.reason, event.taskInfo, taskMetrics))
         stage.pendingPartitions -= task.partitionId
         task match {
           case rt: ResultTask[_, _] =>
@@ -1291,7 +1310,8 @@ class DAGScheduler(
         // Do nothing here, left up to the TaskScheduler to decide how to handle denied commits
 
       case exceptionFailure: ExceptionFailure =>
-        // Do nothing here, left up to the TaskScheduler to decide how to handle user failures
+        // Tasks failed with exceptions might still have accumulator updates.
+        updateAccumulators(event)
 
       case TaskResultLost =>
         // Do nothing here; the TaskScheduler handles these failures and resubmits the task.
@@ -1637,7 +1657,7 @@ private[scheduler] class DAGSchedulerEventProcessLoop(dagScheduler: DAGScheduler
     case GettingResultEvent(taskInfo) =>
       dagScheduler.handleGetTaskResult(taskInfo)
 
-    case completion @ CompletionEvent(task, reason, _, _, taskInfo, taskMetrics) =>
+    case completion: CompletionEvent =>
       dagScheduler.handleTaskCompletion(completion)
 
     case TaskSetFailed(taskSet, reason, exception) =>
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
index dda3b6cc7f960e7428baed2718b11e3178e82629..d5cd2da7a10d199b24a67d500e0edf3a57fdf121 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
@@ -73,9 +73,8 @@ private[scheduler] case class CompletionEvent(
     task: Task[_],
     reason: TaskEndReason,
     result: Any,
-    accumUpdates: Map[Long, Any],
-    taskInfo: TaskInfo,
-    taskMetrics: TaskMetrics)
+    accumUpdates: Seq[AccumulableInfo],
+    taskInfo: TaskInfo)
   extends DAGSchedulerEvent
 
 private[scheduler] case class ExecutorAdded(execId: String, host: String) extends DAGSchedulerEvent
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
index 6590cf6ffd24fbe25d135ea457852bd207d80bcf..885f70e89fbf5a50547f5f1e85408326e18d28ff 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
@@ -30,6 +30,7 @@ import org.apache.spark.rdd.RDD
  * See [[Task]] for more information.
  *
  * @param stageId id of the stage this task belongs to
+ * @param stageAttemptId attempt id of the stage this task belongs to
  * @param taskBinary broadcasted version of the serialized RDD and the function to apply on each
  *                   partition of the given RDD. Once deserialized, the type should be
  *                   (RDD[T], (TaskContext, Iterator[T]) => U).
@@ -37,6 +38,9 @@ import org.apache.spark.rdd.RDD
  * @param locs preferred task execution locations for locality scheduling
  * @param outputId index of the task in this job (a job can launch tasks on only a subset of the
  *                 input RDD's partitions).
+ * @param _initialAccums initial set of accumulators to be used in this task for tracking
+ *                       internal metrics. Other accumulators will be registered later when
+ *                       they are deserialized on the executors.
  */
 private[spark] class ResultTask[T, U](
     stageId: Int,
@@ -45,8 +49,8 @@ private[spark] class ResultTask[T, U](
     partition: Partition,
     locs: Seq[TaskLocation],
     val outputId: Int,
-    internalAccumulators: Seq[Accumulator[Long]])
-  extends Task[U](stageId, stageAttemptId, partition.index, internalAccumulators)
+    _initialAccums: Seq[Accumulator[_]] = InternalAccumulator.create())
+  extends Task[U](stageId, stageAttemptId, partition.index, _initialAccums)
   with Serializable {
 
   @transient private[this] val preferredLocs: Seq[TaskLocation] = {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
index ea97ef0e746d8658a6a22e1d5f4d4ff64797d308..89207dd175ae958d3c47c0dc5494c97f7f2ba197 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
@@ -33,10 +33,14 @@ import org.apache.spark.shuffle.ShuffleWriter
  * See [[org.apache.spark.scheduler.Task]] for more information.
  *
  * @param stageId id of the stage this task belongs to
+ * @param stageAttemptId attempt id of the stage this task belongs to
  * @param taskBinary broadcast version of the RDD and the ShuffleDependency. Once deserialized,
  *                   the type should be (RDD[_], ShuffleDependency[_, _, _]).
  * @param partition partition of the RDD this task is associated with
  * @param locs preferred task execution locations for locality scheduling
+ * @param _initialAccums initial set of accumulators to be used in this task for tracking
+ *                       internal metrics. Other accumulators will be registered later when
+ *                       they are deserialized on the executors.
  */
 private[spark] class ShuffleMapTask(
     stageId: Int,
@@ -44,8 +48,8 @@ private[spark] class ShuffleMapTask(
     taskBinary: Broadcast[Array[Byte]],
     partition: Partition,
     @transient private var locs: Seq[TaskLocation],
-    internalAccumulators: Seq[Accumulator[Long]])
-  extends Task[MapStatus](stageId, stageAttemptId, partition.index, internalAccumulators)
+    _initialAccums: Seq[Accumulator[_]])
+  extends Task[MapStatus](stageId, stageAttemptId, partition.index, _initialAccums)
   with Logging {
 
   /** A constructor used only in test suites. This does not require passing in an RDD. */
diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala
index 6c6883d703beaa38ebd5d30dcfd062d6723accd4..ed3adbd81c28e08dbd1ef622d0760509474acfe2 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.scheduler
 
 import java.util.Properties
+import javax.annotation.Nullable
 
 import scala.collection.Map
 import scala.collection.mutable
@@ -60,7 +61,7 @@ case class SparkListenerTaskEnd(
     taskType: String,
     reason: TaskEndReason,
     taskInfo: TaskInfo,
-    taskMetrics: TaskMetrics)
+    @Nullable taskMetrics: TaskMetrics)
   extends SparkListenerEvent
 
 @DeveloperApi
@@ -111,12 +112,12 @@ case class SparkListenerBlockUpdated(blockUpdatedInfo: BlockUpdatedInfo) extends
 /**
  * Periodic updates from executors.
  * @param execId executor id
- * @param taskMetrics sequence of (task id, stage id, stage attempt, metrics)
+ * @param accumUpdates sequence of (taskId, stageId, stageAttemptId, accumUpdates)
  */
 @DeveloperApi
 case class SparkListenerExecutorMetricsUpdate(
     execId: String,
-    taskMetrics: Seq[(Long, Int, Int, TaskMetrics)])
+    accumUpdates: Seq[(Long, Int, Int, Seq[AccumulableInfo])])
   extends SparkListenerEvent
 
 @DeveloperApi
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
index 7ea24a217bd39d77ae6b3a9d99aaaceadbd9465b..c1c8b47128f228fd3daf298e225d7324c8d54933 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
@@ -74,10 +74,10 @@ private[scheduler] abstract class Stage(
   val name: String = callSite.shortForm
   val details: String = callSite.longForm
 
-  private var _internalAccumulators: Seq[Accumulator[Long]] = Seq.empty
+  private var _internalAccumulators: Seq[Accumulator[_]] = Seq.empty
 
   /** Internal accumulators shared across all tasks in this stage. */
-  def internalAccumulators: Seq[Accumulator[Long]] = _internalAccumulators
+  def internalAccumulators: Seq[Accumulator[_]] = _internalAccumulators
 
   /**
    * Re-initialize the internal accumulators associated with this stage.
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
index fca57928eca1bff17019a8a8763278a963836115..a49f3716e2702e9ca09f205dce24f28fd878f5bd 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.scheduler
 
-import java.io.{ByteArrayOutputStream, DataInputStream, DataOutputStream}
+import java.io.{DataInputStream, DataOutputStream}
 import java.nio.ByteBuffer
 
 import scala.collection.mutable.HashMap
@@ -41,32 +41,29 @@ import org.apache.spark.util.{ByteBufferInputStream, ByteBufferOutputStream, Uti
  * and divides the task output to multiple buckets (based on the task's partitioner).
  *
  * @param stageId id of the stage this task belongs to
+ * @param stageAttemptId attempt id of the stage this task belongs to
  * @param partitionId index of the number in the RDD
+ * @param initialAccumulators initial set of accumulators to be used in this task for tracking
+ *                            internal metrics. Other accumulators will be registered later when
+ *                            they are deserialized on the executors.
  */
 private[spark] abstract class Task[T](
     val stageId: Int,
     val stageAttemptId: Int,
     val partitionId: Int,
-    internalAccumulators: Seq[Accumulator[Long]]) extends Serializable {
+    val initialAccumulators: Seq[Accumulator[_]]) extends Serializable {
 
   /**
-   * The key of the Map is the accumulator id and the value of the Map is the latest accumulator
-   * local value.
-   */
-  type AccumulatorUpdates = Map[Long, Any]
-
-  /**
-   * Called by [[Executor]] to run this task.
+   * Called by [[org.apache.spark.executor.Executor]] to run this task.
    *
    * @param taskAttemptId an identifier for this task attempt that is unique within a SparkContext.
    * @param attemptNumber how many times this task has been attempted (0 for the first attempt)
    * @return the result of the task along with updates of Accumulators.
    */
   final def run(
-    taskAttemptId: Long,
-    attemptNumber: Int,
-    metricsSystem: MetricsSystem)
-  : (T, AccumulatorUpdates) = {
+      taskAttemptId: Long,
+      attemptNumber: Int,
+      metricsSystem: MetricsSystem): T = {
     context = new TaskContextImpl(
       stageId,
       partitionId,
@@ -74,16 +71,14 @@ private[spark] abstract class Task[T](
       attemptNumber,
       taskMemoryManager,
       metricsSystem,
-      internalAccumulators)
+      initialAccumulators)
     TaskContext.setTaskContext(context)
-    context.taskMetrics.setHostname(Utils.localHostName())
-    context.taskMetrics.setAccumulatorsUpdater(context.collectInternalAccumulators)
     taskThread = Thread.currentThread()
     if (_killed) {
       kill(interruptThread = false)
     }
     try {
-      (runTask(context), context.collectAccumulators())
+      runTask(context)
     } finally {
       context.markTaskCompleted()
       try {
@@ -140,6 +135,18 @@ private[spark] abstract class Task[T](
    */
   def executorDeserializeTime: Long = _executorDeserializeTime
 
+  /**
+   * Collect the latest values of accumulators used in this task. If the task failed,
+   * filter out the accumulators whose values should not be included on failures.
+   */
+  def collectAccumulatorUpdates(taskFailed: Boolean = false): Seq[AccumulableInfo] = {
+    if (context != null) {
+      context.taskMetrics.accumulatorUpdates().filter { a => !taskFailed || a.countFailedValues }
+    } else {
+      Seq.empty[AccumulableInfo]
+    }
+  }
+
   /**
    * Kills a task by setting the interrupted flag to true. This relies on the upper level Spark
    * code and user code to properly handle the flag. This function should be idempotent so it can
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala
index b82c7f3fa54f862d1fe599810755ca7783134322..03135e63d75518bbc059a0d83fa39175ae637a62 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala
@@ -20,11 +20,9 @@ package org.apache.spark.scheduler
 import java.io._
 import java.nio.ByteBuffer
 
-import scala.collection.Map
-import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
 
 import org.apache.spark.SparkEnv
-import org.apache.spark.executor.TaskMetrics
 import org.apache.spark.storage.BlockId
 import org.apache.spark.util.Utils
 
@@ -36,31 +34,24 @@ private[spark] case class IndirectTaskResult[T](blockId: BlockId, size: Int)
   extends TaskResult[T] with Serializable
 
 /** A TaskResult that contains the task's return value and accumulator updates. */
-private[spark]
-class DirectTaskResult[T](var valueBytes: ByteBuffer, var accumUpdates: Map[Long, Any],
-    var metrics: TaskMetrics)
+private[spark] class DirectTaskResult[T](
+    var valueBytes: ByteBuffer,
+    var accumUpdates: Seq[AccumulableInfo])
   extends TaskResult[T] with Externalizable {
 
   private var valueObjectDeserialized = false
   private var valueObject: T = _
 
-  def this() = this(null.asInstanceOf[ByteBuffer], null, null)
+  def this() = this(null.asInstanceOf[ByteBuffer], null)
 
   override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException {
-
-    out.writeInt(valueBytes.remaining);
+    out.writeInt(valueBytes.remaining)
     Utils.writeByteBuffer(valueBytes, out)
-
     out.writeInt(accumUpdates.size)
-    for ((key, value) <- accumUpdates) {
-      out.writeLong(key)
-      out.writeObject(value)
-    }
-    out.writeObject(metrics)
+    accumUpdates.foreach(out.writeObject)
   }
 
   override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException {
-
     val blen = in.readInt()
     val byteVal = new Array[Byte](blen)
     in.readFully(byteVal)
@@ -70,13 +61,12 @@ class DirectTaskResult[T](var valueBytes: ByteBuffer, var accumUpdates: Map[Long
     if (numUpdates == 0) {
       accumUpdates = null
     } else {
-      val _accumUpdates = mutable.Map[Long, Any]()
+      val _accumUpdates = new ArrayBuffer[AccumulableInfo]
       for (i <- 0 until numUpdates) {
-        _accumUpdates(in.readLong()) = in.readObject()
+        _accumUpdates += in.readObject.asInstanceOf[AccumulableInfo]
       }
       accumUpdates = _accumUpdates
     }
-    metrics = in.readObject().asInstanceOf[TaskMetrics]
     valueObjectDeserialized = false
   }
 
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
index f4965994d8277f191550f2d1abf16f58e3d8fd1a..c94c4f55e9ced6a5ae16c371730fb03d91c9e138 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.scheduler
 
 import java.nio.ByteBuffer
-import java.util.concurrent.RejectedExecutionException
+import java.util.concurrent.{ExecutorService, RejectedExecutionException}
 
 import scala.language.existentials
 import scala.util.control.NonFatal
@@ -35,9 +35,12 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul
   extends Logging {
 
   private val THREADS = sparkEnv.conf.getInt("spark.resultGetter.threads", 4)
-  private val getTaskResultExecutor = ThreadUtils.newDaemonFixedThreadPool(
-    THREADS, "task-result-getter")
 
+  // Exposed for testing.
+  protected val getTaskResultExecutor: ExecutorService =
+    ThreadUtils.newDaemonFixedThreadPool(THREADS, "task-result-getter")
+
+  // Exposed for testing.
   protected val serializer = new ThreadLocal[SerializerInstance] {
     override def initialValue(): SerializerInstance = {
       sparkEnv.closureSerializer.newInstance()
@@ -45,7 +48,9 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul
   }
 
   def enqueueSuccessfulTask(
-    taskSetManager: TaskSetManager, tid: Long, serializedData: ByteBuffer) {
+      taskSetManager: TaskSetManager,
+      tid: Long,
+      serializedData: ByteBuffer): Unit = {
     getTaskResultExecutor.execute(new Runnable {
       override def run(): Unit = Utils.logUncaughtExceptions {
         try {
@@ -82,7 +87,19 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul
               (deserializedResult, size)
           }
 
-          result.metrics.setResultSize(size)
+          // Set the task result size in the accumulator updates received from the executors.
+          // We need to do this here on the driver because if we did this on the executors then
+          // we would have to serialize the result again after updating the size.
+          result.accumUpdates = result.accumUpdates.map { a =>
+            if (a.name == Some(InternalAccumulator.RESULT_SIZE)) {
+              assert(a.update == Some(0L),
+                "task result size should not have been set on the executors")
+              a.copy(update = Some(size.toLong))
+            } else {
+              a
+            }
+          }
+
           scheduler.handleSuccessfulTask(taskSetManager, tid, result)
         } catch {
           case cnf: ClassNotFoundException =>
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
index 7c0b007db708eb62de0731b03d269bff58220ec7..fccd6e0699341de0265b3fda49eb85e70ce605b7 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
@@ -65,8 +65,10 @@ private[spark] trait TaskScheduler {
    * alive. Return true if the driver knows about the given block manager. Otherwise, return false,
    * indicating that the block manager should re-register.
    */
-  def executorHeartbeatReceived(execId: String, taskMetrics: Array[(Long, TaskMetrics)],
-    blockManagerId: BlockManagerId): Boolean
+  def executorHeartbeatReceived(
+      execId: String,
+      accumUpdates: Array[(Long, Seq[AccumulableInfo])],
+      blockManagerId: BlockManagerId): Boolean
 
   /**
    * Get an application ID associated with the job.
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
index 6e3ef0e54f0fd644980d377656e2d8db099f56fc..29341dfe3043c72bb1323a235265a3d6c8323506 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
@@ -30,7 +30,6 @@ import scala.util.Random
 
 import org.apache.spark._
 import org.apache.spark.TaskState.TaskState
-import org.apache.spark.executor.TaskMetrics
 import org.apache.spark.scheduler.SchedulingMode.SchedulingMode
 import org.apache.spark.scheduler.TaskLocality.TaskLocality
 import org.apache.spark.storage.BlockManagerId
@@ -380,17 +379,17 @@ private[spark] class TaskSchedulerImpl(
    */
   override def executorHeartbeatReceived(
       execId: String,
-      taskMetrics: Array[(Long, TaskMetrics)], // taskId -> TaskMetrics
+      accumUpdates: Array[(Long, Seq[AccumulableInfo])],
       blockManagerId: BlockManagerId): Boolean = {
-
-    val metricsWithStageIds: Array[(Long, Int, Int, TaskMetrics)] = synchronized {
-      taskMetrics.flatMap { case (id, metrics) =>
+    // (taskId, stageId, stageAttemptId, accumUpdates)
+    val accumUpdatesWithTaskIds: Array[(Long, Int, Int, Seq[AccumulableInfo])] = synchronized {
+      accumUpdates.flatMap { case (id, updates) =>
         taskIdToTaskSetManager.get(id).map { taskSetMgr =>
-          (id, taskSetMgr.stageId, taskSetMgr.taskSet.stageAttemptId, metrics)
+          (id, taskSetMgr.stageId, taskSetMgr.taskSet.stageAttemptId, updates)
         }
       }
     }
-    dagScheduler.executorHeartbeatReceived(execId, metricsWithStageIds, blockManagerId)
+    dagScheduler.executorHeartbeatReceived(execId, accumUpdatesWithTaskIds, blockManagerId)
   }
 
   def handleTaskGettingResult(taskSetManager: TaskSetManager, tid: Long): Unit = synchronized {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
index aa39b59d8cce427edbfb87f8ac6170c0e6b8cf80..cf97877476d54e630dbf146bff825325d79148a2 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
@@ -621,8 +621,7 @@ private[spark] class TaskSetManager(
     // "result.value()" in "TaskResultGetter.enqueueSuccessfulTask" before reaching here.
     // Note: "result.value()" only deserializes the value when it's called at the first time, so
     // here "result.value()" just returns the value and won't block other threads.
-    sched.dagScheduler.taskEnded(
-      tasks(index), Success, result.value(), result.accumUpdates, info, result.metrics)
+    sched.dagScheduler.taskEnded(tasks(index), Success, result.value(), result.accumUpdates, info)
     if (!successful(index)) {
       tasksSuccessful += 1
       logInfo("Finished task %s in stage %s (TID %d) in %d ms on %s (%d/%d)".format(
@@ -653,8 +652,7 @@ private[spark] class TaskSetManager(
     info.markFailed()
     val index = info.index
     copiesRunning(index) -= 1
-    var taskMetrics : TaskMetrics = null
-
+    var accumUpdates: Seq[AccumulableInfo] = Seq.empty[AccumulableInfo]
     val failureReason = s"Lost task ${info.id} in stage ${taskSet.id} (TID $tid, ${info.host}): " +
       reason.asInstanceOf[TaskFailedReason].toErrorString
     val failureException: Option[Throwable] = reason match {
@@ -669,7 +667,8 @@ private[spark] class TaskSetManager(
         None
 
       case ef: ExceptionFailure =>
-        taskMetrics = ef.metrics.orNull
+        // ExceptionFailure's might have accumulator updates
+        accumUpdates = ef.accumUpdates
         if (ef.className == classOf[NotSerializableException].getName) {
           // If the task result wasn't serializable, there's no point in trying to re-execute it.
           logError("Task %s in stage %s (TID %d) had a not serializable result: %s; not retrying"
@@ -721,7 +720,7 @@ private[spark] class TaskSetManager(
     // always add to failed executors
     failedExecutors.getOrElseUpdate(index, new HashMap[String, Long]()).
       put(info.executorId, clock.getTimeMillis())
-    sched.dagScheduler.taskEnded(tasks(index), reason, null, null, info, taskMetrics)
+    sched.dagScheduler.taskEnded(tasks(index), reason, null, accumUpdates, info)
     addPendingTask(index)
     if (!isZombie && state != TaskState.KILLED
         && reason.isInstanceOf[TaskFailedReason]
@@ -793,7 +792,8 @@ private[spark] class TaskSetManager(
           addPendingTask(index)
           // Tell the DAGScheduler that this task was resubmitted so that it doesn't think our
           // stage finishes when a total of tasks.size tasks finish.
-          sched.dagScheduler.taskEnded(tasks(index), Resubmitted, null, null, info, null)
+          sched.dagScheduler.taskEnded(
+            tasks(index), Resubmitted, null, Seq.empty[AccumulableInfo], info)
         }
       }
     }
diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
index a57e5b0bfb8656b9b0867c9e423ed70ddc7e003a..acbe16001f5ba4ed6cca658e2775b67300e409d9 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
@@ -103,8 +103,7 @@ private[spark] class BlockStoreShuffleReader[K, C](
         sorter.insertAll(aggregatedIter)
         context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled)
         context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled)
-        context.internalMetricsToAccumulators(
-          InternalAccumulator.PEAK_EXECUTION_MEMORY).add(sorter.peakMemoryUsedBytes)
+        context.taskMetrics().incPeakExecutionMemory(sorter.peakMemoryUsedBytes)
         CompletionIterator[Product2[K, C], Iterator[Product2[K, C]]](sorter.iterator, sorter.stop())
       case None =>
         aggregatedIter
diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala
index 078718ba1126036cbb14c613d8949a1879ba2099..9c92a501503cdaa0817e0b905f4ad4f9e895f4ea 100644
--- a/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala
+++ b/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala
@@ -237,7 +237,8 @@ private[v1] object AllStagesResource {
   }
 
   def convertAccumulableInfo(acc: InternalAccumulableInfo): AccumulableInfo = {
-    new AccumulableInfo(acc.id, acc.name, acc.update, acc.value)
+    new AccumulableInfo(
+      acc.id, acc.name.orNull, acc.update.map(_.toString), acc.value.map(_.toString).orNull)
   }
 
   def convertUiTaskMetrics(internal: InternalTaskMetrics): TaskMetrics = {
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala
index 4a9f8b30525fe4b6da7c9f0d5dc11ef0f6713455..b2aa8bfbe70094fe7932414dac1c563e5085a97b 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala
@@ -325,12 +325,13 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging {
   override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = synchronized {
     val taskInfo = taskStart.taskInfo
     if (taskInfo != null) {
+      val metrics = new TaskMetrics
       val stageData = stageIdToData.getOrElseUpdate((taskStart.stageId, taskStart.stageAttemptId), {
         logWarning("Task start for unknown stage " + taskStart.stageId)
         new StageUIData
       })
       stageData.numActiveTasks += 1
-      stageData.taskData.put(taskInfo.taskId, new TaskUIData(taskInfo))
+      stageData.taskData.put(taskInfo.taskId, new TaskUIData(taskInfo, Some(metrics)))
     }
     for (
       activeJobsDependentOnStage <- stageIdToActiveJobIds.get(taskStart.stageId);
@@ -387,9 +388,9 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging {
             (Some(e.toErrorString), None)
         }
 
-      if (!metrics.isEmpty) {
+      metrics.foreach { m =>
         val oldMetrics = stageData.taskData.get(info.taskId).flatMap(_.taskMetrics)
-        updateAggregateMetrics(stageData, info.executorId, metrics.get, oldMetrics)
+        updateAggregateMetrics(stageData, info.executorId, m, oldMetrics)
       }
 
       val taskData = stageData.taskData.getOrElseUpdate(info.taskId, new TaskUIData(info))
@@ -489,19 +490,18 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging {
   }
 
   override def onExecutorMetricsUpdate(executorMetricsUpdate: SparkListenerExecutorMetricsUpdate) {
-    for ((taskId, sid, sAttempt, taskMetrics) <- executorMetricsUpdate.taskMetrics) {
+    for ((taskId, sid, sAttempt, accumUpdates) <- executorMetricsUpdate.accumUpdates) {
       val stageData = stageIdToData.getOrElseUpdate((sid, sAttempt), {
         logWarning("Metrics update for task in unknown stage " + sid)
         new StageUIData
       })
       val taskData = stageData.taskData.get(taskId)
-      taskData.map { t =>
+      val metrics = TaskMetrics.fromAccumulatorUpdates(accumUpdates)
+      taskData.foreach { t =>
         if (!t.taskInfo.finished) {
-          updateAggregateMetrics(stageData, executorMetricsUpdate.execId, taskMetrics,
-            t.taskMetrics)
-
+          updateAggregateMetrics(stageData, executorMetricsUpdate.execId, metrics, t.taskMetrics)
           // Overwrite task metrics
-          t.taskMetrics = Some(taskMetrics)
+          t.taskMetrics = Some(metrics)
         }
       }
     }
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
index 914f6183cc2a4893a41f78e80a5aa0e75eb4f6b4..29c5ff0b5cf0bad5921c9ddf1a31a1257dd4df34 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
@@ -271,8 +271,12 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
         }
 
       val accumulableHeaders: Seq[String] = Seq("Accumulable", "Value")
-      def accumulableRow(acc: AccumulableInfo): Elem =
-        <tr><td>{acc.name}</td><td>{acc.value}</td></tr>
+      def accumulableRow(acc: AccumulableInfo): Seq[Node] = {
+        (acc.name, acc.value) match {
+          case (Some(name), Some(value)) => <tr><td>{name}</td><td>{value}</td></tr>
+          case _ => Seq.empty[Node]
+        }
+      }
       val accumulableTable = UIUtils.listingTable(
         accumulableHeaders,
         accumulableRow,
@@ -404,13 +408,9 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
             </td> +:
             getFormattedTimeQuantiles(gettingResultTimes)
 
-          val peakExecutionMemory = validTasks.map { case TaskUIData(info, _, _) =>
-            info.accumulables
-              .find { acc => acc.name == InternalAccumulator.PEAK_EXECUTION_MEMORY }
-              .map { acc => acc.update.getOrElse("0").toLong }
-              .getOrElse(0L)
-              .toDouble
-          }
+            val peakExecutionMemory = validTasks.map { case TaskUIData(_, metrics, _) =>
+              metrics.get.peakExecutionMemory.toDouble
+            }
           val peakExecutionMemoryQuantiles = {
             <td>
               <span data-toggle="tooltip"
@@ -891,15 +891,15 @@ private[ui] class TaskDataSource(
     val serializationTime = metrics.map(_.resultSerializationTime).getOrElse(0L)
     val gettingResultTime = getGettingResultTime(info, currentTime)
 
-    val (taskInternalAccumulables, taskExternalAccumulables) =
-      info.accumulables.partition(_.internal)
-    val externalAccumulableReadable = taskExternalAccumulables.map { acc =>
-      StringEscapeUtils.escapeHtml4(s"${acc.name}: ${acc.update.get}")
-    }
-    val peakExecutionMemoryUsed = taskInternalAccumulables
-      .find { acc => acc.name == InternalAccumulator.PEAK_EXECUTION_MEMORY }
-      .map { acc => acc.update.getOrElse("0").toLong }
-      .getOrElse(0L)
+    val externalAccumulableReadable = info.accumulables
+      .filterNot(_.internal)
+      .flatMap { a =>
+        (a.name, a.update) match {
+          case (Some(name), Some(update)) => Some(StringEscapeUtils.escapeHtml4(s"$name: $update"))
+          case _ => None
+        }
+      }
+    val peakExecutionMemoryUsed = metrics.map(_.peakExecutionMemory).getOrElse(0L)
 
     val maybeInput = metrics.flatMap(_.inputMetrics)
     val inputSortable = maybeInput.map(_.bytesRead).getOrElse(0L)
diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
index efa22b99936af6c935ad5fe2d37f4329913fe91a..dc8070cf8aad3b1c06df271cf3d2a3ac5b29207b 100644
--- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
+++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
@@ -233,14 +233,14 @@ private[spark] object JsonProtocol {
 
   def executorMetricsUpdateToJson(metricsUpdate: SparkListenerExecutorMetricsUpdate): JValue = {
     val execId = metricsUpdate.execId
-    val taskMetrics = metricsUpdate.taskMetrics
+    val accumUpdates = metricsUpdate.accumUpdates
     ("Event" -> Utils.getFormattedClassName(metricsUpdate)) ~
     ("Executor ID" -> execId) ~
-    ("Metrics Updated" -> taskMetrics.map { case (taskId, stageId, stageAttemptId, metrics) =>
+      ("Metrics Updated" -> accumUpdates.map { case (taskId, stageId, stageAttemptId, updates) =>
       ("Task ID" -> taskId) ~
       ("Stage ID" -> stageId) ~
       ("Stage Attempt ID" -> stageAttemptId) ~
-      ("Task Metrics" -> taskMetricsToJson(metrics))
+      ("Accumulator Updates" -> JArray(updates.map(accumulableInfoToJson).toList))
     })
   }
 
@@ -265,7 +265,7 @@ private[spark] object JsonProtocol {
     ("Completion Time" -> completionTime) ~
     ("Failure Reason" -> failureReason) ~
     ("Accumulables" -> JArray(
-        stageInfo.accumulables.values.map(accumulableInfoToJson).toList))
+      stageInfo.accumulables.values.map(accumulableInfoToJson).toList))
   }
 
   def taskInfoToJson(taskInfo: TaskInfo): JValue = {
@@ -284,11 +284,44 @@ private[spark] object JsonProtocol {
   }
 
   def accumulableInfoToJson(accumulableInfo: AccumulableInfo): JValue = {
+    val name = accumulableInfo.name
     ("ID" -> accumulableInfo.id) ~
-    ("Name" -> accumulableInfo.name) ~
-    ("Update" -> accumulableInfo.update.map(new JString(_)).getOrElse(JNothing)) ~
-    ("Value" -> accumulableInfo.value) ~
-    ("Internal" -> accumulableInfo.internal)
+    ("Name" -> name) ~
+    ("Update" -> accumulableInfo.update.map { v => accumValueToJson(name, v) }) ~
+    ("Value" -> accumulableInfo.value.map { v => accumValueToJson(name, v) }) ~
+    ("Internal" -> accumulableInfo.internal) ~
+    ("Count Failed Values" -> accumulableInfo.countFailedValues)
+  }
+
+  /**
+   * Serialize the value of an accumulator to JSON.
+   *
+   * For accumulators representing internal task metrics, this looks up the relevant
+   * [[AccumulatorParam]] to serialize the value accordingly. For all other accumulators,
+   * this will simply serialize the value as a string.
+   *
+   * The behavior here must match that of [[accumValueFromJson]]. Exposed for testing.
+   */
+  private[util] def accumValueToJson(name: Option[String], value: Any): JValue = {
+    import AccumulatorParam._
+    if (name.exists(_.startsWith(InternalAccumulator.METRICS_PREFIX))) {
+      (value, InternalAccumulator.getParam(name.get)) match {
+        case (v: Int, IntAccumulatorParam) => JInt(v)
+        case (v: Long, LongAccumulatorParam) => JInt(v)
+        case (v: String, StringAccumulatorParam) => JString(v)
+        case (v, UpdatedBlockStatusesAccumulatorParam) =>
+          JArray(v.asInstanceOf[Seq[(BlockId, BlockStatus)]].toList.map { case (id, status) =>
+            ("Block ID" -> id.toString) ~
+            ("Status" -> blockStatusToJson(status))
+          })
+        case (v, p) =>
+          throw new IllegalArgumentException(s"unexpected combination of accumulator value " +
+            s"type (${v.getClass.getName}) and param (${p.getClass.getName}) in '${name.get}'")
+      }
+    } else {
+      // For all external accumulators, just use strings
+      JString(value.toString)
+    }
   }
 
   def taskMetricsToJson(taskMetrics: TaskMetrics): JValue = {
@@ -303,9 +336,9 @@ private[spark] object JsonProtocol {
       }.getOrElse(JNothing)
     val shuffleWriteMetrics: JValue =
       taskMetrics.shuffleWriteMetrics.map { wm =>
-        ("Shuffle Bytes Written" -> wm.shuffleBytesWritten) ~
-        ("Shuffle Write Time" -> wm.shuffleWriteTime) ~
-        ("Shuffle Records Written" -> wm.shuffleRecordsWritten)
+        ("Shuffle Bytes Written" -> wm.bytesWritten) ~
+        ("Shuffle Write Time" -> wm.writeTime) ~
+        ("Shuffle Records Written" -> wm.recordsWritten)
       }.getOrElse(JNothing)
     val inputMetrics: JValue =
       taskMetrics.inputMetrics.map { im =>
@@ -324,7 +357,6 @@ private[spark] object JsonProtocol {
         ("Block ID" -> id.toString) ~
         ("Status" -> blockStatusToJson(status))
       })
-    ("Host Name" -> taskMetrics.hostname) ~
     ("Executor Deserialize Time" -> taskMetrics.executorDeserializeTime) ~
     ("Executor Run Time" -> taskMetrics.executorRunTime) ~
     ("Result Size" -> taskMetrics.resultSize) ~
@@ -352,12 +384,12 @@ private[spark] object JsonProtocol {
         ("Message" -> fetchFailed.message)
       case exceptionFailure: ExceptionFailure =>
         val stackTrace = stackTraceToJson(exceptionFailure.stackTrace)
-        val metrics = exceptionFailure.metrics.map(taskMetricsToJson).getOrElse(JNothing)
+        val accumUpdates = JArray(exceptionFailure.accumUpdates.map(accumulableInfoToJson).toList)
         ("Class Name" -> exceptionFailure.className) ~
         ("Description" -> exceptionFailure.description) ~
         ("Stack Trace" -> stackTrace) ~
         ("Full Stack Trace" -> exceptionFailure.fullStackTrace) ~
-        ("Metrics" -> metrics)
+        ("Accumulator Updates" -> accumUpdates)
       case taskCommitDenied: TaskCommitDenied =>
         ("Job ID" -> taskCommitDenied.jobID) ~
         ("Partition ID" -> taskCommitDenied.partitionID) ~
@@ -619,14 +651,15 @@ private[spark] object JsonProtocol {
 
   def executorMetricsUpdateFromJson(json: JValue): SparkListenerExecutorMetricsUpdate = {
     val execInfo = (json \ "Executor ID").extract[String]
-    val taskMetrics = (json \ "Metrics Updated").extract[List[JValue]].map { json =>
+    val accumUpdates = (json \ "Metrics Updated").extract[List[JValue]].map { json =>
       val taskId = (json \ "Task ID").extract[Long]
       val stageId = (json \ "Stage ID").extract[Int]
       val stageAttemptId = (json \ "Stage Attempt ID").extract[Int]
-      val metrics = taskMetricsFromJson(json \ "Task Metrics")
-      (taskId, stageId, stageAttemptId, metrics)
+      val updates =
+        (json \ "Accumulator Updates").extract[List[JValue]].map(accumulableInfoFromJson)
+      (taskId, stageId, stageAttemptId, updates)
     }
-    SparkListenerExecutorMetricsUpdate(execInfo, taskMetrics)
+    SparkListenerExecutorMetricsUpdate(execInfo, accumUpdates)
   }
 
   /** --------------------------------------------------------------------- *
@@ -647,7 +680,7 @@ private[spark] object JsonProtocol {
     val completionTime = Utils.jsonOption(json \ "Completion Time").map(_.extract[Long])
     val failureReason = Utils.jsonOption(json \ "Failure Reason").map(_.extract[String])
     val accumulatedValues = (json \ "Accumulables").extractOpt[List[JValue]] match {
-      case Some(values) => values.map(accumulableInfoFromJson(_))
+      case Some(values) => values.map(accumulableInfoFromJson)
       case None => Seq[AccumulableInfo]()
     }
 
@@ -675,7 +708,7 @@ private[spark] object JsonProtocol {
     val finishTime = (json \ "Finish Time").extract[Long]
     val failed = (json \ "Failed").extract[Boolean]
     val accumulables = (json \ "Accumulables").extractOpt[Seq[JValue]] match {
-      case Some(values) => values.map(accumulableInfoFromJson(_))
+      case Some(values) => values.map(accumulableInfoFromJson)
       case None => Seq[AccumulableInfo]()
     }
 
@@ -690,11 +723,43 @@ private[spark] object JsonProtocol {
 
   def accumulableInfoFromJson(json: JValue): AccumulableInfo = {
     val id = (json \ "ID").extract[Long]
-    val name = (json \ "Name").extract[String]
-    val update = Utils.jsonOption(json \ "Update").map(_.extract[String])
-    val value = (json \ "Value").extract[String]
+    val name = (json \ "Name").extractOpt[String]
+    val update = Utils.jsonOption(json \ "Update").map { v => accumValueFromJson(name, v) }
+    val value = Utils.jsonOption(json \ "Value").map { v => accumValueFromJson(name, v) }
     val internal = (json \ "Internal").extractOpt[Boolean].getOrElse(false)
-    AccumulableInfo(id, name, update, value, internal)
+    val countFailedValues = (json \ "Count Failed Values").extractOpt[Boolean].getOrElse(false)
+    new AccumulableInfo(id, name, update, value, internal, countFailedValues)
+  }
+
+  /**
+   * Deserialize the value of an accumulator from JSON.
+   *
+   * For accumulators representing internal task metrics, this looks up the relevant
+   * [[AccumulatorParam]] to deserialize the value accordingly. For all other
+   * accumulators, this will simply deserialize the value as a string.
+   *
+   * The behavior here must match that of [[accumValueToJson]]. Exposed for testing.
+   */
+  private[util] def accumValueFromJson(name: Option[String], value: JValue): Any = {
+    import AccumulatorParam._
+    if (name.exists(_.startsWith(InternalAccumulator.METRICS_PREFIX))) {
+      (value, InternalAccumulator.getParam(name.get)) match {
+        case (JInt(v), IntAccumulatorParam) => v.toInt
+        case (JInt(v), LongAccumulatorParam) => v.toLong
+        case (JString(v), StringAccumulatorParam) => v
+        case (JArray(v), UpdatedBlockStatusesAccumulatorParam) =>
+          v.map { blockJson =>
+            val id = BlockId((blockJson \ "Block ID").extract[String])
+            val status = blockStatusFromJson(blockJson \ "Status")
+            (id, status)
+          }
+        case (v, p) =>
+          throw new IllegalArgumentException(s"unexpected combination of accumulator " +
+            s"value in JSON ($v) and accumulator param (${p.getClass.getName}) in '${name.get}'")
+       }
+     } else {
+       value.extract[String]
+     }
   }
 
   def taskMetricsFromJson(json: JValue): TaskMetrics = {
@@ -702,7 +767,6 @@ private[spark] object JsonProtocol {
       return TaskMetrics.empty
     }
     val metrics = new TaskMetrics
-    metrics.setHostname((json \ "Host Name").extract[String])
     metrics.setExecutorDeserializeTime((json \ "Executor Deserialize Time").extract[Long])
     metrics.setExecutorRunTime((json \ "Executor Run Time").extract[Long])
     metrics.setResultSize((json \ "Result Size").extract[Long])
@@ -787,10 +851,12 @@ private[spark] object JsonProtocol {
         val className = (json \ "Class Name").extract[String]
         val description = (json \ "Description").extract[String]
         val stackTrace = stackTraceFromJson(json \ "Stack Trace")
-        val fullStackTrace = Utils.jsonOption(json \ "Full Stack Trace").
-          map(_.extract[String]).orNull
-        val metrics = Utils.jsonOption(json \ "Metrics").map(taskMetricsFromJson)
-        ExceptionFailure(className, description, stackTrace, fullStackTrace, metrics, None)
+        val fullStackTrace = (json \ "Full Stack Trace").extractOpt[String].orNull
+        // Fallback on getting accumulator updates from TaskMetrics, which was logged in Spark 1.x
+        val accumUpdates = Utils.jsonOption(json \ "Accumulator Updates")
+          .map(_.extract[List[JValue]].map(accumulableInfoFromJson))
+          .getOrElse(taskMetricsFromJson(json \ "Metrics").accumulatorUpdates())
+        ExceptionFailure(className, description, stackTrace, fullStackTrace, None, accumUpdates)
       case `taskResultLost` => TaskResultLost
       case `taskKilled` => TaskKilled
       case `taskCommitDenied` =>
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
index df9e0502e73617382cac157a80467dbde836c5f2..5afd6d6e22c62beab78f2887009fbf89bd42a21b 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
@@ -682,8 +682,7 @@ private[spark] class ExternalSorter[K, V, C](
 
     context.taskMetrics().incMemoryBytesSpilled(memoryBytesSpilled)
     context.taskMetrics().incDiskBytesSpilled(diskBytesSpilled)
-    context.internalMetricsToAccumulators(
-      InternalAccumulator.PEAK_EXECUTION_MEMORY).add(peakMemoryUsedBytes)
+    context.taskMetrics().incPeakExecutionMemory(peakMemoryUsedBytes)
 
     lengths
   }
diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
index 625fdd57eb5d40036142df37a4b181731fe22d82..876c3a2283649e808a3f7be4c9c96f9392b37b27 100644
--- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
+++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
@@ -191,8 +191,6 @@ public class UnsafeShuffleWriterSuite {
       });
 
     when(taskContext.taskMetrics()).thenReturn(taskMetrics);
-    when(taskContext.internalMetricsToAccumulators()).thenReturn(null);
-
     when(shuffleDep.serializer()).thenReturn(Option.<Serializer>apply(serializer));
     when(shuffleDep.partitioner()).thenReturn(hashPartitioner);
   }
diff --git a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala
index 5b84acf40be4e567cd1061172a9fa5de5d7a24d6..11c97d7d9a447e4d16d7802232de9a659ba0e3b8 100644
--- a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala
@@ -17,18 +17,22 @@
 
 package org.apache.spark
 
+import javax.annotation.concurrent.GuardedBy
+
 import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
 import scala.ref.WeakReference
+import scala.util.control.NonFatal
 
 import org.scalatest.Matchers
 import org.scalatest.exceptions.TestFailedException
 
 import org.apache.spark.scheduler._
+import org.apache.spark.serializer.JavaSerializer
 
 
 class AccumulatorSuite extends SparkFunSuite with Matchers with LocalSparkContext {
-  import InternalAccumulator._
+  import AccumulatorParam._
 
   implicit def setAccum[A]: AccumulableParam[mutable.Set[A], A] =
     new AccumulableParam[mutable.Set[A], A] {
@@ -59,7 +63,7 @@ class AccumulatorSuite extends SparkFunSuite with Matchers with LocalSparkContex
     longAcc.value should be (210L + maxInt * 20)
   }
 
-  test ("value not assignable from tasks") {
+  test("value not assignable from tasks") {
     sc = new SparkContext("local", "test")
     val acc : Accumulator[Int] = sc.accumulator(0)
 
@@ -84,7 +88,7 @@ class AccumulatorSuite extends SparkFunSuite with Matchers with LocalSparkContex
     }
   }
 
-  test ("value not readable in tasks") {
+  test("value not readable in tasks") {
     val maxI = 1000
     for (nThreads <- List(1, 10)) { // test single & multi-threaded
       sc = new SparkContext("local[" + nThreads + "]", "test")
@@ -159,193 +163,157 @@ class AccumulatorSuite extends SparkFunSuite with Matchers with LocalSparkContex
     assert(!Accumulators.originals.get(accId).isDefined)
   }
 
-  test("internal accumulators in TaskContext") {
+  test("get accum") {
     sc = new SparkContext("local", "test")
-    val accums = InternalAccumulator.create(sc)
-    val taskContext = new TaskContextImpl(0, 0, 0, 0, null, null, accums)
-    val internalMetricsToAccums = taskContext.internalMetricsToAccumulators
-    val collectedInternalAccums = taskContext.collectInternalAccumulators()
-    val collectedAccums = taskContext.collectAccumulators()
-    assert(internalMetricsToAccums.size > 0)
-    assert(internalMetricsToAccums.values.forall(_.isInternal))
-    assert(internalMetricsToAccums.contains(TEST_ACCUMULATOR))
-    val testAccum = internalMetricsToAccums(TEST_ACCUMULATOR)
-    assert(collectedInternalAccums.size === internalMetricsToAccums.size)
-    assert(collectedInternalAccums.size === collectedAccums.size)
-    assert(collectedInternalAccums.contains(testAccum.id))
-    assert(collectedAccums.contains(testAccum.id))
-  }
+    // Don't register with SparkContext for cleanup
+    var acc = new Accumulable[Int, Int](0, IntAccumulatorParam, None, true, true)
+    val accId = acc.id
+    val ref = WeakReference(acc)
+    assert(ref.get.isDefined)
+    Accumulators.register(ref.get.get)
 
-  test("internal accumulators in a stage") {
-    val listener = new SaveInfoListener
-    val numPartitions = 10
-    sc = new SparkContext("local", "test")
-    sc.addSparkListener(listener)
-    // Have each task add 1 to the internal accumulator
-    val rdd = sc.parallelize(1 to 100, numPartitions).mapPartitions { iter =>
-      TaskContext.get().internalMetricsToAccumulators(TEST_ACCUMULATOR) += 1
-      iter
-    }
-    // Register asserts in job completion callback to avoid flakiness
-    listener.registerJobCompletionCallback { _ =>
-      val stageInfos = listener.getCompletedStageInfos
-      val taskInfos = listener.getCompletedTaskInfos
-      assert(stageInfos.size === 1)
-      assert(taskInfos.size === numPartitions)
-      // The accumulator values should be merged in the stage
-      val stageAccum = findAccumulableInfo(stageInfos.head.accumulables.values, TEST_ACCUMULATOR)
-      assert(stageAccum.value.toLong === numPartitions)
-      // The accumulator should be updated locally on each task
-      val taskAccumValues = taskInfos.map { taskInfo =>
-        val taskAccum = findAccumulableInfo(taskInfo.accumulables, TEST_ACCUMULATOR)
-        assert(taskAccum.update.isDefined)
-        assert(taskAccum.update.get.toLong === 1)
-        taskAccum.value.toLong
-      }
-      // Each task should keep track of the partial value on the way, i.e. 1, 2, ... numPartitions
-      assert(taskAccumValues.sorted === (1L to numPartitions).toSeq)
+    // Remove the explicit reference to it and allow weak reference to get garbage collected
+    acc = null
+    System.gc()
+    assert(ref.get.isEmpty)
+
+    // Getting a garbage collected accum should throw error
+    intercept[IllegalAccessError] {
+      Accumulators.get(accId)
     }
-    rdd.count()
+
+    // Getting a normal accumulator. Note: this has to be separate because referencing an
+    // accumulator above in an `assert` would keep it from being garbage collected.
+    val acc2 = new Accumulable[Long, Long](0L, LongAccumulatorParam, None, true, true)
+    Accumulators.register(acc2)
+    assert(Accumulators.get(acc2.id) === Some(acc2))
+
+    // Getting an accumulator that does not exist should return None
+    assert(Accumulators.get(100000).isEmpty)
   }
 
-  test("internal accumulators in multiple stages") {
-    val listener = new SaveInfoListener
-    val numPartitions = 10
-    sc = new SparkContext("local", "test")
-    sc.addSparkListener(listener)
-    // Each stage creates its own set of internal accumulators so the
-    // values for the same metric should not be mixed up across stages
-    val rdd = sc.parallelize(1 to 100, numPartitions)
-      .map { i => (i, i) }
-      .mapPartitions { iter =>
-        TaskContext.get().internalMetricsToAccumulators(TEST_ACCUMULATOR) += 1
-        iter
-      }
-      .reduceByKey { case (x, y) => x + y }
-      .mapPartitions { iter =>
-        TaskContext.get().internalMetricsToAccumulators(TEST_ACCUMULATOR) += 10
-        iter
-      }
-      .repartition(numPartitions * 2)
-      .mapPartitions { iter =>
-        TaskContext.get().internalMetricsToAccumulators(TEST_ACCUMULATOR) += 100
-        iter
-      }
-    // Register asserts in job completion callback to avoid flakiness
-    listener.registerJobCompletionCallback { _ =>
-      // We ran 3 stages, and the accumulator values should be distinct
-      val stageInfos = listener.getCompletedStageInfos
-      assert(stageInfos.size === 3)
-      val (firstStageAccum, secondStageAccum, thirdStageAccum) =
-        (findAccumulableInfo(stageInfos(0).accumulables.values, TEST_ACCUMULATOR),
-        findAccumulableInfo(stageInfos(1).accumulables.values, TEST_ACCUMULATOR),
-        findAccumulableInfo(stageInfos(2).accumulables.values, TEST_ACCUMULATOR))
-      assert(firstStageAccum.value.toLong === numPartitions)
-      assert(secondStageAccum.value.toLong === numPartitions * 10)
-      assert(thirdStageAccum.value.toLong === numPartitions * 2 * 100)
-    }
-    rdd.count()
+  test("only external accums are automatically registered") {
+    val accEx = new Accumulator(0, IntAccumulatorParam, Some("external"), internal = false)
+    val accIn = new Accumulator(0, IntAccumulatorParam, Some("internal"), internal = true)
+    assert(!accEx.isInternal)
+    assert(accIn.isInternal)
+    assert(Accumulators.get(accEx.id).isDefined)
+    assert(Accumulators.get(accIn.id).isEmpty)
   }
 
-  test("internal accumulators in fully resubmitted stages") {
-    testInternalAccumulatorsWithFailedTasks((i: Int) => true) // fail all tasks
+  test("copy") {
+    val acc1 = new Accumulable[Long, Long](456L, LongAccumulatorParam, Some("x"), true, false)
+    val acc2 = acc1.copy()
+    assert(acc1.id === acc2.id)
+    assert(acc1.value === acc2.value)
+    assert(acc1.name === acc2.name)
+    assert(acc1.isInternal === acc2.isInternal)
+    assert(acc1.countFailedValues === acc2.countFailedValues)
+    assert(acc1 !== acc2)
+    // Modifying one does not affect the other
+    acc1.add(44L)
+    assert(acc1.value === 500L)
+    assert(acc2.value === 456L)
+    acc2.add(144L)
+    assert(acc1.value === 500L)
+    assert(acc2.value === 600L)
   }
 
-  test("internal accumulators in partially resubmitted stages") {
-    testInternalAccumulatorsWithFailedTasks((i: Int) => i % 2 == 0) // fail a subset
+  test("register multiple accums with same ID") {
+    // Make sure these are internal accums so we don't automatically register them already
+    val acc1 = new Accumulable[Int, Int](0, IntAccumulatorParam, None, true, true)
+    val acc2 = acc1.copy()
+    assert(acc1 !== acc2)
+    assert(acc1.id === acc2.id)
+    assert(Accumulators.originals.isEmpty)
+    assert(Accumulators.get(acc1.id).isEmpty)
+    Accumulators.register(acc1)
+    Accumulators.register(acc2)
+    // The second one does not override the first one
+    assert(Accumulators.originals.size === 1)
+    assert(Accumulators.get(acc1.id) === Some(acc1))
   }
 
-  /**
-   * Return the accumulable info that matches the specified name.
-   */
-  private def findAccumulableInfo(
-      accums: Iterable[AccumulableInfo],
-      name: String): AccumulableInfo = {
-    accums.find { a => a.name == name }.getOrElse {
-      throw new TestFailedException(s"internal accumulator '$name' not found", 0)
-    }
+  test("string accumulator param") {
+    val acc = new Accumulator("", StringAccumulatorParam, Some("darkness"))
+    assert(acc.value === "")
+    acc.setValue("feeds")
+    assert(acc.value === "feeds")
+    acc.add("your")
+    assert(acc.value === "your") // value is overwritten, not concatenated
+    acc += "soul"
+    assert(acc.value === "soul")
+    acc ++= "with"
+    assert(acc.value === "with")
+    acc.merge("kindness")
+    assert(acc.value === "kindness")
   }
 
-  /**
-   * Test whether internal accumulators are merged properly if some tasks fail.
-   */
-  private def testInternalAccumulatorsWithFailedTasks(failCondition: (Int => Boolean)): Unit = {
-    val listener = new SaveInfoListener
-    val numPartitions = 10
-    val numFailedPartitions = (0 until numPartitions).count(failCondition)
-    // This says use 1 core and retry tasks up to 2 times
-    sc = new SparkContext("local[1, 2]", "test")
-    sc.addSparkListener(listener)
-    val rdd = sc.parallelize(1 to 100, numPartitions).mapPartitionsWithIndex { case (i, iter) =>
-      val taskContext = TaskContext.get()
-      taskContext.internalMetricsToAccumulators(TEST_ACCUMULATOR) += 1
-      // Fail the first attempts of a subset of the tasks
-      if (failCondition(i) && taskContext.attemptNumber() == 0) {
-        throw new Exception("Failing a task intentionally.")
-      }
-      iter
-    }
-    // Register asserts in job completion callback to avoid flakiness
-    listener.registerJobCompletionCallback { _ =>
-      val stageInfos = listener.getCompletedStageInfos
-      val taskInfos = listener.getCompletedTaskInfos
-      assert(stageInfos.size === 1)
-      assert(taskInfos.size === numPartitions + numFailedPartitions)
-      val stageAccum = findAccumulableInfo(stageInfos.head.accumulables.values, TEST_ACCUMULATOR)
-      // We should not double count values in the merged accumulator
-      assert(stageAccum.value.toLong === numPartitions)
-      val taskAccumValues = taskInfos.flatMap { taskInfo =>
-        if (!taskInfo.failed) {
-          // If a task succeeded, its update value should always be 1
-          val taskAccum = findAccumulableInfo(taskInfo.accumulables, TEST_ACCUMULATOR)
-          assert(taskAccum.update.isDefined)
-          assert(taskAccum.update.get.toLong === 1)
-          Some(taskAccum.value.toLong)
-        } else {
-          // If a task failed, we should not get its accumulator values
-          assert(taskInfo.accumulables.isEmpty)
-          None
-        }
-      }
-      assert(taskAccumValues.sorted === (1L to numPartitions).toSeq)
-    }
-    rdd.count()
+  test("list accumulator param") {
+    val acc = new Accumulator(Seq.empty[Int], new ListAccumulatorParam[Int], Some("numbers"))
+    assert(acc.value === Seq.empty[Int])
+    acc.add(Seq(1, 2))
+    assert(acc.value === Seq(1, 2))
+    acc += Seq(3, 4)
+    assert(acc.value === Seq(1, 2, 3, 4))
+    acc ++= Seq(5, 6)
+    assert(acc.value === Seq(1, 2, 3, 4, 5, 6))
+    acc.merge(Seq(7, 8))
+    assert(acc.value === Seq(1, 2, 3, 4, 5, 6, 7, 8))
+    acc.setValue(Seq(9, 10))
+    assert(acc.value === Seq(9, 10))
+  }
+
+  test("value is reset on the executors") {
+    val acc1 = new Accumulator(0, IntAccumulatorParam, Some("thing"), internal = false)
+    val acc2 = new Accumulator(0L, LongAccumulatorParam, Some("thing2"), internal = false)
+    val externalAccums = Seq(acc1, acc2)
+    val internalAccums = InternalAccumulator.create()
+    // Set some values; these should not be observed later on the "executors"
+    acc1.setValue(10)
+    acc2.setValue(20L)
+    internalAccums
+      .find(_.name == Some(InternalAccumulator.TEST_ACCUM))
+      .get.asInstanceOf[Accumulator[Long]]
+      .setValue(30L)
+    // Simulate the task being serialized and sent to the executors.
+    val dummyTask = new DummyTask(internalAccums, externalAccums)
+    val serInstance = new JavaSerializer(new SparkConf).newInstance()
+    val taskSer = Task.serializeWithDependencies(
+      dummyTask, mutable.HashMap(), mutable.HashMap(), serInstance)
+    // Now we're on the executors.
+    // Deserialize the task and assert that its accumulators are zero'ed out.
+    val (_, _, taskBytes) = Task.deserializeWithDependencies(taskSer)
+    val taskDeser = serInstance.deserialize[DummyTask](
+      taskBytes, Thread.currentThread.getContextClassLoader)
+    // Assert that executors see only zeros
+    taskDeser.externalAccums.foreach { a => assert(a.localValue == a.zero) }
+    taskDeser.internalAccums.foreach { a => assert(a.localValue == a.zero) }
   }
 
 }
 
 private[spark] object AccumulatorSuite {
 
+  import InternalAccumulator._
+
   /**
-   * Run one or more Spark jobs and verify that the peak execution memory accumulator
-   * is updated afterwards.
+   * Run one or more Spark jobs and verify that in at least one job the peak execution memory
+   * accumulator is updated afterwards.
    */
   def verifyPeakExecutionMemorySet(
       sc: SparkContext,
       testName: String)(testBody: => Unit): Unit = {
     val listener = new SaveInfoListener
     sc.addSparkListener(listener)
-    // Register asserts in job completion callback to avoid flakiness
-    listener.registerJobCompletionCallback { jobId =>
-      if (jobId == 0) {
-        // The first job is a dummy one to verify that the accumulator does not already exist
-        val accums = listener.getCompletedStageInfos.flatMap(_.accumulables.values)
-        assert(!accums.exists(_.name == InternalAccumulator.PEAK_EXECUTION_MEMORY))
-      } else {
-        // In the subsequent jobs, verify that peak execution memory is updated
-        val accum = listener.getCompletedStageInfos
-          .flatMap(_.accumulables.values)
-          .find(_.name == InternalAccumulator.PEAK_EXECUTION_MEMORY)
-          .getOrElse {
-          throw new TestFailedException(
-            s"peak execution memory accumulator not set in '$testName'", 0)
-        }
-        assert(accum.value.toLong > 0)
-      }
-    }
-    // Run the jobs
-    sc.parallelize(1 to 10).count()
     testBody
+    val accums = listener.getCompletedStageInfos.flatMap(_.accumulables.values)
+    val isSet = accums.exists { a =>
+      a.name == Some(PEAK_EXECUTION_MEMORY) && a.value.exists(_.asInstanceOf[Long] > 0L)
+    }
+    if (!isSet) {
+      throw new TestFailedException(s"peak execution memory accumulator not set in '$testName'", 0)
+    }
   }
 }
 
@@ -357,6 +325,10 @@ private class SaveInfoListener extends SparkListener {
   private val completedTaskInfos: ArrayBuffer[TaskInfo] = new ArrayBuffer[TaskInfo]
   private var jobCompletionCallback: (Int => Unit) = null // parameter is job ID
 
+  // Accesses must be synchronized to ensure failures in `jobCompletionCallback` are propagated
+  @GuardedBy("this")
+  private var exception: Throwable = null
+
   def getCompletedStageInfos: Seq[StageInfo] = completedStageInfos.toArray.toSeq
   def getCompletedTaskInfos: Seq[TaskInfo] = completedTaskInfos.toArray.toSeq
 
@@ -365,9 +337,20 @@ private class SaveInfoListener extends SparkListener {
     jobCompletionCallback = callback
   }
 
-  override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = {
+  /** Throw a stored exception, if any. */
+  def maybeThrowException(): Unit = synchronized {
+    if (exception != null) { throw exception }
+  }
+
+  override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = synchronized {
     if (jobCompletionCallback != null) {
-      jobCompletionCallback(jobEnd.jobId)
+      try {
+        jobCompletionCallback(jobEnd.jobId)
+      } catch {
+        // Store any exception thrown here so we can throw them later in the main thread.
+        // Otherwise, if `jobCompletionCallback` threw something it wouldn't fail the test.
+        case NonFatal(e) => exception = e
+      }
     }
   }
 
@@ -379,3 +362,14 @@ private class SaveInfoListener extends SparkListener {
     completedTaskInfos += taskEnd.taskInfo
   }
 }
+
+
+/**
+ * A dummy [[Task]] that contains internal and external [[Accumulator]]s.
+ */
+private[spark] class DummyTask(
+    val internalAccums: Seq[Accumulator[_]],
+    val externalAccums: Seq[Accumulator[_]])
+  extends Task[Int](0, 0, 0, internalAccums) {
+  override def runTask(c: TaskContext): Int = 1
+}
diff --git a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala
index 4e678fbac6a397972d24d18f5f265e4bb58fa772..80a1de6065b43a9549546310de700749eafe30e3 100644
--- a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala
@@ -801,7 +801,7 @@ class ExecutorAllocationManagerSuite
     assert(maxNumExecutorsNeeded(manager) === 1)
 
     // If the task is failed, we expect it to be resubmitted later.
-    val taskEndReason = ExceptionFailure(null, null, null, null, null, None)
+    val taskEndReason = ExceptionFailure(null, null, null, null, None)
     sc.listenerBus.postToAll(SparkListenerTaskEnd(0, 0, null, taskEndReason, taskInfo, null))
     assert(maxNumExecutorsNeeded(manager) === 1)
   }
diff --git a/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala
index c7f629a14ba24f34ea332b3c78eaab2abd9a0cb6..3777d77f8f5b33071d12f127cf4b7bbe6765becb 100644
--- a/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala
+++ b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala
@@ -215,14 +215,16 @@ class HeartbeatReceiverSuite
     val metrics = new TaskMetrics
     val blockManagerId = BlockManagerId(executorId, "localhost", 12345)
     val response = heartbeatReceiverRef.askWithRetry[HeartbeatResponse](
-      Heartbeat(executorId, Array(1L -> metrics), blockManagerId))
+      Heartbeat(executorId, Array(1L -> metrics.accumulatorUpdates()), blockManagerId))
     if (executorShouldReregister) {
       assert(response.reregisterBlockManager)
     } else {
       assert(!response.reregisterBlockManager)
       // Additionally verify that the scheduler callback is called with the correct parameters
       verify(scheduler).executorHeartbeatReceived(
-        Matchers.eq(executorId), Matchers.eq(Array(1L -> metrics)), Matchers.eq(blockManagerId))
+        Matchers.eq(executorId),
+        Matchers.eq(Array(1L -> metrics.accumulatorUpdates())),
+        Matchers.eq(blockManagerId))
     }
   }
 
diff --git a/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala b/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala
new file mode 100644
index 0000000000000000000000000000000000000000..630b46f828df718fc3ae0172e5cb0cdb7390e0f5
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala
@@ -0,0 +1,331 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.scheduler.AccumulableInfo
+import org.apache.spark.storage.{BlockId, BlockStatus}
+
+
+class InternalAccumulatorSuite extends SparkFunSuite with LocalSparkContext {
+  import InternalAccumulator._
+  import AccumulatorParam._
+
+  test("get param") {
+    assert(getParam(EXECUTOR_DESERIALIZE_TIME) === LongAccumulatorParam)
+    assert(getParam(EXECUTOR_RUN_TIME) === LongAccumulatorParam)
+    assert(getParam(RESULT_SIZE) === LongAccumulatorParam)
+    assert(getParam(JVM_GC_TIME) === LongAccumulatorParam)
+    assert(getParam(RESULT_SERIALIZATION_TIME) === LongAccumulatorParam)
+    assert(getParam(MEMORY_BYTES_SPILLED) === LongAccumulatorParam)
+    assert(getParam(DISK_BYTES_SPILLED) === LongAccumulatorParam)
+    assert(getParam(PEAK_EXECUTION_MEMORY) === LongAccumulatorParam)
+    assert(getParam(UPDATED_BLOCK_STATUSES) === UpdatedBlockStatusesAccumulatorParam)
+    assert(getParam(TEST_ACCUM) === LongAccumulatorParam)
+    // shuffle read
+    assert(getParam(shuffleRead.REMOTE_BLOCKS_FETCHED) === IntAccumulatorParam)
+    assert(getParam(shuffleRead.LOCAL_BLOCKS_FETCHED) === IntAccumulatorParam)
+    assert(getParam(shuffleRead.REMOTE_BYTES_READ) === LongAccumulatorParam)
+    assert(getParam(shuffleRead.LOCAL_BYTES_READ) === LongAccumulatorParam)
+    assert(getParam(shuffleRead.FETCH_WAIT_TIME) === LongAccumulatorParam)
+    assert(getParam(shuffleRead.RECORDS_READ) === LongAccumulatorParam)
+    // shuffle write
+    assert(getParam(shuffleWrite.BYTES_WRITTEN) === LongAccumulatorParam)
+    assert(getParam(shuffleWrite.RECORDS_WRITTEN) === LongAccumulatorParam)
+    assert(getParam(shuffleWrite.WRITE_TIME) === LongAccumulatorParam)
+    // input
+    assert(getParam(input.READ_METHOD) === StringAccumulatorParam)
+    assert(getParam(input.RECORDS_READ) === LongAccumulatorParam)
+    assert(getParam(input.BYTES_READ) === LongAccumulatorParam)
+    // output
+    assert(getParam(output.WRITE_METHOD) === StringAccumulatorParam)
+    assert(getParam(output.RECORDS_WRITTEN) === LongAccumulatorParam)
+    assert(getParam(output.BYTES_WRITTEN) === LongAccumulatorParam)
+    // default to Long
+    assert(getParam(METRICS_PREFIX + "anything") === LongAccumulatorParam)
+    intercept[IllegalArgumentException] {
+      getParam("something that does not start with the right prefix")
+    }
+  }
+
+  test("create by name") {
+    val executorRunTime = create(EXECUTOR_RUN_TIME)
+    val updatedBlockStatuses = create(UPDATED_BLOCK_STATUSES)
+    val shuffleRemoteBlocksRead = create(shuffleRead.REMOTE_BLOCKS_FETCHED)
+    val inputReadMethod = create(input.READ_METHOD)
+    assert(executorRunTime.name === Some(EXECUTOR_RUN_TIME))
+    assert(updatedBlockStatuses.name === Some(UPDATED_BLOCK_STATUSES))
+    assert(shuffleRemoteBlocksRead.name === Some(shuffleRead.REMOTE_BLOCKS_FETCHED))
+    assert(inputReadMethod.name === Some(input.READ_METHOD))
+    assert(executorRunTime.value.isInstanceOf[Long])
+    assert(updatedBlockStatuses.value.isInstanceOf[Seq[_]])
+    // We cannot assert the type of the value directly since the type parameter is erased.
+    // Instead, try casting a `Seq` of expected type and see if it fails in run time.
+    updatedBlockStatuses.setValueAny(Seq.empty[(BlockId, BlockStatus)])
+    assert(shuffleRemoteBlocksRead.value.isInstanceOf[Int])
+    assert(inputReadMethod.value.isInstanceOf[String])
+    // default to Long
+    val anything = create(METRICS_PREFIX + "anything")
+    assert(anything.value.isInstanceOf[Long])
+  }
+
+  test("create") {
+    val accums = create()
+    val shuffleReadAccums = createShuffleReadAccums()
+    val shuffleWriteAccums = createShuffleWriteAccums()
+    val inputAccums = createInputAccums()
+    val outputAccums = createOutputAccums()
+    // assert they're all internal
+    assert(accums.forall(_.isInternal))
+    assert(shuffleReadAccums.forall(_.isInternal))
+    assert(shuffleWriteAccums.forall(_.isInternal))
+    assert(inputAccums.forall(_.isInternal))
+    assert(outputAccums.forall(_.isInternal))
+    // assert they all count on failures
+    assert(accums.forall(_.countFailedValues))
+    assert(shuffleReadAccums.forall(_.countFailedValues))
+    assert(shuffleWriteAccums.forall(_.countFailedValues))
+    assert(inputAccums.forall(_.countFailedValues))
+    assert(outputAccums.forall(_.countFailedValues))
+    // assert they all have names
+    assert(accums.forall(_.name.isDefined))
+    assert(shuffleReadAccums.forall(_.name.isDefined))
+    assert(shuffleWriteAccums.forall(_.name.isDefined))
+    assert(inputAccums.forall(_.name.isDefined))
+    assert(outputAccums.forall(_.name.isDefined))
+    // assert `accums` is a strict superset of the others
+    val accumNames = accums.map(_.name.get).toSet
+    val shuffleReadAccumNames = shuffleReadAccums.map(_.name.get).toSet
+    val shuffleWriteAccumNames = shuffleWriteAccums.map(_.name.get).toSet
+    val inputAccumNames = inputAccums.map(_.name.get).toSet
+    val outputAccumNames = outputAccums.map(_.name.get).toSet
+    assert(shuffleReadAccumNames.subsetOf(accumNames))
+    assert(shuffleWriteAccumNames.subsetOf(accumNames))
+    assert(inputAccumNames.subsetOf(accumNames))
+    assert(outputAccumNames.subsetOf(accumNames))
+  }
+
+  test("naming") {
+    val accums = create()
+    val shuffleReadAccums = createShuffleReadAccums()
+    val shuffleWriteAccums = createShuffleWriteAccums()
+    val inputAccums = createInputAccums()
+    val outputAccums = createOutputAccums()
+    // assert that prefixes are properly namespaced
+    assert(SHUFFLE_READ_METRICS_PREFIX.startsWith(METRICS_PREFIX))
+    assert(SHUFFLE_WRITE_METRICS_PREFIX.startsWith(METRICS_PREFIX))
+    assert(INPUT_METRICS_PREFIX.startsWith(METRICS_PREFIX))
+    assert(OUTPUT_METRICS_PREFIX.startsWith(METRICS_PREFIX))
+    assert(accums.forall(_.name.get.startsWith(METRICS_PREFIX)))
+    // assert they all start with the expected prefixes
+    assert(shuffleReadAccums.forall(_.name.get.startsWith(SHUFFLE_READ_METRICS_PREFIX)))
+    assert(shuffleWriteAccums.forall(_.name.get.startsWith(SHUFFLE_WRITE_METRICS_PREFIX)))
+    assert(inputAccums.forall(_.name.get.startsWith(INPUT_METRICS_PREFIX)))
+    assert(outputAccums.forall(_.name.get.startsWith(OUTPUT_METRICS_PREFIX)))
+  }
+
+  test("internal accumulators in TaskContext") {
+    val taskContext = TaskContext.empty()
+    val accumUpdates = taskContext.taskMetrics.accumulatorUpdates()
+    assert(accumUpdates.size > 0)
+    assert(accumUpdates.forall(_.internal))
+    val testAccum = taskContext.taskMetrics.getAccum(TEST_ACCUM)
+    assert(accumUpdates.exists(_.id == testAccum.id))
+  }
+
+  test("internal accumulators in a stage") {
+    val listener = new SaveInfoListener
+    val numPartitions = 10
+    sc = new SparkContext("local", "test")
+    sc.addSparkListener(listener)
+    // Have each task add 1 to the internal accumulator
+    val rdd = sc.parallelize(1 to 100, numPartitions).mapPartitions { iter =>
+      TaskContext.get().taskMetrics().getAccum(TEST_ACCUM) += 1
+      iter
+    }
+    // Register asserts in job completion callback to avoid flakiness
+    listener.registerJobCompletionCallback { _ =>
+      val stageInfos = listener.getCompletedStageInfos
+      val taskInfos = listener.getCompletedTaskInfos
+      assert(stageInfos.size === 1)
+      assert(taskInfos.size === numPartitions)
+      // The accumulator values should be merged in the stage
+      val stageAccum = findTestAccum(stageInfos.head.accumulables.values)
+      assert(stageAccum.value.get.asInstanceOf[Long] === numPartitions)
+      // The accumulator should be updated locally on each task
+      val taskAccumValues = taskInfos.map { taskInfo =>
+        val taskAccum = findTestAccum(taskInfo.accumulables)
+        assert(taskAccum.update.isDefined)
+        assert(taskAccum.update.get.asInstanceOf[Long] === 1L)
+        taskAccum.value.get.asInstanceOf[Long]
+      }
+      // Each task should keep track of the partial value on the way, i.e. 1, 2, ... numPartitions
+      assert(taskAccumValues.sorted === (1L to numPartitions).toSeq)
+    }
+    rdd.count()
+  }
+
+  test("internal accumulators in multiple stages") {
+    val listener = new SaveInfoListener
+    val numPartitions = 10
+    sc = new SparkContext("local", "test")
+    sc.addSparkListener(listener)
+    // Each stage creates its own set of internal accumulators so the
+    // values for the same metric should not be mixed up across stages
+    val rdd = sc.parallelize(1 to 100, numPartitions)
+      .map { i => (i, i) }
+      .mapPartitions { iter =>
+        TaskContext.get().taskMetrics().getAccum(TEST_ACCUM) += 1
+        iter
+      }
+      .reduceByKey { case (x, y) => x + y }
+      .mapPartitions { iter =>
+        TaskContext.get().taskMetrics().getAccum(TEST_ACCUM) += 10
+        iter
+      }
+      .repartition(numPartitions * 2)
+      .mapPartitions { iter =>
+        TaskContext.get().taskMetrics().getAccum(TEST_ACCUM) += 100
+        iter
+      }
+    // Register asserts in job completion callback to avoid flakiness
+    listener.registerJobCompletionCallback { _ =>
+    // We ran 3 stages, and the accumulator values should be distinct
+      val stageInfos = listener.getCompletedStageInfos
+      assert(stageInfos.size === 3)
+      val (firstStageAccum, secondStageAccum, thirdStageAccum) =
+        (findTestAccum(stageInfos(0).accumulables.values),
+        findTestAccum(stageInfos(1).accumulables.values),
+        findTestAccum(stageInfos(2).accumulables.values))
+      assert(firstStageAccum.value.get.asInstanceOf[Long] === numPartitions)
+      assert(secondStageAccum.value.get.asInstanceOf[Long] === numPartitions * 10)
+      assert(thirdStageAccum.value.get.asInstanceOf[Long] === numPartitions * 2 * 100)
+    }
+    rdd.count()
+  }
+
+  // TODO: these two tests are incorrect; they don't actually trigger stage retries.
+  ignore("internal accumulators in fully resubmitted stages") {
+    testInternalAccumulatorsWithFailedTasks((i: Int) => true) // fail all tasks
+  }
+
+  ignore("internal accumulators in partially resubmitted stages") {
+    testInternalAccumulatorsWithFailedTasks((i: Int) => i % 2 == 0) // fail a subset
+  }
+
+  test("internal accumulators are registered for cleanups") {
+    sc = new SparkContext("local", "test") {
+      private val myCleaner = new SaveAccumContextCleaner(this)
+      override def cleaner: Option[ContextCleaner] = Some(myCleaner)
+    }
+    assert(Accumulators.originals.isEmpty)
+    sc.parallelize(1 to 100).map { i => (i, i) }.reduceByKey { _ + _ }.count()
+    val internalAccums = InternalAccumulator.create()
+    // We ran 2 stages, so we should have 2 sets of internal accumulators, 1 for each stage
+    assert(Accumulators.originals.size === internalAccums.size * 2)
+    val accumsRegistered = sc.cleaner match {
+      case Some(cleaner: SaveAccumContextCleaner) => cleaner.accumsRegisteredForCleanup
+      case _ => Seq.empty[Long]
+    }
+    // Make sure the same set of accumulators is registered for cleanup
+    assert(accumsRegistered.size === internalAccums.size * 2)
+    assert(accumsRegistered.toSet === Accumulators.originals.keys.toSet)
+  }
+
+  /**
+   * Return the accumulable info that matches the specified name.
+   */
+  private def findTestAccum(accums: Iterable[AccumulableInfo]): AccumulableInfo = {
+    accums.find { a => a.name == Some(TEST_ACCUM) }.getOrElse {
+      fail(s"unable to find internal accumulator called $TEST_ACCUM")
+    }
+  }
+
+  /**
+   * Test whether internal accumulators are merged properly if some tasks fail.
+   * TODO: make this actually retry the stage.
+   */
+  private def testInternalAccumulatorsWithFailedTasks(failCondition: (Int => Boolean)): Unit = {
+    val listener = new SaveInfoListener
+    val numPartitions = 10
+    val numFailedPartitions = (0 until numPartitions).count(failCondition)
+    // This says use 1 core and retry tasks up to 2 times
+    sc = new SparkContext("local[1, 2]", "test")
+    sc.addSparkListener(listener)
+    val rdd = sc.parallelize(1 to 100, numPartitions).mapPartitionsWithIndex { case (i, iter) =>
+      val taskContext = TaskContext.get()
+      taskContext.taskMetrics().getAccum(TEST_ACCUM) += 1
+      // Fail the first attempts of a subset of the tasks
+      if (failCondition(i) && taskContext.attemptNumber() == 0) {
+        throw new Exception("Failing a task intentionally.")
+      }
+      iter
+    }
+    // Register asserts in job completion callback to avoid flakiness
+    listener.registerJobCompletionCallback { _ =>
+      val stageInfos = listener.getCompletedStageInfos
+      val taskInfos = listener.getCompletedTaskInfos
+      assert(stageInfos.size === 1)
+      assert(taskInfos.size === numPartitions + numFailedPartitions)
+      val stageAccum = findTestAccum(stageInfos.head.accumulables.values)
+      // If all partitions failed, then we would resubmit the whole stage again and create a
+      // fresh set of internal accumulators. Otherwise, these internal accumulators do count
+      // failed values, so we must include the failed values.
+      val expectedAccumValue =
+        if (numPartitions == numFailedPartitions) {
+          numPartitions
+        } else {
+          numPartitions + numFailedPartitions
+        }
+      assert(stageAccum.value.get.asInstanceOf[Long] === expectedAccumValue)
+      val taskAccumValues = taskInfos.flatMap { taskInfo =>
+        if (!taskInfo.failed) {
+          // If a task succeeded, its update value should always be 1
+          val taskAccum = findTestAccum(taskInfo.accumulables)
+          assert(taskAccum.update.isDefined)
+          assert(taskAccum.update.get.asInstanceOf[Long] === 1L)
+          assert(taskAccum.value.isDefined)
+          Some(taskAccum.value.get.asInstanceOf[Long])
+        } else {
+          // If a task failed, we should not get its accumulator values
+          assert(taskInfo.accumulables.isEmpty)
+          None
+        }
+      }
+      assert(taskAccumValues.sorted === (1L to numPartitions).toSeq)
+    }
+    rdd.count()
+    listener.maybeThrowException()
+  }
+
+  /**
+   * A special [[ContextCleaner]] that saves the IDs of the accumulators registered for cleanup.
+   */
+  private class SaveAccumContextCleaner(sc: SparkContext) extends ContextCleaner(sc) {
+    private val accumsRegistered = new ArrayBuffer[Long]
+
+    override def registerAccumulatorForCleanup(a: Accumulable[_, _]): Unit = {
+      accumsRegistered += a.id
+      super.registerAccumulatorForCleanup(a)
+    }
+
+    def accumsRegisteredForCleanup: Seq[Long] = accumsRegistered.toArray
+  }
+
+}
diff --git a/core/src/test/scala/org/apache/spark/SparkFunSuite.scala b/core/src/test/scala/org/apache/spark/SparkFunSuite.scala
index 9be9db01c7de9016f8e1267dced0d6e47f0f75d9..d3359c7406e45ca07d2d70890543416cd8ba8af0 100644
--- a/core/src/test/scala/org/apache/spark/SparkFunSuite.scala
+++ b/core/src/test/scala/org/apache/spark/SparkFunSuite.scala
@@ -42,6 +42,8 @@ private[spark] abstract class SparkFunSuite extends FunSuite with Logging {
       test()
     } finally {
       logInfo(s"\n\n===== FINISHED $shortSuiteName: '$testName' =====\n")
+      // Avoid leaking map entries in tests that use accumulators without SparkContext
+      Accumulators.clear()
     }
   }
 
diff --git a/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala b/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala
index e5ec2aa1be355b44f9d563966d97eec43a0c0977..15be0b194ed8ebf29b5b5614267d0555eea445ca 100644
--- a/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala
@@ -17,12 +17,542 @@
 
 package org.apache.spark.executor
 
-import org.apache.spark.SparkFunSuite
+import org.scalatest.Assertions
+
+import org.apache.spark._
+import org.apache.spark.scheduler.AccumulableInfo
+import org.apache.spark.storage.{BlockId, BlockStatus, StorageLevel, TestBlockId}
+
 
 class TaskMetricsSuite extends SparkFunSuite {
-  test("[SPARK-5701] updateShuffleReadMetrics: ShuffleReadMetrics not added when no shuffle deps") {
-    val taskMetrics = new TaskMetrics()
-    taskMetrics.mergeShuffleReadMetrics()
-    assert(taskMetrics.shuffleReadMetrics.isEmpty)
+  import AccumulatorParam._
+  import InternalAccumulator._
+  import StorageLevel._
+  import TaskMetricsSuite._
+
+  test("create") {
+    val internalAccums = InternalAccumulator.create()
+    val tm1 = new TaskMetrics
+    val tm2 = new TaskMetrics(internalAccums)
+    assert(tm1.accumulatorUpdates().size === internalAccums.size)
+    assert(tm1.shuffleReadMetrics.isEmpty)
+    assert(tm1.shuffleWriteMetrics.isEmpty)
+    assert(tm1.inputMetrics.isEmpty)
+    assert(tm1.outputMetrics.isEmpty)
+    assert(tm2.accumulatorUpdates().size === internalAccums.size)
+    assert(tm2.shuffleReadMetrics.isEmpty)
+    assert(tm2.shuffleWriteMetrics.isEmpty)
+    assert(tm2.inputMetrics.isEmpty)
+    assert(tm2.outputMetrics.isEmpty)
+    // TaskMetrics constructor expects minimal set of initial accumulators
+    intercept[IllegalArgumentException] { new TaskMetrics(Seq.empty[Accumulator[_]]) }
+  }
+
+  test("create with unnamed accum") {
+    intercept[IllegalArgumentException] {
+      new TaskMetrics(
+        InternalAccumulator.create() ++ Seq(
+          new Accumulator(0, IntAccumulatorParam, None, internal = true)))
+    }
+  }
+
+  test("create with duplicate name accum") {
+    intercept[IllegalArgumentException] {
+      new TaskMetrics(
+        InternalAccumulator.create() ++ Seq(
+          new Accumulator(0, IntAccumulatorParam, Some(RESULT_SIZE), internal = true)))
+    }
+  }
+
+  test("create with external accum") {
+    intercept[IllegalArgumentException] {
+      new TaskMetrics(
+        InternalAccumulator.create() ++ Seq(
+          new Accumulator(0, IntAccumulatorParam, Some("x"))))
+    }
+  }
+
+  test("create shuffle read metrics") {
+    import shuffleRead._
+    val accums = InternalAccumulator.createShuffleReadAccums()
+      .map { a => (a.name.get, a) }.toMap[String, Accumulator[_]]
+    accums(REMOTE_BLOCKS_FETCHED).setValueAny(1)
+    accums(LOCAL_BLOCKS_FETCHED).setValueAny(2)
+    accums(REMOTE_BYTES_READ).setValueAny(3L)
+    accums(LOCAL_BYTES_READ).setValueAny(4L)
+    accums(FETCH_WAIT_TIME).setValueAny(5L)
+    accums(RECORDS_READ).setValueAny(6L)
+    val sr = new ShuffleReadMetrics(accums)
+    assert(sr.remoteBlocksFetched === 1)
+    assert(sr.localBlocksFetched === 2)
+    assert(sr.remoteBytesRead === 3L)
+    assert(sr.localBytesRead === 4L)
+    assert(sr.fetchWaitTime === 5L)
+    assert(sr.recordsRead === 6L)
+  }
+
+  test("create shuffle write metrics") {
+    import shuffleWrite._
+    val accums = InternalAccumulator.createShuffleWriteAccums()
+      .map { a => (a.name.get, a) }.toMap[String, Accumulator[_]]
+    accums(BYTES_WRITTEN).setValueAny(1L)
+    accums(RECORDS_WRITTEN).setValueAny(2L)
+    accums(WRITE_TIME).setValueAny(3L)
+    val sw = new ShuffleWriteMetrics(accums)
+    assert(sw.bytesWritten === 1L)
+    assert(sw.recordsWritten === 2L)
+    assert(sw.writeTime === 3L)
+  }
+
+  test("create input metrics") {
+    import input._
+    val accums = InternalAccumulator.createInputAccums()
+      .map { a => (a.name.get, a) }.toMap[String, Accumulator[_]]
+    accums(BYTES_READ).setValueAny(1L)
+    accums(RECORDS_READ).setValueAny(2L)
+    accums(READ_METHOD).setValueAny(DataReadMethod.Hadoop.toString)
+    val im = new InputMetrics(accums)
+    assert(im.bytesRead === 1L)
+    assert(im.recordsRead === 2L)
+    assert(im.readMethod === DataReadMethod.Hadoop)
+  }
+
+  test("create output metrics") {
+    import output._
+    val accums = InternalAccumulator.createOutputAccums()
+      .map { a => (a.name.get, a) }.toMap[String, Accumulator[_]]
+    accums(BYTES_WRITTEN).setValueAny(1L)
+    accums(RECORDS_WRITTEN).setValueAny(2L)
+    accums(WRITE_METHOD).setValueAny(DataWriteMethod.Hadoop.toString)
+    val om = new OutputMetrics(accums)
+    assert(om.bytesWritten === 1L)
+    assert(om.recordsWritten === 2L)
+    assert(om.writeMethod === DataWriteMethod.Hadoop)
+  }
+
+  test("mutating values") {
+    val accums = InternalAccumulator.create()
+    val tm = new TaskMetrics(accums)
+    // initial values
+    assertValueEquals(tm, _.executorDeserializeTime, accums, EXECUTOR_DESERIALIZE_TIME, 0L)
+    assertValueEquals(tm, _.executorRunTime, accums, EXECUTOR_RUN_TIME, 0L)
+    assertValueEquals(tm, _.resultSize, accums, RESULT_SIZE, 0L)
+    assertValueEquals(tm, _.jvmGCTime, accums, JVM_GC_TIME, 0L)
+    assertValueEquals(tm, _.resultSerializationTime, accums, RESULT_SERIALIZATION_TIME, 0L)
+    assertValueEquals(tm, _.memoryBytesSpilled, accums, MEMORY_BYTES_SPILLED, 0L)
+    assertValueEquals(tm, _.diskBytesSpilled, accums, DISK_BYTES_SPILLED, 0L)
+    assertValueEquals(tm, _.peakExecutionMemory, accums, PEAK_EXECUTION_MEMORY, 0L)
+    assertValueEquals(tm, _.updatedBlockStatuses, accums, UPDATED_BLOCK_STATUSES,
+      Seq.empty[(BlockId, BlockStatus)])
+    // set or increment values
+    tm.setExecutorDeserializeTime(100L)
+    tm.setExecutorDeserializeTime(1L) // overwrite
+    tm.setExecutorRunTime(200L)
+    tm.setExecutorRunTime(2L)
+    tm.setResultSize(300L)
+    tm.setResultSize(3L)
+    tm.setJvmGCTime(400L)
+    tm.setJvmGCTime(4L)
+    tm.setResultSerializationTime(500L)
+    tm.setResultSerializationTime(5L)
+    tm.incMemoryBytesSpilled(600L)
+    tm.incMemoryBytesSpilled(6L) // add
+    tm.incDiskBytesSpilled(700L)
+    tm.incDiskBytesSpilled(7L)
+    tm.incPeakExecutionMemory(800L)
+    tm.incPeakExecutionMemory(8L)
+    val block1 = (TestBlockId("a"), BlockStatus(MEMORY_ONLY, 1L, 2L))
+    val block2 = (TestBlockId("b"), BlockStatus(MEMORY_ONLY, 3L, 4L))
+    tm.incUpdatedBlockStatuses(Seq(block1))
+    tm.incUpdatedBlockStatuses(Seq(block2))
+    // assert new values exist
+    assertValueEquals(tm, _.executorDeserializeTime, accums, EXECUTOR_DESERIALIZE_TIME, 1L)
+    assertValueEquals(tm, _.executorRunTime, accums, EXECUTOR_RUN_TIME, 2L)
+    assertValueEquals(tm, _.resultSize, accums, RESULT_SIZE, 3L)
+    assertValueEquals(tm, _.jvmGCTime, accums, JVM_GC_TIME, 4L)
+    assertValueEquals(tm, _.resultSerializationTime, accums, RESULT_SERIALIZATION_TIME, 5L)
+    assertValueEquals(tm, _.memoryBytesSpilled, accums, MEMORY_BYTES_SPILLED, 606L)
+    assertValueEquals(tm, _.diskBytesSpilled, accums, DISK_BYTES_SPILLED, 707L)
+    assertValueEquals(tm, _.peakExecutionMemory, accums, PEAK_EXECUTION_MEMORY, 808L)
+    assertValueEquals(tm, _.updatedBlockStatuses, accums, UPDATED_BLOCK_STATUSES,
+      Seq(block1, block2))
+  }
+
+  test("mutating shuffle read metrics values") {
+    import shuffleRead._
+    val accums = InternalAccumulator.create()
+    val tm = new TaskMetrics(accums)
+    def assertValEquals[T](tmValue: ShuffleReadMetrics => T, name: String, value: T): Unit = {
+      assertValueEquals(tm, tm => tmValue(tm.shuffleReadMetrics.get), accums, name, value)
+    }
+    // create shuffle read metrics
+    assert(tm.shuffleReadMetrics.isEmpty)
+    tm.registerTempShuffleReadMetrics()
+    tm.mergeShuffleReadMetrics()
+    assert(tm.shuffleReadMetrics.isDefined)
+    val sr = tm.shuffleReadMetrics.get
+    // initial values
+    assertValEquals(_.remoteBlocksFetched, REMOTE_BLOCKS_FETCHED, 0)
+    assertValEquals(_.localBlocksFetched, LOCAL_BLOCKS_FETCHED, 0)
+    assertValEquals(_.remoteBytesRead, REMOTE_BYTES_READ, 0L)
+    assertValEquals(_.localBytesRead, LOCAL_BYTES_READ, 0L)
+    assertValEquals(_.fetchWaitTime, FETCH_WAIT_TIME, 0L)
+    assertValEquals(_.recordsRead, RECORDS_READ, 0L)
+    // set and increment values
+    sr.setRemoteBlocksFetched(100)
+    sr.setRemoteBlocksFetched(10)
+    sr.incRemoteBlocksFetched(1) // 10 + 1
+    sr.incRemoteBlocksFetched(1) // 10 + 1 + 1
+    sr.setLocalBlocksFetched(200)
+    sr.setLocalBlocksFetched(20)
+    sr.incLocalBlocksFetched(2)
+    sr.incLocalBlocksFetched(2)
+    sr.setRemoteBytesRead(300L)
+    sr.setRemoteBytesRead(30L)
+    sr.incRemoteBytesRead(3L)
+    sr.incRemoteBytesRead(3L)
+    sr.setLocalBytesRead(400L)
+    sr.setLocalBytesRead(40L)
+    sr.incLocalBytesRead(4L)
+    sr.incLocalBytesRead(4L)
+    sr.setFetchWaitTime(500L)
+    sr.setFetchWaitTime(50L)
+    sr.incFetchWaitTime(5L)
+    sr.incFetchWaitTime(5L)
+    sr.setRecordsRead(600L)
+    sr.setRecordsRead(60L)
+    sr.incRecordsRead(6L)
+    sr.incRecordsRead(6L)
+    // assert new values exist
+    assertValEquals(_.remoteBlocksFetched, REMOTE_BLOCKS_FETCHED, 12)
+    assertValEquals(_.localBlocksFetched, LOCAL_BLOCKS_FETCHED, 24)
+    assertValEquals(_.remoteBytesRead, REMOTE_BYTES_READ, 36L)
+    assertValEquals(_.localBytesRead, LOCAL_BYTES_READ, 48L)
+    assertValEquals(_.fetchWaitTime, FETCH_WAIT_TIME, 60L)
+    assertValEquals(_.recordsRead, RECORDS_READ, 72L)
+  }
+
+  test("mutating shuffle write metrics values") {
+    import shuffleWrite._
+    val accums = InternalAccumulator.create()
+    val tm = new TaskMetrics(accums)
+    def assertValEquals[T](tmValue: ShuffleWriteMetrics => T, name: String, value: T): Unit = {
+      assertValueEquals(tm, tm => tmValue(tm.shuffleWriteMetrics.get), accums, name, value)
+    }
+    // create shuffle write metrics
+    assert(tm.shuffleWriteMetrics.isEmpty)
+    tm.registerShuffleWriteMetrics()
+    assert(tm.shuffleWriteMetrics.isDefined)
+    val sw = tm.shuffleWriteMetrics.get
+    // initial values
+    assertValEquals(_.bytesWritten, BYTES_WRITTEN, 0L)
+    assertValEquals(_.recordsWritten, RECORDS_WRITTEN, 0L)
+    assertValEquals(_.writeTime, WRITE_TIME, 0L)
+    // increment and decrement values
+    sw.incBytesWritten(100L)
+    sw.incBytesWritten(10L) // 100 + 10
+    sw.decBytesWritten(1L) // 100 + 10 - 1
+    sw.decBytesWritten(1L) // 100 + 10 - 1 - 1
+    sw.incRecordsWritten(200L)
+    sw.incRecordsWritten(20L)
+    sw.decRecordsWritten(2L)
+    sw.decRecordsWritten(2L)
+    sw.incWriteTime(300L)
+    sw.incWriteTime(30L)
+    // assert new values exist
+    assertValEquals(_.bytesWritten, BYTES_WRITTEN, 108L)
+    assertValEquals(_.recordsWritten, RECORDS_WRITTEN, 216L)
+    assertValEquals(_.writeTime, WRITE_TIME, 330L)
+  }
+
+  test("mutating input metrics values") {
+    import input._
+    val accums = InternalAccumulator.create()
+    val tm = new TaskMetrics(accums)
+    def assertValEquals(tmValue: InputMetrics => Any, name: String, value: Any): Unit = {
+      assertValueEquals(tm, tm => tmValue(tm.inputMetrics.get), accums, name, value,
+        (x: Any, y: Any) => assert(x.toString === y.toString))
+    }
+    // create input metrics
+    assert(tm.inputMetrics.isEmpty)
+    tm.registerInputMetrics(DataReadMethod.Memory)
+    assert(tm.inputMetrics.isDefined)
+    val in = tm.inputMetrics.get
+    // initial values
+    assertValEquals(_.bytesRead, BYTES_READ, 0L)
+    assertValEquals(_.recordsRead, RECORDS_READ, 0L)
+    assertValEquals(_.readMethod, READ_METHOD, DataReadMethod.Memory)
+    // set and increment values
+    in.setBytesRead(1L)
+    in.setBytesRead(2L)
+    in.incRecordsRead(1L)
+    in.incRecordsRead(2L)
+    in.setReadMethod(DataReadMethod.Disk)
+    // assert new values exist
+    assertValEquals(_.bytesRead, BYTES_READ, 2L)
+    assertValEquals(_.recordsRead, RECORDS_READ, 3L)
+    assertValEquals(_.readMethod, READ_METHOD, DataReadMethod.Disk)
+  }
+
+  test("mutating output metrics values") {
+    import output._
+    val accums = InternalAccumulator.create()
+    val tm = new TaskMetrics(accums)
+    def assertValEquals(tmValue: OutputMetrics => Any, name: String, value: Any): Unit = {
+      assertValueEquals(tm, tm => tmValue(tm.outputMetrics.get), accums, name, value,
+        (x: Any, y: Any) => assert(x.toString === y.toString))
+    }
+    // create input metrics
+    assert(tm.outputMetrics.isEmpty)
+    tm.registerOutputMetrics(DataWriteMethod.Hadoop)
+    assert(tm.outputMetrics.isDefined)
+    val out = tm.outputMetrics.get
+    // initial values
+    assertValEquals(_.bytesWritten, BYTES_WRITTEN, 0L)
+    assertValEquals(_.recordsWritten, RECORDS_WRITTEN, 0L)
+    assertValEquals(_.writeMethod, WRITE_METHOD, DataWriteMethod.Hadoop)
+    // set values
+    out.setBytesWritten(1L)
+    out.setBytesWritten(2L)
+    out.setRecordsWritten(3L)
+    out.setRecordsWritten(4L)
+    out.setWriteMethod(DataWriteMethod.Hadoop)
+    // assert new values exist
+    assertValEquals(_.bytesWritten, BYTES_WRITTEN, 2L)
+    assertValEquals(_.recordsWritten, RECORDS_WRITTEN, 4L)
+    // Note: this doesn't actually test anything, but there's only one DataWriteMethod
+    // so we can't set it to anything else
+    assertValEquals(_.writeMethod, WRITE_METHOD, DataWriteMethod.Hadoop)
+  }
+
+  test("merging multiple shuffle read metrics") {
+    val tm = new TaskMetrics
+    assert(tm.shuffleReadMetrics.isEmpty)
+    val sr1 = tm.registerTempShuffleReadMetrics()
+    val sr2 = tm.registerTempShuffleReadMetrics()
+    val sr3 = tm.registerTempShuffleReadMetrics()
+    assert(tm.shuffleReadMetrics.isEmpty)
+    sr1.setRecordsRead(10L)
+    sr2.setRecordsRead(10L)
+    sr1.setFetchWaitTime(1L)
+    sr2.setFetchWaitTime(2L)
+    sr3.setFetchWaitTime(3L)
+    tm.mergeShuffleReadMetrics()
+    assert(tm.shuffleReadMetrics.isDefined)
+    val sr = tm.shuffleReadMetrics.get
+    assert(sr.remoteBlocksFetched === 0L)
+    assert(sr.recordsRead === 20L)
+    assert(sr.fetchWaitTime === 6L)
+
+    // SPARK-5701: calling merge without any shuffle deps does nothing
+    val tm2 = new TaskMetrics
+    tm2.mergeShuffleReadMetrics()
+    assert(tm2.shuffleReadMetrics.isEmpty)
+  }
+
+  test("register multiple shuffle write metrics") {
+    val tm = new TaskMetrics
+    val sw1 = tm.registerShuffleWriteMetrics()
+    val sw2 = tm.registerShuffleWriteMetrics()
+    assert(sw1 === sw2)
+    assert(tm.shuffleWriteMetrics === Some(sw1))
+  }
+
+  test("register multiple input metrics") {
+    val tm = new TaskMetrics
+    val im1 = tm.registerInputMetrics(DataReadMethod.Memory)
+    val im2 = tm.registerInputMetrics(DataReadMethod.Memory)
+    // input metrics with a different read method than the one already registered are ignored
+    val im3 = tm.registerInputMetrics(DataReadMethod.Hadoop)
+    assert(im1 === im2)
+    assert(im1 !== im3)
+    assert(tm.inputMetrics === Some(im1))
+    im2.setBytesRead(50L)
+    im3.setBytesRead(100L)
+    assert(tm.inputMetrics.get.bytesRead === 50L)
+  }
+
+  test("register multiple output metrics") {
+    val tm = new TaskMetrics
+    val om1 = tm.registerOutputMetrics(DataWriteMethod.Hadoop)
+    val om2 = tm.registerOutputMetrics(DataWriteMethod.Hadoop)
+    assert(om1 === om2)
+    assert(tm.outputMetrics === Some(om1))
+  }
+
+  test("additional accumulables") {
+    val internalAccums = InternalAccumulator.create()
+    val tm = new TaskMetrics(internalAccums)
+    assert(tm.accumulatorUpdates().size === internalAccums.size)
+    val acc1 = new Accumulator(0, IntAccumulatorParam, Some("a"))
+    val acc2 = new Accumulator(0, IntAccumulatorParam, Some("b"))
+    val acc3 = new Accumulator(0, IntAccumulatorParam, Some("c"))
+    val acc4 = new Accumulator(0, IntAccumulatorParam, Some("d"),
+      internal = true, countFailedValues = true)
+    tm.registerAccumulator(acc1)
+    tm.registerAccumulator(acc2)
+    tm.registerAccumulator(acc3)
+    tm.registerAccumulator(acc4)
+    acc1 += 1
+    acc2 += 2
+    val newUpdates = tm.accumulatorUpdates().map { a => (a.id, a) }.toMap
+    assert(newUpdates.contains(acc1.id))
+    assert(newUpdates.contains(acc2.id))
+    assert(newUpdates.contains(acc3.id))
+    assert(newUpdates.contains(acc4.id))
+    assert(newUpdates(acc1.id).name === Some("a"))
+    assert(newUpdates(acc2.id).name === Some("b"))
+    assert(newUpdates(acc3.id).name === Some("c"))
+    assert(newUpdates(acc4.id).name === Some("d"))
+    assert(newUpdates(acc1.id).update === Some(1))
+    assert(newUpdates(acc2.id).update === Some(2))
+    assert(newUpdates(acc3.id).update === Some(0))
+    assert(newUpdates(acc4.id).update === Some(0))
+    assert(!newUpdates(acc3.id).internal)
+    assert(!newUpdates(acc3.id).countFailedValues)
+    assert(newUpdates(acc4.id).internal)
+    assert(newUpdates(acc4.id).countFailedValues)
+    assert(newUpdates.values.map(_.update).forall(_.isDefined))
+    assert(newUpdates.values.map(_.value).forall(_.isEmpty))
+    assert(newUpdates.size === internalAccums.size + 4)
+  }
+
+  test("existing values in shuffle read accums") {
+    // set shuffle read accum before passing it into TaskMetrics
+    val accums = InternalAccumulator.create()
+    val srAccum = accums.find(_.name === Some(shuffleRead.FETCH_WAIT_TIME))
+    assert(srAccum.isDefined)
+    srAccum.get.asInstanceOf[Accumulator[Long]] += 10L
+    val tm = new TaskMetrics(accums)
+    assert(tm.shuffleReadMetrics.isDefined)
+    assert(tm.shuffleWriteMetrics.isEmpty)
+    assert(tm.inputMetrics.isEmpty)
+    assert(tm.outputMetrics.isEmpty)
+  }
+
+  test("existing values in shuffle write accums") {
+    // set shuffle write accum before passing it into TaskMetrics
+    val accums = InternalAccumulator.create()
+    val swAccum = accums.find(_.name === Some(shuffleWrite.RECORDS_WRITTEN))
+    assert(swAccum.isDefined)
+    swAccum.get.asInstanceOf[Accumulator[Long]] += 10L
+    val tm = new TaskMetrics(accums)
+    assert(tm.shuffleReadMetrics.isEmpty)
+    assert(tm.shuffleWriteMetrics.isDefined)
+    assert(tm.inputMetrics.isEmpty)
+    assert(tm.outputMetrics.isEmpty)
+  }
+
+  test("existing values in input accums") {
+    // set input accum before passing it into TaskMetrics
+    val accums = InternalAccumulator.create()
+    val inAccum = accums.find(_.name === Some(input.RECORDS_READ))
+    assert(inAccum.isDefined)
+    inAccum.get.asInstanceOf[Accumulator[Long]] += 10L
+    val tm = new TaskMetrics(accums)
+    assert(tm.shuffleReadMetrics.isEmpty)
+    assert(tm.shuffleWriteMetrics.isEmpty)
+    assert(tm.inputMetrics.isDefined)
+    assert(tm.outputMetrics.isEmpty)
   }
+
+  test("existing values in output accums") {
+    // set output accum before passing it into TaskMetrics
+    val accums = InternalAccumulator.create()
+    val outAccum = accums.find(_.name === Some(output.RECORDS_WRITTEN))
+    assert(outAccum.isDefined)
+    outAccum.get.asInstanceOf[Accumulator[Long]] += 10L
+    val tm4 = new TaskMetrics(accums)
+    assert(tm4.shuffleReadMetrics.isEmpty)
+    assert(tm4.shuffleWriteMetrics.isEmpty)
+    assert(tm4.inputMetrics.isEmpty)
+    assert(tm4.outputMetrics.isDefined)
+  }
+
+  test("from accumulator updates") {
+    val accumUpdates1 = InternalAccumulator.create().map { a =>
+      AccumulableInfo(a.id, a.name, Some(3L), None, a.isInternal, a.countFailedValues)
+    }
+    val metrics1 = TaskMetrics.fromAccumulatorUpdates(accumUpdates1)
+    assertUpdatesEquals(metrics1.accumulatorUpdates(), accumUpdates1)
+    // Test this with additional accumulators. Only the ones registered with `Accumulators`
+    // will show up in the reconstructed TaskMetrics. In practice, all accumulators created
+    // on the driver, internal or not, should be registered with `Accumulators` at some point.
+    // Here we show that reconstruction will succeed even if there are unregistered accumulators.
+    val param = IntAccumulatorParam
+    val registeredAccums = Seq(
+      new Accumulator(0, param, Some("a"), internal = true, countFailedValues = true),
+      new Accumulator(0, param, Some("b"), internal = true, countFailedValues = false),
+      new Accumulator(0, param, Some("c"), internal = false, countFailedValues = true),
+      new Accumulator(0, param, Some("d"), internal = false, countFailedValues = false))
+    val unregisteredAccums = Seq(
+      new Accumulator(0, param, Some("e"), internal = true, countFailedValues = true),
+      new Accumulator(0, param, Some("f"), internal = true, countFailedValues = false))
+    registeredAccums.foreach(Accumulators.register)
+    registeredAccums.foreach { a => assert(Accumulators.originals.contains(a.id)) }
+    unregisteredAccums.foreach { a => assert(!Accumulators.originals.contains(a.id)) }
+    // set some values in these accums
+    registeredAccums.zipWithIndex.foreach { case (a, i) => a.setValue(i) }
+    unregisteredAccums.zipWithIndex.foreach { case (a, i) => a.setValue(i) }
+    val registeredAccumInfos = registeredAccums.map(makeInfo)
+    val unregisteredAccumInfos = unregisteredAccums.map(makeInfo)
+    val accumUpdates2 = accumUpdates1 ++ registeredAccumInfos ++ unregisteredAccumInfos
+    val metrics2 = TaskMetrics.fromAccumulatorUpdates(accumUpdates2)
+    // accumulators that were not registered with `Accumulators` will not show up
+    assertUpdatesEquals(metrics2.accumulatorUpdates(), accumUpdates1 ++ registeredAccumInfos)
+  }
+}
+
+
+private[spark] object TaskMetricsSuite extends Assertions {
+
+  /**
+   * Assert that the following three things are equal to `value`:
+   *   (1) TaskMetrics value
+   *   (2) TaskMetrics accumulator update value
+   *   (3) Original accumulator value
+   */
+  def assertValueEquals(
+      tm: TaskMetrics,
+      tmValue: TaskMetrics => Any,
+      accums: Seq[Accumulator[_]],
+      metricName: String,
+      value: Any,
+      assertEquals: (Any, Any) => Unit = (x: Any, y: Any) => assert(x === y)): Unit = {
+    assertEquals(tmValue(tm), value)
+    val accum = accums.find(_.name == Some(metricName))
+    assert(accum.isDefined)
+    assertEquals(accum.get.value, value)
+    val accumUpdate = tm.accumulatorUpdates().find(_.name == Some(metricName))
+    assert(accumUpdate.isDefined)
+    assert(accumUpdate.get.value === None)
+    assertEquals(accumUpdate.get.update, Some(value))
+  }
+
+  /**
+   * Assert that two lists of accumulator updates are equal.
+   * Note: this does NOT check accumulator ID equality.
+   */
+  def assertUpdatesEquals(
+      updates1: Seq[AccumulableInfo],
+      updates2: Seq[AccumulableInfo]): Unit = {
+    assert(updates1.size === updates2.size)
+    updates1.zip(updates2).foreach { case (info1, info2) =>
+      // do not assert ID equals here
+      assert(info1.name === info2.name)
+      assert(info1.update === info2.update)
+      assert(info1.value === info2.value)
+      assert(info1.internal === info2.internal)
+      assert(info1.countFailedValues === info2.countFailedValues)
+    }
+  }
+
+  /**
+   * Make an [[AccumulableInfo]] out of an [[Accumulable]] with the intent to use the
+   * info as an accumulator update.
+   */
+  def makeInfo(a: Accumulable[_, _]): AccumulableInfo = {
+    new AccumulableInfo(a.id, a.name, Some(a.value), None, a.isInternal, a.countFailedValues)
+  }
+
 }
diff --git a/core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala b/core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala
index 0e60cc8e778736e5a5e5be0ac9b6a54ebddfb856..2b5e4b80e96ab96e6af4d3d9c46911d4ab39509a 100644
--- a/core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala
+++ b/core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala
@@ -31,7 +31,6 @@ object MemoryTestingUtils {
       taskAttemptId = 0,
       attemptNumber = 0,
       taskMemoryManager = taskMemoryManager,
-      metricsSystem = env.metricsSystem,
-      internalAccumulators = Seq.empty)
+      metricsSystem = env.metricsSystem)
   }
 }
diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
index 370a284d2950f5d3223b03e73ba96eb97ad73788..d9c71ec2eae7b5031660505195f9bcc7263024ce 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -23,7 +23,6 @@ import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map}
 import scala.language.reflectiveCalls
 import scala.util.control.NonFatal
 
-import org.scalatest.BeforeAndAfter
 import org.scalatest.concurrent.Timeouts
 import org.scalatest.time.SpanSugar._
 
@@ -96,8 +95,7 @@ class MyRDD(
 
 class DAGSchedulerSuiteDummyException extends Exception
 
-class DAGSchedulerSuite
-  extends SparkFunSuite with BeforeAndAfter with LocalSparkContext with Timeouts {
+class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeouts {
 
   val conf = new SparkConf
   /** Set of TaskSets the DAGScheduler has requested executed. */
@@ -111,8 +109,10 @@ class DAGSchedulerSuite
     override def schedulingMode: SchedulingMode = SchedulingMode.NONE
     override def start() = {}
     override def stop() = {}
-    override def executorHeartbeatReceived(execId: String, taskMetrics: Array[(Long, TaskMetrics)],
-      blockManagerId: BlockManagerId): Boolean = true
+    override def executorHeartbeatReceived(
+        execId: String,
+        accumUpdates: Array[(Long, Seq[AccumulableInfo])],
+        blockManagerId: BlockManagerId): Boolean = true
     override def submitTasks(taskSet: TaskSet) = {
       // normally done by TaskSetManager
       taskSet.tasks.foreach(_.epoch = mapOutputTracker.getEpoch)
@@ -189,7 +189,8 @@ class DAGSchedulerSuite
     override def jobFailed(exception: Exception): Unit = { failure = exception }
   }
 
-  before {
+  override def beforeEach(): Unit = {
+    super.beforeEach()
     sc = new SparkContext("local", "DAGSchedulerSuite")
     sparkListener.submittedStageInfos.clear()
     sparkListener.successfulStages.clear()
@@ -202,17 +203,21 @@ class DAGSchedulerSuite
     results.clear()
     mapOutputTracker = new MapOutputTrackerMaster(conf)
     scheduler = new DAGScheduler(
-        sc,
-        taskScheduler,
-        sc.listenerBus,
-        mapOutputTracker,
-        blockManagerMaster,
-        sc.env)
+      sc,
+      taskScheduler,
+      sc.listenerBus,
+      mapOutputTracker,
+      blockManagerMaster,
+      sc.env)
     dagEventProcessLoopTester = new DAGSchedulerEventProcessLoopTester(scheduler)
   }
 
-  after {
-    scheduler.stop()
+  override def afterEach(): Unit = {
+    try {
+      scheduler.stop()
+    } finally {
+      super.afterEach()
+    }
   }
 
   override def afterAll() {
@@ -242,26 +247,31 @@ class DAGSchedulerSuite
    * directly through CompletionEvents.
    */
   private val jobComputeFunc = (context: TaskContext, it: Iterator[(_)]) =>
-     it.next.asInstanceOf[Tuple2[_, _]]._1
+    it.next.asInstanceOf[Tuple2[_, _]]._1
 
   /** Send the given CompletionEvent messages for the tasks in the TaskSet. */
   private def complete(taskSet: TaskSet, results: Seq[(TaskEndReason, Any)]) {
     assert(taskSet.tasks.size >= results.size)
     for ((result, i) <- results.zipWithIndex) {
       if (i < taskSet.tasks.size) {
-        runEvent(CompletionEvent(
-          taskSet.tasks(i), result._1, result._2, null, createFakeTaskInfo(), null))
+        runEvent(makeCompletionEvent(taskSet.tasks(i), result._1, result._2))
       }
     }
   }
 
-  private def completeWithAccumulator(accumId: Long, taskSet: TaskSet,
-                                      results: Seq[(TaskEndReason, Any)]) {
+  private def completeWithAccumulator(
+      accumId: Long,
+      taskSet: TaskSet,
+      results: Seq[(TaskEndReason, Any)]) {
     assert(taskSet.tasks.size >= results.size)
     for ((result, i) <- results.zipWithIndex) {
       if (i < taskSet.tasks.size) {
-        runEvent(CompletionEvent(taskSet.tasks(i), result._1, result._2,
-          Map[Long, Any]((accumId, 1)), createFakeTaskInfo(), null))
+        runEvent(makeCompletionEvent(
+          taskSet.tasks(i),
+          result._1,
+          result._2,
+          Seq(new AccumulableInfo(
+            accumId, Some(""), Some(1), None, internal = false, countFailedValues = false))))
       }
     }
   }
@@ -338,9 +348,12 @@ class DAGSchedulerSuite
   }
 
   test("equals and hashCode AccumulableInfo") {
-    val accInfo1 = new AccumulableInfo(1, " Accumulable " + 1, Some("delta" + 1), "val" + 1, true)
-    val accInfo2 = new AccumulableInfo(1, " Accumulable " + 1, Some("delta" + 1), "val" + 1, false)
-    val accInfo3 = new AccumulableInfo(1, " Accumulable " + 1, Some("delta" + 1), "val" + 1, false)
+    val accInfo1 = new AccumulableInfo(
+      1, Some("a1"), Some("delta1"), Some("val1"), internal = true, countFailedValues = false)
+    val accInfo2 = new AccumulableInfo(
+      1, Some("a1"), Some("delta1"), Some("val1"), internal = false, countFailedValues = false)
+    val accInfo3 = new AccumulableInfo(
+      1, Some("a1"), Some("delta1"), Some("val1"), internal = false, countFailedValues = false)
     assert(accInfo1 !== accInfo2)
     assert(accInfo2 === accInfo3)
     assert(accInfo2.hashCode() === accInfo3.hashCode())
@@ -464,7 +477,7 @@ class DAGSchedulerSuite
       override def defaultParallelism(): Int = 2
       override def executorHeartbeatReceived(
           execId: String,
-          taskMetrics: Array[(Long, TaskMetrics)],
+          accumUpdates: Array[(Long, Seq[AccumulableInfo])],
           blockManagerId: BlockManagerId): Boolean = true
       override def executorLost(executorId: String, reason: ExecutorLossReason): Unit = {}
       override def applicationAttemptId(): Option[String] = None
@@ -499,8 +512,8 @@ class DAGSchedulerSuite
     val reduceRdd = new MyRDD(sc, 1, List(shuffleDep), tracker = mapOutputTracker)
     submit(reduceRdd, Array(0))
     complete(taskSets(0), Seq(
-        (Success, makeMapStatus("hostA", 1)),
-        (Success, makeMapStatus("hostB", 1))))
+      (Success, makeMapStatus("hostA", 1)),
+      (Success, makeMapStatus("hostB", 1))))
     assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1).toSet ===
       HashSet(makeBlockManagerId("hostA"), makeBlockManagerId("hostB")))
     complete(taskSets(1), Seq((Success, 42)))
@@ -515,12 +528,12 @@ class DAGSchedulerSuite
     val reduceRdd = new MyRDD(sc, 2, List(shuffleDep), tracker = mapOutputTracker)
     submit(reduceRdd, Array(0, 1))
     complete(taskSets(0), Seq(
-        (Success, makeMapStatus("hostA", reduceRdd.partitions.length)),
-        (Success, makeMapStatus("hostB", reduceRdd.partitions.length))))
+      (Success, makeMapStatus("hostA", reduceRdd.partitions.length)),
+      (Success, makeMapStatus("hostB", reduceRdd.partitions.length))))
     // the 2nd ResultTask failed
     complete(taskSets(1), Seq(
-        (Success, 42),
-        (FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0, "ignored"), null)))
+      (Success, 42),
+      (FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0, "ignored"), null)))
     // this will get called
     // blockManagerMaster.removeExecutor("exec-hostA")
     // ask the scheduler to try it again
@@ -829,23 +842,17 @@ class DAGSchedulerSuite
       HashSet("hostA", "hostB"))
 
     // The first result task fails, with a fetch failure for the output from the first mapper.
-    runEvent(CompletionEvent(
+    runEvent(makeCompletionEvent(
       taskSets(1).tasks(0),
       FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0, "ignored"),
-      null,
-      Map[Long, Any](),
-      createFakeTaskInfo(),
       null))
     sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
     assert(sparkListener.failedStages.contains(1))
 
     // The second ResultTask fails, with a fetch failure for the output from the second mapper.
-    runEvent(CompletionEvent(
+    runEvent(makeCompletionEvent(
       taskSets(1).tasks(0),
       FetchFailed(makeBlockManagerId("hostA"), shuffleId, 1, 1, "ignored"),
-      null,
-      Map[Long, Any](),
-      createFakeTaskInfo(),
       null))
     // The SparkListener should not receive redundant failure events.
     sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
@@ -882,12 +889,9 @@ class DAGSchedulerSuite
       HashSet("hostA", "hostB"))
 
     // The first result task fails, with a fetch failure for the output from the first mapper.
-    runEvent(CompletionEvent(
+    runEvent(makeCompletionEvent(
       taskSets(1).tasks(0),
       FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0, "ignored"),
-      null,
-      Map[Long, Any](),
-      createFakeTaskInfo(),
       null))
     sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
     assert(sparkListener.failedStages.contains(1))
@@ -900,12 +904,9 @@ class DAGSchedulerSuite
     assert(countSubmittedMapStageAttempts() === 2)
 
     // The second ResultTask fails, with a fetch failure for the output from the second mapper.
-    runEvent(CompletionEvent(
+    runEvent(makeCompletionEvent(
       taskSets(1).tasks(1),
       FetchFailed(makeBlockManagerId("hostB"), shuffleId, 1, 1, "ignored"),
-      null,
-      Map[Long, Any](),
-      createFakeTaskInfo(),
       null))
 
     // Another ResubmitFailedStages event should not result in another attempt for the map
@@ -920,11 +921,11 @@ class DAGSchedulerSuite
   }
 
   /**
-    * This tests the case where a late FetchFailed comes in after the map stage has finished getting
-    * retried and a new reduce stage starts running.
-    */
+   * This tests the case where a late FetchFailed comes in after the map stage has finished getting
+   * retried and a new reduce stage starts running.
+   */
   test("extremely late fetch failures don't cause multiple concurrent attempts for " +
-      "the same stage") {
+    "the same stage") {
     val shuffleMapRdd = new MyRDD(sc, 2, Nil)
     val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(2))
     val shuffleId = shuffleDep.shuffleId
@@ -952,12 +953,9 @@ class DAGSchedulerSuite
     assert(countSubmittedReduceStageAttempts() === 1)
 
     // The first result task fails, with a fetch failure for the output from the first mapper.
-    runEvent(CompletionEvent(
+    runEvent(makeCompletionEvent(
       taskSets(1).tasks(0),
       FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0, "ignored"),
-      null,
-      Map[Long, Any](),
-      createFakeTaskInfo(),
       null))
 
     // Trigger resubmission of the failed map stage and finish the re-started map task.
@@ -971,12 +969,9 @@ class DAGSchedulerSuite
     assert(countSubmittedReduceStageAttempts() === 2)
 
     // A late FetchFailed arrives from the second task in the original reduce stage.
-    runEvent(CompletionEvent(
+    runEvent(makeCompletionEvent(
       taskSets(1).tasks(1),
       FetchFailed(makeBlockManagerId("hostB"), shuffleId, 1, 1, "ignored"),
-      null,
-      Map[Long, Any](),
-      createFakeTaskInfo(),
       null))
 
     // Running ResubmitFailedStages shouldn't result in any more attempts for the map stage, because
@@ -1007,48 +1002,36 @@ class DAGSchedulerSuite
     assert(shuffleStage.numAvailableOutputs === 0)
 
     // should be ignored for being too old
-    runEvent(CompletionEvent(
+    runEvent(makeCompletionEvent(
       taskSet.tasks(0),
       Success,
-      makeMapStatus("hostA", reduceRdd.partitions.size),
-      null,
-      createFakeTaskInfo(),
-      null))
+      makeMapStatus("hostA", reduceRdd.partitions.size)))
     assert(shuffleStage.numAvailableOutputs === 0)
 
     // should work because it's a non-failed host (so the available map outputs will increase)
-    runEvent(CompletionEvent(
+    runEvent(makeCompletionEvent(
       taskSet.tasks(0),
       Success,
-      makeMapStatus("hostB", reduceRdd.partitions.size),
-      null,
-      createFakeTaskInfo(),
-      null))
+      makeMapStatus("hostB", reduceRdd.partitions.size)))
     assert(shuffleStage.numAvailableOutputs === 1)
 
     // should be ignored for being too old
-    runEvent(CompletionEvent(
+    runEvent(makeCompletionEvent(
       taskSet.tasks(0),
       Success,
-      makeMapStatus("hostA", reduceRdd.partitions.size),
-      null,
-      createFakeTaskInfo(),
-      null))
+      makeMapStatus("hostA", reduceRdd.partitions.size)))
     assert(shuffleStage.numAvailableOutputs === 1)
 
     // should work because it's a new epoch, which will increase the number of available map
     // outputs, and also finish the stage
     taskSet.tasks(1).epoch = newEpoch
-    runEvent(CompletionEvent(
+    runEvent(makeCompletionEvent(
       taskSet.tasks(1),
       Success,
-      makeMapStatus("hostA", reduceRdd.partitions.size),
-      null,
-      createFakeTaskInfo(),
-      null))
+      makeMapStatus("hostA", reduceRdd.partitions.size)))
     assert(shuffleStage.numAvailableOutputs === 2)
     assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1).toSet ===
-           HashSet(makeBlockManagerId("hostB"), makeBlockManagerId("hostA")))
+      HashSet(makeBlockManagerId("hostB"), makeBlockManagerId("hostA")))
 
     // finish the next stage normally, which completes the job
     complete(taskSets(1), Seq((Success, 42), (Success, 43)))
@@ -1140,12 +1123,9 @@ class DAGSchedulerSuite
 
     // then one executor dies, and a task fails in stage 1
     runEvent(ExecutorLost("exec-hostA"))
-    runEvent(CompletionEvent(
+    runEvent(makeCompletionEvent(
       taskSets(1).tasks(0),
       FetchFailed(null, firstShuffleId, 2, 0, "Fetch failed"),
-      null,
-      null,
-      createFakeTaskInfo(),
       null))
 
     // so we resubmit stage 0, which completes happily
@@ -1155,13 +1135,10 @@ class DAGSchedulerSuite
     assert(stage0Resubmit.stageAttemptId === 1)
     val task = stage0Resubmit.tasks(0)
     assert(task.partitionId === 2)
-    runEvent(CompletionEvent(
+    runEvent(makeCompletionEvent(
       task,
       Success,
-      makeMapStatus("hostC", shuffleMapRdd.partitions.length),
-      null,
-      createFakeTaskInfo(),
-      null))
+      makeMapStatus("hostC", shuffleMapRdd.partitions.length)))
 
     // now here is where things get tricky : we will now have a task set representing
     // the second attempt for stage 1, but we *also* have some tasks for the first attempt for
@@ -1174,28 +1151,19 @@ class DAGSchedulerSuite
     // we'll have some tasks finish from the first attempt, and some finish from the second attempt,
     // so that we actually have all stage outputs, though no attempt has completed all its
     // tasks
-    runEvent(CompletionEvent(
+    runEvent(makeCompletionEvent(
       taskSets(3).tasks(0),
       Success,
-      makeMapStatus("hostC", reduceRdd.partitions.length),
-      null,
-      createFakeTaskInfo(),
-      null))
-    runEvent(CompletionEvent(
+      makeMapStatus("hostC", reduceRdd.partitions.length)))
+    runEvent(makeCompletionEvent(
       taskSets(3).tasks(1),
       Success,
-      makeMapStatus("hostC", reduceRdd.partitions.length),
-      null,
-      createFakeTaskInfo(),
-      null))
+      makeMapStatus("hostC", reduceRdd.partitions.length)))
     // late task finish from the first attempt
-    runEvent(CompletionEvent(
+    runEvent(makeCompletionEvent(
       taskSets(1).tasks(2),
       Success,
-      makeMapStatus("hostB", reduceRdd.partitions.length),
-      null,
-      createFakeTaskInfo(),
-      null))
+      makeMapStatus("hostB", reduceRdd.partitions.length)))
 
     // What should happen now is that we submit stage 2.  However, we might not see an error
     // b/c of DAGScheduler's error handling (it tends to swallow errors and just log them).  But
@@ -1242,21 +1210,21 @@ class DAGSchedulerSuite
     submit(reduceRdd, Array(0))
 
     // complete some of the tasks from the first stage, on one host
-    runEvent(CompletionEvent(
-      taskSets(0).tasks(0), Success,
-      makeMapStatus("hostA", reduceRdd.partitions.length), null, createFakeTaskInfo(), null))
-    runEvent(CompletionEvent(
-      taskSets(0).tasks(1), Success,
-      makeMapStatus("hostA", reduceRdd.partitions.length), null, createFakeTaskInfo(), null))
+    runEvent(makeCompletionEvent(
+      taskSets(0).tasks(0),
+      Success,
+      makeMapStatus("hostA", reduceRdd.partitions.length)))
+    runEvent(makeCompletionEvent(
+      taskSets(0).tasks(1),
+      Success,
+      makeMapStatus("hostA", reduceRdd.partitions.length)))
 
     // now that host goes down
     runEvent(ExecutorLost("exec-hostA"))
 
     // so we resubmit those tasks
-    runEvent(CompletionEvent(
-      taskSets(0).tasks(0), Resubmitted, null, null, createFakeTaskInfo(), null))
-    runEvent(CompletionEvent(
-      taskSets(0).tasks(1), Resubmitted, null, null, createFakeTaskInfo(), null))
+    runEvent(makeCompletionEvent(taskSets(0).tasks(0), Resubmitted, null))
+    runEvent(makeCompletionEvent(taskSets(0).tasks(1), Resubmitted, null))
 
     // now complete everything on a different host
     complete(taskSets(0), Seq(
@@ -1449,12 +1417,12 @@ class DAGSchedulerSuite
     // DAGScheduler will immediately resubmit the stage after it appears to have no pending tasks
     // rather than marking it is as failed and waiting.
     complete(taskSets(0), Seq(
-        (Success, makeMapStatus("hostA", 1)),
-       (Success, makeMapStatus("hostB", 1))))
+      (Success, makeMapStatus("hostA", 1)),
+      (Success, makeMapStatus("hostB", 1))))
     // have hostC complete the resubmitted task
     complete(taskSets(1), Seq((Success, makeMapStatus("hostC", 1))))
     assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1).toSet ===
-           HashSet(makeBlockManagerId("hostC"), makeBlockManagerId("hostB")))
+      HashSet(makeBlockManagerId("hostC"), makeBlockManagerId("hostB")))
     complete(taskSets(2), Seq((Success, 42)))
     assert(results === Map(0 -> 42))
     assertDataStructuresEmpty()
@@ -1469,15 +1437,15 @@ class DAGSchedulerSuite
     submit(finalRdd, Array(0))
     // have the first stage complete normally
     complete(taskSets(0), Seq(
-        (Success, makeMapStatus("hostA", 2)),
-        (Success, makeMapStatus("hostB", 2))))
+      (Success, makeMapStatus("hostA", 2)),
+      (Success, makeMapStatus("hostB", 2))))
     // have the second stage complete normally
     complete(taskSets(1), Seq(
-        (Success, makeMapStatus("hostA", 1)),
-        (Success, makeMapStatus("hostC", 1))))
+      (Success, makeMapStatus("hostA", 1)),
+      (Success, makeMapStatus("hostC", 1))))
     // fail the third stage because hostA went down
     complete(taskSets(2), Seq(
-        (FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0, "ignored"), null)))
+      (FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0, "ignored"), null)))
     // TODO assert this:
     // blockManagerMaster.removeExecutor("exec-hostA")
     // have DAGScheduler try again
@@ -1500,15 +1468,15 @@ class DAGSchedulerSuite
     cacheLocations(shuffleTwoRdd.id -> 1) = Seq(makeBlockManagerId("hostC"))
     // complete stage 0
     complete(taskSets(0), Seq(
-        (Success, makeMapStatus("hostA", 2)),
-        (Success, makeMapStatus("hostB", 2))))
+      (Success, makeMapStatus("hostA", 2)),
+      (Success, makeMapStatus("hostB", 2))))
     // complete stage 1
     complete(taskSets(1), Seq(
-        (Success, makeMapStatus("hostA", 1)),
-        (Success, makeMapStatus("hostB", 1))))
+      (Success, makeMapStatus("hostA", 1)),
+      (Success, makeMapStatus("hostB", 1))))
     // pretend stage 2 failed because hostA went down
     complete(taskSets(2), Seq(
-        (FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0, "ignored"), null)))
+      (FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0, "ignored"), null)))
     // TODO assert this:
     // blockManagerMaster.removeExecutor("exec-hostA")
     // DAGScheduler should notice the cached copy of the second shuffle and try to get it rerun.
@@ -1606,6 +1574,28 @@ class DAGSchedulerSuite
     assertDataStructuresEmpty()
   }
 
+  test("accumulators are updated on exception failures") {
+    val acc1 = sc.accumulator(0L, "ingenieur")
+    val acc2 = sc.accumulator(0L, "boulanger")
+    val acc3 = sc.accumulator(0L, "agriculteur")
+    assert(Accumulators.get(acc1.id).isDefined)
+    assert(Accumulators.get(acc2.id).isDefined)
+    assert(Accumulators.get(acc3.id).isDefined)
+    val accInfo1 = new AccumulableInfo(
+      acc1.id, acc1.name, Some(15L), None, internal = false, countFailedValues = false)
+    val accInfo2 = new AccumulableInfo(
+      acc2.id, acc2.name, Some(13L), None, internal = false, countFailedValues = false)
+    val accInfo3 = new AccumulableInfo(
+      acc3.id, acc3.name, Some(18L), None, internal = false, countFailedValues = false)
+    val accumUpdates = Seq(accInfo1, accInfo2, accInfo3)
+    val exceptionFailure = new ExceptionFailure(new SparkException("fondue?"), accumUpdates)
+    submit(new MyRDD(sc, 1, Nil), Array(0))
+    runEvent(makeCompletionEvent(taskSets.head.tasks.head, exceptionFailure, "result"))
+    assert(Accumulators.get(acc1.id).get.value === 15L)
+    assert(Accumulators.get(acc2.id).get.value === 13L)
+    assert(Accumulators.get(acc3.id).get.value === 18L)
+  }
+
   test("reduce tasks should be placed locally with map output") {
     // Create an shuffleMapRdd with 1 partition
     val shuffleMapRdd = new MyRDD(sc, 1, Nil)
@@ -1614,9 +1604,9 @@ class DAGSchedulerSuite
     val reduceRdd = new MyRDD(sc, 1, List(shuffleDep), tracker = mapOutputTracker)
     submit(reduceRdd, Array(0))
     complete(taskSets(0), Seq(
-        (Success, makeMapStatus("hostA", 1))))
+      (Success, makeMapStatus("hostA", 1))))
     assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1).toSet ===
-           HashSet(makeBlockManagerId("hostA")))
+      HashSet(makeBlockManagerId("hostA")))
 
     // Reducer should run on the same host that map task ran
     val reduceTaskSet = taskSets(1)
@@ -1884,8 +1874,7 @@ class DAGSchedulerSuite
     submitMapStage(shuffleDep)
 
     val oldTaskSet = taskSets(0)
-    runEvent(CompletionEvent(oldTaskSet.tasks(0), Success, makeMapStatus("hostA", 2),
-      null, createFakeTaskInfo(), null))
+    runEvent(makeCompletionEvent(oldTaskSet.tasks(0), Success, makeMapStatus("hostA", 2)))
     assert(results.size === 0)    // Map stage job should not be complete yet
 
     // Pretend host A was lost
@@ -1895,23 +1884,19 @@ class DAGSchedulerSuite
     assert(newEpoch > oldEpoch)
 
     // Suppose we also get a completed event from task 1 on the same host; this should be ignored
-    runEvent(CompletionEvent(oldTaskSet.tasks(1), Success, makeMapStatus("hostA", 2),
-      null, createFakeTaskInfo(), null))
+    runEvent(makeCompletionEvent(oldTaskSet.tasks(1), Success, makeMapStatus("hostA", 2)))
     assert(results.size === 0)    // Map stage job should not be complete yet
 
     // A completion from another task should work because it's a non-failed host
-    runEvent(CompletionEvent(oldTaskSet.tasks(2), Success, makeMapStatus("hostB", 2),
-      null, createFakeTaskInfo(), null))
+    runEvent(makeCompletionEvent(oldTaskSet.tasks(2), Success, makeMapStatus("hostB", 2)))
     assert(results.size === 0)    // Map stage job should not be complete yet
 
     // Now complete tasks in the second task set
     val newTaskSet = taskSets(1)
     assert(newTaskSet.tasks.size === 2)     // Both tasks 0 and 1 were on on hostA
-    runEvent(CompletionEvent(newTaskSet.tasks(0), Success, makeMapStatus("hostB", 2),
-      null, createFakeTaskInfo(), null))
+    runEvent(makeCompletionEvent(newTaskSet.tasks(0), Success, makeMapStatus("hostB", 2)))
     assert(results.size === 0)    // Map stage job should not be complete yet
-    runEvent(CompletionEvent(newTaskSet.tasks(1), Success, makeMapStatus("hostB", 2),
-      null, createFakeTaskInfo(), null))
+    runEvent(makeCompletionEvent(newTaskSet.tasks(1), Success, makeMapStatus("hostB", 2)))
     assert(results.size === 1)    // Map stage job should now finally be complete
     assertDataStructuresEmpty()
 
@@ -1962,5 +1947,21 @@ class DAGSchedulerSuite
     info
   }
 
-}
+  private def makeCompletionEvent(
+      task: Task[_],
+      reason: TaskEndReason,
+      result: Any,
+      extraAccumUpdates: Seq[AccumulableInfo] = Seq.empty[AccumulableInfo],
+      taskInfo: TaskInfo = createFakeTaskInfo()): CompletionEvent = {
+    val accumUpdates = reason match {
+      case Success =>
+        task.initialAccumulators.map { a =>
+          new AccumulableInfo(a.id, a.name, Some(a.zero), None, a.isInternal, a.countFailedValues)
+        }
+      case ef: ExceptionFailure => ef.accumUpdates
+      case _ => Seq.empty[AccumulableInfo]
+    }
+    CompletionEvent(task, reason, result, accumUpdates ++ extraAccumUpdates, taskInfo)
+  }
 
+}
diff --git a/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala
index 761e82e6cf1cecb996ff1e421d5d69fcbf45af02..35215c15ea805c43617870579101fa76a5ba0103 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala
@@ -26,7 +26,7 @@ import org.scalatest.BeforeAndAfter
 import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite}
 import org.apache.spark.deploy.SparkHadoopUtil
 import org.apache.spark.io.CompressionCodec
-import org.apache.spark.util.{JsonProtocol, Utils}
+import org.apache.spark.util.{JsonProtocol, JsonProtocolSuite, Utils}
 
 /**
  * Test whether ReplayListenerBus replays events from logs correctly.
@@ -131,7 +131,11 @@ class ReplayListenerSuite extends SparkFunSuite with BeforeAndAfter {
     assert(sc.eventLogger.isDefined)
     val originalEvents = sc.eventLogger.get.loggedEvents
     val replayedEvents = eventMonster.loggedEvents
-    originalEvents.zip(replayedEvents).foreach { case (e1, e2) => assert(e1 === e2) }
+    originalEvents.zip(replayedEvents).foreach { case (e1, e2) =>
+      // Don't compare the JSON here because accumulators in StageInfo may be out of order
+      JsonProtocolSuite.assertEquals(
+        JsonProtocol.sparkEventFromJson(e1), JsonProtocol.sparkEventFromJson(e2))
+    }
   }
 
   /**
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
index e5ec44a9f3b6b77ff330ff6d9c57a831bb6c0b3c..b3bb86db10a32fa77ec589f998f38e1e88b969b1 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
@@ -22,6 +22,8 @@ import org.mockito.Mockito._
 import org.scalatest.BeforeAndAfter
 
 import org.apache.spark._
+import org.apache.spark.executor.TaskMetricsSuite
+import org.apache.spark.memory.TaskMemoryManager
 import org.apache.spark.metrics.source.JvmSource
 import org.apache.spark.network.util.JavaUtils
 import org.apache.spark.rdd.RDD
@@ -57,8 +59,7 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
     val closureSerializer = SparkEnv.get.closureSerializer.newInstance()
     val func = (c: TaskContext, i: Iterator[String]) => i.next()
     val taskBinary = sc.broadcast(JavaUtils.bufferToArray(closureSerializer.serialize((rdd, func))))
-    val task = new ResultTask[String, String](
-      0, 0, taskBinary, rdd.partitions(0), Seq.empty, 0, Seq.empty)
+    val task = new ResultTask[String, String](0, 0, taskBinary, rdd.partitions(0), Seq.empty, 0)
     intercept[RuntimeException] {
       task.run(0, 0, null)
     }
@@ -97,6 +98,57 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
     }.collect()
     assert(attemptIdsWithFailedTask.toSet === Set(0, 1))
   }
+
+  test("accumulators are updated on exception failures") {
+    // This means use 1 core and 4 max task failures
+    sc = new SparkContext("local[1,4]", "test")
+    val param = AccumulatorParam.LongAccumulatorParam
+    // Create 2 accumulators, one that counts failed values and another that doesn't
+    val acc1 = new Accumulator(0L, param, Some("x"), internal = false, countFailedValues = true)
+    val acc2 = new Accumulator(0L, param, Some("y"), internal = false, countFailedValues = false)
+    // Fail first 3 attempts of every task. This means each task should be run 4 times.
+    sc.parallelize(1 to 10, 10).map { i =>
+      acc1 += 1
+      acc2 += 1
+      if (TaskContext.get.attemptNumber() <= 2) {
+        throw new Exception("you did something wrong")
+      } else {
+        0
+      }
+    }.count()
+    // The one that counts failed values should be 4x the one that didn't,
+    // since we ran each task 4 times
+    assert(Accumulators.get(acc1.id).get.value === 40L)
+    assert(Accumulators.get(acc2.id).get.value === 10L)
+  }
+
+  test("failed tasks collect only accumulators whose values count during failures") {
+    sc = new SparkContext("local", "test")
+    val param = AccumulatorParam.LongAccumulatorParam
+    val acc1 = new Accumulator(0L, param, Some("x"), internal = false, countFailedValues = true)
+    val acc2 = new Accumulator(0L, param, Some("y"), internal = false, countFailedValues = false)
+    val initialAccums = InternalAccumulator.create()
+    // Create a dummy task. We won't end up running this; we just want to collect
+    // accumulator updates from it.
+    val task = new Task[Int](0, 0, 0, Seq.empty[Accumulator[_]]) {
+      context = new TaskContextImpl(0, 0, 0L, 0,
+        new TaskMemoryManager(SparkEnv.get.memoryManager, 0L),
+        SparkEnv.get.metricsSystem,
+        initialAccums)
+      context.taskMetrics.registerAccumulator(acc1)
+      context.taskMetrics.registerAccumulator(acc2)
+      override def runTask(tc: TaskContext): Int = 0
+    }
+    // First, simulate task success. This should give us all the accumulators.
+    val accumUpdates1 = task.collectAccumulatorUpdates(taskFailed = false)
+    val accumUpdates2 = (initialAccums ++ Seq(acc1, acc2)).map(TaskMetricsSuite.makeInfo)
+    TaskMetricsSuite.assertUpdatesEquals(accumUpdates1, accumUpdates2)
+    // Now, simulate task failures. This should give us only the accums that count failed values.
+    val accumUpdates3 = task.collectAccumulatorUpdates(taskFailed = true)
+    val accumUpdates4 = (initialAccums ++ Seq(acc1)).map(TaskMetricsSuite.makeInfo)
+    TaskMetricsSuite.assertUpdatesEquals(accumUpdates3, accumUpdates4)
+  }
+
 }
 
 private object TaskContextSuite {
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala
index cc2557c2f1df29c463eb14a6101186ee453ec174..b5385c11a926e2132002ede25c6590b9342bbe58 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala
@@ -21,10 +21,15 @@ import java.io.File
 import java.net.URL
 import java.nio.ByteBuffer
 
+import scala.collection.mutable.ArrayBuffer
 import scala.concurrent.duration._
 import scala.language.postfixOps
 import scala.util.control.NonFatal
 
+import com.google.common.util.concurrent.MoreExecutors
+import org.mockito.ArgumentCaptor
+import org.mockito.Matchers.{any, anyLong}
+import org.mockito.Mockito.{spy, times, verify}
 import org.scalatest.BeforeAndAfter
 import org.scalatest.concurrent.Eventually._
 
@@ -33,13 +38,14 @@ import org.apache.spark.storage.TaskResultBlockId
 import org.apache.spark.TestUtils.JavaSourceFromString
 import org.apache.spark.util.{MutableURLClassLoader, RpcUtils, Utils}
 
+
 /**
  * Removes the TaskResult from the BlockManager before delegating to a normal TaskResultGetter.
  *
  * Used to test the case where a BlockManager evicts the task result (or dies) before the
  * TaskResult is retrieved.
  */
-class ResultDeletingTaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedulerImpl)
+private class ResultDeletingTaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedulerImpl)
   extends TaskResultGetter(sparkEnv, scheduler) {
   var removedResult = false
 
@@ -72,6 +78,31 @@ class ResultDeletingTaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedule
   }
 }
 
+
+/**
+ * A [[TaskResultGetter]] that stores the [[DirectTaskResult]]s it receives from executors
+ * _before_ modifying the results in any way.
+ */
+private class MyTaskResultGetter(env: SparkEnv, scheduler: TaskSchedulerImpl)
+  extends TaskResultGetter(env, scheduler) {
+
+  // Use the current thread so we can access its results synchronously
+  protected override val getTaskResultExecutor = MoreExecutors.sameThreadExecutor()
+
+  // DirectTaskResults that we receive from the executors
+  private val _taskResults = new ArrayBuffer[DirectTaskResult[_]]
+
+  def taskResults: Seq[DirectTaskResult[_]] = _taskResults
+
+  override def enqueueSuccessfulTask(tsm: TaskSetManager, tid: Long, data: ByteBuffer): Unit = {
+    // work on a copy since the super class still needs to use the buffer
+    val newBuffer = data.duplicate()
+    _taskResults += env.closureSerializer.newInstance().deserialize[DirectTaskResult[_]](newBuffer)
+    super.enqueueSuccessfulTask(tsm, tid, data)
+  }
+}
+
+
 /**
  * Tests related to handling task results (both direct and indirect).
  */
@@ -182,5 +213,39 @@ class TaskResultGetterSuite extends SparkFunSuite with BeforeAndAfter with Local
       Thread.currentThread.setContextClassLoader(originalClassLoader)
     }
   }
+
+  test("task result size is set on the driver, not the executors") {
+    import InternalAccumulator._
+
+    // Set up custom TaskResultGetter and TaskSchedulerImpl spy
+    sc = new SparkContext("local", "test", conf)
+    val scheduler = sc.taskScheduler.asInstanceOf[TaskSchedulerImpl]
+    val spyScheduler = spy(scheduler)
+    val resultGetter = new MyTaskResultGetter(sc.env, spyScheduler)
+    val newDAGScheduler = new DAGScheduler(sc, spyScheduler)
+    scheduler.taskResultGetter = resultGetter
+    sc.dagScheduler = newDAGScheduler
+    sc.taskScheduler = spyScheduler
+    sc.taskScheduler.setDAGScheduler(newDAGScheduler)
+
+    // Just run 1 task and capture the corresponding DirectTaskResult
+    sc.parallelize(1 to 1, 1).count()
+    val captor = ArgumentCaptor.forClass(classOf[DirectTaskResult[_]])
+    verify(spyScheduler, times(1)).handleSuccessfulTask(any(), anyLong(), captor.capture())
+
+    // When a task finishes, the executor sends a serialized DirectTaskResult to the driver
+    // without setting the result size so as to avoid serializing the result again. Instead,
+    // the result size is set later in TaskResultGetter on the driver before passing the
+    // DirectTaskResult on to TaskSchedulerImpl. In this test, we capture the DirectTaskResult
+    // before and after the result size is set.
+    assert(resultGetter.taskResults.size === 1)
+    val resBefore = resultGetter.taskResults.head
+    val resAfter = captor.getValue
+    val resSizeBefore = resBefore.accumUpdates.find(_.name == Some(RESULT_SIZE)).flatMap(_.update)
+    val resSizeAfter = resAfter.accumUpdates.find(_.name == Some(RESULT_SIZE)).flatMap(_.update)
+    assert(resSizeBefore.exists(_ == 0L))
+    assert(resSizeAfter.exists(_.toString.toLong > 0L))
+  }
+
 }
 
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
index ecc18fc6e15b4da6251d9647e03a8e0147b95805..a2e74365641a6cfbb64eac281355c172c527a171 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
@@ -24,7 +24,6 @@ import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
 
 import org.apache.spark._
-import org.apache.spark.executor.TaskMetrics
 import org.apache.spark.util.ManualClock
 
 class FakeDAGScheduler(sc: SparkContext, taskScheduler: FakeTaskScheduler)
@@ -38,9 +37,8 @@ class FakeDAGScheduler(sc: SparkContext, taskScheduler: FakeTaskScheduler)
       task: Task[_],
       reason: TaskEndReason,
       result: Any,
-      accumUpdates: Map[Long, Any],
-      taskInfo: TaskInfo,
-      taskMetrics: TaskMetrics) {
+      accumUpdates: Seq[AccumulableInfo],
+      taskInfo: TaskInfo) {
     taskScheduler.endedTasks(taskInfo.index) = reason
   }
 
@@ -167,14 +165,17 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg
     val taskSet = FakeTask.createTaskSet(1)
     val clock = new ManualClock
     val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock)
+    val accumUpdates = taskSet.tasks.head.initialAccumulators.map { a =>
+      new AccumulableInfo(a.id, a.name, Some(0L), None, a.isInternal, a.countFailedValues)
+    }
 
     // Offer a host with NO_PREF as the constraint,
     // we should get a nopref task immediately since that's what we only have
-    var taskOption = manager.resourceOffer("exec1", "host1", NO_PREF)
+    val taskOption = manager.resourceOffer("exec1", "host1", NO_PREF)
     assert(taskOption.isDefined)
 
     // Tell it the task has finished
-    manager.handleSuccessfulTask(0, createTaskResult(0))
+    manager.handleSuccessfulTask(0, createTaskResult(0, accumUpdates))
     assert(sched.endedTasks(0) === Success)
     assert(sched.finishedManagers.contains(manager))
   }
@@ -184,10 +185,15 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg
     val sched = new FakeTaskScheduler(sc, ("exec1", "host1"))
     val taskSet = FakeTask.createTaskSet(3)
     val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES)
+    val accumUpdatesByTask: Array[Seq[AccumulableInfo]] = taskSet.tasks.map { task =>
+      task.initialAccumulators.map { a =>
+        new AccumulableInfo(a.id, a.name, Some(0L), None, a.isInternal, a.countFailedValues)
+      }
+    }
 
     // First three offers should all find tasks
     for (i <- 0 until 3) {
-      var taskOption = manager.resourceOffer("exec1", "host1", NO_PREF)
+      val taskOption = manager.resourceOffer("exec1", "host1", NO_PREF)
       assert(taskOption.isDefined)
       val task = taskOption.get
       assert(task.executorId === "exec1")
@@ -198,14 +204,14 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg
     assert(manager.resourceOffer("exec1", "host1", NO_PREF) === None)
 
     // Finish the first two tasks
-    manager.handleSuccessfulTask(0, createTaskResult(0))
-    manager.handleSuccessfulTask(1, createTaskResult(1))
+    manager.handleSuccessfulTask(0, createTaskResult(0, accumUpdatesByTask(0)))
+    manager.handleSuccessfulTask(1, createTaskResult(1, accumUpdatesByTask(1)))
     assert(sched.endedTasks(0) === Success)
     assert(sched.endedTasks(1) === Success)
     assert(!sched.finishedManagers.contains(manager))
 
     // Finish the last task
-    manager.handleSuccessfulTask(2, createTaskResult(2))
+    manager.handleSuccessfulTask(2, createTaskResult(2, accumUpdatesByTask(2)))
     assert(sched.endedTasks(2) === Success)
     assert(sched.finishedManagers.contains(manager))
   }
@@ -620,7 +626,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg
 
     // multiple 1k result
     val r = sc.makeRDD(0 until 10, 10).map(genBytes(1024)).collect()
-    assert(10 === r.size )
+    assert(10 === r.size)
 
     // single 10M result
     val thrown = intercept[SparkException] {sc.makeRDD(genBytes(10 << 20)(0), 1).collect()}
@@ -761,7 +767,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg
     // Regression test for SPARK-2931
     sc = new SparkContext("local", "test")
     val sched = new FakeTaskScheduler(sc,
-        ("execA", "host1"), ("execB", "host2"), ("execC", "host3"))
+      ("execA", "host1"), ("execB", "host2"), ("execC", "host3"))
     val taskSet = FakeTask.createTaskSet(3,
       Seq(TaskLocation("host1")),
       Seq(TaskLocation("host2")),
@@ -786,8 +792,10 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg
     assert(TaskLocation("executor_host1_3") === ExecutorCacheTaskLocation("host1", "3"))
   }
 
-  def createTaskResult(id: Int): DirectTaskResult[Int] = {
+  private def createTaskResult(
+      id: Int,
+      accumUpdates: Seq[AccumulableInfo] = Seq.empty[AccumulableInfo]): DirectTaskResult[Int] = {
     val valueSer = SparkEnv.get.serializer.newInstance()
-    new DirectTaskResult[Int](valueSer.serialize(id), mutable.Map.empty, new TaskMetrics)
+    new DirectTaskResult[Int](valueSer.serialize(id), accumUpdates)
   }
 }
diff --git a/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala b/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala
index 86699e7f56953817c3966e796d67cf58d7141225..b83ffa3282e4d6ea2c829d980e0ec2c4cb348b55 100644
--- a/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala
@@ -31,6 +31,8 @@ import org.apache.spark.ui.scope.RDDOperationGraphListener
 
 class StagePageSuite extends SparkFunSuite with LocalSparkContext {
 
+  private val peakExecutionMemory = 10
+
   test("peak execution memory only displayed if unsafe is enabled") {
     val unsafeConf = "spark.sql.unsafe.enabled"
     val conf = new SparkConf(false).set(unsafeConf, "true")
@@ -52,7 +54,7 @@ class StagePageSuite extends SparkFunSuite with LocalSparkContext {
     val conf = new SparkConf(false).set(unsafeConf, "true")
     val html = renderStagePage(conf).toString().toLowerCase
     // verify min/25/50/75/max show task value not cumulative values
-    assert(html.contains("<td>10.0 b</td>" * 5))
+    assert(html.contains(s"<td>$peakExecutionMemory.0 b</td>" * 5))
   }
 
   /**
@@ -79,14 +81,13 @@ class StagePageSuite extends SparkFunSuite with LocalSparkContext {
     (1 to 2).foreach {
       taskId =>
         val taskInfo = new TaskInfo(taskId, taskId, 0, 0, "0", "localhost", TaskLocality.ANY, false)
-        val peakExecutionMemory = 10
-        taskInfo.accumulables += new AccumulableInfo(0, InternalAccumulator.PEAK_EXECUTION_MEMORY,
-          Some(peakExecutionMemory.toString), (peakExecutionMemory * taskId).toString, true)
         jobListener.onStageSubmitted(SparkListenerStageSubmitted(stageInfo))
         jobListener.onTaskStart(SparkListenerTaskStart(0, 0, taskInfo))
         taskInfo.markSuccessful()
+        val taskMetrics = TaskMetrics.empty
+        taskMetrics.incPeakExecutionMemory(peakExecutionMemory)
         jobListener.onTaskEnd(
-          SparkListenerTaskEnd(0, 0, "result", Success, taskInfo, TaskMetrics.empty))
+          SparkListenerTaskEnd(0, 0, "result", Success, taskInfo, taskMetrics))
     }
     jobListener.onStageCompleted(SparkListenerStageCompleted(stageInfo))
     page.render(request)
diff --git a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala
index 607617cbe91caedf0e46267d8c3d6b408e4f2d3d..18a16a25bfac5f7720036e07a418b3b021697791 100644
--- a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala
@@ -240,7 +240,7 @@ class JobProgressListenerSuite extends SparkFunSuite with LocalSparkContext with
     val taskFailedReasons = Seq(
       Resubmitted,
       new FetchFailed(null, 0, 0, 0, "ignored"),
-      ExceptionFailure("Exception", "description", null, null, None, None),
+      ExceptionFailure("Exception", "description", null, null, None),
       TaskResultLost,
       TaskKilled,
       ExecutorLostFailure("0", true, Some("Induced failure")),
@@ -269,20 +269,22 @@ class JobProgressListenerSuite extends SparkFunSuite with LocalSparkContext with
     val execId = "exe-1"
 
     def makeTaskMetrics(base: Int): TaskMetrics = {
-      val taskMetrics = new TaskMetrics()
-      taskMetrics.setExecutorRunTime(base + 4)
-      taskMetrics.incDiskBytesSpilled(base + 5)
-      taskMetrics.incMemoryBytesSpilled(base + 6)
+      val accums = InternalAccumulator.create()
+      accums.foreach(Accumulators.register)
+      val taskMetrics = new TaskMetrics(accums)
       val shuffleReadMetrics = taskMetrics.registerTempShuffleReadMetrics()
+      val shuffleWriteMetrics = taskMetrics.registerShuffleWriteMetrics()
+      val inputMetrics = taskMetrics.registerInputMetrics(DataReadMethod.Hadoop)
+      val outputMetrics = taskMetrics.registerOutputMetrics(DataWriteMethod.Hadoop)
       shuffleReadMetrics.incRemoteBytesRead(base + 1)
       shuffleReadMetrics.incLocalBytesRead(base + 9)
       shuffleReadMetrics.incRemoteBlocksFetched(base + 2)
       taskMetrics.mergeShuffleReadMetrics()
-      val shuffleWriteMetrics = taskMetrics.registerShuffleWriteMetrics()
       shuffleWriteMetrics.incBytesWritten(base + 3)
-      val inputMetrics = taskMetrics.registerInputMetrics(DataReadMethod.Hadoop)
-      inputMetrics.incBytesRead(base + 7)
-      val outputMetrics = taskMetrics.registerOutputMetrics(DataWriteMethod.Hadoop)
+      taskMetrics.setExecutorRunTime(base + 4)
+      taskMetrics.incDiskBytesSpilled(base + 5)
+      taskMetrics.incMemoryBytesSpilled(base + 6)
+      inputMetrics.setBytesRead(base + 7)
       outputMetrics.setBytesWritten(base + 8)
       taskMetrics
     }
@@ -300,9 +302,9 @@ class JobProgressListenerSuite extends SparkFunSuite with LocalSparkContext with
     listener.onTaskStart(SparkListenerTaskStart(1, 0, makeTaskInfo(1237L)))
 
     listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate(execId, Array(
-      (1234L, 0, 0, makeTaskMetrics(0)),
-      (1235L, 0, 0, makeTaskMetrics(100)),
-      (1236L, 1, 0, makeTaskMetrics(200)))))
+      (1234L, 0, 0, makeTaskMetrics(0).accumulatorUpdates()),
+      (1235L, 0, 0, makeTaskMetrics(100).accumulatorUpdates()),
+      (1236L, 1, 0, makeTaskMetrics(200).accumulatorUpdates()))))
 
     var stage0Data = listener.stageIdToData.get((0, 0)).get
     var stage1Data = listener.stageIdToData.get((1, 0)).get
diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala
index e5ca2de4ad53772236108ad47ef2b16bdacbc1c0..57021d1d3d52844060024baf0b7ec10d13fa3d31 100644
--- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala
@@ -22,6 +22,10 @@ import java.util.Properties
 import scala.collection.Map
 
 import org.json4s.jackson.JsonMethods._
+import org.json4s.JsonAST.{JArray, JInt, JString, JValue}
+import org.json4s.JsonDSL._
+import org.scalatest.Assertions
+import org.scalatest.exceptions.TestFailedException
 
 import org.apache.spark._
 import org.apache.spark.executor._
@@ -32,12 +36,7 @@ import org.apache.spark.shuffle.MetadataFetchFailedException
 import org.apache.spark.storage._
 
 class JsonProtocolSuite extends SparkFunSuite {
-
-  val jobSubmissionTime = 1421191042750L
-  val jobCompletionTime = 1421191296660L
-
-  val executorAddedTime = 1421458410000L
-  val executorRemovedTime = 1421458922000L
+  import JsonProtocolSuite._
 
   test("SparkListenerEvent") {
     val stageSubmitted =
@@ -82,9 +81,13 @@ class JsonProtocolSuite extends SparkFunSuite {
     val executorAdded = SparkListenerExecutorAdded(executorAddedTime, "exec1",
       new ExecutorInfo("Hostee.awesome.com", 11, logUrlMap))
     val executorRemoved = SparkListenerExecutorRemoved(executorRemovedTime, "exec2", "test reason")
-    val executorMetricsUpdate = SparkListenerExecutorMetricsUpdate("exec3", Seq(
-      (1L, 2, 3, makeTaskMetrics(300L, 400L, 500L, 600L, 700, 800,
-        hasHadoopInput = true, hasOutput = true))))
+    val executorMetricsUpdate = {
+      // Use custom accum ID for determinism
+      val accumUpdates =
+        makeTaskMetrics(300L, 400L, 500L, 600L, 700, 800, hasHadoopInput = true, hasOutput = true)
+          .accumulatorUpdates().zipWithIndex.map { case (a, i) => a.copy(id = i) }
+      SparkListenerExecutorMetricsUpdate("exec3", Seq((1L, 2, 3, accumUpdates)))
+    }
 
     testEvent(stageSubmitted, stageSubmittedJsonString)
     testEvent(stageCompleted, stageCompletedJsonString)
@@ -142,7 +145,7 @@ class JsonProtocolSuite extends SparkFunSuite {
       "Some exception")
     val fetchMetadataFailed = new MetadataFetchFailedException(17,
       19, "metadata Fetch failed exception").toTaskEndReason
-    val exceptionFailure = new ExceptionFailure(exception, None)
+    val exceptionFailure = new ExceptionFailure(exception, Seq.empty[AccumulableInfo])
     testTaskEndReason(Success)
     testTaskEndReason(Resubmitted)
     testTaskEndReason(fetchFailed)
@@ -166,9 +169,8 @@ class JsonProtocolSuite extends SparkFunSuite {
    |  Backward compatibility tests  |
    * ============================== */
 
-  test("ExceptionFailure backward compatibility") {
-    val exceptionFailure = ExceptionFailure("To be", "or not to be", stackTrace, null,
-      None, None)
+  test("ExceptionFailure backward compatibility: full stack trace") {
+    val exceptionFailure = ExceptionFailure("To be", "or not to be", stackTrace, null, None)
     val oldEvent = JsonProtocol.taskEndReasonToJson(exceptionFailure)
       .removeField({ _._1 == "Full Stack Trace" })
     assertEquals(exceptionFailure, JsonProtocol.taskEndReasonFromJson(oldEvent))
@@ -273,14 +275,13 @@ class JsonProtocolSuite extends SparkFunSuite {
     assert(expectedFetchFailed === JsonProtocol.taskEndReasonFromJson(oldEvent))
   }
 
-  test("ShuffleReadMetrics: Local bytes read and time taken backwards compatibility") {
-    // Metrics about local shuffle bytes read and local read time were added in 1.3.1.
+  test("ShuffleReadMetrics: Local bytes read backwards compatibility") {
+    // Metrics about local shuffle bytes read were added in 1.3.1.
     val metrics = makeTaskMetrics(1L, 2L, 3L, 4L, 5, 6,
       hasHadoopInput = false, hasOutput = false, hasRecords = false)
     assert(metrics.shuffleReadMetrics.nonEmpty)
     val newJson = JsonProtocol.taskMetricsToJson(metrics)
     val oldJson = newJson.removeField { case (field, _) => field == "Local Bytes Read" }
-      .removeField { case (field, _) => field == "Local Read Time" }
     val newMetrics = JsonProtocol.taskMetricsFromJson(oldJson)
     assert(newMetrics.shuffleReadMetrics.get.localBytesRead == 0)
   }
@@ -371,22 +372,76 @@ class JsonProtocolSuite extends SparkFunSuite {
   }
 
   test("AccumulableInfo backward compatibility") {
-    // "Internal" property of AccumulableInfo were added after 1.5.1.
-    val accumulableInfo = makeAccumulableInfo(1)
+    // "Internal" property of AccumulableInfo was added in 1.5.1
+    val accumulableInfo = makeAccumulableInfo(1, internal = true, countFailedValues = true)
     val oldJson = JsonProtocol.accumulableInfoToJson(accumulableInfo)
       .removeField({ _._1 == "Internal" })
     val oldInfo = JsonProtocol.accumulableInfoFromJson(oldJson)
-    assert(false === oldInfo.internal)
+    assert(!oldInfo.internal)
+    // "Count Failed Values" property of AccumulableInfo was added in 2.0.0
+    val oldJson2 = JsonProtocol.accumulableInfoToJson(accumulableInfo)
+      .removeField({ _._1 == "Count Failed Values" })
+    val oldInfo2 = JsonProtocol.accumulableInfoFromJson(oldJson2)
+    assert(!oldInfo2.countFailedValues)
+  }
+
+  test("ExceptionFailure backward compatibility: accumulator updates") {
+    // "Task Metrics" was replaced with "Accumulator Updates" in 2.0.0. For older event logs,
+    // we should still be able to fallback to constructing the accumulator updates from the
+    // "Task Metrics" field, if it exists.
+    val tm = makeTaskMetrics(1L, 2L, 3L, 4L, 5, 6, hasHadoopInput = true, hasOutput = true)
+    val tmJson = JsonProtocol.taskMetricsToJson(tm)
+    val accumUpdates = tm.accumulatorUpdates()
+    val exception = new SparkException("sentimental")
+    val exceptionFailure = new ExceptionFailure(exception, accumUpdates)
+    val exceptionFailureJson = JsonProtocol.taskEndReasonToJson(exceptionFailure)
+    val tmFieldJson: JValue = "Task Metrics" -> tmJson
+    val oldExceptionFailureJson: JValue =
+      exceptionFailureJson.removeField { _._1 == "Accumulator Updates" }.merge(tmFieldJson)
+    val oldExceptionFailure =
+      JsonProtocol.taskEndReasonFromJson(oldExceptionFailureJson).asInstanceOf[ExceptionFailure]
+    assert(exceptionFailure.className === oldExceptionFailure.className)
+    assert(exceptionFailure.description === oldExceptionFailure.description)
+    assertSeqEquals[StackTraceElement](
+      exceptionFailure.stackTrace, oldExceptionFailure.stackTrace, assertStackTraceElementEquals)
+    assert(exceptionFailure.fullStackTrace === oldExceptionFailure.fullStackTrace)
+    assertSeqEquals[AccumulableInfo](
+      exceptionFailure.accumUpdates, oldExceptionFailure.accumUpdates, (x, y) => x == y)
   }
 
-  /** -------------------------- *
-   | Helper test running methods |
-   * --------------------------- */
+  test("AccumulableInfo value de/serialization") {
+    import InternalAccumulator._
+    val blocks = Seq[(BlockId, BlockStatus)](
+      (TestBlockId("meebo"), BlockStatus(StorageLevel.MEMORY_ONLY, 1L, 2L)),
+      (TestBlockId("feebo"), BlockStatus(StorageLevel.DISK_ONLY, 3L, 4L)))
+    val blocksJson = JArray(blocks.toList.map { case (id, status) =>
+      ("Block ID" -> id.toString) ~
+      ("Status" -> JsonProtocol.blockStatusToJson(status))
+    })
+    testAccumValue(Some(RESULT_SIZE), 3L, JInt(3))
+    testAccumValue(Some(shuffleRead.REMOTE_BLOCKS_FETCHED), 2, JInt(2))
+    testAccumValue(Some(input.READ_METHOD), "aka", JString("aka"))
+    testAccumValue(Some(UPDATED_BLOCK_STATUSES), blocks, blocksJson)
+    // For anything else, we just cast the value to a string
+    testAccumValue(Some("anything"), blocks, JString(blocks.toString))
+    testAccumValue(Some("anything"), 123, JString("123"))
+  }
+
+}
+
+
+private[spark] object JsonProtocolSuite extends Assertions {
+  import InternalAccumulator._
+
+  private val jobSubmissionTime = 1421191042750L
+  private val jobCompletionTime = 1421191296660L
+  private val executorAddedTime = 1421458410000L
+  private val executorRemovedTime = 1421458922000L
 
   private def testEvent(event: SparkListenerEvent, jsonString: String) {
     val actualJsonString = compact(render(JsonProtocol.sparkEventToJson(event)))
     val newEvent = JsonProtocol.sparkEventFromJson(parse(actualJsonString))
-    assertJsonStringEquals(jsonString, actualJsonString)
+    assertJsonStringEquals(jsonString, actualJsonString, event.getClass.getSimpleName)
     assertEquals(event, newEvent)
   }
 
@@ -440,11 +495,19 @@ class JsonProtocolSuite extends SparkFunSuite {
     assertEquals(info, newInfo)
   }
 
+  private def testAccumValue(name: Option[String], value: Any, expectedJson: JValue): Unit = {
+    val json = JsonProtocol.accumValueToJson(name, value)
+    assert(json === expectedJson)
+    val newValue = JsonProtocol.accumValueFromJson(name, json)
+    val expectedValue = if (name.exists(_.startsWith(METRICS_PREFIX))) value else value.toString
+    assert(newValue === expectedValue)
+  }
+
   /** -------------------------------- *
    | Util methods for comparing events |
-   * --------------------------------- */
+    * --------------------------------- */
 
-  private def assertEquals(event1: SparkListenerEvent, event2: SparkListenerEvent) {
+  private[spark] def assertEquals(event1: SparkListenerEvent, event2: SparkListenerEvent) {
     (event1, event2) match {
       case (e1: SparkListenerStageSubmitted, e2: SparkListenerStageSubmitted) =>
         assert(e1.properties === e2.properties)
@@ -478,14 +541,17 @@ class JsonProtocolSuite extends SparkFunSuite {
         assert(e1.executorId === e1.executorId)
       case (e1: SparkListenerExecutorMetricsUpdate, e2: SparkListenerExecutorMetricsUpdate) =>
         assert(e1.execId === e2.execId)
-        assertSeqEquals[(Long, Int, Int, TaskMetrics)](e1.taskMetrics, e2.taskMetrics, (a, b) => {
-          val (taskId1, stageId1, stageAttemptId1, metrics1) = a
-          val (taskId2, stageId2, stageAttemptId2, metrics2) = b
-          assert(taskId1 === taskId2)
-          assert(stageId1 === stageId2)
-          assert(stageAttemptId1 === stageAttemptId2)
-          assertEquals(metrics1, metrics2)
-        })
+        assertSeqEquals[(Long, Int, Int, Seq[AccumulableInfo])](
+          e1.accumUpdates,
+          e2.accumUpdates,
+          (a, b) => {
+            val (taskId1, stageId1, stageAttemptId1, updates1) = a
+            val (taskId2, stageId2, stageAttemptId2, updates2) = b
+            assert(taskId1 === taskId2)
+            assert(stageId1 === stageId2)
+            assert(stageAttemptId1 === stageAttemptId2)
+            assertSeqEquals[AccumulableInfo](updates1, updates2, (a, b) => a.equals(b))
+          })
       case (e1, e2) =>
         assert(e1 === e2)
       case _ => fail("Events don't match in types!")
@@ -544,7 +610,6 @@ class JsonProtocolSuite extends SparkFunSuite {
   }
 
   private def assertEquals(metrics1: TaskMetrics, metrics2: TaskMetrics) {
-    assert(metrics1.hostname === metrics2.hostname)
     assert(metrics1.executorDeserializeTime === metrics2.executorDeserializeTime)
     assert(metrics1.resultSize === metrics2.resultSize)
     assert(metrics1.jvmGCTime === metrics2.jvmGCTime)
@@ -601,7 +666,7 @@ class JsonProtocolSuite extends SparkFunSuite {
         assert(r1.description === r2.description)
         assertSeqEquals(r1.stackTrace, r2.stackTrace, assertStackTraceElementEquals)
         assert(r1.fullStackTrace === r2.fullStackTrace)
-        assertOptionEquals(r1.metrics, r2.metrics, assertTaskMetricsEquals)
+        assertSeqEquals[AccumulableInfo](r1.accumUpdates, r2.accumUpdates, (a, b) => a.equals(b))
       case (TaskResultLost, TaskResultLost) =>
       case (TaskKilled, TaskKilled) =>
       case (TaskCommitDenied(jobId1, partitionId1, attemptNumber1),
@@ -637,10 +702,16 @@ class JsonProtocolSuite extends SparkFunSuite {
       assertStackTraceElementEquals)
   }
 
-  private def assertJsonStringEquals(json1: String, json2: String) {
+  private def assertJsonStringEquals(expected: String, actual: String, metadata: String) {
     val formatJsonString = (json: String) => json.replaceAll("[\\s|]", "")
-    assert(formatJsonString(json1) === formatJsonString(json2),
-      s"input ${formatJsonString(json1)} got ${formatJsonString(json2)}")
+    if (formatJsonString(expected) != formatJsonString(actual)) {
+      // scalastyle:off
+      // This prints something useful if the JSON strings don't match
+      println("=== EXPECTED ===\n" + pretty(parse(expected)) + "\n")
+      println("=== ACTUAL ===\n" + pretty(parse(actual)) + "\n")
+      // scalastyle:on
+      throw new TestFailedException(s"$metadata JSON did not equal", 1)
+    }
   }
 
   private def assertSeqEquals[T](seq1: Seq[T], seq2: Seq[T], assertEquals: (T, T) => Unit) {
@@ -699,7 +770,7 @@ class JsonProtocolSuite extends SparkFunSuite {
 
   /** ----------------------------------- *
    | Util methods for constructing events |
-   * ------------------------------------ */
+    * ------------------------------------ */
 
   private val properties = {
     val p = new Properties
@@ -746,8 +817,12 @@ class JsonProtocolSuite extends SparkFunSuite {
     taskInfo
   }
 
-  private def makeAccumulableInfo(id: Int, internal: Boolean = false): AccumulableInfo =
-    AccumulableInfo(id, " Accumulable " + id, Some("delta" + id), "val" + id, internal)
+  private def makeAccumulableInfo(
+      id: Int,
+      internal: Boolean = false,
+      countFailedValues: Boolean = false): AccumulableInfo =
+    new AccumulableInfo(id, Some(s"Accumulable$id"), Some(s"delta$id"), Some(s"val$id"),
+      internal, countFailedValues)
 
   /**
    * Creates a TaskMetrics object describing a task that read data from Hadoop (if hasHadoopInput is
@@ -764,7 +839,6 @@ class JsonProtocolSuite extends SparkFunSuite {
       hasOutput: Boolean,
       hasRecords: Boolean = true) = {
     val t = new TaskMetrics
-    t.setHostname("localhost")
     t.setExecutorDeserializeTime(a)
     t.setExecutorRunTime(b)
     t.setResultSize(c)
@@ -774,7 +848,7 @@ class JsonProtocolSuite extends SparkFunSuite {
 
     if (hasHadoopInput) {
       val inputMetrics = t.registerInputMetrics(DataReadMethod.Hadoop)
-      inputMetrics.incBytesRead(d + e + f)
+      inputMetrics.setBytesRead(d + e + f)
       inputMetrics.incRecordsRead(if (hasRecords) (d + e + f) / 100 else -1)
     } else {
       val sr = t.registerTempShuffleReadMetrics()
@@ -794,7 +868,7 @@ class JsonProtocolSuite extends SparkFunSuite {
       val sw = t.registerShuffleWriteMetrics()
       sw.incBytesWritten(a + b + c)
       sw.incWriteTime(b + c + d)
-      sw.setRecordsWritten(if (hasRecords) (a + b + c) / 100 else -1)
+      sw.incRecordsWritten(if (hasRecords) (a + b + c) / 100 else -1)
     }
     // Make at most 6 blocks
     t.setUpdatedBlockStatuses((1 to (e % 5 + 1)).map { i =>
@@ -826,14 +900,16 @@ class JsonProtocolSuite extends SparkFunSuite {
       |        "Name": "Accumulable2",
       |        "Update": "delta2",
       |        "Value": "val2",
-      |        "Internal": false
+      |        "Internal": false,
+      |        "Count Failed Values": false
       |      },
       |      {
       |        "ID": 1,
       |        "Name": "Accumulable1",
       |        "Update": "delta1",
       |        "Value": "val1",
-      |        "Internal": false
+      |        "Internal": false,
+      |        "Count Failed Values": false
       |      }
       |    ]
       |  },
@@ -881,14 +957,16 @@ class JsonProtocolSuite extends SparkFunSuite {
       |        "Name": "Accumulable2",
       |        "Update": "delta2",
       |        "Value": "val2",
-      |        "Internal": false
+      |        "Internal": false,
+      |        "Count Failed Values": false
       |      },
       |      {
       |        "ID": 1,
       |        "Name": "Accumulable1",
       |        "Update": "delta1",
       |        "Value": "val1",
-      |        "Internal": false
+      |        "Internal": false,
+      |        "Count Failed Values": false
       |      }
       |    ]
       |  }
@@ -919,21 +997,24 @@ class JsonProtocolSuite extends SparkFunSuite {
       |        "Name": "Accumulable1",
       |        "Update": "delta1",
       |        "Value": "val1",
-      |        "Internal": false
+      |        "Internal": false,
+      |        "Count Failed Values": false
       |      },
       |      {
       |        "ID": 2,
       |        "Name": "Accumulable2",
       |        "Update": "delta2",
       |        "Value": "val2",
-      |        "Internal": false
+      |        "Internal": false,
+      |        "Count Failed Values": false
       |      },
       |      {
       |        "ID": 3,
       |        "Name": "Accumulable3",
       |        "Update": "delta3",
       |        "Value": "val3",
-      |        "Internal": true
+      |        "Internal": true,
+      |        "Count Failed Values": false
       |      }
       |    ]
       |  }
@@ -962,21 +1043,24 @@ class JsonProtocolSuite extends SparkFunSuite {
       |        "Name": "Accumulable1",
       |        "Update": "delta1",
       |        "Value": "val1",
-      |        "Internal": false
+      |        "Internal": false,
+      |        "Count Failed Values": false
       |      },
       |      {
       |        "ID": 2,
       |        "Name": "Accumulable2",
       |        "Update": "delta2",
       |        "Value": "val2",
-      |        "Internal": false
+      |        "Internal": false,
+      |        "Count Failed Values": false
       |      },
       |      {
       |        "ID": 3,
       |        "Name": "Accumulable3",
       |        "Update": "delta3",
       |        "Value": "val3",
-      |        "Internal": true
+      |        "Internal": true,
+      |        "Count Failed Values": false
       |      }
       |    ]
       |  }
@@ -1011,26 +1095,28 @@ class JsonProtocolSuite extends SparkFunSuite {
       |        "Name": "Accumulable1",
       |        "Update": "delta1",
       |        "Value": "val1",
-      |        "Internal": false
+      |        "Internal": false,
+      |        "Count Failed Values": false
       |      },
       |      {
       |        "ID": 2,
       |        "Name": "Accumulable2",
       |        "Update": "delta2",
       |        "Value": "val2",
-      |        "Internal": false
+      |        "Internal": false,
+      |        "Count Failed Values": false
       |      },
       |      {
       |        "ID": 3,
       |        "Name": "Accumulable3",
       |        "Update": "delta3",
       |        "Value": "val3",
-      |        "Internal": true
+      |        "Internal": true,
+      |        "Count Failed Values": false
       |      }
       |    ]
       |  },
       |  "Task Metrics": {
-      |    "Host Name": "localhost",
       |    "Executor Deserialize Time": 300,
       |    "Executor Run Time": 400,
       |    "Result Size": 500,
@@ -1044,7 +1130,7 @@ class JsonProtocolSuite extends SparkFunSuite {
       |      "Fetch Wait Time": 900,
       |      "Remote Bytes Read": 1000,
       |      "Local Bytes Read": 1100,
-      |      "Total Records Read" : 10
+      |      "Total Records Read": 10
       |    },
       |    "Shuffle Write Metrics": {
       |      "Shuffle Bytes Written": 1200,
@@ -1098,26 +1184,28 @@ class JsonProtocolSuite extends SparkFunSuite {
       |        "Name": "Accumulable1",
       |        "Update": "delta1",
       |        "Value": "val1",
-      |        "Internal": false
+      |        "Internal": false,
+      |        "Count Failed Values": false
       |      },
       |      {
       |        "ID": 2,
       |        "Name": "Accumulable2",
       |        "Update": "delta2",
       |        "Value": "val2",
-      |        "Internal": false
+      |        "Internal": false,
+      |        "Count Failed Values": false
       |      },
       |      {
       |        "ID": 3,
       |        "Name": "Accumulable3",
       |        "Update": "delta3",
       |        "Value": "val3",
-      |        "Internal": true
+      |        "Internal": true,
+      |        "Count Failed Values": false
       |      }
       |    ]
       |  },
       |  "Task Metrics": {
-      |    "Host Name": "localhost",
       |    "Executor Deserialize Time": 300,
       |    "Executor Run Time": 400,
       |    "Result Size": 500,
@@ -1182,26 +1270,28 @@ class JsonProtocolSuite extends SparkFunSuite {
       |        "Name": "Accumulable1",
       |        "Update": "delta1",
       |        "Value": "val1",
-      |        "Internal": false
+      |        "Internal": false,
+      |        "Count Failed Values": false
       |      },
       |      {
       |        "ID": 2,
       |        "Name": "Accumulable2",
       |        "Update": "delta2",
       |        "Value": "val2",
-      |        "Internal": false
+      |        "Internal": false,
+      |        "Count Failed Values": false
       |      },
       |      {
       |        "ID": 3,
       |        "Name": "Accumulable3",
       |        "Update": "delta3",
       |        "Value": "val3",
-      |        "Internal": true
+      |        "Internal": true,
+      |        "Count Failed Values": false
       |      }
       |    ]
       |  },
       |  "Task Metrics": {
-      |    "Host Name": "localhost",
       |    "Executor Deserialize Time": 300,
       |    "Executor Run Time": 400,
       |    "Result Size": 500,
@@ -1273,17 +1363,19 @@ class JsonProtocolSuite extends SparkFunSuite {
       |      "Accumulables": [
       |        {
       |          "ID": 2,
-      |          "Name": " Accumulable 2",
+      |          "Name": "Accumulable2",
       |          "Update": "delta2",
       |          "Value": "val2",
-      |          "Internal": false
+      |          "Internal": false,
+      |          "Count Failed Values": false
       |        },
       |        {
       |          "ID": 1,
-      |          "Name": " Accumulable 1",
+      |          "Name": "Accumulable1",
       |          "Update": "delta1",
       |          "Value": "val1",
-      |          "Internal": false
+      |          "Internal": false,
+      |          "Count Failed Values": false
       |        }
       |      ]
       |    },
@@ -1331,17 +1423,19 @@ class JsonProtocolSuite extends SparkFunSuite {
       |      "Accumulables": [
       |        {
       |          "ID": 2,
-      |          "Name": " Accumulable 2",
+      |          "Name": "Accumulable2",
       |          "Update": "delta2",
       |          "Value": "val2",
-      |          "Internal": false
+      |          "Internal": false,
+      |          "Count Failed Values": false
       |        },
       |        {
       |          "ID": 1,
-      |          "Name": " Accumulable 1",
+      |          "Name": "Accumulable1",
       |          "Update": "delta1",
       |          "Value": "val1",
-      |          "Internal": false
+      |          "Internal": false,
+      |          "Count Failed Values": false
       |        }
       |      ]
       |    },
@@ -1405,17 +1499,19 @@ class JsonProtocolSuite extends SparkFunSuite {
       |      "Accumulables": [
       |        {
       |          "ID": 2,
-      |          "Name": " Accumulable 2",
+      |          "Name": "Accumulable2",
       |          "Update": "delta2",
       |          "Value": "val2",
-      |          "Internal": false
+      |          "Internal": false,
+      |          "Count Failed Values": false
       |        },
       |        {
       |          "ID": 1,
-      |          "Name": " Accumulable 1",
+      |          "Name": "Accumulable1",
       |          "Update": "delta1",
       |          "Value": "val1",
-      |          "Internal": false
+      |          "Internal": false,
+      |          "Count Failed Values": false
       |        }
       |      ]
       |    },
@@ -1495,17 +1591,19 @@ class JsonProtocolSuite extends SparkFunSuite {
       |      "Accumulables": [
       |        {
       |          "ID": 2,
-      |          "Name": " Accumulable 2",
+      |          "Name": "Accumulable2",
       |          "Update": "delta2",
       |          "Value": "val2",
-      |          "Internal": false
+      |          "Internal": false,
+      |          "Count Failed Values": false
       |        },
       |        {
       |          "ID": 1,
-      |          "Name": " Accumulable 1",
+      |          "Name": "Accumulable1",
       |          "Update": "delta1",
       |          "Value": "val1",
-      |          "Internal": false
+      |          "Internal": false,
+      |          "Count Failed Values": false
       |        }
       |      ]
       |    }
@@ -1657,51 +1755,208 @@ class JsonProtocolSuite extends SparkFunSuite {
     """
 
   private val executorMetricsUpdateJsonString =
-  s"""
-     |{
-     |  "Event": "SparkListenerExecutorMetricsUpdate",
-     |  "Executor ID": "exec3",
-     |  "Metrics Updated": [
-     |  {
-     |    "Task ID": 1,
-     |    "Stage ID": 2,
-     |    "Stage Attempt ID": 3,
-     |    "Task Metrics": {
-     |    "Host Name": "localhost",
-     |    "Executor Deserialize Time": 300,
-     |    "Executor Run Time": 400,
-     |    "Result Size": 500,
-     |    "JVM GC Time": 600,
-     |    "Result Serialization Time": 700,
-     |    "Memory Bytes Spilled": 800,
-     |    "Disk Bytes Spilled": 0,
-     |    "Input Metrics": {
-     |      "Data Read Method": "Hadoop",
-     |      "Bytes Read": 2100,
-     |      "Records Read": 21
-     |    },
-     |    "Output Metrics": {
-     |      "Data Write Method": "Hadoop",
-     |      "Bytes Written": 1200,
-     |      "Records Written": 12
-     |    },
-     |    "Updated Blocks": [
-     |      {
-     |        "Block ID": "rdd_0_0",
-     |        "Status": {
-     |          "Storage Level": {
-     |            "Use Disk": true,
-     |            "Use Memory": true,
-     |            "Deserialized": false,
-     |            "Replication": 2
-     |          },
-     |          "Memory Size": 0,
-     |          "Disk Size": 0
-     |        }
-     |      }
-     |    ]
-     |  }
-     |  }]
-     |}
-   """.stripMargin
+    s"""
+      |{
+      |  "Event": "SparkListenerExecutorMetricsUpdate",
+      |  "Executor ID": "exec3",
+      |  "Metrics Updated": [
+      |    {
+      |      "Task ID": 1,
+      |      "Stage ID": 2,
+      |      "Stage Attempt ID": 3,
+      |      "Accumulator Updates": [
+      |        {
+      |          "ID": 0,
+      |          "Name": "$EXECUTOR_DESERIALIZE_TIME",
+      |          "Update": 300,
+      |          "Internal": true,
+      |          "Count Failed Values": true
+      |        },
+      |        {
+      |          "ID": 1,
+      |          "Name": "$EXECUTOR_RUN_TIME",
+      |          "Update": 400,
+      |          "Internal": true,
+      |          "Count Failed Values": true
+      |        },
+      |        {
+      |          "ID": 2,
+      |          "Name": "$RESULT_SIZE",
+      |          "Update": 500,
+      |          "Internal": true,
+      |          "Count Failed Values": true
+      |        },
+      |        {
+      |          "ID": 3,
+      |          "Name": "$JVM_GC_TIME",
+      |          "Update": 600,
+      |          "Internal": true,
+      |          "Count Failed Values": true
+      |        },
+      |        {
+      |          "ID": 4,
+      |          "Name": "$RESULT_SERIALIZATION_TIME",
+      |          "Update": 700,
+      |          "Internal": true,
+      |          "Count Failed Values": true
+      |        },
+      |        {
+      |          "ID": 5,
+      |          "Name": "$MEMORY_BYTES_SPILLED",
+      |          "Update": 800,
+      |          "Internal": true,
+      |          "Count Failed Values": true
+      |        },
+      |        {
+      |          "ID": 6,
+      |          "Name": "$DISK_BYTES_SPILLED",
+      |          "Update": 0,
+      |          "Internal": true,
+      |          "Count Failed Values": true
+      |        },
+      |        {
+      |          "ID": 7,
+      |          "Name": "$PEAK_EXECUTION_MEMORY",
+      |          "Update": 0,
+      |          "Internal": true,
+      |          "Count Failed Values": true
+      |        },
+      |        {
+      |          "ID": 8,
+      |          "Name": "$UPDATED_BLOCK_STATUSES",
+      |          "Update": [
+      |            {
+      |              "BlockID": "rdd_0_0",
+      |              "Status": {
+      |                "StorageLevel": {
+      |                  "UseDisk": true,
+      |                  "UseMemory": true,
+      |                  "Deserialized": false,
+      |                  "Replication": 2
+      |                },
+      |                "MemorySize": 0,
+      |                "DiskSize": 0
+      |              }
+      |            }
+      |          ],
+      |          "Internal": true,
+      |          "Count Failed Values": true
+      |        },
+      |        {
+      |          "ID": 9,
+      |          "Name": "${shuffleRead.REMOTE_BLOCKS_FETCHED}",
+      |          "Update": 0,
+      |          "Internal": true,
+      |          "Count Failed Values": true
+      |        },
+      |        {
+      |          "ID": 10,
+      |          "Name": "${shuffleRead.LOCAL_BLOCKS_FETCHED}",
+      |          "Update": 0,
+      |          "Internal": true,
+      |          "Count Failed Values": true
+      |        },
+      |        {
+      |          "ID": 11,
+      |          "Name": "${shuffleRead.REMOTE_BYTES_READ}",
+      |          "Update": 0,
+      |          "Internal": true,
+      |          "Count Failed Values": true
+      |        },
+      |        {
+      |          "ID": 12,
+      |          "Name": "${shuffleRead.LOCAL_BYTES_READ}",
+      |          "Update": 0,
+      |          "Internal": true,
+      |          "Count Failed Values": true
+      |        },
+      |        {
+      |          "ID": 13,
+      |          "Name": "${shuffleRead.FETCH_WAIT_TIME}",
+      |          "Update": 0,
+      |          "Internal": true,
+      |          "Count Failed Values": true
+      |        },
+      |        {
+      |          "ID": 14,
+      |          "Name": "${shuffleRead.RECORDS_READ}",
+      |          "Update": 0,
+      |          "Internal": true,
+      |          "Count Failed Values": true
+      |        },
+      |        {
+      |          "ID": 15,
+      |          "Name": "${shuffleWrite.BYTES_WRITTEN}",
+      |          "Update": 0,
+      |          "Internal": true,
+      |          "Count Failed Values": true
+      |        },
+      |        {
+      |          "ID": 16,
+      |          "Name": "${shuffleWrite.RECORDS_WRITTEN}",
+      |          "Update": 0,
+      |          "Internal": true,
+      |          "Count Failed Values": true
+      |        },
+      |        {
+      |          "ID": 17,
+      |          "Name": "${shuffleWrite.WRITE_TIME}",
+      |          "Update": 0,
+      |          "Internal": true,
+      |          "Count Failed Values": true
+      |        },
+      |        {
+      |          "ID": 18,
+      |          "Name": "${input.READ_METHOD}",
+      |          "Update": "Hadoop",
+      |          "Internal": true,
+      |          "Count Failed Values": true
+      |        },
+      |        {
+      |          "ID": 19,
+      |          "Name": "${input.BYTES_READ}",
+      |          "Update": 2100,
+      |          "Internal": true,
+      |          "Count Failed Values": true
+      |        },
+      |        {
+      |          "ID": 20,
+      |          "Name": "${input.RECORDS_READ}",
+      |          "Update": 21,
+      |          "Internal": true,
+      |          "Count Failed Values": true
+      |        },
+      |        {
+      |          "ID": 21,
+      |          "Name": "${output.WRITE_METHOD}",
+      |          "Update": "Hadoop",
+      |          "Internal": true,
+      |          "Count Failed Values": true
+      |        },
+      |        {
+      |          "ID": 22,
+      |          "Name": "${output.BYTES_WRITTEN}",
+      |          "Update": 1200,
+      |          "Internal": true,
+      |          "Count Failed Values": true
+      |        },
+      |        {
+      |          "ID": 23,
+      |          "Name": "${output.RECORDS_WRITTEN}",
+      |          "Update": 12,
+      |          "Internal": true,
+      |          "Count Failed Values": true
+      |        },
+      |        {
+      |          "ID": 24,
+      |          "Name": "$TEST_ACCUM",
+      |          "Update": 0,
+      |          "Internal": true,
+      |          "Count Failed Values": true
+      |        }
+      |      ]
+      |    }
+      |  ]
+      |}
+    """.stripMargin
 }
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index fc7dc2181de851062495dd489464dd2404475bc6..968a2903f3010b762019de118b45148bcf9fea69 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -175,6 +175,15 @@ object MimaExcludes {
       ) ++ Seq(
         // SPARK-12510 Refactor ActorReceiver to support Java
         ProblemFilters.exclude[AbstractClassProblem]("org.apache.spark.streaming.receiver.ActorReceiver")
+      ) ++ Seq(
+        // SPARK-12895 Implement TaskMetrics using accumulators
+        ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.TaskContext.internalMetricsToAccumulators"),
+        ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.TaskContext.collectInternalAccumulators"),
+        ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.TaskContext.collectAccumulators")
+      ) ++ Seq(
+        // SPARK-12896 Send only accumulator updates to driver, not TaskMetrics
+        ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.Accumulable.this"),
+        ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.Accumulator.this")
       ) ++ Seq(
         // SPARK-12692 Scala style: Fix the style violation (Space before "," or ":")
         ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.flume.sink.SparkSink.org$apache$spark$streaming$flume$sink$Logging$$log_"),
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala
index 73dc8cb984471086791bd18efc8c54858f5389bb..75cb6d1137c353ae7fb797fbb2a1a85189ad6683 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala
@@ -79,17 +79,17 @@ case class Sort(
         sorter.setTestSpillFrequency(testSpillFrequency)
       }
 
+      val metrics = TaskContext.get().taskMetrics()
       // Remember spill data size of this task before execute this operator so that we can
       // figure out how many bytes we spilled for this operator.
-      val spillSizeBefore = TaskContext.get().taskMetrics().memoryBytesSpilled
+      val spillSizeBefore = metrics.memoryBytesSpilled
 
       val sortedIterator = sorter.sort(iter.asInstanceOf[Iterator[UnsafeRow]])
 
       dataSize += sorter.getPeakMemoryUsage
-      spillSize += TaskContext.get().taskMetrics().memoryBytesSpilled - spillSizeBefore
+      spillSize += metrics.memoryBytesSpilled - spillSizeBefore
+      metrics.incPeakExecutionMemory(sorter.getPeakMemoryUsage)
 
-      TaskContext.get().internalMetricsToAccumulators(
-        InternalAccumulator.PEAK_EXECUTION_MEMORY).add(sorter.getPeakMemoryUsage)
       sortedIterator
     }
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
index 41799c596b6d3065d585600ea3d12514a67f8271..001e9c306ac4558ff2073fbb50370c36748069c0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
@@ -418,10 +418,10 @@ class TungstenAggregationIterator(
         val mapMemory = hashMap.getPeakMemoryUsedBytes
         val sorterMemory = Option(externalSorter).map(_.getPeakMemoryUsedBytes).getOrElse(0L)
         val peakMemory = Math.max(mapMemory, sorterMemory)
+        val metrics = TaskContext.get().taskMetrics()
         dataSize += peakMemory
-        spillSize += TaskContext.get().taskMetrics().memoryBytesSpilled - spillSizeBefore
-        TaskContext.get().internalMetricsToAccumulators(
-          InternalAccumulator.PEAK_EXECUTION_MEMORY).add(peakMemory)
+        spillSize += metrics.memoryBytesSpilled - spillSizeBefore
+        metrics.incPeakExecutionMemory(peakMemory)
       }
       numOutputRows += 1
       res
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala
index 8222b84d33e3b0e3874ce9a8bfdcbfa706be56fa..edd87c2d8ed07bfb7b91d97626ac05d82d9137ef 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala
@@ -136,14 +136,17 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag](
 
       // Find a function that will return the FileSystem bytes read by this thread. Do this before
       // creating RecordReader, because RecordReader's constructor might read some bytes
-      val bytesReadCallback = inputMetrics.bytesReadCallback.orElse {
-        split.serializableHadoopSplit.value match {
-          case _: FileSplit | _: CombineFileSplit =>
-            SparkHadoopUtil.get.getFSBytesReadOnThreadCallback()
-          case _ => None
+      val getBytesReadCallback: Option[() => Long] = split.serializableHadoopSplit.value match {
+        case _: FileSplit | _: CombineFileSplit =>
+              SparkHadoopUtil.get.getFSBytesReadOnThreadCallback()
+        case _ => None
+      }
+
+      def updateBytesRead(): Unit = {
+        getBytesReadCallback.foreach { getBytesRead =>
+          inputMetrics.setBytesRead(getBytesRead())
         }
       }
-      inputMetrics.setBytesReadCallback(bytesReadCallback)
 
       val format = inputFormatClass.newInstance
       format match {
@@ -208,6 +211,9 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag](
         if (!finished) {
           inputMetrics.incRecordsRead(1)
         }
+        if (inputMetrics.recordsRead % SparkHadoopUtil.UPDATE_INPUT_METRICS_INTERVAL_RECORDS == 0) {
+          updateBytesRead()
+        }
         reader.getCurrentValue
       }
 
@@ -228,8 +234,8 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag](
           } finally {
             reader = null
           }
-          if (bytesReadCallback.isDefined) {
-            inputMetrics.updateBytesRead()
+          if (getBytesReadCallback.isDefined) {
+            updateBytesRead()
           } else if (split.serializableHadoopSplit.value.isInstanceOf[FileSplit] ||
             split.serializableHadoopSplit.value.isInstanceOf[CombineFileSplit]) {
             // If we can't get the bytes read from the FS stats, fall back to the split size,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
index c9ea579b5e809b51607cbac03537fc896c3555ae..04640711d99d0938b6a42f90d16e23e040913798 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
@@ -111,8 +111,7 @@ case class BroadcastHashJoin(
       val hashedRelation = broadcastRelation.value
       hashedRelation match {
         case unsafe: UnsafeHashedRelation =>
-          TaskContext.get().internalMetricsToAccumulators(
-            InternalAccumulator.PEAK_EXECUTION_MEMORY).add(unsafe.getUnsafeSize)
+          TaskContext.get().taskMetrics().incPeakExecutionMemory(unsafe.getUnsafeSize)
         case _ =>
       }
       hashJoin(streamedIter, numStreamedRows, hashedRelation, numOutputRows)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala
index 6c7fa2eee5bfae9d325ebf219392d003843ec72f..db8edd169dcfaccb32c1b0f4fd10799af72f01f8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala
@@ -119,8 +119,7 @@ case class BroadcastHashOuterJoin(
 
       hashTable match {
         case unsafe: UnsafeHashedRelation =>
-          TaskContext.get().internalMetricsToAccumulators(
-            InternalAccumulator.PEAK_EXECUTION_MEMORY).add(unsafe.getUnsafeSize)
+          TaskContext.get().taskMetrics().incPeakExecutionMemory(unsafe.getUnsafeSize)
         case _ =>
       }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala
index 004407b2e692512d5704ec3b4cbf5b6d58465904..8929dc3af19121497f4c916ff3236e222ad31ae4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala
@@ -66,8 +66,7 @@ case class BroadcastLeftSemiJoinHash(
         val hashedRelation = broadcastedRelation.value
         hashedRelation match {
           case unsafe: UnsafeHashedRelation =>
-            TaskContext.get().internalMetricsToAccumulators(
-              InternalAccumulator.PEAK_EXECUTION_MEMORY).add(unsafe.getUnsafeSize)
+            TaskContext.get().taskMetrics().incPeakExecutionMemory(unsafe.getUnsafeSize)
           case _ =>
         }
         hashSemiJoin(streamIter, numLeftRows, hashedRelation, numOutputRows)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala
index 52735c9d7f8c473756b707e5192c08c4bb52cd37..950dc7816241f943f541ed23c1bbd9c8a8162a45 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.execution.metric
 
-import org.apache.spark.{Accumulable, AccumulableParam, SparkContext}
+import org.apache.spark.{Accumulable, AccumulableParam, Accumulators, SparkContext}
 import org.apache.spark.util.Utils
 
 /**
@@ -28,7 +28,7 @@ import org.apache.spark.util.Utils
  */
 private[sql] abstract class SQLMetric[R <: SQLMetricValue[T], T](
     name: String, val param: SQLMetricParam[R, T])
-  extends Accumulable[R, T](param.zero, param, Some(name), true) {
+  extends Accumulable[R, T](param.zero, param, Some(name), internal = true) {
 
   def reset(): Unit = {
     this.value = param.zero
@@ -131,6 +131,8 @@ private[sql] object SQLMetrics {
       name: String,
       param: LongSQLMetricParam): LongSQLMetric = {
     val acc = new LongSQLMetric(name, param)
+    // This is an internal accumulator so we need to register it explicitly.
+    Accumulators.register(acc)
     sc.cleaner.foreach(_.registerAccumulatorForCleanup(acc))
     acc
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala
index 83c64f755f90fbb4ce39599a6f94e97cd2673dbc..544606f1168b641632ccc173ee149190a92be014 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala
@@ -139,9 +139,8 @@ private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Loggi
 
   override def onExecutorMetricsUpdate(
       executorMetricsUpdate: SparkListenerExecutorMetricsUpdate): Unit = synchronized {
-    for ((taskId, stageId, stageAttemptID, metrics) <- executorMetricsUpdate.taskMetrics) {
-      updateTaskAccumulatorValues(taskId, stageId, stageAttemptID, metrics.accumulatorUpdates(),
-        finishTask = false)
+    for ((taskId, stageId, stageAttemptID, accumUpdates) <- executorMetricsUpdate.accumUpdates) {
+      updateTaskAccumulatorValues(taskId, stageId, stageAttemptID, accumUpdates, finishTask = false)
     }
   }
 
@@ -177,7 +176,7 @@ private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Loggi
       taskId: Long,
       stageId: Int,
       stageAttemptID: Int,
-      accumulatorUpdates: Map[Long, Any],
+      accumulatorUpdates: Seq[AccumulableInfo],
       finishTask: Boolean): Unit = {
 
     _stageIdToStageMetrics.get(stageId) match {
@@ -289,8 +288,10 @@ private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Loggi
           for (stageId <- executionUIData.stages;
                stageMetrics <- _stageIdToStageMetrics.get(stageId).toIterable;
                taskMetrics <- stageMetrics.taskIdToMetricUpdates.values;
-               accumulatorUpdate <- taskMetrics.accumulatorUpdates.toSeq) yield {
-            accumulatorUpdate
+               accumulatorUpdate <- taskMetrics.accumulatorUpdates) yield {
+            assert(accumulatorUpdate.update.isDefined, s"accumulator update from " +
+              s"task did not have a partial value: ${accumulatorUpdate.name}")
+            (accumulatorUpdate.id, accumulatorUpdate.update.get)
           }
         }.filter { case (id, _) => executionUIData.accumulatorMetrics.contains(id) }
         mergeAccumulatorUpdates(accumulatorUpdates, accumulatorId =>
@@ -328,9 +329,10 @@ private[spark] class SQLHistoryListener(conf: SparkConf, sparkUI: SparkUI)
       taskEnd.taskInfo.taskId,
       taskEnd.stageId,
       taskEnd.stageAttemptId,
-      taskEnd.taskInfo.accumulables.map { acc =>
-        (acc.id, new LongSQLMetricValue(acc.update.getOrElse("0").toLong))
-      }.toMap,
+      taskEnd.taskInfo.accumulables.map { a =>
+        val newValue = new LongSQLMetricValue(a.update.map(_.asInstanceOf[Long]).getOrElse(0L))
+        a.copy(update = Some(newValue))
+      },
       finishTask = true)
   }
 
@@ -406,4 +408,4 @@ private[ui] class SQLStageMetrics(
 private[ui] class SQLTaskMetrics(
     val attemptId: Long, // TODO not used yet
     var finished: Boolean,
-    var accumulatorUpdates: Map[Long, Any])
+    var accumulatorUpdates: Seq[AccumulableInfo])
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 47308966e92cb00abc112154cbce57f056d64b6b..10ccd4b8f60db56f7f642a1d2d511f9ea12d35f2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -1648,7 +1648,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
 
   test("external sorting updates peak execution memory") {
     AccumulatorSuite.verifyPeakExecutionMemorySet(sparkContext, "external sort") {
-      sortTest()
+      sql("SELECT * FROM testData2 ORDER BY a ASC, b ASC").collect()
     }
   }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ReferenceSort.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ReferenceSort.scala
index 9575d26fd123f8caefd23d37944c502d7f21b200..273937fa8ce90546181b7a9fa950ed5bed22bc4f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ReferenceSort.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ReferenceSort.scala
@@ -49,8 +49,7 @@ case class ReferenceSort(
       val context = TaskContext.get()
       context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled)
       context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled)
-      context.internalMetricsToAccumulators(
-        InternalAccumulator.PEAK_EXECUTION_MEMORY).add(sorter.peakMemoryUsedBytes)
+      context.taskMetrics().incPeakExecutionMemory(sorter.peakMemoryUsedBytes)
       CompletionIterator[InternalRow, Iterator[InternalRow]](baseIterator, sorter.stop())
     }, preservesPartitioning = true)
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
index 9c258cb31f460d0cde720b88927a059595bb71ad..c7df8b51e2f13d9190077730177adffcfdd956bc 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
@@ -71,8 +71,7 @@ class UnsafeFixedWidthAggregationMapSuite
         taskAttemptId = Random.nextInt(10000),
         attemptNumber = 0,
         taskMemoryManager = taskMemoryManager,
-        metricsSystem = null,
-        internalAccumulators = Seq.empty))
+        metricsSystem = null))
 
       try {
         f
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala
index 8a95359d9de254fbae0210ef3067e0a07a6d918e..e03bd6a3e7d208e7905fdaa1cc4fdcd77603a247 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala
@@ -117,8 +117,7 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite with SharedSQLContext {
       taskAttemptId = 98456,
       attemptNumber = 0,
       taskMemoryManager = taskMemMgr,
-      metricsSystem = null,
-      internalAccumulators = Seq.empty))
+      metricsSystem = null))
 
     val sorter = new UnsafeKVExternalSorter(
       keySchema, valueSchema, SparkEnv.get.blockManager, pageSize)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala
index 647a7e9a4e196851cecfbe12f4e5c709db716294..86c2c25c2c7e1f169499a941d308d6f9a1ad169f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala
@@ -17,12 +17,19 @@
 
 package org.apache.spark.sql.execution.columnar
 
+import org.scalatest.BeforeAndAfterEach
+
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql._
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.test.SQLTestData._
 
-class PartitionBatchPruningSuite extends SparkFunSuite with SharedSQLContext {
+
+class PartitionBatchPruningSuite
+  extends SparkFunSuite
+  with BeforeAndAfterEach
+  with SharedSQLContext {
+
   import testImplicits._
 
   private lazy val originalColumnBatchSize = sqlContext.conf.columnBatchSize
@@ -32,30 +39,41 @@ class PartitionBatchPruningSuite extends SparkFunSuite with SharedSQLContext {
     super.beforeAll()
     // Make a table with 5 partitions, 2 batches per partition, 10 elements per batch
     sqlContext.setConf(SQLConf.COLUMN_BATCH_SIZE, 10)
-
-    val pruningData = sparkContext.makeRDD((1 to 100).map { key =>
-      val string = if (((key - 1) / 10) % 2 == 0) null else key.toString
-      TestData(key, string)
-    }, 5).toDF()
-    pruningData.registerTempTable("pruningData")
-
     // Enable in-memory partition pruning
     sqlContext.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, true)
     // Enable in-memory table scan accumulators
     sqlContext.setConf("spark.sql.inMemoryTableScanStatistics.enable", "true")
-    sqlContext.cacheTable("pruningData")
   }
 
   override protected def afterAll(): Unit = {
     try {
       sqlContext.setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize)
       sqlContext.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning)
-      sqlContext.uncacheTable("pruningData")
     } finally {
       super.afterAll()
     }
   }
 
+  override protected def beforeEach(): Unit = {
+    super.beforeEach()
+    // This creates accumulators, which get cleaned up after every single test,
+    // so we need to do this before every test.
+    val pruningData = sparkContext.makeRDD((1 to 100).map { key =>
+      val string = if (((key - 1) / 10) % 2 == 0) null else key.toString
+      TestData(key, string)
+    }, 5).toDF()
+    pruningData.registerTempTable("pruningData")
+    sqlContext.cacheTable("pruningData")
+  }
+
+  override protected def afterEach(): Unit = {
+    try {
+      sqlContext.uncacheTable("pruningData")
+    } finally {
+      super.afterEach()
+    }
+  }
+
   // Comparisons
   checkBatchPruning("SELECT key FROM pruningData WHERE key = 1", 1, 1)(Seq(1))
   checkBatchPruning("SELECT key FROM pruningData WHERE 1 = key", 1, 1)(Seq(1))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala
index 81a159d542c67c10ab62eceecf4cf680af7661b1..2c408c8878470fc8f49bda3358aff375e77f0731 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala
@@ -19,6 +19,8 @@ package org.apache.spark.sql.execution.ui
 
 import java.util.Properties
 
+import org.mockito.Mockito.{mock, when}
+
 import org.apache.spark.{SparkConf, SparkContext, SparkException, SparkFunSuite}
 import org.apache.spark.executor.TaskMetrics
 import org.apache.spark.scheduler._
@@ -67,9 +69,11 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext {
   )
 
   private def createTaskMetrics(accumulatorUpdates: Map[Long, Long]): TaskMetrics = {
-    val metrics = new TaskMetrics
-    metrics.setAccumulatorsUpdater(() => accumulatorUpdates.mapValues(new LongSQLMetricValue(_)))
-    metrics.updateAccumulators()
+    val metrics = mock(classOf[TaskMetrics])
+    when(metrics.accumulatorUpdates()).thenReturn(accumulatorUpdates.map { case (id, update) =>
+      new AccumulableInfo(id, Some(""), Some(new LongSQLMetricValue(update)),
+        value = None, internal = true, countFailedValues = true)
+    }.toSeq)
     metrics
   }
 
@@ -114,17 +118,17 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext {
     assert(listener.getExecutionMetrics(0).isEmpty)
 
     listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate("", Seq(
-      // (task id, stage id, stage attempt, metrics)
-      (0L, 0, 0, createTaskMetrics(accumulatorUpdates)),
-      (1L, 0, 0, createTaskMetrics(accumulatorUpdates))
+      // (task id, stage id, stage attempt, accum updates)
+      (0L, 0, 0, createTaskMetrics(accumulatorUpdates).accumulatorUpdates()),
+      (1L, 0, 0, createTaskMetrics(accumulatorUpdates).accumulatorUpdates())
     )))
 
     checkAnswer(listener.getExecutionMetrics(0), accumulatorUpdates.mapValues(_ * 2))
 
     listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate("", Seq(
-      // (task id, stage id, stage attempt, metrics)
-      (0L, 0, 0, createTaskMetrics(accumulatorUpdates)),
-      (1L, 0, 0, createTaskMetrics(accumulatorUpdates.mapValues(_ * 2)))
+      // (task id, stage id, stage attempt, accum updates)
+      (0L, 0, 0, createTaskMetrics(accumulatorUpdates).accumulatorUpdates()),
+      (1L, 0, 0, createTaskMetrics(accumulatorUpdates.mapValues(_ * 2)).accumulatorUpdates())
     )))
 
     checkAnswer(listener.getExecutionMetrics(0), accumulatorUpdates.mapValues(_ * 3))
@@ -133,9 +137,9 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext {
     listener.onStageSubmitted(SparkListenerStageSubmitted(createStageInfo(0, 1)))
 
     listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate("", Seq(
-      // (task id, stage id, stage attempt, metrics)
-      (0L, 0, 1, createTaskMetrics(accumulatorUpdates)),
-      (1L, 0, 1, createTaskMetrics(accumulatorUpdates))
+      // (task id, stage id, stage attempt, accum updates)
+      (0L, 0, 1, createTaskMetrics(accumulatorUpdates).accumulatorUpdates()),
+      (1L, 0, 1, createTaskMetrics(accumulatorUpdates).accumulatorUpdates())
     )))
 
     checkAnswer(listener.getExecutionMetrics(0), accumulatorUpdates.mapValues(_ * 2))
@@ -173,9 +177,9 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext {
     listener.onStageSubmitted(SparkListenerStageSubmitted(createStageInfo(1, 0)))
 
     listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate("", Seq(
-      // (task id, stage id, stage attempt, metrics)
-      (0L, 1, 0, createTaskMetrics(accumulatorUpdates)),
-      (1L, 1, 0, createTaskMetrics(accumulatorUpdates))
+      // (task id, stage id, stage attempt, accum updates)
+      (0L, 1, 0, createTaskMetrics(accumulatorUpdates).accumulatorUpdates()),
+      (1L, 1, 0, createTaskMetrics(accumulatorUpdates).accumulatorUpdates())
     )))
 
     checkAnswer(listener.getExecutionMetrics(0), accumulatorUpdates.mapValues(_ * 7))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
index b46b0d2f6040ad70164b4df16d028493c944d957..9a24a2487a254010f95a1d6af359a4e314a086b2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
@@ -140,7 +140,7 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext {
         .filter(_._2.name == InternalAccumulator.PEAK_EXECUTION_MEMORY)
 
       assert(peakMemoryAccumulator.size == 1)
-      peakMemoryAccumulator.head._2.value.toLong
+      peakMemoryAccumulator.head._2.value.get.asInstanceOf[Long]
     }
 
     assert(sparkListener.getCompletedStageInfos.length == 2)