diff --git a/core/src/main/scala/org/apache/spark/Accumulable.scala b/core/src/main/scala/org/apache/spark/Accumulable.scala
index e8f053c150693448e31ba734f37d7de3a0f75571..c76720c4bb8b22e0bea077c3403958019c4851bc 100644
--- a/core/src/main/scala/org/apache/spark/Accumulable.scala
+++ b/core/src/main/scala/org/apache/spark/Accumulable.scala
@@ -63,7 +63,7 @@ class Accumulable[R, T] private (
       param: AccumulableParam[R, T],
       name: Option[String],
       countFailedValues: Boolean) = {
-    this(Accumulators.newId(), initialValue, param, name, countFailedValues)
+    this(AccumulatorContext.newId(), initialValue, param, name, countFailedValues)
   }
 
   private[spark] def this(initialValue: R, param: AccumulableParam[R, T], name: Option[String]) = {
@@ -72,34 +72,23 @@ class Accumulable[R, T] private (
 
   def this(initialValue: R, param: AccumulableParam[R, T]) = this(initialValue, param, None)
 
-  @volatile @transient private var value_ : R = initialValue // Current value on driver
-  val zero = param.zero(initialValue) // Zero value to be passed to executors
-  private var deserialized = false
-
-  Accumulators.register(this)
-
-  /**
-   * Return a copy of this [[Accumulable]].
-   *
-   * The copy will have the same ID as the original and will not be registered with
-   * [[Accumulators]] again. This method exists so that the caller can avoid passing the
-   * same mutable instance around.
-   */
-  private[spark] def copy(): Accumulable[R, T] = {
-    new Accumulable[R, T](id, initialValue, param, name, countFailedValues)
-  }
+  val zero = param.zero(initialValue)
+  private[spark] val newAcc = new LegacyAccumulatorWrapper(initialValue, param)
+  newAcc.metadata = AccumulatorMetadata(id, name, countFailedValues)
+  // Register the new accumulator in the constructor, to preserve the previous behaviour.
+  AccumulatorContext.register(newAcc)
 
   /**
    * Add more data to this accumulator / accumulable
    * @param term the data to add
    */
-  def += (term: T) { value_ = param.addAccumulator(value_, term) }
+  def += (term: T) { newAcc.add(term) }
 
   /**
    * Add more data to this accumulator / accumulable
    * @param term the data to add
    */
-  def add(term: T) { value_ = param.addAccumulator(value_, term) }
+  def add(term: T) { newAcc.add(term) }
 
   /**
    * Merge two accumulable objects together
@@ -107,7 +96,7 @@ class Accumulable[R, T] private (
    * Normally, a user will not want to use this version, but will instead call `+=`.
    * @param term the other `R` that will get merged with this
    */
-  def ++= (term: R) { value_ = param.addInPlace(value_, term)}
+  def ++= (term: R) { newAcc._value = param.addInPlace(newAcc._value, term) }
 
   /**
    * Merge two accumulable objects together
@@ -115,18 +104,12 @@ class Accumulable[R, T] private (
    * Normally, a user will not want to use this version, but will instead call `add`.
    * @param term the other `R` that will get merged with this
    */
-  def merge(term: R) { value_ = param.addInPlace(value_, term)}
+  def merge(term: R) { newAcc._value = param.addInPlace(newAcc._value, term) }
 
   /**
    * Access the accumulator's current value; only allowed on driver.
    */
-  def value: R = {
-    if (!deserialized) {
-      value_
-    } else {
-      throw new UnsupportedOperationException("Can't read accumulator value in task")
-    }
-  }
+  def value: R = newAcc.value
 
   /**
    * Get the current value of this accumulator from within a task.
@@ -137,14 +120,14 @@ class Accumulable[R, T] private (
    * The typical use of this method is to directly mutate the local value, eg., to add
    * an element to a Set.
    */
-  def localValue: R = value_
+  def localValue: R = newAcc.localValue
 
   /**
    * Set the accumulator's value; only allowed on driver.
    */
   def value_= (newValue: R) {
-    if (!deserialized) {
-      value_ = newValue
+    if (newAcc.isAtDriverSide) {
+      newAcc._value = newValue
     } else {
       throw new UnsupportedOperationException("Can't assign accumulator value in task")
     }
@@ -153,7 +136,7 @@ class Accumulable[R, T] private (
   /**
    * Set the accumulator's value. For internal use only.
    */
-  def setValue(newValue: R): Unit = { value_ = newValue }
+  def setValue(newValue: R): Unit = { newAcc._value = newValue }
 
   /**
    * Set the accumulator's value. For internal use only.
@@ -168,22 +151,7 @@ class Accumulable[R, T] private (
     new AccumulableInfo(id, name, update, value, isInternal, countFailedValues)
   }
 
-  // Called by Java when deserializing an object
-  private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException {
-    in.defaultReadObject()
-    value_ = zero
-    deserialized = true
-
-    // Automatically register the accumulator when it is deserialized with the task closure.
-    // This is for external accumulators and internal ones that do not represent task level
-    // metrics, e.g. internal SQL metrics, which are per-operator.
-    val taskContext = TaskContext.get()
-    if (taskContext != null) {
-      taskContext.registerAccumulator(this)
-    }
-  }
-
-  override def toString: String = if (value_ == null) "null" else value_.toString
+  override def toString: String = if (newAcc._value == null) "null" else newAcc._value.toString
 }
 
 
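Reviewer note: a minimal sketch (not part of the patch) of how the legacy API now flows through the wrapper. The legacy `Accumulable` keeps its public surface, but `+=`, `add` and `value` simply delegate to the embedded `LegacyAccumulatorWrapper` (`newAcc`); `sc` below is an assumed existing `SparkContext`.

```scala
// Hypothetical usage of the unchanged legacy API after this change.
val legacy = sc.accumulator(0)                       // backed by a LegacyAccumulatorWrapper
sc.parallelize(1 to 100).foreach(_ => legacy += 1)   // `+=` delegates to newAcc.add on executors
assert(legacy.value == 100)                          // `value` delegates to newAcc.value on the driver
```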
diff --git a/core/src/main/scala/org/apache/spark/Accumulator.scala b/core/src/main/scala/org/apache/spark/Accumulator.scala
index 0c17f014c90db4085e7244717789eb74d2c56968..9b007b97760883e14f4680255bd0872535c507d6 100644
--- a/core/src/main/scala/org/apache/spark/Accumulator.scala
+++ b/core/src/main/scala/org/apache/spark/Accumulator.scala
@@ -68,73 +68,6 @@ class Accumulator[T] private[spark] (
   extends Accumulable[T, T](initialValue, param, name, countFailedValues)
 
 
-// TODO: The multi-thread support in accumulators is kind of lame; check
-// if there's a more intuitive way of doing it right
-private[spark] object Accumulators extends Logging {
-  /**
-   * This global map holds the original accumulator objects that are created on the driver.
-   * It keeps weak references to these objects so that accumulators can be garbage-collected
-   * once the RDDs and user-code that reference them are cleaned up.
-   * TODO: Don't use a global map; these should be tied to a SparkContext (SPARK-13051).
-   */
-  @GuardedBy("Accumulators")
-  val originals = mutable.Map[Long, WeakReference[Accumulable[_, _]]]()
-
-  private val nextId = new AtomicLong(0L)
-
-  /**
-   * Return a globally unique ID for a new [[Accumulable]].
-   * Note: Once you copy the [[Accumulable]] the ID is no longer unique.
-   */
-  def newId(): Long = nextId.getAndIncrement
-
-  /**
-   * Register an [[Accumulable]] created on the driver such that it can be used on the executors.
-   *
-   * All accumulators registered here can later be used as a container for accumulating partial
-   * values across multiple tasks. This is what [[org.apache.spark.scheduler.DAGScheduler]] does.
-   * Note: if an accumulator is registered here, it should also be registered with the active
-   * context cleaner for cleanup so as to avoid memory leaks.
-   *
-   * If an [[Accumulable]] with the same ID was already registered, this does nothing instead
-   * of overwriting it. This happens when we copy accumulators, e.g. when we reconstruct
-   * [[org.apache.spark.executor.TaskMetrics]] from accumulator updates.
-   */
-  def register(a: Accumulable[_, _]): Unit = synchronized {
-    if (!originals.contains(a.id)) {
-      originals(a.id) = new WeakReference[Accumulable[_, _]](a)
-    }
-  }
-
-  /**
-   * Unregister the [[Accumulable]] with the given ID, if any.
-   */
-  def remove(accId: Long): Unit = synchronized {
-    originals.remove(accId)
-  }
-
-  /**
-   * Return the [[Accumulable]] registered with the given ID, if any.
-   */
-  def get(id: Long): Option[Accumulable[_, _]] = synchronized {
-    originals.get(id).map { weakRef =>
-      // Since we are storing weak references, we must check whether the underlying data is valid.
-      weakRef.get.getOrElse {
-        throw new IllegalAccessError(s"Attempted to access garbage collected accumulator $id")
-      }
-    }
-  }
-
-  /**
-   * Clear all registered [[Accumulable]]s. For testing only.
-   */
-  def clear(): Unit = synchronized {
-    originals.clear()
-  }
-
-}
-
-
 /**
  * A simpler version of [[org.apache.spark.AccumulableParam]] where the only data type you can add
  * in is the same type as the accumulated value. An implicit AccumulatorParam object needs to be
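Reviewer note: the `Accumulators` registry removed above is replaced by `AccumulatorContext`, added in NewAccumulator.scala below. A rough sketch of the mapping between the old and new call sites; both objects are `private[spark]`, so this is Spark-internal code, and `acc` is assumed to be an already-registered accumulator.

```scala
// Spark-internal sketch (illustrative only, not part of the patch).
val id: Long = AccumulatorContext.newId()            // was Accumulators.newId()
AccumulatorContext.register(acc)                     // was Accumulators.register(acc); no-op if already present
val found: Option[NewAccumulator[_, _]] = AccumulatorContext.get(acc.id)  // weak-reference lookup
AccumulatorContext.remove(acc.id)                    // what ContextCleaner.doCleanupAccum now calls
```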
diff --git a/core/src/main/scala/org/apache/spark/ContextCleaner.scala b/core/src/main/scala/org/apache/spark/ContextCleaner.scala
index 76692ccec815316e621f530b934e9b5e8ef0182b..63a00a84af3cd2d0ee3a1c899d0757d5f6eae834 100644
--- a/core/src/main/scala/org/apache/spark/ContextCleaner.scala
+++ b/core/src/main/scala/org/apache/spark/ContextCleaner.scala
@@ -144,7 +144,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
     registerForCleanup(rdd, CleanRDD(rdd.id))
   }
 
-  def registerAccumulatorForCleanup(a: Accumulable[_, _]): Unit = {
+  def registerAccumulatorForCleanup(a: NewAccumulator[_, _]): Unit = {
     registerForCleanup(a, CleanAccum(a.id))
   }
 
@@ -241,7 +241,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
   def doCleanupAccum(accId: Long, blocking: Boolean): Unit = {
     try {
       logDebug("Cleaning accumulator " + accId)
-      Accumulators.remove(accId)
+      AccumulatorContext.remove(accId)
       listeners.asScala.foreach(_.accumCleaned(accId))
       logInfo("Cleaned accumulator " + accId)
     } catch {
diff --git a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala
index 2bdbd3fae9b827e0a4523e030bc124c791ff7c4f..9eac05fdf9f3d3e15a56fdd02bd07d1d31297cb2 100644
--- a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala
+++ b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala
@@ -35,7 +35,7 @@ import org.apache.spark.util.{Clock, SystemClock, ThreadUtils, Utils}
  */
 private[spark] case class Heartbeat(
     executorId: String,
-    accumUpdates: Array[(Long, Seq[AccumulableInfo])], // taskId -> accum updates
+    accumUpdates: Array[(Long, Seq[NewAccumulator[_, _]])], // taskId -> accumulator updates
     blockManagerId: BlockManagerId)
 
 /**
diff --git a/core/src/main/scala/org/apache/spark/NewAccumulator.scala b/core/src/main/scala/org/apache/spark/NewAccumulator.scala
new file mode 100644
index 0000000000000000000000000000000000000000..edb9b741a87123dcb637c0f6e42bdcc74a499752
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/NewAccumulator.scala
@@ -0,0 +1,391 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import java.{lang => jl}
+import java.io.ObjectInputStream
+import java.util.concurrent.atomic.AtomicLong
+import javax.annotation.concurrent.GuardedBy
+
+import org.apache.spark.scheduler.AccumulableInfo
+import org.apache.spark.util.Utils
+
+
+private[spark] case class AccumulatorMetadata(
+    id: Long,
+    name: Option[String],
+    countFailedValues: Boolean) extends Serializable
+
+
+/**
+ * The base class for accumulators, which can accumulate inputs of type `IN` and produce output
+ * of type `OUT`.
+ */
+abstract class NewAccumulator[IN, OUT] extends Serializable {
+  private[spark] var metadata: AccumulatorMetadata = _
+  private[this] var atDriverSide = true
+
+  private[spark] def register(
+      sc: SparkContext,
+      name: Option[String] = None,
+      countFailedValues: Boolean = false): Unit = {
+    if (this.metadata != null) {
+      throw new IllegalStateException("Cannot register an Accumulator twice.")
+    }
+    this.metadata = AccumulatorMetadata(AccumulatorContext.newId(), name, countFailedValues)
+    AccumulatorContext.register(this)
+    sc.cleaner.foreach(_.registerAccumulatorForCleanup(this))
+  }
+
+  /**
+   * Returns true if this accumulator has been registered.  Note that all accumulators must be
+   * registered before use, or they will throw an exception.
+   */
+  final def isRegistered: Boolean =
+    metadata != null && AccumulatorContext.originals.containsKey(metadata.id)
+
+  private def assertMetadataNotNull(): Unit = {
+    if (metadata == null) {
+      throw new IllegalAccessError("The metadata of this accumulator has not been assigned yet.")
+    }
+  }
+
+  /**
+   * Returns the id of this accumulator; can only be called after registration.
+   */
+  final def id: Long = {
+    assertMetadataNotNull()
+    metadata.id
+  }
+
+  /**
+   * Returns the name of this accumulator; can only be called after registration.
+   */
+  final def name: Option[String] = {
+    assertMetadataNotNull()
+    metadata.name
+  }
+
+  /**
+   * Whether to accumulate values from failed tasks. This is set to true for system and time
+   * metrics like serialization time or bytes spilled, and false for things with absolute values
+   * like number of input rows.  This should be used for internal metrics only.
+   */
+  private[spark] final def countFailedValues: Boolean = {
+    assertMetadataNotNull()
+    metadata.countFailedValues
+  }
+
+  /**
+   * Creates an [[AccumulableInfo]] representation of this [[NewAccumulator]] with the provided
+   * values.
+   */
+  private[spark] def toInfo(update: Option[Any], value: Option[Any]): AccumulableInfo = {
+    val isInternal = name.exists(_.startsWith(InternalAccumulator.METRICS_PREFIX))
+    new AccumulableInfo(id, name, update, value, isInternal, countFailedValues)
+  }
+
+  final private[spark] def isAtDriverSide: Boolean = atDriverSide
+
+  /**
+   * Returns whether this accumulator is zero value or not, e.g. for a counter accumulator, 0 is
+   * the zero value; for a list accumulator, Nil is the zero value.
+   */
+  def isZero(): Boolean
+
+  /**
+   * Creates a new copy of this accumulator, which is zero value, i.e. calling `isZero` on the
+   * copy must return true.
+   */
+  def copyAndReset(): NewAccumulator[IN, OUT]
+
+  /**
+   * Takes the inputs and accumulates them, e.g. it can be a simple `+=` for a counter accumulator.
+   */
+  def add(v: IN): Unit
+
+  /**
+   * Merges another same-type accumulator into this one and updates its state, i.e. the merge
+   * should happen in place.
+   */
+  def merge(other: NewAccumulator[IN, OUT]): Unit
+
+  /**
+   * Access this accumulator's current value; only allowed on driver.
+   */
+  final def value: OUT = {
+    if (atDriverSide) {
+      localValue
+    } else {
+      throw new UnsupportedOperationException("Can't read accumulator value in task")
+    }
+  }
+
+  /**
+   * Returns the current local value of this accumulator.
+   *
+   * This is NOT the global value of the accumulator.  To get the global value after a
+   * completed operation on the dataset, call `value`.
+   */
+  def localValue: OUT
+
+  // Called by Java when serializing an object
+  final protected def writeReplace(): Any = {
+    if (atDriverSide) {
+      if (!isRegistered) {
+        throw new UnsupportedOperationException(
+          "Accumulator must be registered before send to executor")
+      }
+      val copy = copyAndReset()
+      assert(copy.isZero(), "copyAndReset must return a zero value copy")
+      copy.metadata = metadata
+      copy
+    } else {
+      this
+    }
+  }
+
+  // Called by Java when deserializing an object
+  private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException {
+    in.defaultReadObject()
+    if (atDriverSide) {
+      atDriverSide = false
+
+      // Automatically register the accumulator when it is deserialized with the task closure.
+      // This is for external accumulators and internal ones that do not represent task level
+      // metrics, e.g. internal SQL metrics, which are per-operator.
+      val taskContext = TaskContext.get()
+      if (taskContext != null) {
+        taskContext.registerAccumulator(this)
+      }
+    } else {
+      atDriverSide = true
+    }
+  }
+
+  override def toString: String = {
+    if (metadata == null) {
+      "Un-registered Accumulator: " + getClass.getSimpleName
+    } else {
+      getClass.getSimpleName + s"(id: $id, name: $name, value: $localValue)"
+    }
+  }
+}
+
+
+private[spark] object AccumulatorContext {
+
+  /**
+   * This global map holds the original accumulator objects that are created on the driver.
+   * It keeps weak references to these objects so that accumulators can be garbage-collected
+   * once the RDDs and user-code that reference them are cleaned up.
+   * TODO: Don't use a global map; these should be tied to a SparkContext (SPARK-13051).
+   */
+  @GuardedBy("AccumulatorContext")
+  val originals = new java.util.HashMap[Long, jl.ref.WeakReference[NewAccumulator[_, _]]]
+
+  private[this] val nextId = new AtomicLong(0L)
+
+  /**
+   * Return a globally unique ID for a new [[NewAccumulator]].
+   * Note: Once you copy the [[NewAccumulator]], the ID is no longer unique.
+   */
+  def newId(): Long = nextId.getAndIncrement
+
+  /**
+   * Register a [[NewAccumulator]] created on the driver such that it can be used on the executors.
+   *
+   * All accumulators registered here can later be used as a container for accumulating partial
+   * values across multiple tasks. This is what [[org.apache.spark.scheduler.DAGScheduler]] does.
+   * Note: if an accumulator is registered here, it should also be registered with the active
+   * context cleaner for cleanup so as to avoid memory leaks.
+   *
+   * If a [[NewAccumulator]] with the same ID was already registered, this does nothing instead
+   * of overwriting it. We never register the same accumulator twice; this is just a sanity check.
+   */
+  def register(a: NewAccumulator[_, _]): Unit = synchronized {
+    if (!originals.containsKey(a.id)) {
+      originals.put(a.id, new jl.ref.WeakReference[NewAccumulator[_, _]](a))
+    }
+  }
+
+  /**
+   * Unregister the [[NewAccumulator]] with the given ID, if any.
+   */
+  def remove(id: Long): Unit = synchronized {
+    originals.remove(id)
+  }
+
+  /**
+   * Return the [[NewAccumulator]] registered with the given ID, if any.
+   */
+  def get(id: Long): Option[NewAccumulator[_, _]] = synchronized {
+    Option(originals.get(id)).map { ref =>
+      // Since we are storing weak references, we must check whether the underlying data is valid.
+      val acc = ref.get
+      if (acc eq null) {
+        throw new IllegalAccessError(s"Attempted to access garbage collected accumulator $id")
+      }
+      acc
+    }
+  }
+
+  /**
+   * Clear all registered [[NewAccumulator]]s. For testing only.
+   */
+  def clear(): Unit = synchronized {
+    originals.clear()
+  }
+}
+
+
+class LongAccumulator extends NewAccumulator[jl.Long, jl.Long] {
+  private[this] var _sum = 0L
+
+  override def isZero(): Boolean = _sum == 0
+
+  override def copyAndReset(): LongAccumulator = new LongAccumulator
+
+  override def add(v: jl.Long): Unit = _sum += v
+
+  def add(v: Long): Unit = _sum += v
+
+  def sum: Long = _sum
+
+  override def merge(other: NewAccumulator[jl.Long, jl.Long]): Unit = other match {
+    case o: LongAccumulator => _sum += o.sum
+    case _ => throw new UnsupportedOperationException(
+      s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}")
+  }
+
+  private[spark] def setValue(newValue: Long): Unit = _sum = newValue
+
+  override def localValue: jl.Long = _sum
+}
+
+
+class DoubleAccumulator extends NewAccumulator[jl.Double, jl.Double] {
+  private[this] var _sum = 0.0
+
+  override def isZero(): Boolean = _sum == 0.0
+
+  override def copyAndReset(): DoubleAccumulator = new DoubleAccumulator
+
+  override def add(v: jl.Double): Unit = _sum += v
+
+  def add(v: Double): Unit = _sum += v
+
+  def sum: Double = _sum
+
+  override def merge(other: NewAccumulator[jl.Double, jl.Double]): Unit = other match {
+    case o: DoubleAccumulator => _sum += o.sum
+    case _ => throw new UnsupportedOperationException(
+      s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}")
+  }
+
+  private[spark] def setValue(newValue: Double): Unit = _sum = newValue
+
+  override def localValue: jl.Double = _sum
+}
+
+
+class AverageAccumulator extends NewAccumulator[jl.Double, jl.Double] {
+  private[this] var _sum = 0.0
+  private[this] var _count = 0L
+
+  override def isZero(): Boolean = _sum == 0.0 && _count == 0
+
+  override def copyAndReset(): AverageAccumulator = new AverageAccumulator
+
+  override def add(v: jl.Double): Unit = {
+    _sum += v
+    _count += 1
+  }
+
+  def add(d: Double): Unit = {
+    _sum += d
+    _count += 1
+  }
+
+  override def merge(other: NewAccumulator[jl.Double, jl.Double]): Unit = other match {
+    case o: AverageAccumulator =>
+      _sum += o.sum
+      _count += o.count
+    case _ => throw new UnsupportedOperationException(
+      s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}")
+  }
+
+  override def localValue: jl.Double = if (_count == 0) {
+    Double.NaN
+  } else {
+    _sum / _count
+  }
+
+  def sum: Double = _sum
+
+  def count: Long = _count
+}
+
+
+class ListAccumulator[T] extends NewAccumulator[T, java.util.List[T]] {
+  private[this] val _list: java.util.List[T] = new java.util.ArrayList[T]
+
+  override def isZero(): Boolean = _list.isEmpty
+
+  override def copyAndReset(): ListAccumulator[T] = new ListAccumulator
+
+  override def add(v: T): Unit = _list.add(v)
+
+  override def merge(other: NewAccumulator[T, java.util.List[T]]): Unit = other match {
+    case o: ListAccumulator[T] => _list.addAll(o.localValue)
+    case _ => throw new UnsupportedOperationException(
+      s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}")
+  }
+
+  override def localValue: java.util.List[T] = java.util.Collections.unmodifiableList(_list)
+
+  private[spark] def setValue(newValue: java.util.List[T]): Unit = {
+    _list.clear()
+    _list.addAll(newValue)
+  }
+}
+
+
+class LegacyAccumulatorWrapper[R, T](
+    initialValue: R,
+    param: org.apache.spark.AccumulableParam[R, T]) extends NewAccumulator[T, R] {
+  private[spark] var _value = initialValue  // Current value on driver
+
+  override def isZero(): Boolean = _value == param.zero(initialValue)
+
+  override def copyAndReset(): LegacyAccumulatorWrapper[R, T] = {
+    val acc = new LegacyAccumulatorWrapper(initialValue, param)
+    acc._value = param.zero(initialValue)
+    acc
+  }
+
+  override def add(v: T): Unit = _value = param.addAccumulator(_value, v)
+
+  override def merge(other: NewAccumulator[T, R]): Unit = other match {
+    case o: LegacyAccumulatorWrapper[R, T] => _value = param.addInPlace(_value, o.localValue)
+    case _ => throw new UnsupportedOperationException(
+      s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}")
+  }
+
+  override def localValue: R = _value
+}
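Reviewer note: to make the new contract concrete, here is a hypothetical user-defined accumulator (not part of the patch) implementing `isZero`/`copyAndReset`/`add`/`merge`/`localValue`, registered through the new `SparkContext.register` added later in this patch; `sc` is an assumed existing `SparkContext`.

```scala
// Hypothetical example: a max-tracking accumulator built on the new contract.
class MaxAccumulator extends NewAccumulator[Long, Long] {
  private var _max = Long.MinValue

  // Zero means "no value seen yet".
  override def isZero(): Boolean = _max == Long.MinValue

  override def copyAndReset(): MaxAccumulator = new MaxAccumulator

  override def add(v: Long): Unit = _max = math.max(_max, v)

  override def merge(other: NewAccumulator[Long, Long]): Unit = other match {
    case o: MaxAccumulator => _max = math.max(_max, o.localValue)
    case _ => throw new UnsupportedOperationException(
      s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}")
  }

  override def localValue: Long = _max
}

// Registration is mandatory before the accumulator is used in a job.
val maxAcc = new MaxAccumulator
sc.register(maxAcc, "max value seen")               // `sc` is an assumed SparkContext
sc.parallelize(Seq(3L, 7L, 5L)).foreach(maxAcc.add)
assert(maxAcc.value == 7L)                          // `value` is only readable on the driver
```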
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index f322a770bf18795a13970da31e781e896f8b9589..865989aee0c8dcdec43896e32f1ef56510aff111 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -1217,10 +1217,9 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
    * Create an [[org.apache.spark.Accumulator]] variable of a given type, which tasks can "add"
    * values to using the `+=` method. Only the driver can access the accumulator's `value`.
    */
-  def accumulator[T](initialValue: T)(implicit param: AccumulatorParam[T]): Accumulator[T] =
-  {
+  def accumulator[T](initialValue: T)(implicit param: AccumulatorParam[T]): Accumulator[T] = {
     val acc = new Accumulator(initialValue, param)
-    cleaner.foreach(_.registerAccumulatorForCleanup(acc))
+    cleaner.foreach(_.registerAccumulatorForCleanup(acc.newAcc))
     acc
   }
 
@@ -1232,7 +1231,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
   def accumulator[T](initialValue: T, name: String)(implicit param: AccumulatorParam[T])
     : Accumulator[T] = {
     val acc = new Accumulator(initialValue, param, Some(name))
-    cleaner.foreach(_.registerAccumulatorForCleanup(acc))
+    cleaner.foreach(_.registerAccumulatorForCleanup(acc.newAcc))
     acc
   }
 
@@ -1245,7 +1244,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
   def accumulable[R, T](initialValue: R)(implicit param: AccumulableParam[R, T])
     : Accumulable[R, T] = {
     val acc = new Accumulable(initialValue, param)
-    cleaner.foreach(_.registerAccumulatorForCleanup(acc))
+    cleaner.foreach(_.registerAccumulatorForCleanup(acc.newAcc))
     acc
   }
 
@@ -1259,7 +1258,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
   def accumulable[R, T](initialValue: R, name: String)(implicit param: AccumulableParam[R, T])
     : Accumulable[R, T] = {
     val acc = new Accumulable(initialValue, param, Some(name))
-    cleaner.foreach(_.registerAccumulatorForCleanup(acc))
+    cleaner.foreach(_.registerAccumulatorForCleanup(acc.newAcc))
     acc
   }
 
@@ -1273,7 +1272,101 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
       (initialValue: R): Accumulable[R, T] = {
     val param = new GrowableAccumulableParam[R, T]
     val acc = new Accumulable(initialValue, param)
-    cleaner.foreach(_.registerAccumulatorForCleanup(acc))
+    cleaner.foreach(_.registerAccumulatorForCleanup(acc.newAcc))
+    acc
+  }
+
+  /**
+   * Register the given accumulator.  Note that accumulators must be registered before use, or
+   * they will throw an exception.
+   */
+  def register(acc: NewAccumulator[_, _]): Unit = {
+    acc.register(this)
+  }
+
+  /**
+   * Register the given accumulator with the given name.  Note that accumulators must be
+   * registered before use, or they will throw an exception.
+   */
+  def register(acc: NewAccumulator[_, _], name: String): Unit = {
+    acc.register(this, name = Some(name))
+  }
+
+  /**
+   * Create and register a long accumulator, which starts with 0 and accumulates inputs by `add`.
+   */
+  def longAccumulator: LongAccumulator = {
+    val acc = new LongAccumulator
+    register(acc)
+    acc
+  }
+
+  /**
+   * Create and register a long accumulator, which starts with 0 and accumulates inputs by `add`.
+   */
+  def longAccumulator(name: String): LongAccumulator = {
+    val acc = new LongAccumulator
+    register(acc, name)
+    acc
+  }
+
+  /**
+   * Create and register a double accumulator, which starts with 0 and accumulates inputs by `add`.
+   */
+  def doubleAccumulator: DoubleAccumulator = {
+    val acc = new DoubleAccumulator
+    register(acc)
+    acc
+  }
+
+  /**
+   * Create and register a double accumulator, which starts with 0 and accumulates inputs by `add`.
+   */
+  def doubleAccumulator(name: String): DoubleAccumulator = {
+    val acc = new DoubleAccumulator
+    register(acc, name)
+    acc
+  }
+
+  /**
+   * Create and register an average accumulator, which accumulates double inputs by recording the
+   * total sum and total count, and produces the output as sum / count.  Note that Double.NaN
+   * will be returned if no input is added.
+   */
+  def averageAccumulator: AverageAccumulator = {
+    val acc = new AverageAccumulator
+    register(acc)
+    acc
+  }
+
+  /**
+   * Create and register an average accumulator, which accumulates double inputs by recording the
+   * total sum and total count, and produces the output as sum / count.  Note that Double.NaN
+   * will be returned if no input is added.
+   */
+  def averageAccumulator(name: String): AverageAccumulator = {
+    val acc = new AverageAccumulator
+    register(acc, name)
+    acc
+  }
+
+  /**
+   * Create and register a list accumulator, which starts with an empty list and accumulates
+   * inputs by adding them to the inner list.
+   */
+  def listAccumulator[T]: ListAccumulator[T] = {
+    val acc = new ListAccumulator[T]
+    register(acc)
+    acc
+  }
+
+  /**
+   * Create and register a list accumulator, which starts with an empty list and accumulates
+   * inputs by adding them to the inner list.
+   */
+  def listAccumulator[T](name: String): ListAccumulator[T] = {
+    val acc = new ListAccumulator[T]
+    register(acc, name)
     acc
   }
 
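Reviewer note: a short usage sketch (not part of the patch) for the new helper methods above, assuming an existing `SparkContext` named `sc`.

```scala
val counter = sc.longAccumulator("my counter")
val avg = sc.averageAccumulator("my average")
sc.parallelize(1 to 4).foreach { i =>
  counter.add(i.toLong)    // executors can only add; reading `value` in a task throws
  avg.add(i.toDouble)
}
println(counter.value)     // 10, readable on the driver only
println(avg.value)         // 2.5 = sum / count; Double.NaN if nothing was added
```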
diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala
index e7940bd9eddcda1b4739534fd3e55872e03b541c..9e532574628675b6332e0037e333783976dd5221 100644
--- a/core/src/main/scala/org/apache/spark/TaskContext.scala
+++ b/core/src/main/scala/org/apache/spark/TaskContext.scala
@@ -188,6 +188,6 @@ abstract class TaskContext extends Serializable {
    * Register an accumulator that belongs to this task. Accumulators must call this method when
    * deserializing in executors.
    */
-  private[spark] def registerAccumulator(a: Accumulable[_, _]): Unit
+  private[spark] def registerAccumulator(a: NewAccumulator[_, _]): Unit
 
 }
diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
index 43e555670dc694c6911eb0dab6c956189f07233c..bc3807f5db1809c0d4622dbefdf20d344d42b3e2 100644
--- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
+++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
@@ -122,7 +122,7 @@ private[spark] class TaskContextImpl(
   override def getMetricsSources(sourceName: String): Seq[Source] =
     metricsSystem.getSourcesByName(sourceName)
 
-  private[spark] override def registerAccumulator(a: Accumulable[_, _]): Unit = {
+  private[spark] override def registerAccumulator(a: NewAccumulator[_, _]): Unit = {
     taskMetrics.registerAccumulator(a)
   }
 
diff --git a/core/src/main/scala/org/apache/spark/TaskEndReason.scala b/core/src/main/scala/org/apache/spark/TaskEndReason.scala
index 7487cfe9c55094f98d724ab0955b1acb17deebe2..82ba2d0c274be1aa43352fc596c44b320effc8c3 100644
--- a/core/src/main/scala/org/apache/spark/TaskEndReason.scala
+++ b/core/src/main/scala/org/apache/spark/TaskEndReason.scala
@@ -19,10 +19,7 @@ package org.apache.spark
 
 import java.io.{ObjectInputStream, ObjectOutputStream}
 
-import scala.util.Try
-
 import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.executor.TaskMetrics
 import org.apache.spark.internal.Logging
 import org.apache.spark.scheduler.AccumulableInfo
 import org.apache.spark.storage.BlockManagerId
@@ -120,18 +117,10 @@ case class ExceptionFailure(
     stackTrace: Array[StackTraceElement],
     fullStackTrace: String,
     private val exceptionWrapper: Option[ThrowableSerializationWrapper],
-    accumUpdates: Seq[AccumulableInfo] = Seq.empty[AccumulableInfo])
+    accumUpdates: Seq[AccumulableInfo] = Seq.empty,
+    private[spark] var accums: Seq[NewAccumulator[_, _]] = Nil)
   extends TaskFailedReason {
 
-  @deprecated("use accumUpdates instead", "2.0.0")
-  val metrics: Option[TaskMetrics] = {
-    if (accumUpdates.nonEmpty) {
-      Try(TaskMetrics.fromAccumulatorUpdates(accumUpdates)).toOption
-    } else {
-      None
-    }
-  }
-
   /**
    * `preserveCause` is used to keep the exception itself so it is available to the
    * driver. This may be set to `false` in the event that the exception is not in fact
@@ -149,6 +138,11 @@ case class ExceptionFailure(
     this(e, accumUpdates, preserveCause = true)
   }
 
+  private[spark] def withAccums(accums: Seq[NewAccumulator[_, _]]): ExceptionFailure = {
+    this.accums = accums
+    this
+  }
+
   def exception: Option[Throwable] = exceptionWrapper.flatMap(w => Option(w.exception))
 
   override def toErrorString: String =
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index 650f05c309d2047e602e9b97def3e13cff190b4f..4d61f7e23248b5c7629bf25b44754e82446ccf63 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -353,22 +353,24 @@ private[spark] class Executor(
           logError(s"Exception in $taskName (TID $taskId)", t)
 
           // Collect latest accumulator values to report back to the driver
-          val accumulatorUpdates: Seq[AccumulableInfo] =
+          val accums: Seq[NewAccumulator[_, _]] =
             if (task != null) {
               task.metrics.setExecutorRunTime(System.currentTimeMillis() - taskStart)
               task.metrics.setJvmGCTime(computeTotalGcTime() - startGCTime)
               task.collectAccumulatorUpdates(taskFailed = true)
             } else {
-              Seq.empty[AccumulableInfo]
+              Seq.empty
             }
 
+          val accUpdates = accums.map(acc => acc.toInfo(Some(acc.localValue), None))
+
           val serializedTaskEndReason = {
             try {
-              ser.serialize(new ExceptionFailure(t, accumulatorUpdates))
+              ser.serialize(new ExceptionFailure(t, accUpdates).withAccums(accums))
             } catch {
               case _: NotSerializableException =>
                 // t is not serializable so just send the stacktrace
-                ser.serialize(new ExceptionFailure(t, accumulatorUpdates, preserveCause = false))
+                ser.serialize(new ExceptionFailure(t, accUpdates, false).withAccums(accums))
             }
           }
           execBackend.statusUpdate(taskId, TaskState.FAILED, serializedTaskEndReason)
@@ -476,14 +478,14 @@ private[spark] class Executor(
   /** Reports heartbeat and metrics for active tasks to the driver. */
   private def reportHeartBeat(): Unit = {
     // list of (task id, accumUpdates) to send back to the driver
-    val accumUpdates = new ArrayBuffer[(Long, Seq[AccumulableInfo])]()
+    val accumUpdates = new ArrayBuffer[(Long, Seq[NewAccumulator[_, _]])]()
     val curGCTime = computeTotalGcTime()
 
     for (taskRunner <- runningTasks.values().asScala) {
       if (taskRunner.task != null) {
         taskRunner.task.metrics.mergeShuffleReadMetrics()
         taskRunner.task.metrics.setJvmGCTime(curGCTime - taskRunner.startGCTime)
-        accumUpdates += ((taskRunner.taskId, taskRunner.task.metrics.accumulatorUpdates()))
+        accumUpdates += ((taskRunner.taskId, taskRunner.task.metrics.accumulators()))
       }
     }
 
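Reviewer note: with this change the heartbeat carries live `NewAccumulator` objects instead of `AccumulableInfo`. A rough, Spark-internal sketch of how a receiver could convert them back into listener-facing `AccumulableInfo`, mirroring the `toInfo(Some(acc.localValue), None)` call used for `ExceptionFailure` above; the helper name is illustrative and not in the patch.

```scala
// Illustrative helper: convert heartbeat payloads to AccumulableInfo.
def toAccumInfos(
    accumUpdates: Array[(Long, Seq[NewAccumulator[_, _]])]): Array[(Long, Seq[AccumulableInfo])] = {
  accumUpdates.map { case (taskId, accums) =>
    taskId -> accums.map(a => a.toInfo(Some(a.localValue), None))
  }
}
```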
diff --git a/core/src/main/scala/org/apache/spark/executor/InputMetrics.scala b/core/src/main/scala/org/apache/spark/executor/InputMetrics.scala
index 535352e7dd7a1e25acb42ee6772da2f34ddacf6b..6f7160ac0d3a37762d7221d783f0760dd28b9a0b 100644
--- a/core/src/main/scala/org/apache/spark/executor/InputMetrics.scala
+++ b/core/src/main/scala/org/apache/spark/executor/InputMetrics.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.executor
 
-import org.apache.spark.InternalAccumulator
+import org.apache.spark.LongAccumulator
 import org.apache.spark.annotation.DeveloperApi
 
 
@@ -40,20 +40,18 @@ object DataReadMethod extends Enumeration with Serializable {
  */
 @DeveloperApi
 class InputMetrics private[spark] () extends Serializable {
-  import InternalAccumulator._
-
-  private[executor] val _bytesRead = TaskMetrics.createLongAccum(input.BYTES_READ)
-  private[executor] val _recordsRead = TaskMetrics.createLongAccum(input.RECORDS_READ)
+  private[executor] val _bytesRead = new LongAccumulator
+  private[executor] val _recordsRead = new LongAccumulator
 
   /**
    * Total number of bytes read.
    */
-  def bytesRead: Long = _bytesRead.localValue
+  def bytesRead: Long = _bytesRead.sum
 
   /**
    * Total number of records read.
    */
-  def recordsRead: Long = _recordsRead.localValue
+  def recordsRead: Long = _recordsRead.sum
 
   private[spark] def incBytesRead(v: Long): Unit = _bytesRead.add(v)
   private[spark] def incRecordsRead(v: Long): Unit = _recordsRead.add(v)
diff --git a/core/src/main/scala/org/apache/spark/executor/OutputMetrics.scala b/core/src/main/scala/org/apache/spark/executor/OutputMetrics.scala
index 586c98b15637b1d821afda3f4d437f929813e78f..db3924cb6937e1db906e39d44fa4a930a83eb963 100644
--- a/core/src/main/scala/org/apache/spark/executor/OutputMetrics.scala
+++ b/core/src/main/scala/org/apache/spark/executor/OutputMetrics.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.executor
 
-import org.apache.spark.InternalAccumulator
+import org.apache.spark.LongAccumulator
 import org.apache.spark.annotation.DeveloperApi
 
 
@@ -39,20 +39,18 @@ object DataWriteMethod extends Enumeration with Serializable {
  */
 @DeveloperApi
 class OutputMetrics private[spark] () extends Serializable {
-  import InternalAccumulator._
-
-  private[executor] val _bytesWritten = TaskMetrics.createLongAccum(output.BYTES_WRITTEN)
-  private[executor] val _recordsWritten = TaskMetrics.createLongAccum(output.RECORDS_WRITTEN)
+  private[executor] val _bytesWritten = new LongAccumulator
+  private[executor] val _recordsWritten = new LongAccumulator
 
   /**
    * Total number of bytes written.
    */
-  def bytesWritten: Long = _bytesWritten.localValue
+  def bytesWritten: Long = _bytesWritten.sum
 
   /**
    * Total number of records written.
    */
-  def recordsWritten: Long = _recordsWritten.localValue
+  def recordsWritten: Long = _recordsWritten.sum
 
   private[spark] def setBytesWritten(v: Long): Unit = _bytesWritten.setValue(v)
   private[spark] def setRecordsWritten(v: Long): Unit = _recordsWritten.setValue(v)
diff --git a/core/src/main/scala/org/apache/spark/executor/ShuffleReadMetrics.scala b/core/src/main/scala/org/apache/spark/executor/ShuffleReadMetrics.scala
index f012a74db6c2c4d30a9fbf79faa360d2229377d1..fa962108c3064f0ba8ba00378938521a59ad2b7c 100644
--- a/core/src/main/scala/org/apache/spark/executor/ShuffleReadMetrics.scala
+++ b/core/src/main/scala/org/apache/spark/executor/ShuffleReadMetrics.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.executor
 
-import org.apache.spark.InternalAccumulator
+import org.apache.spark.LongAccumulator
 import org.apache.spark.annotation.DeveloperApi
 
 
@@ -28,52 +28,44 @@ import org.apache.spark.annotation.DeveloperApi
  */
 @DeveloperApi
 class ShuffleReadMetrics private[spark] () extends Serializable {
-  import InternalAccumulator._
-
-  private[executor] val _remoteBlocksFetched =
-    TaskMetrics.createIntAccum(shuffleRead.REMOTE_BLOCKS_FETCHED)
-  private[executor] val _localBlocksFetched =
-    TaskMetrics.createIntAccum(shuffleRead.LOCAL_BLOCKS_FETCHED)
-  private[executor] val _remoteBytesRead =
-    TaskMetrics.createLongAccum(shuffleRead.REMOTE_BYTES_READ)
-  private[executor] val _localBytesRead =
-    TaskMetrics.createLongAccum(shuffleRead.LOCAL_BYTES_READ)
-  private[executor] val _fetchWaitTime =
-    TaskMetrics.createLongAccum(shuffleRead.FETCH_WAIT_TIME)
-  private[executor] val _recordsRead =
-    TaskMetrics.createLongAccum(shuffleRead.RECORDS_READ)
+  private[executor] val _remoteBlocksFetched = new LongAccumulator
+  private[executor] val _localBlocksFetched = new LongAccumulator
+  private[executor] val _remoteBytesRead = new LongAccumulator
+  private[executor] val _localBytesRead = new LongAccumulator
+  private[executor] val _fetchWaitTime = new LongAccumulator
+  private[executor] val _recordsRead = new LongAccumulator
 
   /**
    * Number of remote blocks fetched in this shuffle by this task.
    */
-  def remoteBlocksFetched: Int = _remoteBlocksFetched.localValue
+  def remoteBlocksFetched: Long = _remoteBlocksFetched.sum
 
   /**
    * Number of local blocks fetched in this shuffle by this task.
    */
-  def localBlocksFetched: Int = _localBlocksFetched.localValue
+  def localBlocksFetched: Long = _localBlocksFetched.sum
 
   /**
    * Total number of remote bytes read from the shuffle by this task.
    */
-  def remoteBytesRead: Long = _remoteBytesRead.localValue
+  def remoteBytesRead: Long = _remoteBytesRead.sum
 
   /**
    * Shuffle data that was read from the local disk (as opposed to from a remote executor).
    */
-  def localBytesRead: Long = _localBytesRead.localValue
+  def localBytesRead: Long = _localBytesRead.sum
 
   /**
    * Time the task spent waiting for remote shuffle blocks. This only includes the time
    * blocking on shuffle input data. For instance if block B is being fetched while the task is
    * still not finished processing block A, it is not considered to be blocking on block B.
    */
-  def fetchWaitTime: Long = _fetchWaitTime.localValue
+  def fetchWaitTime: Long = _fetchWaitTime.sum
 
   /**
    * Total number of records read from the shuffle by this task.
    */
-  def recordsRead: Long = _recordsRead.localValue
+  def recordsRead: Long = _recordsRead.sum
 
   /**
    * Total bytes fetched in the shuffle by this task (both remote and local).
@@ -83,10 +75,10 @@ class ShuffleReadMetrics private[spark] () extends Serializable {
   /**
    * Number of blocks fetched in this shuffle by this task (remote or local).
    */
-  def totalBlocksFetched: Int = remoteBlocksFetched + localBlocksFetched
+  def totalBlocksFetched: Long = remoteBlocksFetched + localBlocksFetched
 
-  private[spark] def incRemoteBlocksFetched(v: Int): Unit = _remoteBlocksFetched.add(v)
-  private[spark] def incLocalBlocksFetched(v: Int): Unit = _localBlocksFetched.add(v)
+  private[spark] def incRemoteBlocksFetched(v: Long): Unit = _remoteBlocksFetched.add(v)
+  private[spark] def incLocalBlocksFetched(v: Long): Unit = _localBlocksFetched.add(v)
   private[spark] def incRemoteBytesRead(v: Long): Unit = _remoteBytesRead.add(v)
   private[spark] def incLocalBytesRead(v: Long): Unit = _localBytesRead.add(v)
   private[spark] def incFetchWaitTime(v: Long): Unit = _fetchWaitTime.add(v)
@@ -104,12 +96,12 @@ class ShuffleReadMetrics private[spark] () extends Serializable {
    * [[TempShuffleReadMetrics]] into `this`.
    */
   private[spark] def setMergeValues(metrics: Seq[TempShuffleReadMetrics]): Unit = {
-    _remoteBlocksFetched.setValue(_remoteBlocksFetched.zero)
-    _localBlocksFetched.setValue(_localBlocksFetched.zero)
-    _remoteBytesRead.setValue(_remoteBytesRead.zero)
-    _localBytesRead.setValue(_localBytesRead.zero)
-    _fetchWaitTime.setValue(_fetchWaitTime.zero)
-    _recordsRead.setValue(_recordsRead.zero)
+    _remoteBlocksFetched.setValue(0)
+    _localBlocksFetched.setValue(0)
+    _remoteBytesRead.setValue(0)
+    _localBytesRead.setValue(0)
+    _fetchWaitTime.setValue(0)
+    _recordsRead.setValue(0)
     metrics.foreach { metric =>
       _remoteBlocksFetched.add(metric.remoteBlocksFetched)
       _localBlocksFetched.add(metric.localBlocksFetched)
@@ -127,22 +119,22 @@ class ShuffleReadMetrics private[spark] () extends Serializable {
  * last.
  */
 private[spark] class TempShuffleReadMetrics {
-  private[this] var _remoteBlocksFetched = 0
-  private[this] var _localBlocksFetched = 0
+  private[this] var _remoteBlocksFetched = 0L
+  private[this] var _localBlocksFetched = 0L
   private[this] var _remoteBytesRead = 0L
   private[this] var _localBytesRead = 0L
   private[this] var _fetchWaitTime = 0L
   private[this] var _recordsRead = 0L
 
-  def incRemoteBlocksFetched(v: Int): Unit = _remoteBlocksFetched += v
-  def incLocalBlocksFetched(v: Int): Unit = _localBlocksFetched += v
+  def incRemoteBlocksFetched(v: Long): Unit = _remoteBlocksFetched += v
+  def incLocalBlocksFetched(v: Long): Unit = _localBlocksFetched += v
   def incRemoteBytesRead(v: Long): Unit = _remoteBytesRead += v
   def incLocalBytesRead(v: Long): Unit = _localBytesRead += v
   def incFetchWaitTime(v: Long): Unit = _fetchWaitTime += v
   def incRecordsRead(v: Long): Unit = _recordsRead += v
 
-  def remoteBlocksFetched: Int = _remoteBlocksFetched
-  def localBlocksFetched: Int = _localBlocksFetched
+  def remoteBlocksFetched: Long = _remoteBlocksFetched
+  def localBlocksFetched: Long = _localBlocksFetched
   def remoteBytesRead: Long = _remoteBytesRead
   def localBytesRead: Long = _localBytesRead
   def fetchWaitTime: Long = _fetchWaitTime
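Reviewer note: a Spark-internal sketch (not part of the patch) of how the per-dependency `TempShuffleReadMetrics` values fold into the task-level `ShuffleReadMetrics` through `setMergeValues`, which now resets each accumulator to 0 before summing.

```scala
val readMetrics = new ShuffleReadMetrics            // private[spark] constructor
val perDependency = Seq(new TempShuffleReadMetrics, new TempShuffleReadMetrics)
perDependency(0).incRecordsRead(10L)
perDependency(1).incRecordsRead(5L)
readMetrics.setMergeValues(perDependency)           // reset to 0, then add each temp value
assert(readMetrics.recordsRead == 15L)
```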
diff --git a/core/src/main/scala/org/apache/spark/executor/ShuffleWriteMetrics.scala b/core/src/main/scala/org/apache/spark/executor/ShuffleWriteMetrics.scala
index 7326fba841587241199138269bb8c9ecc9e87661..0e70a4f522849a7f8ac03a0d04ef0b8055eea98a 100644
--- a/core/src/main/scala/org/apache/spark/executor/ShuffleWriteMetrics.scala
+++ b/core/src/main/scala/org/apache/spark/executor/ShuffleWriteMetrics.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.executor
 
-import org.apache.spark.InternalAccumulator
+import org.apache.spark.LongAccumulator
 import org.apache.spark.annotation.DeveloperApi
 
 
@@ -28,29 +28,24 @@ import org.apache.spark.annotation.DeveloperApi
  */
 @DeveloperApi
 class ShuffleWriteMetrics private[spark] () extends Serializable {
-  import InternalAccumulator._
-
-  private[executor] val _bytesWritten =
-    TaskMetrics.createLongAccum(shuffleWrite.BYTES_WRITTEN)
-  private[executor] val _recordsWritten =
-    TaskMetrics.createLongAccum(shuffleWrite.RECORDS_WRITTEN)
-  private[executor] val _writeTime =
-    TaskMetrics.createLongAccum(shuffleWrite.WRITE_TIME)
+  private[executor] val _bytesWritten = new LongAccumulator
+  private[executor] val _recordsWritten = new LongAccumulator
+  private[executor] val _writeTime = new LongAccumulator
 
   /**
    * Number of bytes written for the shuffle by this task.
    */
-  def bytesWritten: Long = _bytesWritten.localValue
+  def bytesWritten: Long = _bytesWritten.sum
 
   /**
    * Total number of records written to the shuffle by this task.
    */
-  def recordsWritten: Long = _recordsWritten.localValue
+  def recordsWritten: Long = _recordsWritten.sum
 
   /**
    * Time the task spent blocking on writes to disk or buffer cache, in nanoseconds.
    */
-  def writeTime: Long = _writeTime.localValue
+  def writeTime: Long = _writeTime.sum
 
   private[spark] def incBytesWritten(v: Long): Unit = _bytesWritten.add(v)
   private[spark] def incRecordsWritten(v: Long): Unit = _recordsWritten.add(v)
diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
index 8513d053f2e971dc1964c1aac39163c12c0f2d77..0b64917219a7e29ffe340473341b077592a6251a 100644
--- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
+++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
@@ -17,10 +17,9 @@
 
 package org.apache.spark.executor
 
-import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.{ArrayBuffer, LinkedHashMap}
 
 import org.apache.spark._
-import org.apache.spark.AccumulatorParam.{IntAccumulatorParam, LongAccumulatorParam, UpdatedBlockStatusesAccumulatorParam}
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.internal.Logging
 import org.apache.spark.scheduler.AccumulableInfo
@@ -42,53 +41,51 @@ import org.apache.spark.storage.{BlockId, BlockStatus}
  */
 @DeveloperApi
 class TaskMetrics private[spark] () extends Serializable {
-  import InternalAccumulator._
-
   // Each metric is internally represented as an accumulator
-  private val _executorDeserializeTime = TaskMetrics.createLongAccum(EXECUTOR_DESERIALIZE_TIME)
-  private val _executorRunTime = TaskMetrics.createLongAccum(EXECUTOR_RUN_TIME)
-  private val _resultSize = TaskMetrics.createLongAccum(RESULT_SIZE)
-  private val _jvmGCTime = TaskMetrics.createLongAccum(JVM_GC_TIME)
-  private val _resultSerializationTime = TaskMetrics.createLongAccum(RESULT_SERIALIZATION_TIME)
-  private val _memoryBytesSpilled = TaskMetrics.createLongAccum(MEMORY_BYTES_SPILLED)
-  private val _diskBytesSpilled = TaskMetrics.createLongAccum(DISK_BYTES_SPILLED)
-  private val _peakExecutionMemory = TaskMetrics.createLongAccum(PEAK_EXECUTION_MEMORY)
-  private val _updatedBlockStatuses = TaskMetrics.createBlocksAccum(UPDATED_BLOCK_STATUSES)
+  private val _executorDeserializeTime = new LongAccumulator
+  private val _executorRunTime = new LongAccumulator
+  private val _resultSize = new LongAccumulator
+  private val _jvmGCTime = new LongAccumulator
+  private val _resultSerializationTime = new LongAccumulator
+  private val _memoryBytesSpilled = new LongAccumulator
+  private val _diskBytesSpilled = new LongAccumulator
+  private val _peakExecutionMemory = new LongAccumulator
+  private val _updatedBlockStatuses = new BlockStatusesAccumulator
 
   /**
    * Time taken on the executor to deserialize this task.
    */
-  def executorDeserializeTime: Long = _executorDeserializeTime.localValue
+  def executorDeserializeTime: Long = _executorDeserializeTime.sum
 
   /**
    * Time the executor spends actually running the task (including fetching shuffle data).
    */
-  def executorRunTime: Long = _executorRunTime.localValue
+  def executorRunTime: Long = _executorRunTime.sum
 
   /**
    * The number of bytes this task transmitted back to the driver as the TaskResult.
    */
-  def resultSize: Long = _resultSize.localValue
+  def resultSize: Long = _resultSize.sum
 
   /**
    * Amount of time the JVM spent in garbage collection while executing this task.
    */
-  def jvmGCTime: Long = _jvmGCTime.localValue
+  def jvmGCTime: Long = _jvmGCTime.sum
 
   /**
    * Amount of time spent serializing the task result.
    */
-  def resultSerializationTime: Long = _resultSerializationTime.localValue
+  def resultSerializationTime: Long = _resultSerializationTime.sum
 
   /**
    * The number of in-memory bytes spilled by this task.
    */
-  def memoryBytesSpilled: Long = _memoryBytesSpilled.localValue
+  def memoryBytesSpilled: Long = _memoryBytesSpilled.sum
 
   /**
    * The number of on-disk bytes spilled by this task.
    */
-  def diskBytesSpilled: Long = _diskBytesSpilled.localValue
+  def diskBytesSpilled: Long = _diskBytesSpilled.sum
 
   /**
    * Peak memory used by internal data structures created during shuffles, aggregations and
@@ -96,7 +93,7 @@ class TaskMetrics private[spark] () extends Serializable {
    * across all such data structures created in this task. For SQL jobs, this only tracks all
    * unsafe operators and ExternalSort.
    */
-  def peakExecutionMemory: Long = _peakExecutionMemory.localValue
+  def peakExecutionMemory: Long = _peakExecutionMemory.sum
 
   /**
    * Storage statuses of any blocks that have been updated as a result of this task.
@@ -114,7 +111,7 @@ class TaskMetrics private[spark] () extends Serializable {
   private[spark] def incMemoryBytesSpilled(v: Long): Unit = _memoryBytesSpilled.add(v)
   private[spark] def incDiskBytesSpilled(v: Long): Unit = _diskBytesSpilled.add(v)
   private[spark] def incPeakExecutionMemory(v: Long): Unit = _peakExecutionMemory.add(v)
-  private[spark] def incUpdatedBlockStatuses(v: Seq[(BlockId, BlockStatus)]): Unit =
+  private[spark] def incUpdatedBlockStatuses(v: (BlockId, BlockStatus)): Unit =
     _updatedBlockStatuses.add(v)
   private[spark] def setUpdatedBlockStatuses(v: Seq[(BlockId, BlockStatus)]): Unit =
     _updatedBlockStatuses.setValue(v)
@@ -175,124 +172,143 @@ class TaskMetrics private[spark] () extends Serializable {
   }
 
   // Only used for test
-  private[spark] val testAccum =
-    sys.props.get("spark.testing").map(_ => TaskMetrics.createLongAccum(TEST_ACCUM))
-
-  @transient private[spark] lazy val internalAccums: Seq[Accumulable[_, _]] = {
-    val in = inputMetrics
-    val out = outputMetrics
-    val sr = shuffleReadMetrics
-    val sw = shuffleWriteMetrics
-    Seq(_executorDeserializeTime, _executorRunTime, _resultSize, _jvmGCTime,
-      _resultSerializationTime, _memoryBytesSpilled, _diskBytesSpilled, _peakExecutionMemory,
-      _updatedBlockStatuses, sr._remoteBlocksFetched, sr._localBlocksFetched, sr._remoteBytesRead,
-      sr._localBytesRead, sr._fetchWaitTime, sr._recordsRead, sw._bytesWritten, sw._recordsWritten,
-      sw._writeTime, in._bytesRead, in._recordsRead, out._bytesWritten, out._recordsWritten) ++
-      testAccum
-  }
+  private[spark] val testAccum = sys.props.get("spark.testing").map(_ => new LongAccumulator)
+
+
+  import InternalAccumulator._
+  @transient private[spark] lazy val nameToAccums = LinkedHashMap(
+    EXECUTOR_DESERIALIZE_TIME -> _executorDeserializeTime,
+    EXECUTOR_RUN_TIME -> _executorRunTime,
+    RESULT_SIZE -> _resultSize,
+    JVM_GC_TIME -> _jvmGCTime,
+    RESULT_SERIALIZATION_TIME -> _resultSerializationTime,
+    MEMORY_BYTES_SPILLED -> _memoryBytesSpilled,
+    DISK_BYTES_SPILLED -> _diskBytesSpilled,
+    PEAK_EXECUTION_MEMORY -> _peakExecutionMemory,
+    UPDATED_BLOCK_STATUSES -> _updatedBlockStatuses,
+    shuffleRead.REMOTE_BLOCKS_FETCHED -> shuffleReadMetrics._remoteBlocksFetched,
+    shuffleRead.LOCAL_BLOCKS_FETCHED -> shuffleReadMetrics._localBlocksFetched,
+    shuffleRead.REMOTE_BYTES_READ -> shuffleReadMetrics._remoteBytesRead,
+    shuffleRead.LOCAL_BYTES_READ -> shuffleReadMetrics._localBytesRead,
+    shuffleRead.FETCH_WAIT_TIME -> shuffleReadMetrics._fetchWaitTime,
+    shuffleRead.RECORDS_READ -> shuffleReadMetrics._recordsRead,
+    shuffleWrite.BYTES_WRITTEN -> shuffleWriteMetrics._bytesWritten,
+    shuffleWrite.RECORDS_WRITTEN -> shuffleWriteMetrics._recordsWritten,
+    shuffleWrite.WRITE_TIME -> shuffleWriteMetrics._writeTime,
+    input.BYTES_READ -> inputMetrics._bytesRead,
+    input.RECORDS_READ -> inputMetrics._recordsRead,
+    output.BYTES_WRITTEN -> outputMetrics._bytesWritten,
+    output.RECORDS_WRITTEN -> outputMetrics._recordsWritten
+  ) ++ testAccum.map(TEST_ACCUM -> _)
+
+  @transient private[spark] lazy val internalAccums: Seq[NewAccumulator[_, _]] =
+    nameToAccums.values.toIndexedSeq
 
   /* ========================== *
    |        OTHER THINGS        |
    * ========================== */
 
-  private[spark] def registerForCleanup(sc: SparkContext): Unit = {
-    internalAccums.foreach { accum =>
-      sc.cleaner.foreach(_.registerAccumulatorForCleanup(accum))
+  private[spark] def register(sc: SparkContext): Unit = {
+    nameToAccums.foreach {
+      case (name, acc) => acc.register(sc, name = Some(name), countFailedValues = true)
     }
   }
 
   /**
    * External accumulators registered with this task.
    */
-  @transient private lazy val externalAccums = new ArrayBuffer[Accumulable[_, _]]
+  @transient private lazy val externalAccums = new ArrayBuffer[NewAccumulator[_, _]]
 
-  private[spark] def registerAccumulator(a: Accumulable[_, _]): Unit = {
+  private[spark] def registerAccumulator(a: NewAccumulator[_, _]): Unit = {
     externalAccums += a
   }
 
-  /**
-   * Return the latest updates of accumulators in this task.
-   *
-   * The [[AccumulableInfo.update]] field is always defined and the [[AccumulableInfo.value]]
-   * field is always empty, since this represents the partial updates recorded in this task,
-   * not the aggregated value across multiple tasks.
-   */
-  def accumulatorUpdates(): Seq[AccumulableInfo] = {
-    (internalAccums ++ externalAccums).map { a => a.toInfo(Some(a.localValue), None) }
-  }
+  private[spark] def accumulators(): Seq[NewAccumulator[_, _]] = internalAccums ++ externalAccums
 }
 
-/**
- * Internal subclass of [[TaskMetrics]] which is used only for posting events to listeners.
- * Its purpose is to obviate the need for the driver to reconstruct the original accumulators,
- * which might have been garbage-collected. See SPARK-13407 for more details.
- *
- * Instances of this class should be considered read-only and users should not call `inc*()` or
- * `set*()` methods. While we could override the setter methods to throw
- * UnsupportedOperationException, we choose not to do so because the overrides would quickly become
- * out-of-date when new metrics are added.
- */
-private[spark] class ListenerTaskMetrics(accumUpdates: Seq[AccumulableInfo]) extends TaskMetrics {
-
-  override def accumulatorUpdates(): Seq[AccumulableInfo] = accumUpdates
-
-  override private[spark] def registerAccumulator(a: Accumulable[_, _]): Unit = {
-    throw new UnsupportedOperationException("This TaskMetrics is read-only")
-  }
-}
 
 private[spark] object TaskMetrics extends Logging {
+  import InternalAccumulator._
 
   /**
    * Create an empty task metrics that doesn't register its accumulators.
    */
   def empty: TaskMetrics = {
-    val metrics = new TaskMetrics
-    metrics.internalAccums.foreach(acc => Accumulators.remove(acc.id))
-    metrics
+    val tm = new TaskMetrics
+    tm.nameToAccums.foreach { case (name, acc) =>
+      acc.metadata = AccumulatorMetadata(AccumulatorContext.newId(), Some(name), true)
+    }
+    tm
+  }
+
+  def registered: TaskMetrics = {
+    val tm = empty
+    tm.internalAccums.foreach(AccumulatorContext.register)
+    tm
   }
 
   /**
-   * Create a new accumulator representing an internal task metric.
+   * Construct a [[TaskMetrics]] object from a list of [[AccumulableInfo]], called on driver only.
+   * The returned [[TaskMetrics]] is only used to read some internal metrics, so we don't need to
+   * handle any external accumulator info that is passed in.
    */
-  private def newMetric[T](
-      initialValue: T,
-      name: String,
-      param: AccumulatorParam[T]): Accumulator[T] = {
-    new Accumulator[T](initialValue, param, Some(name), countFailedValues = true)
+  def fromAccumulatorInfos(infos: Seq[AccumulableInfo]): TaskMetrics = {
+    val tm = new TaskMetrics
+    infos.filter(info => info.name.isDefined && info.update.isDefined).foreach { info =>
+      val name = info.name.get
+      val value = info.update.get
+      if (name == UPDATED_BLOCK_STATUSES) {
+        tm.setUpdatedBlockStatuses(value.asInstanceOf[Seq[(BlockId, BlockStatus)]])
+      } else {
+        tm.nameToAccums.get(name).foreach(
+          _.asInstanceOf[LongAccumulator].setValue(value.asInstanceOf[Long])
+        )
+      }
+    }
+    tm
   }
 
-  def createLongAccum(name: String): Accumulator[Long] = {
-    newMetric(0L, name, LongAccumulatorParam)
-  }
+  /**
+   * Construct a [[TaskMetrics]] object from a list of accumulator updates, called on driver only.
+   */
+  def fromAccumulators(accums: Seq[NewAccumulator[_, _]]): TaskMetrics = {
+    val tm = new TaskMetrics
+    val (internalAccums, externalAccums) =
+      accums.partition(a => a.name.isDefined && tm.nameToAccums.contains(a.name.get))
+
+    internalAccums.foreach { acc =>
+      val tmAcc = tm.nameToAccums(acc.name.get).asInstanceOf[NewAccumulator[Any, Any]]
+      tmAcc.metadata = acc.metadata
+      tmAcc.merge(acc.asInstanceOf[NewAccumulator[Any, Any]])
+    }
 
-  def createIntAccum(name: String): Accumulator[Int] = {
-    newMetric(0, name, IntAccumulatorParam)
+    tm.externalAccums ++= externalAccums
+    tm
   }
+}
+
+
+private[spark] class BlockStatusesAccumulator
+  extends NewAccumulator[(BlockId, BlockStatus), Seq[(BlockId, BlockStatus)]] {
+  private[this] var _seq = ArrayBuffer.empty[(BlockId, BlockStatus)]
 
-  def createBlocksAccum(name: String): Accumulator[Seq[(BlockId, BlockStatus)]] = {
-    newMetric(Nil, name, UpdatedBlockStatusesAccumulatorParam)
+  override def isZero(): Boolean = _seq.isEmpty
+
+  override def copyAndReset(): BlockStatusesAccumulator = new BlockStatusesAccumulator
+
+  override def add(v: (BlockId, BlockStatus)): Unit = _seq += v
+
+  override def merge(other: NewAccumulator[(BlockId, BlockStatus), Seq[(BlockId, BlockStatus)]])
+  : Unit = other match {
+    case o: BlockStatusesAccumulator => _seq ++= o.localValue
+    case _ => throw new UnsupportedOperationException(
+      s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}")
   }
 
-  /**
-   * Construct a [[TaskMetrics]] object from a list of accumulator updates, called on driver only.
-   *
-   * Executors only send accumulator updates back to the driver, not [[TaskMetrics]]. However, we
-   * need the latter to post task end events to listeners, so we need to reconstruct the metrics
-   * on the driver.
-   *
-   * This assumes the provided updates contain the initial set of accumulators representing
-   * internal task level metrics.
-   */
-  def fromAccumulatorUpdates(accumUpdates: Seq[AccumulableInfo]): TaskMetrics = {
-    val definedAccumUpdates = accumUpdates.filter(_.update.isDefined)
-    val metrics = new ListenerTaskMetrics(definedAccumUpdates)
-    // We don't register this [[ListenerTaskMetrics]] for cleanup, and this is only used to post
-    // event, we should un-register all accumulators immediately.
-    metrics.internalAccums.foreach(acc => Accumulators.remove(acc.id))
-    definedAccumUpdates.filter(_.internal).foreach { accum =>
-      metrics.internalAccums.find(_.name == accum.name).foreach(_.setValueAny(accum.update.get))
-    }
-    metrics
+  override def localValue: Seq[(BlockId, BlockStatus)] = _seq
+
+  def setValue(newValue: Seq[(BlockId, BlockStatus)]): Unit = {
+    _seq.clear()
+    _seq ++= newValue
   }
 }
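// Sketch (not part of the patch above): a minimal accumulator written against the same
// contract that BlockStatusesAccumulator implements -- isZero / copyAndReset / add /
// merge / localValue. Assumes it sits in the org.apache.spark package so the
// private[spark] NewAccumulator API is visible; the class name is hypothetical.
private[spark] class StringSetAccumulator extends NewAccumulator[String, Set[String]] {
  private[this] var _set = Set.empty[String]

  override def isZero(): Boolean = _set.isEmpty

  // A fresh, zero-valued copy is what gets serialized out to executors.
  override def copyAndReset(): StringSetAccumulator = new StringSetAccumulator

  override def add(v: String): Unit = _set += v

  // Driver-side merge of a partial result reported back by a task.
  override def merge(other: NewAccumulator[String, Set[String]]): Unit = other match {
    case o: StringSetAccumulator => _set ++= o.localValue
    case _ => throw new UnsupportedOperationException(
      s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}")
  }

  override def localValue: Set[String] = _set
}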
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index b7fb608ea50641960871d3c757e948c5e02fbdef..a96d5f6fbf082e100c841cc96aab504b2bee5118 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -209,7 +209,7 @@ class DAGScheduler(
       task: Task[_],
       reason: TaskEndReason,
       result: Any,
-      accumUpdates: Seq[AccumulableInfo],
+      accumUpdates: Seq[NewAccumulator[_, _]],
       taskInfo: TaskInfo): Unit = {
     eventProcessLoop.post(
       CompletionEvent(task, reason, result, accumUpdates, taskInfo))
@@ -1088,21 +1088,19 @@ class DAGScheduler(
     val task = event.task
     val stage = stageIdToStage(task.stageId)
     try {
-      event.accumUpdates.foreach { ainfo =>
-        assert(ainfo.update.isDefined, "accumulator from task should have a partial value")
-        val id = ainfo.id
-        val partialValue = ainfo.update.get
+      event.accumUpdates.foreach { updates =>
+        val id = updates.id
         // Find the corresponding accumulator on the driver and update it
-        val acc: Accumulable[Any, Any] = Accumulators.get(id) match {
-          case Some(accum) => accum.asInstanceOf[Accumulable[Any, Any]]
+        val acc: NewAccumulator[Any, Any] = AccumulatorContext.get(id) match {
+          case Some(accum) => accum.asInstanceOf[NewAccumulator[Any, Any]]
           case None =>
             throw new SparkException(s"attempted to access non-existent accumulator $id")
         }
-        acc ++= partialValue
+        acc.merge(updates.asInstanceOf[NewAccumulator[Any, Any]])
         // To avoid UI cruft, ignore cases where value wasn't updated
-        if (acc.name.isDefined && partialValue != acc.zero) {
+        if (acc.name.isDefined && !updates.isZero()) {
           stage.latestInfo.accumulables(id) = acc.toInfo(None, Some(acc.value))
-          event.taskInfo.accumulables += acc.toInfo(Some(partialValue), Some(acc.value))
+          event.taskInfo.accumulables += acc.toInfo(Some(updates.value), Some(acc.value))
         }
       }
     } catch {
@@ -1131,7 +1129,7 @@ class DAGScheduler(
     val taskMetrics: TaskMetrics =
       if (event.accumUpdates.nonEmpty) {
         try {
-          TaskMetrics.fromAccumulatorUpdates(event.accumUpdates)
+          TaskMetrics.fromAccumulators(event.accumUpdates)
         } catch {
           case NonFatal(e) =>
             logError(s"Error when attempting to reconstruct metrics for task $taskId", e)
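// Sketch (not part of the patch above) of the driver-side flow the DAGScheduler hunk
// implements: an accumulator shipped back from a task is looked up by id in
// AccumulatorContext and merged into the registered original. Assumes code inside the
// org.apache.spark package; `update` stands for one entry of event.accumUpdates.
def applyDriverUpdate(update: NewAccumulator[_, _]): Unit = {
  AccumulatorContext.get(update.id) match {
    case Some(original) =>
      original.asInstanceOf[NewAccumulator[Any, Any]]
        .merge(update.asInstanceOf[NewAccumulator[Any, Any]])
    case None =>
      // Mirrors the SparkException the scheduler throws when the original has been GC'ed.
      throw new SparkException(s"attempted to access non-existent accumulator ${update.id}")
  }
}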
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
index a3845c6acd77419742fd31e3ed8f618f6f72690e..e57a2246d87293428eb778eb6985f3eb17e89d3e 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
@@ -71,7 +71,7 @@ private[scheduler] case class CompletionEvent(
     task: Task[_],
     reason: TaskEndReason,
     result: Any,
-    accumUpdates: Seq[AccumulableInfo],
+    accumUpdates: Seq[NewAccumulator[_, _]],
     taskInfo: TaskInfo)
   extends DAGSchedulerEvent
 
diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala
index 080ea6c33a7dd309bba3b881c75b41adb7c05a5d..7618dfeeedf8d40dba8942fdba31eeeea028771d 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala
@@ -21,18 +21,15 @@ import java.util.Properties
 import javax.annotation.Nullable
 
 import scala.collection.Map
-import scala.collection.mutable
 
 import com.fasterxml.jackson.annotation.JsonTypeInfo
 
 import org.apache.spark.{SparkConf, TaskEndReason}
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.executor.TaskMetrics
-import org.apache.spark.internal.Logging
 import org.apache.spark.scheduler.cluster.ExecutorInfo
 import org.apache.spark.storage.{BlockManagerId, BlockUpdatedInfo}
 import org.apache.spark.ui.SparkUI
-import org.apache.spark.util.{Distribution, Utils}
 
 @DeveloperApi
 @JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "Event")
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
index 02185bf631fdcd946e76f71681ab57321467dbe8..2f972b064b4778e80c5374c5a918dbe31e0bbad5 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
@@ -112,7 +112,7 @@ private[scheduler] abstract class Stage(
       numPartitionsToCompute: Int,
       taskLocalityPreferences: Seq[Seq[TaskLocation]] = Seq.empty): Unit = {
     val metrics = new TaskMetrics
-    metrics.registerForCleanup(rdd.sparkContext)
+    metrics.register(rdd.sparkContext)
     _latestInfo = StageInfo.fromStage(
       this, nextAttemptId, Some(numPartitionsToCompute), metrics, taskLocalityPreferences)
     nextAttemptId += 1
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
index eb10f3e69b09249aab5108e511bbe1d1c9038127..e7ca6efd84aee14a4abdfea5b5e2295f100bbcdc 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -23,7 +23,7 @@ import java.util.Properties
 
 import scala.collection.mutable.HashMap
 
-import org.apache.spark.{SparkEnv, TaskContext, TaskContextImpl}
+import org.apache.spark._
 import org.apache.spark.executor.TaskMetrics
 import org.apache.spark.memory.{MemoryMode, TaskMemoryManager}
 import org.apache.spark.metrics.MetricsSystem
@@ -52,7 +52,7 @@ private[spark] abstract class Task[T](
     val stageAttemptId: Int,
     val partitionId: Int,
     // The default value is only used in tests.
-    val metrics: TaskMetrics = TaskMetrics.empty,
+    val metrics: TaskMetrics = TaskMetrics.registered,
     @transient var localProperties: Properties = new Properties) extends Serializable {
 
   /**
@@ -153,11 +153,11 @@ private[spark] abstract class Task[T](
    * Collect the latest values of accumulators used in this task. If the task failed,
    * filter out the accumulators whose values should not be included on failures.
    */
-  def collectAccumulatorUpdates(taskFailed: Boolean = false): Seq[AccumulableInfo] = {
+  def collectAccumulatorUpdates(taskFailed: Boolean = false): Seq[NewAccumulator[_, _]] = {
     if (context != null) {
-      context.taskMetrics.accumulatorUpdates().filter { a => !taskFailed || a.countFailedValues }
+      context.taskMetrics.accumulators().filter { a => !taskFailed || a.countFailedValues }
     } else {
-      Seq.empty[AccumulableInfo]
+      Seq.empty
     }
   }
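// Sketch (not part of the patch above, test-body style) of how the countFailedValues flag
// interacts with the filter in collectAccumulatorUpdates. Names are hypothetical; uses the
// createLongAccum test helper added later in this patch.
def collect(accums: Seq[NewAccumulator[_, _]], taskFailed: Boolean): Seq[NewAccumulator[_, _]] =
  accums.filter { a => !taskFailed || a.countFailedValues }

val counted    = AccumulatorSuite.createLongAccum("counted", countFailedValues = true)
val notCounted = AccumulatorSuite.createLongAccum("notCounted")
assert(collect(Seq(counted, notCounted), taskFailed = false).size == 2)   // success: all reported
assert(collect(Seq(counted, notCounted), taskFailed = true) == Seq(counted))  // failure: opt-in only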
 
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala
index 03135e63d75518bbc059a0d83fa39175ae637a62..b472c5511b738e8db9ed9fcef8fd07df3dd60050 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala
@@ -22,7 +22,7 @@ import java.nio.ByteBuffer
 
 import scala.collection.mutable.ArrayBuffer
 
-import org.apache.spark.SparkEnv
+import org.apache.spark.{NewAccumulator, SparkEnv}
 import org.apache.spark.storage.BlockId
 import org.apache.spark.util.Utils
 
@@ -36,7 +36,7 @@ private[spark] case class IndirectTaskResult[T](blockId: BlockId, size: Int)
 /** A TaskResult that contains the task's return value and accumulator updates. */
 private[spark] class DirectTaskResult[T](
     var valueBytes: ByteBuffer,
-    var accumUpdates: Seq[AccumulableInfo])
+    var accumUpdates: Seq[NewAccumulator[_, _]])
   extends TaskResult[T] with Externalizable {
 
   private var valueObjectDeserialized = false
@@ -61,9 +61,9 @@ private[spark] class DirectTaskResult[T](
     if (numUpdates == 0) {
       accumUpdates = null
     } else {
-      val _accumUpdates = new ArrayBuffer[AccumulableInfo]
+      val _accumUpdates = new ArrayBuffer[NewAccumulator[_, _]]
       for (i <- 0 until numUpdates) {
-        _accumUpdates += in.readObject.asInstanceOf[AccumulableInfo]
+        _accumUpdates += in.readObject.asInstanceOf[NewAccumulator[_, _]]
       }
       accumUpdates = _accumUpdates
     }
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
index ae7ef46abbf31c05c11bb29d8224c9afff7ee222..b438c285fdf1fa1e5ffc3277b989706ad35832a2 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
@@ -93,9 +93,10 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul
           // we would have to serialize the result again after updating the size.
           result.accumUpdates = result.accumUpdates.map { a =>
             if (a.name == Some(InternalAccumulator.RESULT_SIZE)) {
-              assert(a.update == Some(0L),
-                "task result size should not have been set on the executors")
-              a.copy(update = Some(size.toLong))
+              val acc = a.asInstanceOf[LongAccumulator]
+              assert(acc.sum == 0L, "task result size should not have been set on the executors")
+              acc.setValue(size.toLong)
+              acc
             } else {
               a
             }
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
index 647d44a0f0680710105ba06cb63c0fad615695c5..75a0c56311977adf71011d2b683afbd681c16704 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.scheduler
 
+import org.apache.spark.NewAccumulator
 import org.apache.spark.scheduler.SchedulingMode.SchedulingMode
 import org.apache.spark.storage.BlockManagerId
 
@@ -66,7 +67,7 @@ private[spark] trait TaskScheduler {
    */
   def executorHeartbeatReceived(
       execId: String,
-      accumUpdates: Array[(Long, Seq[AccumulableInfo])],
+      accumUpdates: Array[(Long, Seq[NewAccumulator[_, _]])],
       blockManagerId: BlockManagerId): Boolean
 
   /**
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
index f31ec2af4ebd6deb0eb4270fbd99cd8c4cb3ef43..776a3226cc78da28f46243c8ffbfc0a1c38d8266 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
@@ -384,13 +384,14 @@ private[spark] class TaskSchedulerImpl(
    */
   override def executorHeartbeatReceived(
       execId: String,
-      accumUpdates: Array[(Long, Seq[AccumulableInfo])],
+      accumUpdates: Array[(Long, Seq[NewAccumulator[_, _]])],
       blockManagerId: BlockManagerId): Boolean = {
     // (taskId, stageId, stageAttemptId, accumUpdates)
     val accumUpdatesWithTaskIds: Array[(Long, Int, Int, Seq[AccumulableInfo])] = synchronized {
       accumUpdates.flatMap { case (id, updates) =>
         taskIdToTaskSetManager.get(id).map { taskSetMgr =>
-          (id, taskSetMgr.stageId, taskSetMgr.taskSet.stageAttemptId, updates)
+          (id, taskSetMgr.stageId, taskSetMgr.taskSet.stageAttemptId,
+            updates.map(acc => acc.toInfo(Some(acc.value), None)))
         }
       }
     }
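// Sketch (not part of the patch above) of the conversion done when forwarding executor
// heartbeats: each raw accumulator is flattened into an AccumulableInfo whose `update`
// field carries the partial value while `value` (the aggregated driver-side total) is
// left empty. Assumes code inside org.apache.spark.scheduler; the helper name is hypothetical.
def toHeartbeatInfos(updates: Seq[NewAccumulator[_, _]]): Seq[AccumulableInfo] =
  updates.map(acc => acc.toInfo(Some(acc.value), None))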
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
index 6e08cdd87a8d1f185b55671f39af566565601ae7..b79f643a7481f5363317e314442a234c65d6bb0c 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
@@ -647,7 +647,7 @@ private[spark] class TaskSetManager(
     info.markFailed()
     val index = info.index
     copiesRunning(index) -= 1
-    var accumUpdates: Seq[AccumulableInfo] = Seq.empty[AccumulableInfo]
+    var accumUpdates: Seq[NewAccumulator[_, _]] = Seq.empty
     val failureReason = s"Lost task ${info.id} in stage ${taskSet.id} (TID $tid, ${info.host}): " +
       reason.asInstanceOf[TaskFailedReason].toErrorString
     val failureException: Option[Throwable] = reason match {
@@ -663,7 +663,7 @@ private[spark] class TaskSetManager(
 
       case ef: ExceptionFailure =>
         // ExceptionFailure's might have accumulator updates
-        accumUpdates = ef.accumUpdates
+        accumUpdates = ef.accums
         if (ef.className == classOf[NotSerializableException].getName) {
           // If the task result wasn't serializable, there's no point in trying to re-execute it.
           logError("Task %s in stage %s (TID %d) had a not serializable result: %s; not retrying"
@@ -788,7 +788,7 @@ private[spark] class TaskSetManager(
           // Tell the DAGScheduler that this task was resubmitted so that it doesn't think our
           // stage finishes when a total of tasks.size tasks finish.
           sched.dagScheduler.taskEnded(
-            tasks(index), Resubmitted, null, Seq.empty[AccumulableInfo], info)
+            tasks(index), Resubmitted, null, Seq.empty, info)
         }
       }
     }
diff --git a/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala b/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala
index 8daca6c390635211039d5375638dd1c35f813e27..c04b483831704b51e179f3b964cccf9f36173ef1 100644
--- a/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala
@@ -266,7 +266,13 @@ private[spark] object SerializationDebugger extends Logging {
       (o, desc)
     } else {
       // write place
-      findObjectAndDescriptor(desc.invokeWriteReplace(o))
+      val replaced = desc.invokeWriteReplace(o)
+      // `writeReplace` may return the same object.
+      if (replaced eq o) {
+        (o, desc)
+      } else {
+        findObjectAndDescriptor(replaced)
+      }
     }
   }
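// Sketch (not part of the patch above, class name hypothetical) of the case the eq-check
// guards against: a writeReplace that returns the receiver unchanged would otherwise send
// findObjectAndDescriptor into unbounded recursion.
class SelfReplacing extends java.io.Serializable {
  private def writeReplace(): AnyRef = this  // legal under Java serialization: keep the same object
}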
 
diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala
index ff28796a60f67d4437b59eb20fedc387d8a6a829..32e332a9adb9d22463d2f9733159502af02e0704 100644
--- a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala
+++ b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala
@@ -186,8 +186,8 @@ class OutputMetrics private[spark](
     val recordsWritten: Long)
 
 class ShuffleReadMetrics private[spark](
-    val remoteBlocksFetched: Int,
-    val localBlocksFetched: Int,
+    val remoteBlocksFetched: Long,
+    val localBlocksFetched: Long,
     val fetchWaitTime: Long,
     val remoteBytesRead: Long,
     val localBytesRead: Long,
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index 1c4921666fc5b0eafa1c94c0cac6b2aafaa0ca68..f2d06c7ea80790d0b1fd14f4478343ca18d92e57 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -801,7 +801,7 @@ private[spark] class BlockManager(
           reportBlockStatus(blockId, info, putBlockStatus)
         }
         Option(TaskContext.get()).foreach { c =>
-          c.taskMetrics().incUpdatedBlockStatuses(Seq((blockId, putBlockStatus)))
+          c.taskMetrics().incUpdatedBlockStatuses(blockId -> putBlockStatus)
         }
       }
       logDebug("Put block %s locally took %s".format(blockId, Utils.getUsedTimeMs(startTimeMs)))
@@ -958,7 +958,7 @@ private[spark] class BlockManager(
           reportBlockStatus(blockId, info, putBlockStatus)
         }
         Option(TaskContext.get()).foreach { c =>
-          c.taskMetrics().incUpdatedBlockStatuses(Seq((blockId, putBlockStatus)))
+          c.taskMetrics().incUpdatedBlockStatuses(blockId -> putBlockStatus)
         }
         logDebug("Put block %s locally took %s".format(blockId, Utils.getUsedTimeMs(startTimeMs)))
         if (level.replication > 1) {
@@ -1257,7 +1257,7 @@ private[spark] class BlockManager(
     }
     if (blockIsUpdated) {
       Option(TaskContext.get()).foreach { c =>
-        c.taskMetrics().incUpdatedBlockStatuses(Seq((blockId, status)))
+        c.taskMetrics().incUpdatedBlockStatuses(blockId -> status)
       }
     }
     status.storageLevel
@@ -1311,7 +1311,7 @@ private[spark] class BlockManager(
           reportBlockStatus(blockId, info, removeBlockStatus)
         }
         Option(TaskContext.get()).foreach { c =>
-          c.taskMetrics().incUpdatedBlockStatuses(Seq((blockId, removeBlockStatus)))
+          c.taskMetrics().incUpdatedBlockStatuses(blockId -> removeBlockStatus)
         }
     }
   }
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala
index 9ab7d96e290d647ad114cec92f7fb52fc196c6e0..945830c8bf242c5e17fd5fcdd0637c317c040f06 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala
@@ -375,26 +375,21 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging {
       execSummary.taskTime += info.duration
       stageData.numActiveTasks -= 1
 
-      val (errorMessage, accums): (Option[String], Seq[AccumulableInfo]) =
+      val errorMessage: Option[String] =
         taskEnd.reason match {
           case org.apache.spark.Success =>
             stageData.completedIndices.add(info.index)
             stageData.numCompleteTasks += 1
-            (None, taskEnd.taskMetrics.accumulatorUpdates())
+            None
           case e: ExceptionFailure => // Handle ExceptionFailure because we might have accumUpdates
             stageData.numFailedTasks += 1
-            (Some(e.toErrorString), e.accumUpdates)
+            Some(e.toErrorString)
           case e: TaskFailedReason => // All other failure cases
             stageData.numFailedTasks += 1
-            (Some(e.toErrorString), Seq.empty[AccumulableInfo])
+            Some(e.toErrorString)
         }
 
-      val taskMetrics =
-        if (accums.nonEmpty) {
-          Some(TaskMetrics.fromAccumulatorUpdates(accums))
-        } else {
-          None
-        }
+      val taskMetrics = Option(taskEnd.taskMetrics)
       taskMetrics.foreach { m =>
         val oldMetrics = stageData.taskData.get(info.taskId).flatMap(_.metrics)
         updateAggregateMetrics(stageData, info.executorId, m, oldMetrics)
@@ -503,7 +498,7 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging {
         new StageUIData
       })
       val taskData = stageData.taskData.get(taskId)
-      val metrics = TaskMetrics.fromAccumulatorUpdates(accumUpdates)
+      val metrics = TaskMetrics.fromAccumulatorInfos(accumUpdates)
       taskData.foreach { t =>
         if (!t.taskInfo.finished) {
           updateAggregateMetrics(stageData, executorMetricsUpdate.execId, metrics, t.metrics)
diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
index a613fbc5cc3b27fd11dd1123a121fe2884c20cef..aeab71d9df603bc4c3c87378bee0563c73d56fd7 100644
--- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
+++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
@@ -840,7 +840,9 @@ private[spark] object JsonProtocol {
         // Fallback on getting accumulator updates from TaskMetrics, which was logged in Spark 1.x
         val accumUpdates = Utils.jsonOption(json \ "Accumulator Updates")
           .map(_.extract[List[JValue]].map(accumulableInfoFromJson))
-          .getOrElse(taskMetricsFromJson(json \ "Metrics").accumulatorUpdates())
+          .getOrElse(taskMetricsFromJson(json \ "Metrics").accumulators()
+            .map(acc => acc.toInfo(Some(acc.localValue), None)))
         ExceptionFailure(className, description, stackTrace, fullStackTrace, None, accumUpdates)
       case `taskResultLost` => TaskResultLost
       case `taskKilled` => TaskKilled
diff --git a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala
index 6063476936c7f288f9d5a10d5a403249e2aa5e8a..5f97e58845d7a0577d09f969ece4c0cd19e5d67a 100644
--- a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala
@@ -28,17 +28,17 @@ import scala.util.control.NonFatal
 import org.scalatest.Matchers
 import org.scalatest.exceptions.TestFailedException
 
-import org.apache.spark.executor.TaskMetrics
+import org.apache.spark.AccumulatorParam.{ListAccumulatorParam, StringAccumulatorParam}
 import org.apache.spark.scheduler._
 import org.apache.spark.serializer.JavaSerializer
 
 
 class AccumulatorSuite extends SparkFunSuite with Matchers with LocalSparkContext {
-  import AccumulatorParam._
+  import AccumulatorSuite.createLongAccum
 
   override def afterEach(): Unit = {
     try {
-      Accumulators.clear()
+      AccumulatorContext.clear()
     } finally {
       super.afterEach()
     }
@@ -59,9 +59,30 @@ class AccumulatorSuite extends SparkFunSuite with Matchers with LocalSparkContex
       }
     }
 
+  test("accumulator serialization") {
+    val ser = new JavaSerializer(new SparkConf).newInstance()
+    val acc = createLongAccum("x")
+    acc.add(5)
+    assert(acc.value == 5)
+    assert(acc.isAtDriverSide)
+
+    // serialize and de-serialize it, to simulate sending the accumulator to an executor.
+    val acc2 = ser.deserialize[LongAccumulator](ser.serialize(acc))
+    // value is reset on the executors
+    assert(acc2.localValue == 0)
+    assert(!acc2.isAtDriverSide)
+
+    acc2.add(10)
+    // serialize and de-serialize it again, to simulate sending the accumulator back to the driver.
+    val acc3 = ser.deserialize[LongAccumulator](ser.serialize(acc2))
+    // value is not reset on the driver
+    assert(acc3.value == 10)
+    assert(acc3.isAtDriverSide)
+  }
+
   test ("basic accumulation") {
     sc = new SparkContext("local", "test")
-    val acc : Accumulator[Int] = sc.accumulator(0)
+    val acc: Accumulator[Int] = sc.accumulator(0)
 
     val d = sc.parallelize(1 to 20)
     d.foreach{x => acc += x}
@@ -75,7 +96,7 @@ class AccumulatorSuite extends SparkFunSuite with Matchers with LocalSparkContex
 
   test("value not assignable from tasks") {
     sc = new SparkContext("local", "test")
-    val acc : Accumulator[Int] = sc.accumulator(0)
+    val acc: Accumulator[Int] = sc.accumulator(0)
 
     val d = sc.parallelize(1 to 20)
     an [Exception] should be thrownBy {d.foreach{x => acc.value = x}}
@@ -169,14 +190,13 @@ class AccumulatorSuite extends SparkFunSuite with Matchers with LocalSparkContex
     System.gc()
     assert(ref.get.isEmpty)
 
-    Accumulators.remove(accId)
-    assert(!Accumulators.originals.get(accId).isDefined)
+    AccumulatorContext.remove(accId)
+    assert(!AccumulatorContext.originals.containsKey(accId))
   }
 
   test("get accum") {
-    sc = new SparkContext("local", "test")
     // Don't register with SparkContext for cleanup
-    var acc = new Accumulable[Int, Int](0, IntAccumulatorParam, None, true)
+    var acc = createLongAccum("a")
     val accId = acc.id
     val ref = WeakReference(acc)
     assert(ref.get.isDefined)
@@ -188,44 +208,16 @@ class AccumulatorSuite extends SparkFunSuite with Matchers with LocalSparkContex
 
     // Getting a garbage collected accum should throw error
     intercept[IllegalAccessError] {
-      Accumulators.get(accId)
+      AccumulatorContext.get(accId)
     }
 
     // Getting a normal accumulator. Note: this has to be separate because referencing an
     // accumulator above in an `assert` would keep it from being garbage collected.
-    val acc2 = new Accumulable[Long, Long](0L, LongAccumulatorParam, None, true)
-    assert(Accumulators.get(acc2.id) === Some(acc2))
+    val acc2 = createLongAccum("b")
+    assert(AccumulatorContext.get(acc2.id) === Some(acc2))
 
     // Getting an accumulator that does not exist should return None
-    assert(Accumulators.get(100000).isEmpty)
-  }
-
-  test("copy") {
-    val acc1 = new Accumulable[Long, Long](456L, LongAccumulatorParam, Some("x"), false)
-    val acc2 = acc1.copy()
-    assert(acc1.id === acc2.id)
-    assert(acc1.value === acc2.value)
-    assert(acc1.name === acc2.name)
-    assert(acc1.countFailedValues === acc2.countFailedValues)
-    assert(acc1 !== acc2)
-    // Modifying one does not affect the other
-    acc1.add(44L)
-    assert(acc1.value === 500L)
-    assert(acc2.value === 456L)
-    acc2.add(144L)
-    assert(acc1.value === 500L)
-    assert(acc2.value === 600L)
-  }
-
-  test("register multiple accums with same ID") {
-    val acc1 = new Accumulable[Int, Int](0, IntAccumulatorParam, None, true)
-    // `copy` will create a new Accumulable and register it.
-    val acc2 = acc1.copy()
-    assert(acc1 !== acc2)
-    assert(acc1.id === acc2.id)
-    // The second one does not override the first one
-    assert(Accumulators.originals.size === 1)
-    assert(Accumulators.get(acc1.id) === Some(acc1))
+    assert(AccumulatorContext.get(100000).isEmpty)
   }
 
   test("string accumulator param") {
@@ -257,37 +249,32 @@ class AccumulatorSuite extends SparkFunSuite with Matchers with LocalSparkContex
     acc.setValue(Seq(9, 10))
     assert(acc.value === Seq(9, 10))
   }
-
-  test("value is reset on the executors") {
-    val acc1 = new Accumulator(0, IntAccumulatorParam, Some("thing"))
-    val acc2 = new Accumulator(0L, LongAccumulatorParam, Some("thing2"))
-    val externalAccums = Seq(acc1, acc2)
-    val taskMetrics = new TaskMetrics
-    // Set some values; these should not be observed later on the "executors"
-    acc1.setValue(10)
-    acc2.setValue(20L)
-    taskMetrics.testAccum.get.setValue(30L)
-    // Simulate the task being serialized and sent to the executors.
-    val dummyTask = new DummyTask(taskMetrics, externalAccums)
-    val serInstance = new JavaSerializer(new SparkConf).newInstance()
-    val taskSer = Task.serializeWithDependencies(
-      dummyTask, mutable.HashMap(), mutable.HashMap(), serInstance)
-    // Now we're on the executors.
-    // Deserialize the task and assert that its accumulators are zero'ed out.
-    val (_, _, _, taskBytes) = Task.deserializeWithDependencies(taskSer)
-    val taskDeser = serInstance.deserialize[DummyTask](
-      taskBytes, Thread.currentThread.getContextClassLoader)
-    // Assert that executors see only zeros
-    taskDeser.externalAccums.foreach { a => assert(a.localValue == a.zero) }
-    taskDeser.metrics.internalAccums.foreach { a => assert(a.localValue == a.zero) }
-  }
-
 }
 
 private[spark] object AccumulatorSuite {
-
   import InternalAccumulator._
 
+  /**
+   * Create a long accumulator and register it with [[AccumulatorContext]].
+   */
+  def createLongAccum(
+      name: String,
+      countFailedValues: Boolean = false,
+      initValue: Long = 0,
+      id: Long = AccumulatorContext.newId()): LongAccumulator = {
+    val acc = new LongAccumulator
+    acc.setValue(initValue)
+    acc.metadata = AccumulatorMetadata(id, Some(name), countFailedValues)
+    AccumulatorContext.register(acc)
+    acc
+  }
+
+  /**
+   * Make an [[AccumulableInfo]] out of a [[NewAccumulator]] with the intent to use the
+   * info as an accumulator update.
+   */
+  def makeInfo(a: NewAccumulator[_, _]): AccumulableInfo = a.toInfo(Some(a.localValue), None)
+
   /**
    * Run one or more Spark jobs and verify that in at least one job the peak execution memory
    * accumulator is updated afterwards.
@@ -340,7 +327,6 @@ private class SaveInfoListener extends SparkListener {
     if (jobCompletionCallback != null) {
       jobCompletionSem.acquire()
       if (exception != null) {
-        exception = null
         throw exception
       }
     }
@@ -377,13 +363,3 @@ private class SaveInfoListener extends SparkListener {
       (taskEnd.stageId, taskEnd.stageAttemptId), new ArrayBuffer[TaskInfo]) += taskEnd.taskInfo
   }
 }
-
-
-/**
- * A dummy [[Task]] that contains internal and external [[Accumulator]]s.
- */
-private[spark] class DummyTask(
-    metrics: TaskMetrics,
-    val externalAccums: Seq[Accumulator[_]]) extends Task[Int](0, 0, 0, metrics) {
-  override def runTask(c: TaskContext): Int = 1
-}
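// Sketch (not part of the patch above, test-body style) of using the createLongAccum
// helper; the accumulator name and values are illustrative.
val acc = AccumulatorSuite.createLongAccum("records", countFailedValues = true)
acc.add(3)
acc.add(4)
assert(acc.value == 7)                                 // driver-side total
assert(AccumulatorContext.get(acc.id).contains(acc))   // registered for driver lookups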
diff --git a/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala
index 4d2b3e7f3b14b73bb2d3c0e8f7a0d0d78f2f6057..1adc90ab1e9dd0b018a74fd330a7184d1c3efd85 100644
--- a/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala
+++ b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala
@@ -211,10 +211,10 @@ class HeartbeatReceiverSuite
   private def triggerHeartbeat(
       executorId: String,
       executorShouldReregister: Boolean): Unit = {
-    val metrics = new TaskMetrics
+    val metrics = TaskMetrics.empty
     val blockManagerId = BlockManagerId(executorId, "localhost", 12345)
     val response = heartbeatReceiverRef.askWithRetry[HeartbeatResponse](
-      Heartbeat(executorId, Array(1L -> metrics.accumulatorUpdates()), blockManagerId))
+      Heartbeat(executorId, Array(1L -> metrics.accumulators()), blockManagerId))
     if (executorShouldReregister) {
       assert(response.reregisterBlockManager)
     } else {
@@ -222,7 +222,7 @@ class HeartbeatReceiverSuite
       // Additionally verify that the scheduler callback is called with the correct parameters
       verify(scheduler).executorHeartbeatReceived(
         Matchers.eq(executorId),
-        Matchers.eq(Array(1L -> metrics.accumulatorUpdates())),
+        Matchers.eq(Array(1L -> metrics.accumulators())),
         Matchers.eq(blockManagerId))
     }
   }
diff --git a/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala b/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala
index b074b954247319e20d87ec54fb109cdb3018a9b2..e4474bb813d5e39ee323394384549d98ae23975b 100644
--- a/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark
 
+import scala.collection.JavaConverters._
 import scala.collection.mutable.ArrayBuffer
 
 import org.apache.spark.executor.TaskMetrics
@@ -29,7 +30,7 @@ class InternalAccumulatorSuite extends SparkFunSuite with LocalSparkContext {
 
   override def afterEach(): Unit = {
     try {
-      Accumulators.clear()
+      AccumulatorContext.clear()
     } finally {
       super.afterEach()
     }
@@ -37,9 +38,8 @@ class InternalAccumulatorSuite extends SparkFunSuite with LocalSparkContext {
 
   test("internal accumulators in TaskContext") {
     val taskContext = TaskContext.empty()
-    val accumUpdates = taskContext.taskMetrics.accumulatorUpdates()
+    val accumUpdates = taskContext.taskMetrics.accumulators()
     assert(accumUpdates.size > 0)
-    assert(accumUpdates.forall(_.internal))
     val testAccum = taskContext.taskMetrics.testAccum.get
     assert(accumUpdates.exists(_.id == testAccum.id))
   }
@@ -51,7 +51,7 @@ class InternalAccumulatorSuite extends SparkFunSuite with LocalSparkContext {
     sc.addSparkListener(listener)
     // Have each task add 1 to the internal accumulator
     val rdd = sc.parallelize(1 to 100, numPartitions).mapPartitions { iter =>
-      TaskContext.get().taskMetrics().testAccum.get += 1
+      TaskContext.get().taskMetrics().testAccum.get.add(1)
       iter
     }
     // Register asserts in job completion callback to avoid flakiness
@@ -87,17 +87,17 @@ class InternalAccumulatorSuite extends SparkFunSuite with LocalSparkContext {
     val rdd = sc.parallelize(1 to 100, numPartitions)
       .map { i => (i, i) }
       .mapPartitions { iter =>
-        TaskContext.get().taskMetrics().testAccum.get += 1
+        TaskContext.get().taskMetrics().testAccum.get.add(1)
         iter
       }
       .reduceByKey { case (x, y) => x + y }
       .mapPartitions { iter =>
-        TaskContext.get().taskMetrics().testAccum.get += 10
+        TaskContext.get().taskMetrics().testAccum.get.add(10)
         iter
       }
       .repartition(numPartitions * 2)
       .mapPartitions { iter =>
-        TaskContext.get().taskMetrics().testAccum.get += 100
+        TaskContext.get().taskMetrics().testAccum.get.add(100)
         iter
       }
     // Register asserts in job completion callback to avoid flakiness
@@ -127,7 +127,7 @@ class InternalAccumulatorSuite extends SparkFunSuite with LocalSparkContext {
     // This should retry both stages in the scheduler. Note that we only want to fail the
     // first stage attempt because we want the stage to eventually succeed.
     val x = sc.parallelize(1 to 100, numPartitions)
-      .mapPartitions { iter => TaskContext.get().taskMetrics().testAccum.get += 1; iter }
+      .mapPartitions { iter => TaskContext.get().taskMetrics().testAccum.get.add(1); iter }
       .groupBy(identity)
     val sid = x.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]].shuffleHandle.shuffleId
     val rdd = x.mapPartitionsWithIndex { case (i, iter) =>
@@ -183,18 +183,18 @@ class InternalAccumulatorSuite extends SparkFunSuite with LocalSparkContext {
       private val myCleaner = new SaveAccumContextCleaner(this)
       override def cleaner: Option[ContextCleaner] = Some(myCleaner)
     }
-    assert(Accumulators.originals.isEmpty)
+    assert(AccumulatorContext.originals.isEmpty)
     sc.parallelize(1 to 100).map { i => (i, i) }.reduceByKey { _ + _ }.count()
     val numInternalAccums = TaskMetrics.empty.internalAccums.length
     // We ran 2 stages, so we should have 2 sets of internal accumulators, 1 for each stage
-    assert(Accumulators.originals.size === numInternalAccums * 2)
+    assert(AccumulatorContext.originals.size === numInternalAccums * 2)
     val accumsRegistered = sc.cleaner match {
       case Some(cleaner: SaveAccumContextCleaner) => cleaner.accumsRegisteredForCleanup
       case _ => Seq.empty[Long]
     }
     // Make sure the same set of accumulators is registered for cleanup
     assert(accumsRegistered.size === numInternalAccums * 2)
-    assert(accumsRegistered.toSet === Accumulators.originals.keys.toSet)
+    assert(accumsRegistered.toSet === AccumulatorContext.originals.keySet().asScala)
   }
 
   /**
@@ -212,7 +212,7 @@ class InternalAccumulatorSuite extends SparkFunSuite with LocalSparkContext {
   private class SaveAccumContextCleaner(sc: SparkContext) extends ContextCleaner(sc) {
     private val accumsRegistered = new ArrayBuffer[Long]
 
-    override def registerAccumulatorForCleanup(a: Accumulable[_, _]): Unit = {
+    override def registerAccumulatorForCleanup(a: NewAccumulator[_, _]): Unit = {
       accumsRegistered += a.id
       super.registerAccumulatorForCleanup(a)
     }
diff --git a/core/src/test/scala/org/apache/spark/SparkFunSuite.scala b/core/src/test/scala/org/apache/spark/SparkFunSuite.scala
index 3228752b96389407964f710be1fb0a66aaf74ce0..4aae2c9b4a8e4af8fc8210bedb4c8bcf7e865e22 100644
--- a/core/src/test/scala/org/apache/spark/SparkFunSuite.scala
+++ b/core/src/test/scala/org/apache/spark/SparkFunSuite.scala
@@ -34,7 +34,7 @@ private[spark] abstract class SparkFunSuite
   protected override def afterAll(): Unit = {
     try {
       // Avoid leaking map entries in tests that use accumulators without SparkContext
-      Accumulators.clear()
+      AccumulatorContext.clear()
     } finally {
       super.afterAll()
     }
diff --git a/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala b/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala
index ee70419727e861645545ab966836e7597d80ece3..94f6e1a3a77c15e92dd8718b97527c928b9f2766 100644
--- a/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala
@@ -20,14 +20,11 @@ package org.apache.spark.executor
 import org.scalatest.Assertions
 
 import org.apache.spark._
-import org.apache.spark.scheduler.AccumulableInfo
-import org.apache.spark.storage.{BlockId, BlockStatus, StorageLevel, TestBlockId}
+import org.apache.spark.storage.{BlockStatus, StorageLevel, TestBlockId}
 
 
 class TaskMetricsSuite extends SparkFunSuite {
-  import AccumulatorParam._
   import StorageLevel._
-  import TaskMetricsSuite._
 
   test("mutating values") {
     val tm = new TaskMetrics
@@ -59,8 +56,8 @@ class TaskMetricsSuite extends SparkFunSuite {
     tm.incPeakExecutionMemory(8L)
     val block1 = (TestBlockId("a"), BlockStatus(MEMORY_ONLY, 1L, 2L))
     val block2 = (TestBlockId("b"), BlockStatus(MEMORY_ONLY, 3L, 4L))
-    tm.incUpdatedBlockStatuses(Seq(block1))
-    tm.incUpdatedBlockStatuses(Seq(block2))
+    tm.incUpdatedBlockStatuses(block1)
+    tm.incUpdatedBlockStatuses(block2)
     // assert new values exist
     assert(tm.executorDeserializeTime == 1L)
     assert(tm.executorRunTime == 2L)
@@ -194,18 +191,19 @@ class TaskMetricsSuite extends SparkFunSuite {
   }
 
   test("additional accumulables") {
-    val tm = new TaskMetrics
-    val acc1 = new Accumulator(0, IntAccumulatorParam, Some("a"))
-    val acc2 = new Accumulator(0, IntAccumulatorParam, Some("b"))
-    val acc3 = new Accumulator(0, IntAccumulatorParam, Some("c"))
-    val acc4 = new Accumulator(0, IntAccumulatorParam, Some("d"), countFailedValues = true)
+    val tm = TaskMetrics.empty
+    val acc1 = AccumulatorSuite.createLongAccum("a")
+    val acc2 = AccumulatorSuite.createLongAccum("b")
+    val acc3 = AccumulatorSuite.createLongAccum("c")
+    val acc4 = AccumulatorSuite.createLongAccum("d", true)
     tm.registerAccumulator(acc1)
     tm.registerAccumulator(acc2)
     tm.registerAccumulator(acc3)
     tm.registerAccumulator(acc4)
-    acc1 += 1
-    acc2 += 2
-    val newUpdates = tm.accumulatorUpdates().map { a => (a.id, a) }.toMap
+    acc1.add(1)
+    acc2.add(2)
+    val newUpdates = tm.accumulators()
+      .map(a => (a.id, a.asInstanceOf[NewAccumulator[Any, Any]])).toMap
     assert(newUpdates.contains(acc1.id))
     assert(newUpdates.contains(acc2.id))
     assert(newUpdates.contains(acc3.id))
@@ -214,46 +212,14 @@ class TaskMetricsSuite extends SparkFunSuite {
     assert(newUpdates(acc2.id).name === Some("b"))
     assert(newUpdates(acc3.id).name === Some("c"))
     assert(newUpdates(acc4.id).name === Some("d"))
-    assert(newUpdates(acc1.id).update === Some(1))
-    assert(newUpdates(acc2.id).update === Some(2))
-    assert(newUpdates(acc3.id).update === Some(0))
-    assert(newUpdates(acc4.id).update === Some(0))
+    assert(newUpdates(acc1.id).value === 1)
+    assert(newUpdates(acc2.id).value === 2)
+    assert(newUpdates(acc3.id).value === 0)
+    assert(newUpdates(acc4.id).value === 0)
     assert(!newUpdates(acc3.id).countFailedValues)
     assert(newUpdates(acc4.id).countFailedValues)
-    assert(newUpdates.values.map(_.update).forall(_.isDefined))
-    assert(newUpdates.values.map(_.value).forall(_.isEmpty))
     assert(newUpdates.size === tm.internalAccums.size + 4)
   }
-
-  test("from accumulator updates") {
-    val accumUpdates1 = TaskMetrics.empty.internalAccums.map { a =>
-      AccumulableInfo(a.id, a.name, Some(3L), None, true, a.countFailedValues)
-    }
-    val metrics1 = TaskMetrics.fromAccumulatorUpdates(accumUpdates1)
-    assertUpdatesEquals(metrics1.accumulatorUpdates(), accumUpdates1)
-    // Test this with additional accumulators to ensure that we do not crash when handling
-    // updates from unregistered accumulators. In practice, all accumulators created
-    // on the driver, internal or not, should be registered with `Accumulators` at some point.
-    val param = IntAccumulatorParam
-    val registeredAccums = Seq(
-      new Accumulator(0, param, Some("a"), countFailedValues = true),
-      new Accumulator(0, param, Some("b"), countFailedValues = false))
-    val unregisteredAccums = Seq(
-      new Accumulator(0, param, Some("c"), countFailedValues = true),
-      new Accumulator(0, param, Some("d"), countFailedValues = false))
-    registeredAccums.foreach(Accumulators.register)
-    registeredAccums.foreach(a => assert(Accumulators.originals.contains(a.id)))
-    unregisteredAccums.foreach(a => Accumulators.remove(a.id))
-    unregisteredAccums.foreach(a => assert(!Accumulators.originals.contains(a.id)))
-    // set some values in these accums
-    registeredAccums.zipWithIndex.foreach { case (a, i) => a.setValue(i) }
-    unregisteredAccums.zipWithIndex.foreach { case (a, i) => a.setValue(i) }
-    val registeredAccumInfos = registeredAccums.map(makeInfo)
-    val unregisteredAccumInfos = unregisteredAccums.map(makeInfo)
-    val accumUpdates2 = accumUpdates1 ++ registeredAccumInfos ++ unregisteredAccumInfos
-    // Simply checking that this does not crash:
-    TaskMetrics.fromAccumulatorUpdates(accumUpdates2)
-  }
 }
 
 
@@ -264,21 +230,14 @@ private[spark] object TaskMetricsSuite extends Assertions {
    * Note: this does NOT check accumulator ID equality.
    */
   def assertUpdatesEquals(
-      updates1: Seq[AccumulableInfo],
-      updates2: Seq[AccumulableInfo]): Unit = {
+      updates1: Seq[NewAccumulator[_, _]],
+      updates2: Seq[NewAccumulator[_, _]]): Unit = {
     assert(updates1.size === updates2.size)
-    updates1.zip(updates2).foreach { case (info1, info2) =>
+    updates1.zip(updates2).foreach { case (acc1, acc2) =>
       // do not assert ID equals here
-      assert(info1.name === info2.name)
-      assert(info1.update === info2.update)
-      assert(info1.value === info2.value)
-      assert(info1.countFailedValues === info2.countFailedValues)
+      assert(acc1.name === acc2.name)
+      assert(acc1.countFailedValues === acc2.countFailedValues)
+      assert(acc1.value == acc2.value)
     }
   }
-
-  /**
-   * Make an [[AccumulableInfo]] out of an [[Accumulable]] with the intent to use the
-   * info as an accumulator update.
-   */
-  def makeInfo(a: Accumulable[_, _]): AccumulableInfo = a.toInfo(Some(a.value), None)
 }
diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
index b76c0a4bd1dded4b3e87d633285e282c35a1d753..9912d1f3bc5a7550e66afbacf05cae5027752a34 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -112,7 +112,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou
     override def stop() = {}
     override def executorHeartbeatReceived(
         execId: String,
-        accumUpdates: Array[(Long, Seq[AccumulableInfo])],
+        accumUpdates: Array[(Long, Seq[NewAccumulator[_, _]])],
         blockManagerId: BlockManagerId): Boolean = true
     override def submitTasks(taskSet: TaskSet) = {
       // normally done by TaskSetManager
@@ -277,8 +277,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou
           taskSet.tasks(i),
           result._1,
           result._2,
-          Seq(new AccumulableInfo(
-            accumId, Some(""), Some(1), None, internal = false, countFailedValues = false))))
+          Seq(AccumulatorSuite.createLongAccum("", initValue = 1, id = accumId))))
       }
     }
   }
@@ -484,7 +483,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou
       override def defaultParallelism(): Int = 2
       override def executorHeartbeatReceived(
           execId: String,
-          accumUpdates: Array[(Long, Seq[AccumulableInfo])],
+          accumUpdates: Array[(Long, Seq[NewAccumulator[_, _]])],
           blockManagerId: BlockManagerId): Boolean = true
       override def executorLost(executorId: String, reason: ExecutorLossReason): Unit = {}
       override def applicationAttemptId(): Option[String] = None
@@ -997,10 +996,10 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou
     // complete two tasks
     runEvent(makeCompletionEvent(
       taskSets(0).tasks(0), Success, 42,
-      Seq.empty[AccumulableInfo], createFakeTaskInfoWithId(0)))
+      Seq.empty, createFakeTaskInfoWithId(0)))
     runEvent(makeCompletionEvent(
       taskSets(0).tasks(1), Success, 42,
-      Seq.empty[AccumulableInfo], createFakeTaskInfoWithId(1)))
+      Seq.empty, createFakeTaskInfoWithId(1)))
     sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
     // verify stage exists
     assert(scheduler.stageIdToStage.contains(0))
@@ -1009,10 +1008,10 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou
     // finish other 2 tasks
     runEvent(makeCompletionEvent(
       taskSets(0).tasks(2), Success, 42,
-      Seq.empty[AccumulableInfo], createFakeTaskInfoWithId(2)))
+      Seq.empty, createFakeTaskInfoWithId(2)))
     runEvent(makeCompletionEvent(
       taskSets(0).tasks(3), Success, 42,
-      Seq.empty[AccumulableInfo], createFakeTaskInfoWithId(3)))
+      Seq.empty, createFakeTaskInfoWithId(3)))
     sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
     assert(sparkListener.endedTasks.size == 4)
 
@@ -1023,14 +1022,14 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou
     // with a speculative task and make sure the event is sent out
     runEvent(makeCompletionEvent(
       taskSets(0).tasks(3), Success, 42,
-      Seq.empty[AccumulableInfo], createFakeTaskInfoWithId(5)))
+      Seq.empty, createFakeTaskInfoWithId(5)))
     sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
     assert(sparkListener.endedTasks.size == 5)
 
     // make sure non successful tasks also send out event
     runEvent(makeCompletionEvent(
       taskSets(0).tasks(3), UnknownReason, 42,
-      Seq.empty[AccumulableInfo], createFakeTaskInfoWithId(6)))
+      Seq.empty, createFakeTaskInfoWithId(6)))
     sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
     assert(sparkListener.endedTasks.size == 6)
   }
@@ -1613,37 +1612,43 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou
 
   test("accumulator not calculated for resubmitted result stage") {
     // just for register
-    val accum = new Accumulator[Int](0, AccumulatorParam.IntAccumulatorParam)
+    val accum = AccumulatorSuite.createLongAccum("a")
     val finalRdd = new MyRDD(sc, 1, Nil)
     submit(finalRdd, Array(0))
     completeWithAccumulator(accum.id, taskSets(0), Seq((Success, 42)))
     completeWithAccumulator(accum.id, taskSets(0), Seq((Success, 42)))
     assert(results === Map(0 -> 42))
 
-    val accVal = Accumulators.originals(accum.id).get.get.value
-
-    assert(accVal === 1)
-
+    assert(accum.value === 1)
     assertDataStructuresEmpty()
   }
 
   test("accumulators are updated on exception failures") {
-    val acc1 = sc.accumulator(0L, "ingenieur")
-    val acc2 = sc.accumulator(0L, "boulanger")
-    val acc3 = sc.accumulator(0L, "agriculteur")
-    assert(Accumulators.get(acc1.id).isDefined)
-    assert(Accumulators.get(acc2.id).isDefined)
-    assert(Accumulators.get(acc3.id).isDefined)
-    val accInfo1 = acc1.toInfo(Some(15L), None)
-    val accInfo2 = acc2.toInfo(Some(13L), None)
-    val accInfo3 = acc3.toInfo(Some(18L), None)
-    val accumUpdates = Seq(accInfo1, accInfo2, accInfo3)
-    val exceptionFailure = new ExceptionFailure(new SparkException("fondue?"), accumUpdates)
+    val acc1 = AccumulatorSuite.createLongAccum("ingenieur")
+    val acc2 = AccumulatorSuite.createLongAccum("boulanger")
+    val acc3 = AccumulatorSuite.createLongAccum("agriculteur")
+    assert(AccumulatorContext.get(acc1.id).isDefined)
+    assert(AccumulatorContext.get(acc2.id).isDefined)
+    assert(AccumulatorContext.get(acc3.id).isDefined)
+    val accUpdate1 = new LongAccumulator
+    accUpdate1.metadata = acc1.metadata
+    accUpdate1.setValue(15)
+    val accUpdate2 = new LongAccumulator
+    accUpdate2.metadata = acc2.metadata
+    accUpdate2.setValue(13)
+    val accUpdate3 = new LongAccumulator
+    accUpdate3.metadata = acc3.metadata
+    accUpdate3.setValue(18)
+    val accumUpdates = Seq(accUpdate1, accUpdate2, accUpdate3)
+    val accumInfo = accumUpdates.map(AccumulatorSuite.makeInfo)
+    val exceptionFailure = new ExceptionFailure(
+      new SparkException("fondue?"),
+      accumInfo).copy(accums = accumUpdates)
     submit(new MyRDD(sc, 1, Nil), Array(0))
     runEvent(makeCompletionEvent(taskSets.head.tasks.head, exceptionFailure, "result"))
-    assert(Accumulators.get(acc1.id).get.value === 15L)
-    assert(Accumulators.get(acc2.id).get.value === 13L)
-    assert(Accumulators.get(acc3.id).get.value === 18L)
+    assert(AccumulatorContext.get(acc1.id).get.value === 15L)
+    assert(AccumulatorContext.get(acc2.id).get.value === 13L)
+    assert(AccumulatorContext.get(acc3.id).get.value === 18L)
   }
 
   test("reduce tasks should be placed locally with map output") {
@@ -2007,12 +2012,12 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou
       task: Task[_],
       reason: TaskEndReason,
       result: Any,
-      extraAccumUpdates: Seq[AccumulableInfo] = Seq.empty[AccumulableInfo],
+      extraAccumUpdates: Seq[NewAccumulator[_, _]] = Seq.empty,
       taskInfo: TaskInfo = createFakeTaskInfo()): CompletionEvent = {
     val accumUpdates = reason match {
-      case Success => task.metrics.accumulatorUpdates()
-      case ef: ExceptionFailure => ef.accumUpdates
-      case _ => Seq.empty[AccumulableInfo]
+      case Success => task.metrics.accumulators()
+      case ef: ExceptionFailure => ef.accums
+      case _ => Seq.empty
     }
     CompletionEvent(task, reason, result, accumUpdates ++ extraAccumUpdates, taskInfo)
   }
diff --git a/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala
index 9971d48a52ce7f1e12148df926b6eece4ca41209..16027d944fdfd842b7eca7532fc6241a6c71c833 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala
@@ -17,12 +17,11 @@
 
 package org.apache.spark.scheduler
 
-import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkFunSuite}
+import org.apache.spark.{LocalSparkContext, NewAccumulator, SparkConf, SparkContext, SparkFunSuite}
 import org.apache.spark.scheduler.SchedulingMode.SchedulingMode
 import org.apache.spark.storage.BlockManagerId
 
-class ExternalClusterManagerSuite extends SparkFunSuite with LocalSparkContext
-{
+class ExternalClusterManagerSuite extends SparkFunSuite with LocalSparkContext {
   test("launch of backend and scheduler") {
     val conf = new SparkConf().setMaster("myclusterManager").
         setAppName("testcm").set("spark.driver.allowMultipleContexts", "true")
@@ -68,6 +67,6 @@ private class DummyTaskScheduler extends TaskScheduler {
   override def applicationAttemptId(): Option[String] = None
   def executorHeartbeatReceived(
       execId: String,
-      accumUpdates: Array[(Long, Seq[AccumulableInfo])],
+      accumUpdates: Array[(Long, Seq[NewAccumulator[_, _]])],
       blockManagerId: BlockManagerId): Boolean = true
 }
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
index d55f6f60ece8692260fe187c3d1d24a96da574e2..9aca4dbc236441edbeaca1ceb4a1a839b6f2c4f0 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
@@ -162,18 +162,17 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
     }.count()
     // The one that counts failed values should be 4x the one that didn't,
     // since we ran each task 4 times
-    assert(Accumulators.get(acc1.id).get.value === 40L)
-    assert(Accumulators.get(acc2.id).get.value === 10L)
+    assert(AccumulatorContext.get(acc1.id).get.value === 40L)
+    assert(AccumulatorContext.get(acc2.id).get.value === 10L)
   }
 
   test("failed tasks collect only accumulators whose values count during failures") {
     sc = new SparkContext("local", "test")
-    val param = AccumulatorParam.LongAccumulatorParam
-    val acc1 = new Accumulator(0L, param, Some("x"), countFailedValues = true)
-    val acc2 = new Accumulator(0L, param, Some("y"), countFailedValues = false)
+    val acc1 = AccumulatorSuite.createLongAccum("x", true)
+    val acc2 = AccumulatorSuite.createLongAccum("y", false)
     // Create a dummy task. We won't end up running this; we just want to collect
     // accumulator updates from it.
-    val taskMetrics = new TaskMetrics
+    val taskMetrics = TaskMetrics.empty
     val task = new Task[Int](0, 0, 0) {
       context = new TaskContextImpl(0, 0, 0L, 0,
         new TaskMemoryManager(SparkEnv.get.memoryManager, 0L),
@@ -186,12 +185,11 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
     }
     // First, simulate task success. This should give us all the accumulators.
     val accumUpdates1 = task.collectAccumulatorUpdates(taskFailed = false)
-    val accumUpdates2 = (taskMetrics.internalAccums ++ Seq(acc1, acc2))
-      .map(TaskMetricsSuite.makeInfo)
+    val accumUpdates2 = taskMetrics.internalAccums ++ Seq(acc1, acc2)
     TaskMetricsSuite.assertUpdatesEquals(accumUpdates1, accumUpdates2)
     // Now, simulate task failures. This should give us only the accums that count failed values.
     val accumUpdates3 = task.collectAccumulatorUpdates(taskFailed = true)
-    val accumUpdates4 = (taskMetrics.internalAccums ++ Seq(acc1)).map(TaskMetricsSuite.makeInfo)
+    val accumUpdates4 = taskMetrics.internalAccums ++ Seq(acc1)
     TaskMetricsSuite.assertUpdatesEquals(accumUpdates3, accumUpdates4)
   }
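
`AccumulatorSuite.createLongAccum` is a test helper introduced elsewhere in this patch. Judging from the call sites above, it presumably looks roughly like the sketch below; the real signature and body may differ:

    // Assumed shape: create a named LongAccumulator and register it against the
    // suite's active SparkContext so AccumulatorContext.get(acc.id) can find it.
    def createLongAccum(name: String, countFailedValues: Boolean = false): LongAccumulator = {
      val acc = new LongAccumulator
      val sc = SparkContext.getOrCreate()  // assumption: reuse the test's running context
      acc.register(sc, name = Some(name), countFailedValues = countFailedValues)
      acc
    }
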
 
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala
index b5385c11a926e2132002ede25c6590b9342bbe58..9e472f900b655bb62cd6db13b08b3cc1b7c409e1 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala
@@ -241,8 +241,8 @@ class TaskResultGetterSuite extends SparkFunSuite with BeforeAndAfter with Local
     assert(resultGetter.taskResults.size === 1)
     val resBefore = resultGetter.taskResults.head
     val resAfter = captor.getValue
-    val resSizeBefore = resBefore.accumUpdates.find(_.name == Some(RESULT_SIZE)).flatMap(_.update)
-    val resSizeAfter = resAfter.accumUpdates.find(_.name == Some(RESULT_SIZE)).flatMap(_.update)
+    val resSizeBefore = resBefore.accumUpdates.find(_.name == Some(RESULT_SIZE)).map(_.value)
+    val resSizeAfter = resAfter.accumUpdates.find(_.name == Some(RESULT_SIZE)).map(_.value)
     assert(resSizeBefore.exists(_ == 0L))
     assert(resSizeAfter.exists(_.toString.toLong > 0L))
   }
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
index ecf4b76da5586c588ddb8a3f82f1d25397787bf6..339fc4254d53afa91a2c1377c209eb4d63d98fd1 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
@@ -37,7 +37,7 @@ class FakeDAGScheduler(sc: SparkContext, taskScheduler: FakeTaskScheduler)
       task: Task[_],
       reason: TaskEndReason,
       result: Any,
-      accumUpdates: Seq[AccumulableInfo],
+      accumUpdates: Seq[NewAccumulator[_, _]],
       taskInfo: TaskInfo) {
     taskScheduler.endedTasks(taskInfo.index) = reason
   }
@@ -166,8 +166,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg
     val taskSet = FakeTask.createTaskSet(1)
     val clock = new ManualClock
     val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock)
-    val accumUpdates =
-      taskSet.tasks.head.metrics.internalAccums.map { a => a.toInfo(Some(0L), None) }
+    val accumUpdates = taskSet.tasks.head.metrics.internalAccums
 
     // Offer a host with NO_PREF as the constraint;
     // we should get a nopref task immediately since that's all we have
@@ -185,8 +184,8 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg
     val sched = new FakeTaskScheduler(sc, ("exec1", "host1"))
     val taskSet = FakeTask.createTaskSet(3)
     val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES)
-    val accumUpdatesByTask: Array[Seq[AccumulableInfo]] = taskSet.tasks.map { task =>
-      task.metrics.internalAccums.map { a => a.toInfo(Some(0L), None) }
+    val accumUpdatesByTask: Array[Seq[NewAccumulator[_, _]]] = taskSet.tasks.map { task =>
+      task.metrics.internalAccums
     }
 
     // First three offers should all find tasks
@@ -792,7 +791,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg
 
   private def createTaskResult(
       id: Int,
-      accumUpdates: Seq[AccumulableInfo] = Seq.empty[AccumulableInfo]): DirectTaskResult[Int] = {
+      accumUpdates: Seq[NewAccumulator[_, _]] = Seq.empty): DirectTaskResult[Int] = {
     val valueSer = SparkEnv.get.serializer.newInstance()
     new DirectTaskResult[Int](valueSer.serialize(id), accumUpdates)
   }
diff --git a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala
index 221124829fc545597a5a91580f193e7d1271a099..ce7d51d1c371b8a2d9baf85138a579ad23f6bcfc 100644
--- a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala
@@ -183,7 +183,7 @@ class JobProgressListenerSuite extends SparkFunSuite with LocalSparkContext with
   test("test executor id to summary") {
     val conf = new SparkConf()
     val listener = new JobProgressListener(conf)
-    val taskMetrics = new TaskMetrics()
+    val taskMetrics = TaskMetrics.empty
     val shuffleReadMetrics = taskMetrics.createTempShuffleReadMetrics()
     assert(listener.stageIdToData.size === 0)
 
@@ -230,7 +230,7 @@ class JobProgressListenerSuite extends SparkFunSuite with LocalSparkContext with
   test("test task success vs failure counting for different task end reasons") {
     val conf = new SparkConf()
     val listener = new JobProgressListener(conf)
-    val metrics = new TaskMetrics()
+    val metrics = TaskMetrics.empty
     val taskInfo = new TaskInfo(1234L, 0, 3, 0L, "exe-1", "host1", TaskLocality.NODE_LOCAL, false)
     taskInfo.finishTime = 1
     val task = new ShuffleMapTask(0)
@@ -269,7 +269,7 @@ class JobProgressListenerSuite extends SparkFunSuite with LocalSparkContext with
     val execId = "exe-1"
 
     def makeTaskMetrics(base: Int): TaskMetrics = {
-      val taskMetrics = new TaskMetrics
+      val taskMetrics = TaskMetrics.empty
       val shuffleReadMetrics = taskMetrics.createTempShuffleReadMetrics()
       val shuffleWriteMetrics = taskMetrics.shuffleWriteMetrics
       val inputMetrics = taskMetrics.inputMetrics
@@ -300,9 +300,9 @@ class JobProgressListenerSuite extends SparkFunSuite with LocalSparkContext with
     listener.onTaskStart(SparkListenerTaskStart(1, 0, makeTaskInfo(1237L)))
 
     listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate(execId, Array(
-      (1234L, 0, 0, makeTaskMetrics(0).accumulatorUpdates()),
-      (1235L, 0, 0, makeTaskMetrics(100).accumulatorUpdates()),
-      (1236L, 1, 0, makeTaskMetrics(200).accumulatorUpdates()))))
+      (1234L, 0, 0, makeTaskMetrics(0).accumulators().map(AccumulatorSuite.makeInfo)),
+      (1235L, 0, 0, makeTaskMetrics(100).accumulators().map(AccumulatorSuite.makeInfo)),
+      (1236L, 1, 0, makeTaskMetrics(200).accumulators().map(AccumulatorSuite.makeInfo)))))
 
     var stage0Data = listener.stageIdToData.get((0, 0)).get
     var stage1Data = listener.stageIdToData.get((1, 0)).get
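
`AccumulatorSuite.makeInfo`, used above to feed the executor-metrics event, is another helper from this patch. Given how `toInfo` is called elsewhere in the change, it is presumably a thin wrapper along these lines (an assumption, not the patch's actual code):

    // Assumed wrapper: convert a live accumulator into the AccumulableInfo shape
    // that SparkListenerExecutorMetricsUpdate still carries.
    def makeInfo(acc: NewAccumulator[_, _]): AccumulableInfo =
      acc.toInfo(Some(acc.localValue), None)
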
diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala
index d3b6cdfe86eeccebfcf4939e41022b3b6ffe20da..6fda7378e6cef9a53e94668c773a61b26bac68bc 100644
--- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala
@@ -85,7 +85,8 @@ class JsonProtocolSuite extends SparkFunSuite {
       // Use custom accum ID for determinism
       val accumUpdates =
         makeTaskMetrics(300L, 400L, 500L, 600L, 700, 800, hasHadoopInput = true, hasOutput = true)
-          .accumulatorUpdates().zipWithIndex.map { case (a, i) => a.copy(id = i) }
+          .accumulators().map(AccumulatorSuite.makeInfo)
+          .zipWithIndex.map { case (a, i) => a.copy(id = i) }
       SparkListenerExecutorMetricsUpdate("exec3", Seq((1L, 2, 3, accumUpdates)))
     }
 
@@ -385,7 +386,7 @@ class JsonProtocolSuite extends SparkFunSuite {
     // "Task Metrics" field, if it exists.
     val tm = makeTaskMetrics(1L, 2L, 3L, 4L, 5, 6, hasHadoopInput = true, hasOutput = true)
     val tmJson = JsonProtocol.taskMetricsToJson(tm)
-    val accumUpdates = tm.accumulatorUpdates()
+    val accumUpdates = tm.accumulators().map(AccumulatorSuite.makeInfo)
     val exception = new SparkException("sentimental")
     val exceptionFailure = new ExceptionFailure(exception, accumUpdates)
     val exceptionFailureJson = JsonProtocol.taskEndReasonToJson(exceptionFailure)
@@ -813,7 +814,7 @@ private[spark] object JsonProtocolSuite extends Assertions {
       hasHadoopInput: Boolean,
       hasOutput: Boolean,
       hasRecords: Boolean = true) = {
-    val t = new TaskMetrics
+    val t = TaskMetrics.empty
     t.setExecutorDeserializeTime(a)
     t.setExecutorRunTime(b)
     t.setResultSize(c)
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 0f8648f890c7a2cbad51a83fc16128a5677ec539..6fc49a08fe31662ce9119833ffef6d0f9c9bea9b 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -688,6 +688,18 @@ object MimaExcludes {
       ) ++ Seq(
         // [SPARK-4452][Core]Shuffle data structures can starve others on the same thread for memory
         ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.util.collection.Spillable")
+      ) ++ Seq(
+        // SPARK-14654: New accumulator API
+        ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ExceptionFailure$"),
+        ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ExceptionFailure.apply"),
+        ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ExceptionFailure.metrics"),
+        ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ExceptionFailure.copy"),
+        ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ExceptionFailure.this"),
+        ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.executor.ShuffleReadMetrics.remoteBlocksFetched"),
+        ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.executor.ShuffleReadMetrics.totalBlocksFetched"),
+        ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.executor.ShuffleReadMetrics.localBlocksFetched"),
+        ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.status.api.v1.ShuffleReadMetrics.remoteBlocksFetched"),
+        ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.status.api.v1.ShuffleReadMetrics.localBlocksFetched")
       )
     case v if v.startsWith("1.6") =>
       Seq(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
index 520ceaaaea65477bca6e85860efa4db008222ca4..d6516f26a70f359312eba5154a543a603151edd1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
@@ -106,7 +106,7 @@ private[sql] case class RDDScanExec(
     override val nodeName: String) extends LeafExecNode {
 
   private[sql] override lazy val metrics = Map(
-    "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
+    "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
 
   protected override def doExecute(): RDD[InternalRow] = {
     val numOutputRows = longMetric("numOutputRows")
@@ -147,7 +147,7 @@ private[sql] case class RowDataSourceScanExec(
   extends DataSourceScanExec with CodegenSupport {
 
   private[sql] override lazy val metrics =
-    Map("numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
+    Map("numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
 
   val outputUnsafeRows = relation match {
     case r: HadoopFsRelation if r.fileFormat.isInstanceOf[ParquetSource] =>
@@ -216,7 +216,7 @@ private[sql] case class BatchedDataSourceScanExec(
   extends DataSourceScanExec with CodegenSupport {
 
   private[sql] override lazy val metrics =
-    Map("numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"),
+    Map("numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"),
       "scanTime" -> SQLMetrics.createTimingMetric(sparkContext, "scan time"))
 
   protected override def doExecute(): RDD[InternalRow] = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala
index 7c4756663a6e0d9ab0f130d22fa95ca03e08ac0e..c201822d4479a445823583c3e1f2a452ddde76b1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala
@@ -40,7 +40,7 @@ case class ExpandExec(
   extends UnaryExecNode with CodegenSupport {
 
   private[sql] override lazy val metrics = Map(
-    "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
+    "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
 
   // The GroupExpressions can output data with arbitrary partitioning, so set it
   // as UNKNOWN partitioning
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala
index 10cfec3330a2dc6c9609655b55f542a7f3720fbc..934bc38dc47cbaedb7471277eaee8b5ca52e834e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala
@@ -56,7 +56,7 @@ case class GenerateExec(
   extends UnaryExecNode {
 
   private[sql] override lazy val metrics = Map(
-    "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
+    "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
 
   override def producedAttributes: AttributeSet = AttributeSet(output)
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala
index 4ab447a47b2c9ad7bf08c48be5286e8f3d870449..c5e78b033359d07dcfb53319a971abcf9b1b5954 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala
@@ -31,7 +31,7 @@ private[sql] case class LocalTableScanExec(
     rows: Seq[InternalRow]) extends LeafExecNode {
 
   private[sql] override lazy val metrics = Map(
-    "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
+    "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
 
   private val unsafeRows: Array[InternalRow] = {
     val proj = UnsafeProjection.create(output, output)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index 861ff3cd15874fb4721864b82eef662dbddb894b..0bbe970420707ca1cbaab9d16c05842561b40b21 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.plans.QueryPlan
 import org.apache.spark.sql.catalyst.plans.physical._
-import org.apache.spark.sql.execution.metric.{LongSQLMetric, SQLMetric}
+import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.types.DataType
 import org.apache.spark.util.ThreadUtils
 
@@ -77,7 +77,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
   /**
    * Return all metrics containing metrics of this SparkPlan.
    */
-  private[sql] def metrics: Map[String, SQLMetric[_, _]] = Map.empty
+  private[sql] def metrics: Map[String, SQLMetric] = Map.empty
 
   /**
    * Reset all the metrics.
@@ -89,8 +89,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
   /**
    * Return a LongSQLMetric according to the name.
    */
-  private[sql] def longMetric(name: String): LongSQLMetric =
-    metrics(name).asInstanceOf[LongSQLMetric]
+  private[sql] def longMetric(name: String): SQLMetric = metrics(name)
 
   // TODO: Move to `DistributedPlan`
   /** Specifies how data is partitioned across different nodes in the cluster. */
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala
index cb4b1cfeb9ba382f023ffc77953337541a3d42cd..f84070a0c4bcbc45861dfb54ff410f4516d877a7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala
@@ -55,8 +55,7 @@ private[sql] object SparkPlanInfo {
       case _ => plan.children ++ plan.subqueries
     }
     val metrics = plan.metrics.toSeq.map { case (key, metric) =>
-      new SQLMetricInfo(metric.name.getOrElse(key), metric.id,
-        Utils.getFormattedClassName(metric.param))
+      new SQLMetricInfo(metric.name.getOrElse(key), metric.id, metric.metricType)
     }
 
     new SparkPlanInfo(plan.nodeName, plan.simpleString, children.map(fromSparkPlan),
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala
index 362d0d7a72264aef41a54226c4ae1eeb4aa7df5f..484923428f4adcfb35fc24f61e1d96d8f4e796b2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala
@@ -26,7 +26,7 @@ import com.google.common.io.ByteStreams
 
 import org.apache.spark.serializer.{DeserializationStream, SerializationStream, Serializer, SerializerInstance}
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
-import org.apache.spark.sql.execution.metric.LongSQLMetric
+import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.unsafe.Platform
 
 /**
@@ -42,7 +42,7 @@ import org.apache.spark.unsafe.Platform
  */
 private[sql] class UnsafeRowSerializer(
     numFields: Int,
-    dataSize: LongSQLMetric = null) extends Serializer with Serializable {
+    dataSize: SQLMetric = null) extends Serializer with Serializable {
   override def newInstance(): SerializerInstance =
     new UnsafeRowSerializerInstance(numFields, dataSize)
   override private[spark] def supportsRelocationOfSerializedObjects: Boolean = true
@@ -50,7 +50,7 @@ private[sql] class UnsafeRowSerializer(
 
 private class UnsafeRowSerializerInstance(
     numFields: Int,
-    dataSize: LongSQLMetric) extends SerializerInstance {
+    dataSize: SQLMetric) extends SerializerInstance {
   /**
    * Serializes a stream of UnsafeRows. Within the stream, each record consists of a record
    * length (stored as a 4-byte integer, written high byte first), followed by the record's bytes.
@@ -60,13 +60,10 @@ private class UnsafeRowSerializerInstance(
     private[this] val dOut: DataOutputStream =
       new DataOutputStream(new BufferedOutputStream(out))
 
-    // LongSQLMetricParam.add() is faster than LongSQLMetric.+=
-    val localDataSize = if (dataSize != null) dataSize.localValue else null
-
     override def writeValue[T: ClassTag](value: T): SerializationStream = {
       val row = value.asInstanceOf[UnsafeRow]
-      if (localDataSize != null) {
-        localDataSize.add(row.getSizeInBytes)
+      if (dataSize != null) {
+        dataSize.add(row.getSizeInBytes)
       }
       dOut.writeInt(row.getSizeInBytes)
       row.writeToStream(dOut, writeBuffer)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
index 6a03bd08c547e8cdb4df8e336ce6944195900cc0..15b4abe806678ea95cc03cb90d7300531914ce17 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
@@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.util.toCommentSafeString
 import org.apache.spark.sql.execution.aggregate.TungstenAggregate
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec}
-import org.apache.spark.sql.execution.metric.{LongSQLMetricValue, SQLMetrics}
+import org.apache.spark.sql.execution.metric.SQLMetrics
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 
@@ -52,11 +52,7 @@ trait CodegenSupport extends SparkPlan {
    * @return name of the variable representing the metric
    */
   def metricTerm(ctx: CodegenContext, name: String): String = {
-    val metric = ctx.addReferenceObj(name, longMetric(name))
-    val value = ctx.freshName("metricValue")
-    val cls = classOf[LongSQLMetricValue].getName
-    ctx.addMutableState(cls, value, s"$value = ($cls) $metric.localValue();")
-    value
+    ctx.addReferenceObj(name, longMetric(name))
   }
 
   /**
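
With `LongSQLMetricValue` gone, `metricTerm` now simply registers the `SQLMetric` itself as a codegen reference object, so the generated Java can call `add` on it directly. An illustrative fragment of how a `CodegenSupport` operator typically consumes that term (not part of the patch; operator specifics vary):

    override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
      val numOutput = metricTerm(ctx, "numOutputRows")  // now a reference to the SQLMetric object
      s"""
         |$numOutput.add(1);
         |${consume(ctx, input)}
       """.stripMargin
    }
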
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregateExec.scala
index 3169e0a2fd86a34018e2572eace0ef055a16a3c9..2e74d59c5f5b60dc2a8f2c7ed292bdc84d1efe89 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregateExec.scala
@@ -46,7 +46,7 @@ case class SortBasedAggregateExec(
       AttributeSet(aggregateBufferAttributes)
 
   override private[sql] lazy val metrics = Map(
-    "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
+    "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
 
   override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala
index c35d781d3ebf5f28f302c3f071b8fe9d486d0250..f392b135ce7871a60ccfb6006f881e0edbd0adab 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.aggregate
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction}
-import org.apache.spark.sql.execution.metric.LongSQLMetric
+import org.apache.spark.sql.execution.metric.SQLMetric
 
 /**
  * An iterator used to evaluate [[AggregateFunction]]. It assumes the input rows have been
@@ -35,7 +35,7 @@ class SortBasedAggregationIterator(
     initialInputBufferOffset: Int,
     resultExpressions: Seq[NamedExpression],
     newMutableProjection: (Seq[Expression], Seq[Attribute]) => MutableProjection,
-    numOutputRows: LongSQLMetric)
+    numOutputRows: SQLMetric)
   extends AggregationIterator(
     groupingExpressions,
     valueAttributes,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
index 16362f756f78707ac8f25ba85a41377f448ef25e..d0ba37ee1338be622834057d951bd8d7b82e2b4d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.execution._
-import org.apache.spark.sql.execution.metric.{LongSQLMetricValue, SQLMetrics}
+import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
 import org.apache.spark.sql.types.{DecimalType, StringType, StructType}
 import org.apache.spark.unsafe.KVIterator
 
@@ -51,7 +51,7 @@ case class TungstenAggregate(
       aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes)
 
   override private[sql] lazy val metrics = Map(
-    "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"),
+    "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"),
     "peakMemory" -> SQLMetrics.createSizeMetric(sparkContext, "peak memory"),
     "spillSize" -> SQLMetrics.createSizeMetric(sparkContext, "spill size"),
     "aggTime" -> SQLMetrics.createTimingMetric(sparkContext, "aggregate time"))
@@ -309,8 +309,8 @@ case class TungstenAggregate(
   def finishAggregate(
       hashMap: UnsafeFixedWidthAggregationMap,
       sorter: UnsafeKVExternalSorter,
-      peakMemory: LongSQLMetricValue,
-      spillSize: LongSQLMetricValue): KVIterator[UnsafeRow, UnsafeRow] = {
+      peakMemory: SQLMetric,
+      spillSize: SQLMetric): KVIterator[UnsafeRow, UnsafeRow] = {
 
     // update peak execution memory
     val mapMemory = hashMap.getPeakMemoryUsedBytes
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
index 9db5087fe03ef2b37bb8cccce91d1e38a5bcce08..243aa15deba3db7846c68ced87f1ef6614b7830e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner
 import org.apache.spark.sql.execution.{UnsafeFixedWidthAggregationMap, UnsafeKVExternalSorter}
-import org.apache.spark.sql.execution.metric.LongSQLMetric
+import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.unsafe.KVIterator
 
@@ -86,9 +86,9 @@ class TungstenAggregationIterator(
     originalInputAttributes: Seq[Attribute],
     inputIter: Iterator[InternalRow],
     testFallbackStartsAt: Option[(Int, Int)],
-    numOutputRows: LongSQLMetric,
-    peakMemory: LongSQLMetric,
-    spillSize: LongSQLMetric)
+    numOutputRows: SQLMetric,
+    peakMemory: SQLMetric,
+    spillSize: SQLMetric)
   extends AggregationIterator(
     groupingExpressions,
     originalInputAttributes,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
index 83f527f5551acd8a6a9da2681685f0fb1d619525..77be613b837622b6a19177f14d1eb77f5d7cecc3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
@@ -103,7 +103,7 @@ case class FilterExec(condition: Expression, child: SparkPlan)
   }
 
   private[sql] override lazy val metrics = Map(
-    "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
+    "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
 
   override def inputRDDs(): Seq[RDD[InternalRow]] = {
     child.asInstanceOf[CodegenSupport].inputRDDs()
@@ -229,7 +229,7 @@ case class SampleExec(
   override def output: Seq[Attribute] = child.output
 
   private[sql] override lazy val metrics = Map(
-    "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
+    "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
 
   protected override def doExecute(): RDD[InternalRow] = {
     if (withReplacement) {
@@ -322,7 +322,7 @@ case class RangeExec(
   extends LeafExecNode with CodegenSupport {
 
   private[sql] override lazy val metrics = Map(
-    "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
+    "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
 
   // output attributes should not affect the results
   override lazy val cleanArgs: Seq[Any] = Seq(start, step, numSlices, numElements)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
index cb957b9666f550c5aead5afe7dae6799579c172c..577c34ba618d9a5a2434fb66cd6d2d17f0604b93 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.columnar
 
 import scala.collection.mutable.ArrayBuffer
 
-import org.apache.spark.{Accumulable, Accumulator, Accumulators}
+import org.apache.spark.{Accumulable, Accumulator, AccumulatorContext}
 import org.apache.spark.network.util.JavaUtils
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
@@ -204,7 +204,7 @@ private[sql] case class InMemoryRelation(
     Seq(_cachedColumnBuffers, statisticsToBePropagated, batchStats)
 
   private[sql] def uncache(blocking: Boolean): Unit = {
-    Accumulators.remove(batchStats.id)
+    AccumulatorContext.remove(batchStats.id)
     cachedColumnBuffers.unpersist(blocking)
     _cachedColumnBuffers = null
   }
@@ -217,7 +217,7 @@ private[sql] case class InMemoryTableScanExec(
   extends LeafExecNode {
 
   private[sql] override lazy val metrics = Map(
-    "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
+    "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
 
   override def output: Seq[Attribute] = attributes
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala
index 573ca195ac13fb80b830c1afd9e2b886c872ae37..b6ecd3cb065ae21378ca6adcab159ab30043f9f1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala
@@ -38,10 +38,10 @@ case class BroadcastExchangeExec(
     child: SparkPlan) extends Exchange {
 
   override private[sql] lazy val metrics = Map(
-    "dataSize" -> SQLMetrics.createLongMetric(sparkContext, "data size (bytes)"),
-    "collectTime" -> SQLMetrics.createLongMetric(sparkContext, "time to collect (ms)"),
-    "buildTime" -> SQLMetrics.createLongMetric(sparkContext, "time to build (ms)"),
-    "broadcastTime" -> SQLMetrics.createLongMetric(sparkContext, "time to broadcast (ms)"))
+    "dataSize" -> SQLMetrics.createMetric(sparkContext, "data size (bytes)"),
+    "collectTime" -> SQLMetrics.createMetric(sparkContext, "time to collect (ms)"),
+    "buildTime" -> SQLMetrics.createMetric(sparkContext, "time to build (ms)"),
+    "broadcastTime" -> SQLMetrics.createMetric(sparkContext, "time to broadcast (ms)"))
 
   override def outputPartitioning: Partitioning = BroadcastPartitioning(mode)
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
index b0a6b8f28a4670e5f3713c1ed58dfbcf07fb5197..587c603192cceec6e02a4265ed0e4cd22f80ccc5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
@@ -46,7 +46,7 @@ case class BroadcastHashJoinExec(
   extends BinaryExecNode with HashJoin with CodegenSupport {
 
   override private[sql] lazy val metrics = Map(
-    "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
+    "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
 
   override def outputPartitioning: Partitioning = streamedPlan.outputPartitioning
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala
index 51afa0017dd268357626086a9d7a0a47c71a3e1b..a659bf26e32df1e12494236e412c8c7892e097f5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala
@@ -35,7 +35,7 @@ case class BroadcastNestedLoopJoinExec(
     condition: Option[Expression]) extends BinaryExecNode {
 
   override private[sql] lazy val metrics = Map(
-    "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
+    "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
 
   /** BuildRight means the right relation <=> the broadcast relation. */
   private val (streamed, broadcast) = buildSide match {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala
index 67f59197ad0d4df14ec06124675827e1eca7ac65..8d7ecc442a9e1647c55981db3e691e2a7d20dfe5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala
@@ -86,7 +86,7 @@ case class CartesianProductExec(
   override def output: Seq[Attribute] = left.output ++ right.output
 
   override private[sql] lazy val metrics = Map(
-    "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
+    "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
 
   protected override def doExecute(): RDD[InternalRow] = {
     val numOutputRows = longMetric("numOutputRows")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
index d6feedc27244bff89e09e567ca27a5d6aa793db2..9c173d7bf1011beac9ea1abf9ffeac6541d9f7ca 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
@@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.execution.{RowIterator, SparkPlan}
-import org.apache.spark.sql.execution.metric.LongSQLMetric
+import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.types.{IntegralType, LongType}
 
 trait HashJoin {
@@ -201,7 +201,7 @@ trait HashJoin {
   protected def join(
       streamedIter: Iterator[InternalRow],
       hashed: HashedRelation,
-      numOutputRows: LongSQLMetric): Iterator[InternalRow] = {
+      numOutputRows: SQLMetric): Iterator[InternalRow] = {
 
     val joinedIter = joinType match {
       case Inner =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
index a242a078f608105e96291132d2bbe6fe047f04ff..3ef2fec352203cccf08e2cc2e01e443db9ecda9d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
@@ -40,7 +40,7 @@ case class ShuffledHashJoinExec(
   extends BinaryExecNode with HashJoin {
 
   override private[sql] lazy val metrics = Map(
-    "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"),
+    "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"),
     "buildDataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size of build side"),
     "buildTime" -> SQLMetrics.createTimingMetric(sparkContext, "time to build hash map"))
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
index a4c5491affe839a55419d53ab9ec7ad034038e7d..775f8ac50818f087ed864c17fb43893706c0acf3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCo
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.execution.{BinaryExecNode, CodegenSupport, RowIterator, SparkPlan}
-import org.apache.spark.sql.execution.metric.{LongSQLMetric, SQLMetrics}
+import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
 import org.apache.spark.util.collection.BitSet
 
 /**
@@ -41,7 +41,7 @@ case class SortMergeJoinExec(
     right: SparkPlan) extends BinaryExecNode with CodegenSupport {
 
   override private[sql] lazy val metrics = Map(
-    "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
+    "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
 
   override def output: Seq[Attribute] = {
     joinType match {
@@ -734,7 +734,7 @@ private class LeftOuterIterator(
     rightNullRow: InternalRow,
     boundCondition: InternalRow => Boolean,
     resultProj: InternalRow => InternalRow,
-    numOutputRows: LongSQLMetric)
+    numOutputRows: SQLMetric)
   extends OneSideOuterIterator(
     smjScanner, rightNullRow, boundCondition, resultProj, numOutputRows) {
 
@@ -750,7 +750,7 @@ private class RightOuterIterator(
     leftNullRow: InternalRow,
     boundCondition: InternalRow => Boolean,
     resultProj: InternalRow => InternalRow,
-    numOutputRows: LongSQLMetric)
+    numOutputRows: SQLMetric)
   extends OneSideOuterIterator(smjScanner, leftNullRow, boundCondition, resultProj, numOutputRows) {
 
   protected override def setStreamSideOutput(row: InternalRow): Unit = joinedRow.withRight(row)
@@ -778,7 +778,7 @@ private abstract class OneSideOuterIterator(
     bufferedSideNullRow: InternalRow,
     boundCondition: InternalRow => Boolean,
     resultProj: InternalRow => InternalRow,
-    numOutputRows: LongSQLMetric) extends RowIterator {
+    numOutputRows: SQLMetric) extends RowIterator {
 
   // A row to store the joined result, reused many times
   protected[this] val joinedRow: JoinedRow = new JoinedRow()
@@ -1016,7 +1016,7 @@ private class SortMergeFullOuterJoinScanner(
 private class FullOuterIterator(
     smjScanner: SortMergeFullOuterJoinScanner,
     resultProj: InternalRow => InternalRow,
-    numRows: LongSQLMetric) extends RowIterator {
+    numRows: SQLMetric) extends RowIterator {
   private[this] val joinedRow: JoinedRow = smjScanner.getJoinedRow()
 
   override def advanceNext(): Boolean = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetricInfo.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetricInfo.scala
index 2708219ad34857383597ec2e3a645b114e0f0f7d..adb81519dbc8351e5ae01809a50ef83b517ab69e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetricInfo.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetricInfo.scala
@@ -27,4 +27,4 @@ import org.apache.spark.annotation.DeveloperApi
 class SQLMetricInfo(
     val name: String,
     val accumulatorId: Long,
-    val metricParam: String)
+    val metricType: String)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala
index 5755c00c1f9336ed80c92d9e0a012f10e0c64fb2..7bf92252726129b8c045ddc5f2767b7f6f191cb0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala
@@ -19,200 +19,106 @@ package org.apache.spark.sql.execution.metric
 
 import java.text.NumberFormat
 
-import org.apache.spark.{Accumulable, AccumulableParam, Accumulators, SparkContext}
+import org.apache.spark.{NewAccumulator, SparkContext}
 import org.apache.spark.scheduler.AccumulableInfo
 import org.apache.spark.util.Utils
 
-/**
- * Create a layer for specialized metric. We cannot add `@specialized` to
- * `Accumulable/AccumulableParam` because it will break Java source compatibility.
- *
- * An implementation of SQLMetric should override `+=` and `add` to avoid boxing.
- */
-private[sql] abstract class SQLMetric[R <: SQLMetricValue[T], T](
-    name: String,
-    val param: SQLMetricParam[R, T]) extends Accumulable[R, T](param.zero, param, Some(name)) {
 
-  // Provide special identifier as metadata so we can tell that this is a `SQLMetric` later
-  override def toInfo(update: Option[Any], value: Option[Any]): AccumulableInfo = {
-    new AccumulableInfo(id, Some(name), update, value, true, countFailedValues,
-      Some(SQLMetrics.ACCUM_IDENTIFIER))
-  }
-
-  def reset(): Unit = {
-    this.value = param.zero
-  }
-}
-
-/**
- * Create a layer for specialized metric. We cannot add `@specialized` to
- * `Accumulable/AccumulableParam` because it will break Java source compatibility.
- */
-private[sql] trait SQLMetricParam[R <: SQLMetricValue[T], T] extends AccumulableParam[R, T] {
-
-  /**
-   * A function that defines how we aggregate the final accumulator results among all tasks,
-   * and represent it in string for a SQL physical operator.
-   */
-  val stringValue: Seq[T] => String
-
-  def zero: R
-}
+class SQLMetric(val metricType: String, initValue: Long = 0L) extends NewAccumulator[Long, Long] {
+  // This is a workaround for SPARK-11013.
+  // We may use -1 as the initial value of the accumulator. If the accumulator is valid, we will
+  // update it at the end of the task, so the value will be at least 0. We can then filter out
+  // the -1 values before computing max, min, etc.
+  private[this] var _value = initValue
 
-/**
- * Create a layer for specialized metric. We cannot add `@specialized` to
- * `Accumulable/AccumulableParam` because it will break Java source compatibility.
- */
-private[sql] trait SQLMetricValue[T] extends Serializable {
+  override def copyAndReset(): SQLMetric = new SQLMetric(metricType, initValue)
 
-  def value: T
-
-  override def toString: String = value.toString
-}
-
-/**
- * A wrapper of Long to avoid boxing and unboxing when using Accumulator
- */
-private[sql] class LongSQLMetricValue(private var _value : Long) extends SQLMetricValue[Long] {
-
-  def add(incr: Long): LongSQLMetricValue = {
-    _value += incr
-    this
+  override def merge(other: NewAccumulator[Long, Long]): Unit = other match {
+    case o: SQLMetric => _value += o.localValue
+    case _ => throw new UnsupportedOperationException(
+      s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}")
   }
 
-  // Although there is a boxing here, it's fine because it's only called in SQLListener
-  override def value: Long = _value
-
-  // Needed for SQLListenerSuite
-  override def equals(other: Any): Boolean = other match {
-    case o: LongSQLMetricValue => value == o.value
-    case _ => false
-  }
+  override def isZero(): Boolean = _value == initValue
 
-  override def hashCode(): Int = _value.hashCode()
-}
+  override def add(v: Long): Unit = _value += v
 
-/**
- * A specialized long Accumulable to avoid boxing and unboxing when using Accumulator's
- * `+=` and `add`.
- */
-private[sql] class LongSQLMetric private[metric](name: String, param: LongSQLMetricParam)
-  extends SQLMetric[LongSQLMetricValue, Long](name, param) {
+  def +=(v: Long): Unit = _value += v
 
-  override def +=(term: Long): Unit = {
-    localValue.add(term)
-  }
+  override def localValue: Long = _value
 
-  override def add(term: Long): Unit = {
-    localValue.add(term)
+  // Provide special identifier as metadata so we can tell that this is a `SQLMetric` later
+  private[spark] override def toInfo(update: Option[Any], value: Option[Any]): AccumulableInfo = {
+    new AccumulableInfo(id, name, update, value, true, true, Some(SQLMetrics.ACCUM_IDENTIFIER))
   }
-}
-
-private class LongSQLMetricParam(val stringValue: Seq[Long] => String, initialValue: Long)
-  extends SQLMetricParam[LongSQLMetricValue, Long] {
-
-  override def addAccumulator(r: LongSQLMetricValue, t: Long): LongSQLMetricValue = r.add(t)
 
-  override def addInPlace(r1: LongSQLMetricValue, r2: LongSQLMetricValue): LongSQLMetricValue =
-    r1.add(r2.value)
-
-  override def zero(initialValue: LongSQLMetricValue): LongSQLMetricValue = zero
-
-  override def zero: LongSQLMetricValue = new LongSQLMetricValue(initialValue)
+  def reset(): Unit = _value = initValue
 }
 
-private object LongSQLMetricParam
-  extends LongSQLMetricParam(x => NumberFormat.getInstance().format(x.sum), 0L)
-
-private object StatisticsBytesSQLMetricParam extends LongSQLMetricParam(
-  (values: Seq[Long]) => {
-    // This is a workaround for SPARK-11013.
-    // We use -1 as initial value of the accumulator, if the accumulator is valid, we will update
-    // it at the end of task and the value will be at least 0.
-    val validValues = values.filter(_ >= 0)
-    val Seq(sum, min, med, max) = {
-      val metric = if (validValues.length == 0) {
-        Seq.fill(4)(0L)
-      } else {
-        val sorted = validValues.sorted
-        Seq(sorted.sum, sorted(0), sorted(validValues.length / 2), sorted(validValues.length - 1))
-      }
-      metric.map(Utils.bytesToString)
-    }
-    s"\n$sum ($min, $med, $max)"
-  }, -1L)
-
-private object StatisticsTimingSQLMetricParam extends LongSQLMetricParam(
-  (values: Seq[Long]) => {
-    // This is a workaround for SPARK-11013.
-    // We use -1 as initial value of the accumulator, if the accumulator is valid, we will update
-    // it at the end of task and the value will be at least 0.
-    val validValues = values.filter(_ >= 0)
-    val Seq(sum, min, med, max) = {
-      val metric = if (validValues.length == 0) {
-        Seq.fill(4)(0L)
-      } else {
-        val sorted = validValues.sorted
-        Seq(sorted.sum, sorted(0), sorted(validValues.length / 2), sorted(validValues.length - 1))
-      }
-      metric.map(Utils.msDurationToString)
-    }
-    s"\n$sum ($min, $med, $max)"
-  }, -1L)
 
 private[sql] object SQLMetrics {
-
   // Identifier for distinguishing SQL metrics from other accumulators
   private[sql] val ACCUM_IDENTIFIER = "sql"
 
-  private def createLongMetric(
-      sc: SparkContext,
-      name: String,
-      param: LongSQLMetricParam): LongSQLMetric = {
-    val acc = new LongSQLMetric(name, param)
-    // This is an internal accumulator so we need to register it explicitly.
-    Accumulators.register(acc)
-    sc.cleaner.foreach(_.registerAccumulatorForCleanup(acc))
-    acc
-  }
+  private[sql] val SUM_METRIC = "sum"
+  private[sql] val SIZE_METRIC = "size"
+  private[sql] val TIMING_METRIC = "timing"
 
-  def createLongMetric(sc: SparkContext, name: String): LongSQLMetric = {
-    createLongMetric(sc, name, LongSQLMetricParam)
+  def createMetric(sc: SparkContext, name: String): SQLMetric = {
+    val acc = new SQLMetric(SUM_METRIC)
+    acc.register(sc, name = Some(name), countFailedValues = true)
+    acc
   }
 
   /**
    * Create a metric to report the size information (including total, min, med, max) like data size,
    * spill size, etc.
    */
-  def createSizeMetric(sc: SparkContext, name: String): LongSQLMetric = {
+  def createSizeMetric(sc: SparkContext, name: String): SQLMetric = {
     // The final result of this metric in the physical operator UI may look like:
     // data size total (min, med, max):
     // 100GB (100MB, 1GB, 10GB)
-    createLongMetric(sc, s"$name total (min, med, max)", StatisticsBytesSQLMetricParam)
+    val acc = new SQLMetric(SIZE_METRIC, -1)
+    acc.register(sc, name = Some(s"$name total (min, med, max)"), countFailedValues = true)
+    acc
   }
 
-  def createTimingMetric(sc: SparkContext, name: String): LongSQLMetric = {
+  def createTimingMetric(sc: SparkContext, name: String): SQLMetric = {
     // The final result of this metric in the physical operator UI may look like:
     // duration(min, med, max):
     // 5s (800ms, 1s, 2s)
-    createLongMetric(sc, s"$name total (min, med, max)", StatisticsTimingSQLMetricParam)
-  }
-
-  def getMetricParam(metricParamName: String): SQLMetricParam[SQLMetricValue[Any], Any] = {
-    val longSQLMetricParam = Utils.getFormattedClassName(LongSQLMetricParam)
-    val bytesSQLMetricParam = Utils.getFormattedClassName(StatisticsBytesSQLMetricParam)
-    val timingsSQLMetricParam = Utils.getFormattedClassName(StatisticsTimingSQLMetricParam)
-    val metricParam = metricParamName match {
-      case `longSQLMetricParam` => LongSQLMetricParam
-      case `bytesSQLMetricParam` => StatisticsBytesSQLMetricParam
-      case `timingsSQLMetricParam` => StatisticsTimingSQLMetricParam
-    }
-    metricParam.asInstanceOf[SQLMetricParam[SQLMetricValue[Any], Any]]
+    val acc = new SQLMetric(TIMING_METRIC, -1)
+    acc.register(sc, name = Some(s"$name total (min, med, max)"), countFailedValues = true)
+    acc
   }
 
   /**
-   * A metric that its value will be ignored. Use this one when we need a metric parameter but don't
-   * care about the value.
+   * A function that defines how we aggregate the final accumulator results among all tasks,
+   * and represents the result as a string for a SQL physical operator.
    */
-  val nullLongMetric = new LongSQLMetric("null", LongSQLMetricParam)
+  def stringValue(metricsType: String, values: Seq[Long]): String = {
+    if (metricsType == SUM_METRIC) {
+      NumberFormat.getInstance().format(values.sum)
+    } else {
+      val strFormat: Long => String = if (metricsType == SIZE_METRIC) {
+        Utils.bytesToString
+      } else if (metricsType == TIMING_METRIC) {
+        Utils.msDurationToString
+      } else {
+        throw new IllegalStateException("unexpected metrics type: " + metricsType)
+      }
+
+      val validValues = values.filter(_ >= 0)
+      val Seq(sum, min, med, max) = {
+        val metric = if (validValues.length == 0) {
+          Seq.fill(4)(0L)
+        } else {
+          val sorted = validValues.sorted
+          Seq(sorted.sum, sorted(0), sorted(validValues.length / 2), sorted(validValues.length - 1))
+        }
+        metric.map(strFormat)
+      }
+      s"\n$sum ($min, $med, $max)"
+    }
+  }
 }
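
Taken together, the rewrite collapses the old `SQLMetricParam`/`SQLMetricValue` hierarchy into a single `SQLMetric` keyed by a `metricType` string. A rough usage sketch from code inside `org.apache.spark.sql` (the object is `private[sql]`), assuming an active SparkContext `sc`:

    // Sum metrics start at 0; size and timing metrics start at -1 so that tasks
    // which never touch them can be filtered out before aggregation.
    val rows  = SQLMetrics.createMetric(sc, "number of output rows")   // metricType "sum"
    val spill = SQLMetrics.createSizeMetric(sc, "spill size")          // metricType "size"
    rows.add(3L)
    rows += 2L
    spill.add(4096L)

    // The UI renders the collected per-task Longs through stringValue, e.g. for a size metric:
    SQLMetrics.stringValue(SQLMetrics.SIZE_METRIC, Seq(-1L, 1024L, 2048L, 4096L))
    // => "\n7.0 KB (1024.0 B, 2.0 KB, 4.0 KB)" (exact text depends on Utils.bytesToString)
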
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala
index 5ae9e916adae16d0dc632037ea99e0f1a871a635..9118593c0e4ce6e21bead848058ca0690ff70cb3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala
@@ -164,7 +164,7 @@ private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Loggi
         taskEnd.taskInfo.taskId,
         taskEnd.stageId,
         taskEnd.stageAttemptId,
-        taskEnd.taskMetrics.accumulatorUpdates(),
+        taskEnd.taskMetrics.accumulators().map(a => a.toInfo(Some(a.localValue), None)),
         finishTask = true)
     }
   }
@@ -296,7 +296,7 @@ private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Loggi
           }
         }.filter { case (id, _) => executionUIData.accumulatorMetrics.contains(id) }
         mergeAccumulatorUpdates(accumulatorUpdates, accumulatorId =>
-          executionUIData.accumulatorMetrics(accumulatorId).metricParam)
+          executionUIData.accumulatorMetrics(accumulatorId).metricType)
       case None =>
         // This execution has been dropped
         Map.empty
@@ -305,11 +305,11 @@ private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Loggi
 
   private def mergeAccumulatorUpdates(
       accumulatorUpdates: Seq[(Long, Any)],
-      paramFunc: Long => SQLMetricParam[SQLMetricValue[Any], Any]): Map[Long, String] = {
+      metricTypeFunc: Long => String): Map[Long, String] = {
     accumulatorUpdates.groupBy(_._1).map { case (accumulatorId, values) =>
-      val param = paramFunc(accumulatorId)
-      (accumulatorId,
-        param.stringValue(values.map(_._2.asInstanceOf[SQLMetricValue[Any]].value)))
+      val metricType = metricTypeFunc(accumulatorId)
+      accumulatorId ->
+        SQLMetrics.stringValue(metricType, values.map(_._2.asInstanceOf[Long]))
     }
   }
 
@@ -337,7 +337,7 @@ private[spark] class SQLHistoryListener(conf: SparkConf, sparkUI: SparkUI)
         // Filter out accumulators that are not SQL metrics
         // For now we assume all SQL metrics are Long's that have been JSON serialized as String's
         if (a.metadata == Some(SQLMetrics.ACCUM_IDENTIFIER)) {
-          val newValue = new LongSQLMetricValue(a.update.map(_.toString.toLong).getOrElse(0L))
+          val newValue = a.update.map(_.toString.toLong).getOrElse(0L)
           Some(a.copy(update = Some(newValue)))
         } else {
           None
@@ -403,7 +403,7 @@ private[ui] class SQLExecutionUIData(
 private[ui] case class SQLPlanMetric(
     name: String,
     accumulatorId: Long,
-    metricParam: SQLMetricParam[SQLMetricValue[Any], Any])
+    metricType: String)
 
 /**
  * Store all accumulatorUpdates for all tasks in a Spark stage.
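
On the listener side, merging now works on plain `Long` updates plus the stored `metricType` string instead of resolving a `SQLMetricParam` class name. A small standalone sketch of the same fold (ids and values are made up):

    // Group raw (accumulatorId, value) pairs and render each group with its metric type.
    val updates: Seq[(Long, Any)] = Seq((1L, 10L), (1L, 32L), (2L, 4096L))
    val metricTypeOf = Map(1L -> SQLMetrics.SUM_METRIC, 2L -> SQLMetrics.SIZE_METRIC)
    val rendered: Map[Long, String] = updates.groupBy(_._1).map { case (id, vs) =>
      id -> SQLMetrics.stringValue(metricTypeOf(id), vs.map(_._2.asInstanceOf[Long]))
    }
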
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala
index 1959f1e3680a0464f769b1b71cc90ee253c55767..8f5681bfc7cc62ef8966d8414051970598ed593f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala
@@ -80,8 +80,7 @@ private[sql] object SparkPlanGraph {
     planInfo.nodeName match {
       case "WholeStageCodegen" =>
         val metrics = planInfo.metrics.map { metric =>
-          SQLPlanMetric(metric.name, metric.accumulatorId,
-            SQLMetrics.getMetricParam(metric.metricParam))
+          SQLPlanMetric(metric.name, metric.accumulatorId, metric.metricType)
         }
 
         val cluster = new SparkPlanGraphCluster(
@@ -106,8 +105,7 @@ private[sql] object SparkPlanGraph {
         edges += SparkPlanGraphEdge(node.id, parent.id)
       case name =>
         val metrics = planInfo.metrics.map { metric =>
-          SQLPlanMetric(metric.name, metric.accumulatorId,
-            SQLMetrics.getMetricParam(metric.metricParam))
+          SQLPlanMetric(metric.name, metric.accumulatorId, metric.metricType)
         }
         val node = new SparkPlanGraphNode(
           nodeIdGenerator.getAndIncrement(), planInfo.nodeName,
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
index 4aea21e52a685c8f2721ef2861a1d519277c95c2..0e6356b5781e6838fff7c179a328aa347b38536d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
@@ -22,7 +22,7 @@ import scala.language.postfixOps
 
 import org.scalatest.concurrent.Eventually._
 
-import org.apache.spark.Accumulators
+import org.apache.spark.AccumulatorContext
 import org.apache.spark.sql.execution.RDDScanExec
 import org.apache.spark.sql.execution.columnar._
 import org.apache.spark.sql.execution.exchange.ShuffleExchange
@@ -333,11 +333,11 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
     sql("SELECT * FROM t1").count()
     sql("SELECT * FROM t2").count()
 
-    Accumulators.synchronized {
-      val accsSize = Accumulators.originals.size
+    AccumulatorContext.synchronized {
+      val accsSize = AccumulatorContext.originals.size
       sqlContext.uncacheTable("t1")
       sqlContext.uncacheTable("t2")
-      assert((accsSize - 2) == Accumulators.originals.size)
+      assert((accsSize - 2) == AccumulatorContext.originals.size)
     }
   }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
index 1859c6e7adde8ace8945c501660af92cf8cd4723..8de4d8bbd4e07a6c60b3f58d6ecd36c7fe6b8011 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
@@ -37,8 +37,8 @@ import org.apache.spark.util.{JsonProtocol, Utils}
 class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext {
   import testImplicits._
 
-  test("LongSQLMetric should not box Long") {
-    val l = SQLMetrics.createLongMetric(sparkContext, "long")
+  test("SQLMetric should not box Long") {
+    val l = SQLMetrics.createMetric(sparkContext, "long")
     val f = () => {
       l += 1L
       l.add(1L)
@@ -300,12 +300,12 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext {
   }
 
   test("metrics can be loaded by history server") {
-    val metric = new LongSQLMetric("zanzibar", LongSQLMetricParam)
+    val metric = SQLMetrics.createMetric(sparkContext, "zanzibar")
     metric += 10L
     val metricInfo = metric.toInfo(Some(metric.localValue), None)
     metricInfo.update match {
-      case Some(v: LongSQLMetricValue) => assert(v.value === 10L)
-      case Some(v) => fail(s"metric value was not a LongSQLMetricValue: ${v.getClass.getName}")
+      case Some(v: Long) => assert(v === 10L)
+      case Some(v) => fail(s"metric value was not a Long: ${v.getClass.getName}")
       case _ => fail("metric update is missing")
     }
     assert(metricInfo.metadata === Some(SQLMetrics.ACCUM_IDENTIFIER))
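The SQLHistoryListener hunk and the "metrics can be loaded by history server" test rely on the same convention: a SQL metric is identified by its metadata tag (SQLMetrics.ACCUM_IDENTIFIER) and its update is a plain Long that was JSON-serialized as a String. A hedged sketch of that filter-and-parse step, using a stand-in InfoSketch rather than the real AccumulableInfo:

object HistoryMetricSketch {
  // InfoSketch is a stand-in for AccumulableInfo, for illustration only.
  case class InfoSketch(metadata: Option[String], update: Option[Any])

  // Mirrors the SQLHistoryListener hunk: keep only accumulators tagged as SQL metrics
  // and parse their JSON-serialized update back into a plain Long (defaulting to 0).
  def toSqlMetricUpdate(info: InfoSketch, accumIdentifier: String): Option[Long] = {
    if (info.metadata == Some(accumIdentifier)) {
      Some(info.update.map(_.toString.toLong).getOrElse(0L))
    } else {
      None
    }
  }

  def main(args: Array[String]): Unit = {
    assert(toSqlMetricUpdate(InfoSketch(Some("sql"), Some("10")), "sql") == Some(10L))
    assert(toSqlMetricUpdate(InfoSketch(None, Some("10")), "sql").isEmpty)
  }
}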
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala
index 09bd7f6e8f0a8378fb761240108b1c4b8736275f..8572ed16aa2615b6cf32cb6855efeed7b2169701 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala
@@ -21,18 +21,19 @@ import java.util.Properties
 
 import org.mockito.Mockito.{mock, when}
 
-import org.apache.spark.{SparkConf, SparkContext, SparkException, SparkFunSuite}
+import org.apache.spark._
 import org.apache.spark.executor.TaskMetrics
 import org.apache.spark.scheduler._
 import org.apache.spark.sql.{DataFrame, SQLContext}
 import org.apache.spark.sql.catalyst.util.quietly
 import org.apache.spark.sql.execution.{SparkPlanInfo, SQLExecution}
-import org.apache.spark.sql.execution.metric.{LongSQLMetricValue, SQLMetrics}
+import org.apache.spark.sql.execution.metric.SQLMetrics
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.ui.SparkUI
 
 class SQLListenerSuite extends SparkFunSuite with SharedSQLContext {
   import testImplicits._
+  import org.apache.spark.AccumulatorSuite.makeInfo
 
   private def createTestDataFrame: DataFrame = {
     Seq(
@@ -72,9 +73,11 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext {
 
   private def createTaskMetrics(accumulatorUpdates: Map[Long, Long]): TaskMetrics = {
     val metrics = mock(classOf[TaskMetrics])
-    when(metrics.accumulatorUpdates()).thenReturn(accumulatorUpdates.map { case (id, update) =>
-      new AccumulableInfo(id, Some(""), Some(new LongSQLMetricValue(update)),
-        value = None, internal = true, countFailedValues = true)
+    when(metrics.accumulators()).thenReturn(accumulatorUpdates.map { case (id, update) =>
+      val acc = new LongAccumulator
+      acc.metadata = AccumulatorMetadata(id, Some(""), true)
+      acc.setValue(update)
+      acc
     }.toSeq)
     metrics
   }
@@ -130,16 +133,17 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext {
 
     listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate("", Seq(
       // (task id, stage id, stage attempt, accum updates)
-      (0L, 0, 0, createTaskMetrics(accumulatorUpdates).accumulatorUpdates()),
-      (1L, 0, 0, createTaskMetrics(accumulatorUpdates).accumulatorUpdates())
+      (0L, 0, 0, createTaskMetrics(accumulatorUpdates).accumulators().map(makeInfo)),
+      (1L, 0, 0, createTaskMetrics(accumulatorUpdates).accumulators().map(makeInfo))
     )))
 
     checkAnswer(listener.getExecutionMetrics(0), accumulatorUpdates.mapValues(_ * 2))
 
     listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate("", Seq(
       // (task id, stage id, stage attempt, accum updates)
-      (0L, 0, 0, createTaskMetrics(accumulatorUpdates).accumulatorUpdates()),
-      (1L, 0, 0, createTaskMetrics(accumulatorUpdates.mapValues(_ * 2)).accumulatorUpdates())
+      (0L, 0, 0, createTaskMetrics(accumulatorUpdates).accumulators().map(makeInfo)),
+      (1L, 0, 0,
+        createTaskMetrics(accumulatorUpdates.mapValues(_ * 2)).accumulators().map(makeInfo))
     )))
 
     checkAnswer(listener.getExecutionMetrics(0), accumulatorUpdates.mapValues(_ * 3))
@@ -149,8 +153,8 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext {
 
     listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate("", Seq(
       // (task id, stage id, stage attempt, accum updates)
-      (0L, 0, 1, createTaskMetrics(accumulatorUpdates).accumulatorUpdates()),
-      (1L, 0, 1, createTaskMetrics(accumulatorUpdates).accumulatorUpdates())
+      (0L, 0, 1, createTaskMetrics(accumulatorUpdates).accumulators().map(makeInfo)),
+      (1L, 0, 1, createTaskMetrics(accumulatorUpdates).accumulators().map(makeInfo))
     )))
 
     checkAnswer(listener.getExecutionMetrics(0), accumulatorUpdates.mapValues(_ * 2))
@@ -189,8 +193,8 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext {
 
     listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate("", Seq(
       // (task id, stage id, stage attempt, accum updates)
-      (0L, 1, 0, createTaskMetrics(accumulatorUpdates).accumulatorUpdates()),
-      (1L, 1, 0, createTaskMetrics(accumulatorUpdates).accumulatorUpdates())
+      (0L, 1, 0, createTaskMetrics(accumulatorUpdates).accumulators().map(makeInfo)),
+      (1L, 1, 0, createTaskMetrics(accumulatorUpdates).accumulators().map(makeInfo))
     )))
 
     checkAnswer(listener.getExecutionMetrics(0), accumulatorUpdates.mapValues(_ * 7))
@@ -358,7 +362,7 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext {
     val stageSubmitted = SparkListenerStageSubmitted(stageInfo)
     // This task has both accumulators that are SQL metrics and accumulators that are not.
     // The listener should only track the ones that are actually SQL metrics.
-    val sqlMetric = SQLMetrics.createLongMetric(sparkContext, "beach umbrella")
+    val sqlMetric = SQLMetrics.createMetric(sparkContext, "beach umbrella")
     val nonSqlMetric = sparkContext.accumulator[Int](0, "baseball")
     val sqlMetricInfo = sqlMetric.toInfo(Some(sqlMetric.localValue), None)
     val nonSqlMetricInfo = nonSqlMetric.toInfo(Some(nonSqlMetric.localValue), None)
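createTaskMetrics above now fabricates real accumulators instead of hand-built AccumulableInfos: a LongAccumulator is given AccumulatorMetadata and a value, and AccumulatorSuite.makeInfo converts it whenever the listener needs infos. A sketch of that helper shape, assuming (as the suite's wildcard import suggests) that LongAccumulator and AccumulatorMetadata live in org.apache.spark and that the helper sits inside the suite, where the spark-private metadata field is assignable:

// Assumed to sit inside SQLListenerSuite; every call used here (new LongAccumulator,
// the metadata assignment, setValue) appears verbatim in the hunks above.
private def fakeLongUpdate(id: Long, update: Long): LongAccumulator = {
  val acc = new LongAccumulator
  acc.metadata = AccumulatorMetadata(id, Some(""), true) // (id, name, countFailedValues)
  acc.setValue(update)                                   // the value this "task" reports
  acc
}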
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
index eb25ea06295e6f4e820b14542d9b28536e79b4dd..8a0578c1ff53782af00261221d92e0f151e6f338 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
@@ -96,7 +96,7 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext {
           case w: WholeStageCodegenExec => w.child.longMetric("numOutputRows")
           case other => other.longMetric("numOutputRows")
         }
-        metrics += metric.value.value
+        metrics += metric.value
       }
     }
     sqlContext.listenerManager.register(listener)
@@ -126,9 +126,9 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext {
       override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {}
 
       override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = {
-        metrics += qe.executedPlan.longMetric("dataSize").value.value
+        metrics += qe.executedPlan.longMetric("dataSize").value
         val bottomAgg = qe.executedPlan.children(0).children(0)
-        metrics += bottomAgg.longMetric("dataSize").value.value
+        metrics += bottomAgg.longMetric("dataSize").value
       }
     }
     sqlContext.listenerManager.register(listener)
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala
index 007c3384e5701533dff7f66d42287a31b6b057be..b52b96a80447ccdcf59f4f381b860c36f8a32311 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala
@@ -55,7 +55,7 @@ case class HiveTableScanExec(
     "Partition pruning predicates only supported for partitioned tables.")
 
   private[sql] override lazy val metrics = Map(
-    "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
+    "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
 
   override def producedAttributes: AttributeSet = outputSet ++
     AttributeSet(partitionPruningPred.flatMap(_.references))
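The HiveTableScanExec hunk is representative of the call-site changes across operators: SQLMetrics.createLongMetric becomes SQLMetrics.createMetric, and the resulting metric is read and bumped as a plain Long. A hedged fragment of that usage, assumed to live inside a physical operator where sparkContext and longMetric are in scope; only createMetric, longMetric, and += are taken from the diff itself:

import org.apache.spark.sql.execution.metric.SQLMetrics

// Declare the operator's metrics with the renamed factory, as in the hunk above.
private[sql] override lazy val metrics = Map(
  "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))

// Illustrative helper: bump the metric from the execution path.
protected def countRow(): Unit = {
  val numOutputRows = longMetric("numOutputRows") // the metric registered above
  numOutputRows += 1L                             // now a plain Long-valued accumulator
}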