diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java
index 1aa6ba42012618e84cd6773547c07cb2e12b3560..bf4eaa59ff589cbf04976d3f50ecdc6b7723f3bd 100644
--- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java
@@ -20,6 +20,7 @@ package org.apache.spark.shuffle.unsafe;
 import java.io.File;
 import java.io.IOException;
 import java.util.LinkedList;
+import javax.annotation.Nullable;
 
 import scala.Tuple2;
 
@@ -86,9 +87,12 @@ final class UnsafeShuffleExternalSorter {
 
   private final LinkedList<SpillInfo> spills = new LinkedList<SpillInfo>();
 
+  /** Peak memory used by this sorter so far, in bytes. */
+  private long peakMemoryUsedBytes;
+
   // These variables are reset after spilling:
-  private UnsafeShuffleInMemorySorter sorter;
-  private MemoryBlock currentPage = null;
+  @Nullable private UnsafeShuffleInMemorySorter sorter;
+  @Nullable private MemoryBlock currentPage = null;
   private long currentPagePosition = -1;
   private long freeSpaceInCurrentPage = 0;
 
@@ -106,6 +110,7 @@ final class UnsafeShuffleExternalSorter {
     this.blockManager = blockManager;
     this.taskContext = taskContext;
     this.initialSize = initialSize;
+    this.peakMemoryUsedBytes = initialSize;
     this.numPartitions = numPartitions;
     // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided
     this.fileBufferSizeBytes = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024;
@@ -279,10 +284,26 @@ final class UnsafeShuffleExternalSorter {
     for (MemoryBlock page : allocatedPages) {
       totalPageSize += page.size();
     }
-    return sorter.getMemoryUsage() + totalPageSize;
+    return ((sorter == null) ? 0 : sorter.getMemoryUsage()) + totalPageSize;
+  }
+
+  private void updatePeakMemoryUsed() {
+    long mem = getMemoryUsage();
+    if (mem > peakMemoryUsedBytes) {
+      peakMemoryUsedBytes = mem;
+    }
+  }
+
+  /**
+   * Return the peak memory used so far, in bytes.
+   */
+  long getPeakMemoryUsedBytes() {
+    updatePeakMemoryUsed();
+    return peakMemoryUsedBytes;
   }
 
   private long freeMemory() {
+    updatePeakMemoryUsed();
     long memoryFreed = 0;
     for (MemoryBlock block : allocatedPages) {
       memoryManager.freePage(block);
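Both hunks above follow the same bookkeeping idiom: keep a running maximum of getMemoryUsage(), refresh it at every point where usage is about to shrink (freeMemory) or be read, and report the maximum rather than the live value. A minimal standalone sketch of that idiom, assuming nothing beyond a usage callback (the names below are illustrative, not Spark's API):

    // Running-maximum sketch of the peak-memory bookkeeping added above.
    // `currentUsage` stands in for getMemoryUsage(); illustrative only.
    final class PeakTracker(currentUsage: () => Long) {
      private var peakBytes = 0L

      // Refresh the peak; call before any operation that releases memory.
      def update(): Unit = {
        val mem = currentUsage()
        if (mem > peakBytes) peakBytes = mem
      }

      // Peak observed so far, in bytes, never less than the current usage.
      def peak(): Long = { update(); peakBytes }
    }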
diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java
index d47d6fc9c2ac4772d021b224cfd7d1b532b319be..6e2eeb37c86f145729dc29bb652438f35c451c46 100644
--- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java
@@ -27,6 +27,7 @@ import scala.Product2;
 import scala.collection.JavaConversions;
 import scala.reflect.ClassTag;
 import scala.reflect.ClassTag$;
+import scala.collection.immutable.Map;
 
 import com.google.common.annotations.VisibleForTesting;
 import com.google.common.io.ByteStreams;
@@ -78,8 +79,9 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
   private final SparkConf sparkConf;
   private final boolean transferToEnabled;
 
-  private MapStatus mapStatus = null;
-  private UnsafeShuffleExternalSorter sorter = null;
+  @Nullable private MapStatus mapStatus;
+  @Nullable private UnsafeShuffleExternalSorter sorter;
+  private long peakMemoryUsedBytes = 0;
 
   /** Subclass of ByteArrayOutputStream that exposes `buf` directly. */
   private static final class MyByteArrayOutputStream extends ByteArrayOutputStream {
@@ -131,9 +133,28 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
 
   @VisibleForTesting
   public int maxRecordSizeBytes() {
+    assert(sorter != null);
     return sorter.maxRecordSizeBytes;
   }
 
+  private void updatePeakMemoryUsed() {
+    // sorter can be null if this writer is closed
+    if (sorter != null) {
+      long mem = sorter.getPeakMemoryUsedBytes();
+      if (mem > peakMemoryUsedBytes) {
+        peakMemoryUsedBytes = mem;
+      }
+    }
+  }
+
+  /**
+   * Return the peak memory used so far, in bytes.
+   */
+  public long getPeakMemoryUsedBytes() {
+    updatePeakMemoryUsed();
+    return peakMemoryUsedBytes;
+  }
+
   /**
    * This convenience method should only be called in test code.
    */
@@ -144,7 +165,7 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
 
   @Override
   public void write(scala.collection.Iterator<Product2<K, V>> records) throws IOException {
-    // Keep track of success so we know if we ecountered an exception
+    // Keep track of success so we know if we encountered an exception
     // We do this rather than a standard try/catch/re-throw to handle
     // generic throwables.
     boolean success = false;
@@ -189,6 +210,8 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
 
   @VisibleForTesting
   void closeAndWriteOutput() throws IOException {
+    assert(sorter != null);
+    updatePeakMemoryUsed();
     serBuffer = null;
     serOutputStream = null;
     final SpillInfo[] spills = sorter.closeAndGetSpills();
@@ -209,6 +232,7 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
 
   @VisibleForTesting
   void insertRecordIntoSorter(Product2<K, V> record) throws IOException {
+    assert(sorter != null);
     final K key = record._1();
     final int partitionId = partitioner.getPartition(key);
     serBuffer.reset();
@@ -431,6 +455,14 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
   @Override
   public Option<MapStatus> stop(boolean success) {
     try {
+      // Update task metrics from accumulators (the map is null in UnsafeShuffleWriterSuite)
+      Map<String, Accumulator<Object>> internalAccumulators =
+        taskContext.internalMetricsToAccumulators();
+      if (internalAccumulators != null) {
+        internalAccumulators.apply(InternalAccumulator.PEAK_EXECUTION_MEMORY())
+          .add(getPeakMemoryUsedBytes());
+      }
+
       if (stopping) {
         return Option.apply(null);
       } else {
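The null checks and the updatePeakMemoryUsed() call in closeAndWriteOutput() exist because the writer outlives its sorter: getPeakMemoryUsedBytes() must keep answering after the sorter has been closed and nulled out, so the writer snapshots the sorter's peak into its own field while the sorter still exists. A toy sketch of that capture-before-release pattern (illustrative names, not Spark's API):

    // Toy capture-before-release sketch: snapshot a resource's peak before dropping it.
    final class Owner(private var peakOfResource: Option[() => Long]) {
      private var capturedPeak = 0L
      private def refresh(): Unit = peakOfResource.foreach { peakOf =>
        capturedPeak = math.max(capturedPeak, peakOf())
      }
      def close(): Unit = { refresh(); peakOfResource = None } // snapshot, then drop the reference
      def peakBytes: Long = { refresh(); capturedPeak }        // still answers after close()
    }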
diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
index 01a66084e918e509526842fd8a7e29759f77b0a4..20347433e16b2285f7b2db1277053461dbe1676b 100644
--- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
+++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
@@ -505,7 +505,7 @@ public final class BytesToBytesMap {
       // Here, we'll copy the data into our data pages. Because we only store a relative offset from
       // the key address instead of storing the absolute address of the value, the key and value
       // must be stored in the same memory page.
-      // (8 byte key length) (key) (8 byte value length) (value)
+      // (8 byte key length) (key) (value)
       final long requiredSize = 8 + keyLengthBytes + valueLengthBytes;
 
       // --- Figure out where to insert the new record ---------------------------------------------
@@ -655,7 +655,10 @@ public final class BytesToBytesMap {
     return pageSizeBytes;
   }
 
-  /** Returns the total amount of memory, in bytes, consumed by this map's managed structures. */
+  /**
+   * Returns the total amount of memory, in bytes, consumed by this map's managed structures.
+   * Note that this is also the peak memory used by this map, since the map is append-only.
+   */
   public long getTotalMemoryConsumption() {
     long totalDataPagesSize = 0L;
     for (MemoryBlock dataPage : dataPages) {
@@ -674,7 +677,6 @@ public final class BytesToBytesMap {
     return timeSpentResizingNs;
   }
 
-
   /**
    * Returns the average number of probes per key lookup.
    */
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
index b984301cbbf2bbd55402e1167a12bc9036cfb68a..bf5f965a9d8dc748854ee63964983aff06eac13b 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
@@ -70,13 +70,14 @@ public final class UnsafeExternalSorter {
   private final LinkedList<UnsafeSorterSpillWriter> spillWriters = new LinkedList<>();
 
   // These variables are reset after spilling:
-  private UnsafeInMemorySorter inMemSorter;
+  @Nullable private UnsafeInMemorySorter inMemSorter;
   // Whether the in-mem sorter is created internally, or passed in from outside.
   // If it is passed in from outside, we shouldn't release the in-mem sorter's memory.
   private boolean isInMemSorterExternal = false;
   private MemoryBlock currentPage = null;
   private long currentPagePosition = -1;
   private long freeSpaceInCurrentPage = 0;
+  private long peakMemoryUsedBytes = 0;
 
   public static UnsafeExternalSorter createWithExistingInMemorySorter(
       TaskMemoryManager taskMemoryManager,
@@ -183,6 +184,7 @@ public final class UnsafeExternalSorter {
    * Sort and spill the current records in response to memory pressure.
    */
   public void spill() throws IOException {
+    assert(inMemSorter != null);
     logger.info("Thread {} spilling sort data of {} to disk ({} {} so far)",
       Thread.currentThread().getId(),
       Utils.bytesToString(getMemoryUsage()),
@@ -219,7 +221,22 @@ public final class UnsafeExternalSorter {
     for (MemoryBlock page : allocatedPages) {
       totalPageSize += page.size();
     }
-    return inMemSorter.getMemoryUsage() + totalPageSize;
+    return ((inMemSorter == null) ? 0 : inMemSorter.getMemoryUsage()) + totalPageSize;
+  }
+
+  private void updatePeakMemoryUsed() {
+    long mem = getMemoryUsage();
+    if (mem > peakMemoryUsedBytes) {
+      peakMemoryUsedBytes = mem;
+    }
+  }
+
+  /**
+   * Return the peak memory used so far, in bytes.
+   */
+  public long getPeakMemoryUsedBytes() {
+    updatePeakMemoryUsed();
+    return peakMemoryUsedBytes;
   }
 
   @VisibleForTesting
@@ -233,6 +250,7 @@ public final class UnsafeExternalSorter {
    * @return the number of bytes freed.
    */
   public long freeMemory() {
+    updatePeakMemoryUsed();
     long memoryFreed = 0;
     for (MemoryBlock block : allocatedPages) {
       taskMemoryManager.freePage(block);
@@ -277,7 +295,8 @@ public final class UnsafeExternalSorter {
    * @return true if the record can be inserted without requiring more allocations, false otherwise.
    */
   private boolean haveSpaceForRecord(int requiredSpace) {
-    assert (requiredSpace > 0);
+    assert(requiredSpace > 0);
+    assert(inMemSorter != null);
     return (inMemSorter.hasSpaceForAnotherRecord() && (requiredSpace <= freeSpaceInCurrentPage));
   }
 
@@ -290,6 +309,7 @@ public final class UnsafeExternalSorter {
    *                      the record size.
    */
   private void allocateSpaceForRecord(int requiredSpace) throws IOException {
+    assert(inMemSorter != null);
     // TODO: merge these steps to first calculate total memory requirements for this insert,
     // then try to acquire; no point in acquiring sort buffer only to spill due to no space in the
     // data page.
@@ -350,6 +370,7 @@ public final class UnsafeExternalSorter {
     if (!haveSpaceForRecord(totalSpaceRequired)) {
       allocateSpaceForRecord(totalSpaceRequired);
     }
+    assert(inMemSorter != null);
 
     final long recordAddress =
       taskMemoryManager.encodePageNumberAndOffset(currentPage, currentPagePosition);
@@ -382,6 +403,7 @@ public final class UnsafeExternalSorter {
     if (!haveSpaceForRecord(totalSpaceRequired)) {
       allocateSpaceForRecord(totalSpaceRequired);
     }
+    assert(inMemSorter != null);
 
     final long recordAddress =
       taskMemoryManager.encodePageNumberAndOffset(currentPage, currentPagePosition);
@@ -405,6 +427,7 @@ public final class UnsafeExternalSorter {
   }
 
   public UnsafeSorterIterator getSortedIterator() throws IOException {
+    assert(inMemSorter != null);
     final UnsafeSorterIterator inMemoryIterator = inMemSorter.getSortedIterator();
     int numIteratorsToMerge = spillWriters.size() + (inMemoryIterator.hasNext() ? 1 : 0);
     if (spillWriters.isEmpty()) {
diff --git a/core/src/main/resources/org/apache/spark/ui/static/webui.css b/core/src/main/resources/org/apache/spark/ui/static/webui.css
index b1cef4704224793532f72fa67e8d7442d79dd0ed..648cd1b10480277bce49c52764942795338265cc 100644
--- a/core/src/main/resources/org/apache/spark/ui/static/webui.css
+++ b/core/src/main/resources/org/apache/spark/ui/static/webui.css
@@ -207,7 +207,7 @@ span.additional-metric-title {
 /* Hide all additional metrics by default. This is done here rather than using JavaScript to
  * avoid slow page loads for stage pages with large numbers (e.g., thousands) of tasks. */
 .scheduler_delay, .deserialization_time, .fetch_wait_time, .shuffle_read_remote,
-.serialization_time, .getting_result_time {
+.serialization_time, .getting_result_time, .peak_execution_memory {
   display: none;
 }
 
diff --git a/core/src/main/scala/org/apache/spark/Accumulators.scala b/core/src/main/scala/org/apache/spark/Accumulators.scala
index eb75f26718e1971a7d08c959b5e3e50968fd197f..b6a0119c696fd14a8260aaa3d93b2e0853f0a431 100644
--- a/core/src/main/scala/org/apache/spark/Accumulators.scala
+++ b/core/src/main/scala/org/apache/spark/Accumulators.scala
@@ -152,8 +152,14 @@ class Accumulable[R, T] private[spark] (
     in.defaultReadObject()
     value_ = zero
     deserialized = true
-    val taskContext = TaskContext.get()
-    taskContext.registerAccumulator(this)
+    // Automatically register the accumulator when it is deserialized with the task closure.
+    // Note that internal accumulators are deserialized before the TaskContext is created and
+    // are registered in the TaskContext constructor.
+    if (!isInternal) {
+      val taskContext = TaskContext.get()
+      assume(taskContext != null, "Task context was null when deserializing user accumulators")
+      taskContext.registerAccumulator(this)
+    }
   }
 
   override def toString: String = if (value_ == null) "null" else value_.toString
@@ -248,10 +254,20 @@ GrowableAccumulableParam[R <% Growable[T] with TraversableOnce[T] with Serializa
  * @param param helper object defining how to add elements of type `T`
  * @tparam T result type
  */
-class Accumulator[T](@transient initialValue: T, param: AccumulatorParam[T], name: Option[String])
-  extends Accumulable[T, T](initialValue, param, name) {
+class Accumulator[T] private[spark] (
+    @transient initialValue: T,
+    param: AccumulatorParam[T],
+    name: Option[String],
+    internal: Boolean)
+  extends Accumulable[T, T](initialValue, param, name, internal) {
+
+  def this(initialValue: T, param: AccumulatorParam[T], name: Option[String]) = {
+    this(initialValue, param, name, false)
+  }
 
-  def this(initialValue: T, param: AccumulatorParam[T]) = this(initialValue, param, None)
+  def this(initialValue: T, param: AccumulatorParam[T]) = {
+    this(initialValue, param, None, false)
+  }
 }
 
 /**
@@ -342,3 +358,37 @@ private[spark] object Accumulators extends Logging {
   }
 
 }
+
+private[spark] object InternalAccumulator {
+  val PEAK_EXECUTION_MEMORY = "peakExecutionMemory"
+  val TEST_ACCUMULATOR = "testAccumulator"
+
+  // For testing only.
+  // This needs to be a def since we don't want to reuse the same accumulator across stages.
+  private def maybeTestAccumulator: Option[Accumulator[Long]] = {
+    if (sys.props.contains("spark.testing")) {
+      Some(new Accumulator(
+        0L, AccumulatorParam.LongAccumulatorParam, Some(TEST_ACCUMULATOR), internal = true))
+    } else {
+      None
+    }
+  }
+
+  /**
+   * Accumulators for tracking internal metrics.
+   *
+   * These accumulators are created with the stage such that all tasks in the stage will
+   * add to the same set of accumulators. We do this to report the distribution of accumulator
+   * values across all tasks within each stage.
+   */
+  def create(): Seq[Accumulator[Long]] = {
+    Seq(
+      // Execution memory refers to the memory used by internal data structures created
+      // during shuffles, aggregations and joins. The value of this accumulator should be
+      // approximately the sum of the peak sizes across all such data structures created
+      // in this task. For SQL jobs, this only tracks unsafe operators and ExternalSort.
+      new Accumulator(
+        0L, AccumulatorParam.LongAccumulatorParam, Some(PEAK_EXECUTION_MEMORY), internal = true)
+    ) ++ maybeTestAccumulator.toSeq
+  }
+}
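Taken together with the scheduler and TaskContext changes below, the intended flow is: InternalAccumulator.create() runs once per (fully submitted) stage, every task of that stage receives the same Seq, TaskContextImpl keys the accumulators by name, and task-side code adds its peak into the PEAK_EXECUTION_MEMORY entry so the values aggregate across tasks. A rough standalone sketch of that flow, using a mutable counter as a stand-in for Accumulator[Long] (all names illustrative):

    // Stand-in for Accumulator[Long]; illustrative only.
    final case class NamedCounter(name: String) {
      var value = 0L
      def add(v: Long): Unit = value += v
    }

    val stageAccums = Seq(NamedCounter("peakExecutionMemory")) // ~ InternalAccumulator.create()
    val byName = stageAccums.map(a => a.name -> a).toMap       // ~ internalMetricsToAccumulators

    // Each task in the stage reports its own peak into the shared counter:
    Seq(512L, 2048L, 1024L).foreach(peak => byName("peakExecutionMemory").add(peak))
    assert(stageAccums.head.value == 3584L)                    // aggregated over the three tasks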
diff --git a/core/src/main/scala/org/apache/spark/Aggregator.scala b/core/src/main/scala/org/apache/spark/Aggregator.scala
index ceeb58075d3455cea98e852ed9e9169a1a65e421..289aab9bd9e516f910f0db1d97049f570439bf07 100644
--- a/core/src/main/scala/org/apache/spark/Aggregator.scala
+++ b/core/src/main/scala/org/apache/spark/Aggregator.scala
@@ -58,12 +58,7 @@ case class Aggregator[K, V, C] (
     } else {
       val combiners = new ExternalAppendOnlyMap[K, V, C](createCombiner, mergeValue, mergeCombiners)
       combiners.insertAll(iter)
-      // Update task metrics if context is not null
-      // TODO: Make context non optional in a future release
-      Option(context).foreach { c =>
-        c.taskMetrics.incMemoryBytesSpilled(combiners.memoryBytesSpilled)
-        c.taskMetrics.incDiskBytesSpilled(combiners.diskBytesSpilled)
-      }
+      updateMetrics(context, combiners)
       combiners.iterator
     }
   }
@@ -89,13 +84,18 @@ case class Aggregator[K, V, C] (
     } else {
       val combiners = new ExternalAppendOnlyMap[K, C, C](identity, mergeCombiners, mergeCombiners)
       combiners.insertAll(iter)
-      // Update task metrics if context is not null
-      // TODO: Make context non-optional in a future release
-      Option(context).foreach { c =>
-        c.taskMetrics.incMemoryBytesSpilled(combiners.memoryBytesSpilled)
-        c.taskMetrics.incDiskBytesSpilled(combiners.diskBytesSpilled)
-      }
+      updateMetrics(context, combiners)
       combiners.iterator
     }
   }
+
+  /** Update task metrics after populating the external map. */
+  private def updateMetrics(context: TaskContext, map: ExternalAppendOnlyMap[_, _, _]): Unit = {
+    Option(context).foreach { c =>
+      c.taskMetrics().incMemoryBytesSpilled(map.memoryBytesSpilled)
+      c.taskMetrics().incDiskBytesSpilled(map.diskBytesSpilled)
+      c.internalMetricsToAccumulators(
+        InternalAccumulator.PEAK_EXECUTION_MEMORY).add(map.peakMemoryUsedBytes)
+    }
+  }
 }
diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala
index 5d2c551d58514b8c94fe01dcc6a7fc7c258d1a10..63cca80b2d734ac595cdc9432ed869534f75934d 100644
--- a/core/src/main/scala/org/apache/spark/TaskContext.scala
+++ b/core/src/main/scala/org/apache/spark/TaskContext.scala
@@ -61,12 +61,12 @@ object TaskContext {
   protected[spark] def unset(): Unit = taskContext.remove()
 
   /**
-   * Return an empty task context that is not actually used.
-   * Internal use only.
+   * An empty task context that does not represent an actual task.
    */
-  private[spark] def empty(): TaskContext = {
-    new TaskContextImpl(0, 0, 0, 0, null, null)
+  private[spark] def empty(): TaskContextImpl = {
+    new TaskContextImpl(0, 0, 0, 0, null, null, Seq.empty)
   }
+
 }
 
 
@@ -187,4 +187,9 @@ abstract class TaskContext extends Serializable {
    * accumulator id and the value of the Map is the latest accumulator local value.
    */
   private[spark] def collectAccumulators(): Map[Long, Any]
+
+  /**
+   * Accumulators for tracking internal metrics indexed by the name.
+   */
+  private[spark] val internalMetricsToAccumulators: Map[String, Accumulator[Long]]
 }
diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
index 9ee168ae016f8b7d1d901a7896917100419c383f..5df94c6d3a103500eed031a3dd29a5371a119219 100644
--- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
+++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
@@ -32,6 +32,7 @@ private[spark] class TaskContextImpl(
     override val attemptNumber: Int,
     override val taskMemoryManager: TaskMemoryManager,
     @transient private val metricsSystem: MetricsSystem,
+    internalAccumulators: Seq[Accumulator[Long]],
     val runningLocally: Boolean = false,
     val taskMetrics: TaskMetrics = TaskMetrics.empty)
   extends TaskContext
@@ -114,4 +115,11 @@ private[spark] class TaskContextImpl(
   private[spark] override def collectAccumulators(): Map[Long, Any] = synchronized {
     accumulators.mapValues(_.localValue).toMap
   }
+
+  private[spark] override val internalMetricsToAccumulators: Map[String, Accumulator[Long]] = {
+    // Explicitly register internal accumulators here: they are passed in directly rather than
+    // captured in the task closure, so deserialization does not register them
+    internalAccumulators.foreach(registerAccumulator)
+    internalAccumulators.map { a => (a.name.get, a) }.toMap
+  }
 }
diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
index 130b58882d8ee9bd9af842a886460e48d0413823..9c617fc719cb54afdc58ab504e85047b41c19ca5 100644
--- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
@@ -23,8 +23,7 @@ import java.io.{IOException, ObjectOutputStream}
 
 import scala.collection.mutable.ArrayBuffer
 
-import org.apache.spark.{InterruptibleIterator, Partition, Partitioner, SparkEnv, TaskContext}
-import org.apache.spark.{Dependency, OneToOneDependency, ShuffleDependency}
+import org.apache.spark._
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.util.collection.{ExternalAppendOnlyMap, AppendOnlyMap, CompactBuffer}
 import org.apache.spark.util.Utils
@@ -169,8 +168,10 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part:
       for ((it, depNum) <- rddIterators) {
         map.insertAll(it.map(pair => (pair._1, new CoGroupValue(pair._2, depNum))))
       }
-      context.taskMetrics.incMemoryBytesSpilled(map.memoryBytesSpilled)
-      context.taskMetrics.incDiskBytesSpilled(map.diskBytesSpilled)
+      context.taskMetrics().incMemoryBytesSpilled(map.memoryBytesSpilled)
+      context.taskMetrics().incDiskBytesSpilled(map.diskBytesSpilled)
+      context.internalMetricsToAccumulators(
+        InternalAccumulator.PEAK_EXECUTION_MEMORY).add(map.peakMemoryUsedBytes)
       new InterruptibleIterator(context,
         map.iterator.asInstanceOf[Iterator[(K, Array[Iterable[_]])]])
     }
diff --git a/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala
index e0edd7d4ae9685f7641b25d57912a0a2e49f52ce..11d123eec43ca2accc2561c9eefb2c013e47c6f0 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala
@@ -24,11 +24,12 @@ import org.apache.spark.annotation.DeveloperApi
  * Information about an [[org.apache.spark.Accumulable]] modified during a task or stage.
  */
 @DeveloperApi
-class AccumulableInfo (
+class AccumulableInfo private[spark] (
     val id: Long,
     val name: String,
     val update: Option[String], // represents a partial update within a task
-    val value: String) {
+    val value: String,
+    val internal: Boolean) {
 
   override def equals(other: Any): Boolean = other match {
     case acc: AccumulableInfo =>
@@ -40,10 +41,10 @@ class AccumulableInfo (
 
 object AccumulableInfo {
   def apply(id: Long, name: String, update: Option[String], value: String): AccumulableInfo = {
-    new AccumulableInfo(id, name, update, value)
+    new AccumulableInfo(id, name, update, value, internal = false)
   }
 
   def apply(id: Long, name: String, value: String): AccumulableInfo = {
-    new AccumulableInfo(id, name, None, value)
+    new AccumulableInfo(id, name, None, value, internal = false)
   }
 }
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index c4fa277c21254b81a08f6968bcbd305f59cff2b0..bb489c6b6e98fe81edd411f064849633e02ed286 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -773,16 +773,26 @@ class DAGScheduler(
     stage.pendingTasks.clear()
 
     // First figure out the indexes of partition ids to compute.
-    val partitionsToCompute: Seq[Int] = {
+    val (allPartitions: Seq[Int], partitionsToCompute: Seq[Int]) = {
       stage match {
         case stage: ShuffleMapStage =>
-          (0 until stage.numPartitions).filter(id => stage.outputLocs(id).isEmpty)
+          val allPartitions = 0 until stage.numPartitions
+          val filteredPartitions = allPartitions.filter { id => stage.outputLocs(id).isEmpty }
+          (allPartitions, filteredPartitions)
         case stage: ResultStage =>
           val job = stage.resultOfJob.get
-          (0 until job.numPartitions).filter(id => !job.finished(id))
+          val allPartitions = 0 until job.numPartitions
+          val filteredPartitions = allPartitions.filter { id => !job.finished(id) }
+          (allPartitions, filteredPartitions)
       }
     }
 
+    // Reset internal accumulators only if this stage is not partially submitted.
+    // Otherwise, we may overwrite accumulator values from tasks that have already finished.
+    if (allPartitions == partitionsToCompute) {
+      stage.resetInternalAccumulators()
+    }
+
     val properties = jobIdToActiveJob.get(stage.firstJobId).map(_.properties).orNull
 
     runningStages += stage
@@ -852,7 +862,8 @@ class DAGScheduler(
           partitionsToCompute.map { id =>
             val locs = taskIdToLocations(id)
             val part = stage.rdd.partitions(id)
-            new ShuffleMapTask(stage.id, stage.latestInfo.attemptId, taskBinary, part, locs)
+            new ShuffleMapTask(stage.id, stage.latestInfo.attemptId,
+              taskBinary, part, locs, stage.internalAccumulators)
           }
 
         case stage: ResultStage =>
@@ -861,7 +872,8 @@ class DAGScheduler(
             val p: Int = job.partitions(id)
             val part = stage.rdd.partitions(p)
             val locs = taskIdToLocations(id)
-            new ResultTask(stage.id, stage.latestInfo.attemptId, taskBinary, part, locs, id)
+            new ResultTask(stage.id, stage.latestInfo.attemptId,
+              taskBinary, part, locs, id, stage.internalAccumulators)
           }
       }
     } catch {
@@ -916,9 +928,11 @@ class DAGScheduler(
           // To avoid UI cruft, ignore cases where value wasn't updated
           if (acc.name.isDefined && partialValue != acc.zero) {
             val name = acc.name.get
-            stage.latestInfo.accumulables(id) = AccumulableInfo(id, name, s"${acc.value}")
+            val value = s"${acc.value}"
+            stage.latestInfo.accumulables(id) =
+              new AccumulableInfo(id, name, None, value, acc.isInternal)
             event.taskInfo.accumulables +=
-              AccumulableInfo(id, name, Some(s"$partialValue"), s"${acc.value}")
+              new AccumulableInfo(id, name, Some(s"$partialValue"), value, acc.isInternal)
           }
         }
       } catch {
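The allPartitions == partitionsToCompute guard above is what the comment means by "not partially submitted": internal accumulators are re-created only when every partition of the stage still has to run, so partial values reported by already-finished tasks survive a resubmission. A tiny illustration of the condition itself, with plain sequences and no scheduler involved:

    // Illustration of the reset condition: reset only when nothing has finished yet.
    val allPartitions: Seq[Int] = 0 until 4
    val firstSubmit  = allPartitions.filter(_ => true)  // nothing finished yet
    val resubmission = allPartitions.filter(_ != 2)     // partition 2 already has output
    assert(allPartitions == firstSubmit)                // fresh stage: reset the accumulators
    assert(allPartitions != resubmission)               // partial re-run: keep prior values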
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
index 9c2606e278c54dcaa67c58039e2387cb852978c3..c4dc080e2b22b524a6fde4286d71aaa944f35615 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
@@ -45,8 +45,10 @@ private[spark] class ResultTask[T, U](
     taskBinary: Broadcast[Array[Byte]],
     partition: Partition,
     @transient locs: Seq[TaskLocation],
-    val outputId: Int)
-  extends Task[U](stageId, stageAttemptId, partition.index) with Serializable {
+    val outputId: Int,
+    internalAccumulators: Seq[Accumulator[Long]])
+  extends Task[U](stageId, stageAttemptId, partition.index, internalAccumulators)
+  with Serializable {
 
   @transient private[this] val preferredLocs: Seq[TaskLocation] = {
     if (locs == null) Nil else locs.toSet.toSeq
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
index 14c8c00961487a25b8a3bdf7e6d5d502994af1ff..f478f9982afefccab863565531048daf54d64cfe 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
@@ -43,12 +43,14 @@ private[spark] class ShuffleMapTask(
     stageAttemptId: Int,
     taskBinary: Broadcast[Array[Byte]],
     partition: Partition,
-    @transient private var locs: Seq[TaskLocation])
-  extends Task[MapStatus](stageId, stageAttemptId, partition.index) with Logging {
+    @transient private var locs: Seq[TaskLocation],
+    internalAccumulators: Seq[Accumulator[Long]])
+  extends Task[MapStatus](stageId, stageAttemptId, partition.index, internalAccumulators)
+  with Logging {
 
   /** A constructor used only in test suites. This does not require passing in an RDD. */
   def this(partitionId: Int) {
-    this(0, 0, null, new Partition { override def index: Int = 0 }, null)
+    this(0, 0, null, new Partition { override def index: Int = 0 }, null, null)
   }
 
   @transient private val preferredLocs: Seq[TaskLocation] = {
@@ -69,7 +71,7 @@ private[spark] class ShuffleMapTask(
       val manager = SparkEnv.get.shuffleManager
       writer = manager.getWriter[Any, Any](dep.shuffleHandle, partitionId, context)
       writer.write(rdd.iterator(partition, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]])
-      return writer.stop(success = true).get
+      writer.stop(success = true).get
     } catch {
       case e: Exception =>
         try {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
index 40a333a3e06b2e81165ef643b559a6604c973f38..de05ee256dbfc653913d26db49c7bc8acf11b134 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
@@ -68,6 +68,22 @@ private[spark] abstract class Stage(
   val name = callSite.shortForm
   val details = callSite.longForm
 
+  private var _internalAccumulators: Seq[Accumulator[Long]] = Seq.empty
+
+  /** Internal accumulators shared across all tasks in this stage. */
+  def internalAccumulators: Seq[Accumulator[Long]] = _internalAccumulators
+
+  /**
+   * Re-initialize the internal accumulators associated with this stage.
+   *
+   * This is called every time the stage is submitted, *except* when a subset of tasks
+   * belonging to this stage has already finished. Otherwise, reinitializing the internal
+   * accumulators here would overwrite partial values from the finished tasks.
+   */
+  def resetInternalAccumulators(): Unit = {
+    _internalAccumulators = InternalAccumulator.create()
+  }
+
   /**
    * Pointer to the [StageInfo] object for the most recent attempt. This needs to be initialized
    * here, before any attempts have actually been created, because the DAGScheduler uses this
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
index 1978305cfefbdefc2e81fe15a953dcc30e3c3341..9edf9f048f9fd7567ba6be6ecb7c34dcdff94eda 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -23,7 +23,7 @@ import java.nio.ByteBuffer
 import scala.collection.mutable.HashMap
 
 import org.apache.spark.metrics.MetricsSystem
-import org.apache.spark.{SparkEnv, TaskContextImpl, TaskContext}
+import org.apache.spark.{Accumulator, SparkEnv, TaskContextImpl, TaskContext}
 import org.apache.spark.executor.TaskMetrics
 import org.apache.spark.serializer.SerializerInstance
 import org.apache.spark.unsafe.memory.TaskMemoryManager
@@ -47,7 +47,8 @@ import org.apache.spark.util.Utils
 private[spark] abstract class Task[T](
     val stageId: Int,
     val stageAttemptId: Int,
-    var partitionId: Int) extends Serializable {
+    val partitionId: Int,
+    internalAccumulators: Seq[Accumulator[Long]]) extends Serializable {
 
   /**
    * The key of the Map is the accumulator id and the value of the Map is the latest accumulator
@@ -68,12 +69,13 @@ private[spark] abstract class Task[T](
     metricsSystem: MetricsSystem)
   : (T, AccumulatorUpdates) = {
     context = new TaskContextImpl(
-      stageId = stageId,
-      partitionId = partitionId,
-      taskAttemptId = taskAttemptId,
-      attemptNumber = attemptNumber,
-      taskMemoryManager = taskMemoryManager,
-      metricsSystem = metricsSystem,
+      stageId,
+      partitionId,
+      taskAttemptId,
+      attemptNumber,
+      taskMemoryManager,
+      metricsSystem,
+      internalAccumulators,
       runningLocally = false)
     TaskContext.setTaskContext(context)
     context.taskMetrics.setHostname(Utils.localHostName())
diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
index de79fa56f017bf6d055fe9bc32135595e196236d..0c8f08f0f3b1b206e7922421c9c62869b6622618 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.shuffle.hash
 
-import org.apache.spark.{InterruptibleIterator, Logging, MapOutputTracker, SparkEnv, TaskContext}
+import org.apache.spark._
 import org.apache.spark.serializer.Serializer
 import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReader}
 import org.apache.spark.storage.{BlockManager, ShuffleBlockFetcherIterator}
@@ -100,8 +100,10 @@ private[spark] class HashShuffleReader[K, C](
         // the ExternalSorter won't spill to disk.
         val sorter = new ExternalSorter[K, C, C](ordering = Some(keyOrd), serializer = Some(ser))
         sorter.insertAll(aggregatedIter)
-        context.taskMetrics.incMemoryBytesSpilled(sorter.memoryBytesSpilled)
-        context.taskMetrics.incDiskBytesSpilled(sorter.diskBytesSpilled)
+        context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled)
+        context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled)
+        context.internalMetricsToAccumulators(
+          InternalAccumulator.PEAK_EXECUTION_MEMORY).add(sorter.peakMemoryUsedBytes)
         sorter.iterator
       case None =>
         aggregatedIter
diff --git a/core/src/main/scala/org/apache/spark/ui/ToolTips.scala b/core/src/main/scala/org/apache/spark/ui/ToolTips.scala
index e2d25e36365faff76a08bb8dac8a8b69282520c1..cb122eaed83d10912ce8fea16a85ea9ce37f0a70 100644
--- a/core/src/main/scala/org/apache/spark/ui/ToolTips.scala
+++ b/core/src/main/scala/org/apache/spark/ui/ToolTips.scala
@@ -62,6 +62,13 @@ private[spark] object ToolTips {
     """Time that the executor spent paused for Java garbage collection while the task was
        running."""
 
+  val PEAK_EXECUTION_MEMORY =
+    """Execution memory refers to the memory used by internal data structures created during
+       shuffles, aggregations and joins when Tungsten is enabled. The value of this accumulator
+       should be approximately the sum of the peak sizes across all such data structures created
+       in this task. For SQL jobs, this only tracks unsafe operators, broadcast joins, and
+       external sort."""
+
   val JOB_TIMELINE =
     """Shows when jobs started and ended and when executors joined or left. Drag to scroll.
        Click Enable Zooming and use mouse wheel to zoom in/out."""
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
index cf04b5e59239b923151c9ad39424c126b056f7f9..3954c3d1ef8942b8613041619f186d0a6ce5f17a 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
@@ -26,6 +26,7 @@ import scala.xml.{Elem, Node, Unparsed}
 
 import org.apache.commons.lang3.StringEscapeUtils
 
+import org.apache.spark.{InternalAccumulator, SparkConf}
 import org.apache.spark.executor.TaskMetrics
 import org.apache.spark.scheduler.{AccumulableInfo, TaskInfo}
 import org.apache.spark.ui._
@@ -67,6 +68,8 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
   // if we find that it's okay.
   private val MAX_TIMELINE_TASKS = parent.conf.getInt("spark.ui.timeline.tasks.maximum", 1000)
 
+  private val displayPeakExecutionMemory =
+    parent.conf.getOption("spark.sql.unsafe.enabled").exists(_.toBoolean)
 
   def render(request: HttpServletRequest): Seq[Node] = {
     progressListener.synchronized {
@@ -114,10 +117,11 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
 
       val stageData = stageDataOption.get
       val tasks = stageData.taskData.values.toSeq.sortBy(_.taskInfo.launchTime)
-
       val numCompleted = tasks.count(_.taskInfo.finished)
-      val accumulables = progressListener.stageIdToData((stageId, stageAttemptId)).accumulables
-      val hasAccumulators = accumulables.size > 0
+
+      val allAccumulables = progressListener.stageIdToData((stageId, stageAttemptId)).accumulables
+      val externalAccumulables = allAccumulables.values.filter { acc => !acc.internal }
+      val hasAccumulators = externalAccumulables.size > 0
 
       val summary =
         <div>
@@ -221,6 +225,15 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
                   <span class="additional-metric-title">Getting Result Time</span>
                 </span>
               </li>
+              {if (displayPeakExecutionMemory) {
+                <li>
+                  <span data-toggle="tooltip"
+                        title={ToolTips.PEAK_EXECUTION_MEMORY} data-placement="right">
+                    <input type="checkbox" name={TaskDetailsClassNames.PEAK_EXECUTION_MEMORY}/>
+                    <span class="additional-metric-title">Peak Execution Memory</span>
+                  </span>
+                </li>
+              }}
             </ul>
           </div>
         </div>
@@ -241,11 +254,12 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
       val accumulableTable = UIUtils.listingTable(
         accumulableHeaders,
         accumulableRow,
-        accumulables.values.toSeq)
+        externalAccumulables.toSeq)
 
       val currentTime = System.currentTimeMillis()
       val (taskTable, taskTableHTML) = try {
         val _taskTable = new TaskPagedTable(
+          parent.conf,
           UIUtils.prependBaseUri(parent.basePath) +
             s"/stages/stage?id=${stageId}&attempt=${stageAttemptId}",
           tasks,
@@ -294,12 +308,14 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
         else {
           def getDistributionQuantiles(data: Seq[Double]): IndexedSeq[Double] =
             Distribution(data).get.getQuantiles()
-
           def getFormattedTimeQuantiles(times: Seq[Double]): Seq[Node] = {
             getDistributionQuantiles(times).map { millis =>
               <td>{UIUtils.formatDuration(millis.toLong)}</td>
             }
           }
+          def getFormattedSizeQuantiles(data: Seq[Double]): Seq[Elem] = {
+            getDistributionQuantiles(data).map(d => <td>{Utils.bytesToString(d.toLong)}</td>)
+          }
 
           val deserializationTimes = validTasks.map { case TaskUIData(_, metrics, _) =>
             metrics.get.executorDeserializeTime.toDouble
@@ -349,6 +365,23 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
               </span>
             </td> +:
             getFormattedTimeQuantiles(gettingResultTimes)
+
+          val peakExecutionMemory = validTasks.map { case TaskUIData(info, _, _) =>
+            info.accumulables
+              .find { acc => acc.name == InternalAccumulator.PEAK_EXECUTION_MEMORY }
+              .map { acc => acc.value.toLong }
+              .getOrElse(0L)
+              .toDouble
+          }
+          val peakExecutionMemoryQuantiles = {
+            <td>
+              <span data-toggle="tooltip"
+                    title={ToolTips.PEAK_EXECUTION_MEMORY} data-placement="right">
+                Peak Execution Memory
+              </span>
+            </td> +: getFormattedSizeQuantiles(peakExecutionMemory)
+          }
+
           // The scheduler delay includes the network delay to send the task to the worker
           // machine and to send back the result (but not the time to fetch the task result,
           // if it needed to be fetched from the block manager on the worker).
@@ -359,10 +392,6 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
             title={ToolTips.SCHEDULER_DELAY} data-placement="right">Scheduler Delay</span></td>
           val schedulerDelayQuantiles = schedulerDelayTitle +:
             getFormattedTimeQuantiles(schedulerDelays)
-
-          def getFormattedSizeQuantiles(data: Seq[Double]): Seq[Elem] =
-            getDistributionQuantiles(data).map(d => <td>{Utils.bytesToString(d.toLong)}</td>)
-
           def getFormattedSizeQuantilesWithRecords(data: Seq[Double], records: Seq[Double])
             : Seq[Elem] = {
             val recordDist = getDistributionQuantiles(records).iterator
@@ -466,6 +495,13 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
               {serializationQuantiles}
             </tr>,
             <tr class={TaskDetailsClassNames.GETTING_RESULT_TIME}>{gettingResultQuantiles}</tr>,
+            if (displayPeakExecutionMemory) {
+              <tr class={TaskDetailsClassNames.PEAK_EXECUTION_MEMORY}>
+                {peakExecutionMemoryQuantiles}
+              </tr>
+            } else {
+              Nil
+            },
             if (stageData.hasInput) <tr>{inputQuantiles}</tr> else Nil,
             if (stageData.hasOutput) <tr>{outputQuantiles}</tr> else Nil,
             if (stageData.hasShuffleRead) {
@@ -499,7 +535,7 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
       val executorTable = new ExecutorTable(stageId, stageAttemptId, parent)
 
       val maybeAccumulableTable: Seq[Node] =
-        if (accumulables.size > 0) { <h4>Accumulators</h4> ++ accumulableTable } else Seq()
+        if (hasAccumulators) { <h4>Accumulators</h4> ++ accumulableTable } else Seq()
 
       val content =
         summary ++
@@ -750,29 +786,30 @@ private[ui] case class TaskTableRowBytesSpilledData(
  * Contains all data that needs for sorting and generating HTML. Using this one rather than
  * TaskUIData to avoid creating duplicate contents during sorting the data.
  */
-private[ui] case class TaskTableRowData(
-    index: Int,
-    taskId: Long,
-    attempt: Int,
-    speculative: Boolean,
-    status: String,
-    taskLocality: String,
-    executorIdAndHost: String,
-    launchTime: Long,
-    duration: Long,
-    formatDuration: String,
-    schedulerDelay: Long,
-    taskDeserializationTime: Long,
-    gcTime: Long,
-    serializationTime: Long,
-    gettingResultTime: Long,
-    accumulators: Option[String], // HTML
-    input: Option[TaskTableRowInputData],
-    output: Option[TaskTableRowOutputData],
-    shuffleRead: Option[TaskTableRowShuffleReadData],
-    shuffleWrite: Option[TaskTableRowShuffleWriteData],
-    bytesSpilled: Option[TaskTableRowBytesSpilledData],
-    error: String)
+private[ui] class TaskTableRowData(
+    val index: Int,
+    val taskId: Long,
+    val attempt: Int,
+    val speculative: Boolean,
+    val status: String,
+    val taskLocality: String,
+    val executorIdAndHost: String,
+    val launchTime: Long,
+    val duration: Long,
+    val formatDuration: String,
+    val schedulerDelay: Long,
+    val taskDeserializationTime: Long,
+    val gcTime: Long,
+    val serializationTime: Long,
+    val gettingResultTime: Long,
+    val peakExecutionMemoryUsed: Long,
+    val accumulators: Option[String], // HTML
+    val input: Option[TaskTableRowInputData],
+    val output: Option[TaskTableRowOutputData],
+    val shuffleRead: Option[TaskTableRowShuffleReadData],
+    val shuffleWrite: Option[TaskTableRowShuffleWriteData],
+    val bytesSpilled: Option[TaskTableRowBytesSpilledData],
+    val error: String)
 
 private[ui] class TaskDataSource(
     tasks: Seq[TaskUIData],
@@ -816,10 +853,15 @@ private[ui] class TaskDataSource(
     val serializationTime = metrics.map(_.resultSerializationTime).getOrElse(0L)
     val gettingResultTime = getGettingResultTime(info, currentTime)
 
-    val maybeAccumulators = info.accumulables
-    val accumulatorsReadable = maybeAccumulators.map { acc =>
+    val (taskInternalAccumulables, taskExternalAccumulables) =
+      info.accumulables.partition(_.internal)
+    val externalAccumulableReadable = taskExternalAccumulables.map { acc =>
       StringEscapeUtils.escapeHtml4(s"${acc.name}: ${acc.update.get}")
     }
+    val peakExecutionMemoryUsed = taskInternalAccumulables
+      .find { acc => acc.name == InternalAccumulator.PEAK_EXECUTION_MEMORY }
+      .map { acc => acc.value.toLong }
+      .getOrElse(0L)
 
     val maybeInput = metrics.flatMap(_.inputMetrics)
     val inputSortable = maybeInput.map(_.bytesRead).getOrElse(0L)
@@ -923,7 +965,7 @@ private[ui] class TaskDataSource(
         None
       }
 
-    TaskTableRowData(
+    new TaskTableRowData(
       info.index,
       info.taskId,
       info.attempt,
@@ -939,7 +981,8 @@ private[ui] class TaskDataSource(
       gcTime,
       serializationTime,
       gettingResultTime,
-      if (hasAccumulators) Some(accumulatorsReadable.mkString("<br/>")) else None,
+      peakExecutionMemoryUsed,
+      if (hasAccumulators) Some(externalAccumulableReadable.mkString("<br/>")) else None,
       input,
       output,
       shuffleRead,
@@ -1006,6 +1049,10 @@ private[ui] class TaskDataSource(
         override def compare(x: TaskTableRowData, y: TaskTableRowData): Int =
           Ordering.Long.compare(x.gettingResultTime, y.gettingResultTime)
       }
+      case "Peak Execution Memory" => new Ordering[TaskTableRowData] {
+        override def compare(x: TaskTableRowData, y: TaskTableRowData): Int =
+          Ordering.Long.compare(x.peakExecutionMemoryUsed, y.peakExecutionMemoryUsed)
+      }
       case "Accumulators" =>
         if (hasAccumulators) {
           new Ordering[TaskTableRowData] {
@@ -1132,6 +1179,7 @@ private[ui] class TaskDataSource(
 }
 
 private[ui] class TaskPagedTable(
+    conf: SparkConf,
     basePath: String,
     data: Seq[TaskUIData],
     hasAccumulators: Boolean,
@@ -1143,7 +1191,11 @@ private[ui] class TaskPagedTable(
     currentTime: Long,
     pageSize: Int,
     sortColumn: String,
-    desc: Boolean) extends PagedTable[TaskTableRowData]{
+    desc: Boolean) extends PagedTable[TaskTableRowData] {
+
+  // We only track peak memory used for unsafe operators
+  private val displayPeakExecutionMemory =
+    conf.getOption("spark.sql.unsafe.enabled").exists(_.toBoolean)
 
   override def tableId: String = ""
 
@@ -1195,6 +1247,13 @@ private[ui] class TaskPagedTable(
         ("GC Time", ""),
         ("Result Serialization Time", TaskDetailsClassNames.RESULT_SERIALIZATION_TIME),
         ("Getting Result Time", TaskDetailsClassNames.GETTING_RESULT_TIME)) ++
+        {
+          if (displayPeakExecutionMemory) {
+            Seq(("Peak Execution Memory", TaskDetailsClassNames.PEAK_EXECUTION_MEMORY))
+          } else {
+            Nil
+          }
+        } ++
         {if (hasAccumulators) Seq(("Accumulators", "")) else Nil} ++
         {if (hasInput) Seq(("Input Size / Records", "")) else Nil} ++
         {if (hasOutput) Seq(("Output Size / Records", "")) else Nil} ++
@@ -1271,6 +1330,11 @@ private[ui] class TaskPagedTable(
       <td class={TaskDetailsClassNames.GETTING_RESULT_TIME}>
         {UIUtils.formatDuration(task.gettingResultTime)}
       </td>
+      {if (displayPeakExecutionMemory) {
+        <td class={TaskDetailsClassNames.PEAK_EXECUTION_MEMORY}>
+          {Utils.bytesToString(task.peakExecutionMemoryUsed)}
+        </td>
+      }}
       {if (task.accumulators.nonEmpty) {
         <td>{Unparsed(task.accumulators.get)}</td>
       }}
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/TaskDetailsClassNames.scala b/core/src/main/scala/org/apache/spark/ui/jobs/TaskDetailsClassNames.scala
index 9bf67db8acde1f3256b1d8a8cc58df07531eef71..d2dfc5a32915c2a6f499ec9d0df610d9b5af12df 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/TaskDetailsClassNames.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/TaskDetailsClassNames.scala
@@ -31,4 +31,5 @@ private[spark] object TaskDetailsClassNames {
   val SHUFFLE_READ_REMOTE_SIZE = "shuffle_read_remote"
   val RESULT_SERIALIZATION_TIME = "serialization_time"
   val GETTING_RESULT_TIME = "getting_result_time"
+  val PEAK_EXECUTION_MEMORY = "peak_execution_memory"
 }
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
index d166037351c31d5f0f90976663ee3ebb12faaa1d..f929b12606f0a3b2419381b5527b0eabf3860bd4 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
@@ -89,6 +89,7 @@ class ExternalAppendOnlyMap[K, V, C](
 
   // Number of bytes spilled in total
   private var _diskBytesSpilled = 0L
+  def diskBytesSpilled: Long = _diskBytesSpilled
 
   // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided
   private val fileBufferSize =
@@ -97,6 +98,10 @@ class ExternalAppendOnlyMap[K, V, C](
   // Write metrics for current spill
   private var curWriteMetrics: ShuffleWriteMetrics = _
 
+  // Peak size of the in-memory map observed so far, in bytes
+  private var _peakMemoryUsedBytes: Long = 0L
+  def peakMemoryUsedBytes: Long = _peakMemoryUsedBytes
+
   private val keyComparator = new HashComparator[K]
   private val ser = serializer.newInstance()
 
@@ -126,7 +131,11 @@ class ExternalAppendOnlyMap[K, V, C](
 
     while (entries.hasNext) {
       curEntry = entries.next()
-      if (maybeSpill(currentMap, currentMap.estimateSize())) {
+      val estimatedSize = currentMap.estimateSize()
+      if (estimatedSize > _peakMemoryUsedBytes) {
+        _peakMemoryUsedBytes = estimatedSize
+      }
+      if (maybeSpill(currentMap, estimatedSize)) {
         currentMap = new SizeTrackingAppendOnlyMap[K, C]
       }
       currentMap.changeValue(curEntry._1, update)
@@ -207,8 +216,6 @@ class ExternalAppendOnlyMap[K, V, C](
     spilledMaps.append(new DiskMapIterator(file, blockId, batchSizes))
   }
 
-  def diskBytesSpilled: Long = _diskBytesSpilled
-
   /**
    * Return an iterator that merges the in-memory map with the spilled maps.
    * If no spill has occurred, simply return the in-memory map's iterator.
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
index ba7ec834d622d2df37466eb60ca8d30bf82492c6..19287edbaf1661ce73674034c4c721840e2a5a73 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
@@ -152,6 +152,9 @@ private[spark] class ExternalSorter[K, V, C](
   private var _diskBytesSpilled = 0L
   def diskBytesSpilled: Long = _diskBytesSpilled
 
+  // Peak size of the in-memory data structure observed so far, in bytes
+  private var _peakMemoryUsedBytes: Long = 0L
+  def peakMemoryUsedBytes: Long = _peakMemoryUsedBytes
 
   // A comparator for keys K that orders them within a partition to allow aggregation or sorting.
   // Can be a partial ordering by hash code if a total ordering is not provided through by the
@@ -224,15 +227,22 @@ private[spark] class ExternalSorter[K, V, C](
       return
     }
 
+    var estimatedSize = 0L
     if (usingMap) {
-      if (maybeSpill(map, map.estimateSize())) {
+      estimatedSize = map.estimateSize()
+      if (maybeSpill(map, estimatedSize)) {
         map = new PartitionedAppendOnlyMap[K, C]
       }
     } else {
-      if (maybeSpill(buffer, buffer.estimateSize())) {
+      estimatedSize = buffer.estimateSize()
+      if (maybeSpill(buffer, estimatedSize)) {
         buffer = newBuffer()
       }
     }
+
+    if (estimatedSize > _peakMemoryUsedBytes) {
+      _peakMemoryUsedBytes = estimatedSize
+    }
   }
 
   /**
@@ -684,8 +694,10 @@ private[spark] class ExternalSorter[K, V, C](
       }
     }
 
-    context.taskMetrics.incMemoryBytesSpilled(memoryBytesSpilled)
-    context.taskMetrics.incDiskBytesSpilled(diskBytesSpilled)
+    context.taskMetrics().incMemoryBytesSpilled(memoryBytesSpilled)
+    context.taskMetrics().incDiskBytesSpilled(diskBytesSpilled)
+    context.internalMetricsToAccumulators(
+      InternalAccumulator.PEAK_EXECUTION_MEMORY).add(peakMemoryUsedBytes)
 
     lengths
   }
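Both collection changes above (ExternalAppendOnlyMap and ExternalSorter) capture estimateSize() into a local before maybeSpill runs and fold that captured value into the peak; sampling after a spill would under-report, since spilling replaces the in-memory structure with an empty one. A compact sketch of that ordering under a simplified threshold model (real maybeSpill also consults the shuffle memory manager; names are illustrative):

    // Sample-before-spill sketch: the estimate is recorded before spill() resets anything.
    var peakMemoryUsedBytes = 0L
    def maybeSpillTracked(estimate: Long, threshold: Long)(spill: () => Unit): Boolean = {
      if (estimate > peakMemoryUsedBytes) peakMemoryUsedBytes = estimate
      if (estimate >= threshold) { spill(); true } else false
    }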
diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java
index e948ca33471a4518beaeb670dea2d3c7bbdafbb5..ffe4b4baffb2a69505b74d06e202dbe485fa2237 100644
--- a/core/src/test/java/org/apache/spark/JavaAPISuite.java
+++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java
@@ -51,7 +51,6 @@ import org.junit.Test;
 
 import org.apache.spark.api.java.*;
 import org.apache.spark.api.java.function.*;
-import org.apache.spark.executor.TaskMetrics;
 import org.apache.spark.input.PortableDataStream;
 import org.apache.spark.partial.BoundedDouble;
 import org.apache.spark.partial.PartialResult;
@@ -1011,7 +1010,7 @@ public class JavaAPISuite implements Serializable {
   @Test
   public void iterator() {
     JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5), 2);
-    TaskContext context = new TaskContextImpl(0, 0, 0L, 0, null, null, false, new TaskMetrics());
+    TaskContext context = TaskContext$.MODULE$.empty();
     Assert.assertEquals(1, rdd.iterator(rdd.partitions().get(0), context).next().intValue());
   }
 
diff --git a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java
index 04fc09b323dbb2073beeeede6d0fb75a6a849f67..98c32bbc298d75294e7bcec8e652c0fd1c54ff52 100644
--- a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java
+++ b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java
@@ -190,6 +190,7 @@ public class UnsafeShuffleWriterSuite {
       });
 
     when(taskContext.taskMetrics()).thenReturn(taskMetrics);
+    when(taskContext.internalMetricsToAccumulators()).thenReturn(null);
 
     when(shuffleDep.serializer()).thenReturn(Option.<Serializer>apply(serializer));
     when(shuffleDep.partitioner()).thenReturn(hashPartitioner);
@@ -542,4 +543,57 @@ public class UnsafeShuffleWriterSuite {
     writer.stop(false);
     assertSpillFilesWereCleanedUp();
   }
+
+  @Test
+  public void testPeakMemoryUsed() throws Exception {
+    final long recordLengthBytes = 8;
+    final long pageSizeBytes = 256;
+    final long numRecordsPerPage = pageSizeBytes / recordLengthBytes;
+    final SparkConf conf = new SparkConf().set("spark.buffer.pageSize", pageSizeBytes + "b");
+    final UnsafeShuffleWriter<Object, Object> writer =
+      new UnsafeShuffleWriter<Object, Object>(
+        blockManager,
+        shuffleBlockResolver,
+        taskMemoryManager,
+        shuffleMemoryManager,
+        new UnsafeShuffleHandle<Object, Object>(0, 1, shuffleDep),
+        0, // map id
+        taskContext,
+        conf);
+
+    // Peak memory should be monotonically increasing. More specifically, every time
+    // we allocate a new page it should increase by exactly the size of the page.
+    long previousPeakMemory = writer.getPeakMemoryUsedBytes();
+    long newPeakMemory;
+    try {
+      for (int i = 0; i < numRecordsPerPage * 10; i++) {
+        writer.insertRecordIntoSorter(new Tuple2<Object, Object>(1, 1));
+        newPeakMemory = writer.getPeakMemoryUsedBytes();
+        if (i % numRecordsPerPage == 0) {
+          // We allocated a new page for this record, so peak memory should change
+          assertEquals(previousPeakMemory + pageSizeBytes, newPeakMemory);
+        } else {
+          assertEquals(previousPeakMemory, newPeakMemory);
+        }
+        previousPeakMemory = newPeakMemory;
+      }
+
+      // Spilling should not change peak memory
+      writer.forceSorterToSpill();
+      newPeakMemory = writer.getPeakMemoryUsedBytes();
+      assertEquals(previousPeakMemory, newPeakMemory);
+      for (int i = 0; i < numRecordsPerPage; i++) {
+        writer.insertRecordIntoSorter(new Tuple2<Object, Object>(1, 1));
+      }
+      newPeakMemory = writer.getPeakMemoryUsedBytes();
+      assertEquals(previousPeakMemory, newPeakMemory);
+
+      // Closing the writer should not change peak memory
+      writer.closeAndWriteOutput();
+      newPeakMemory = writer.getPeakMemoryUsedBytes();
+      assertEquals(previousPeakMemory, newPeakMemory);
+    } finally {
+      writer.stop(false);
+    }
+  }
 }
diff --git a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
index dbb7c662d7871b46979ecd1542145a2912589312..0e23a64fb74bb62c5ec9e8d9d982945738e7aa87 100644
--- a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
+++ b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
@@ -25,6 +25,7 @@ import org.junit.*;
 import org.mockito.invocation.InvocationOnMock;
 import org.mockito.stubbing.Answer;
 import static org.hamcrest.Matchers.greaterThan;
+import static org.junit.Assert.*;
 import static org.mockito.AdditionalMatchers.geq;
 import static org.mockito.Mockito.*;
 
@@ -495,4 +496,42 @@ public abstract class AbstractBytesToBytesMapSuite {
     map.growAndRehash();
     map.free();
   }
+
+  @Test
+  public void testTotalMemoryConsumption() {
+    final long recordLengthBytes = 24;
+    final long pageSizeBytes = 256 + 8; // 8 bytes for end-of-page marker
+    final long numRecordsPerPage = (pageSizeBytes - 8) / recordLengthBytes;
+    final BytesToBytesMap map = new BytesToBytesMap(
+      taskMemoryManager, shuffleMemoryManager, 1024, pageSizeBytes);
+
+    // Since BytesToBytesMap is append-only, we expect the total memory consumption to be
+    // monotonically increasing. More specifically, every time we allocate a new page it
+    // should increase by exactly the size of the page. In this regard, the memory usage
+    // at any given time is also the peak memory used.
+    long previousMemory = map.getTotalMemoryConsumption();
+    long newMemory;
+    try {
+      for (long i = 0; i < numRecordsPerPage * 10; i++) {
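+        // Insert a new key (identical to its value); each put appends one fixed-size record to the current page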
+        final long[] value = new long[]{i};
+        map.lookup(value, PlatformDependent.LONG_ARRAY_OFFSET, 8).putNewKey(
+          value,
+          PlatformDependent.LONG_ARRAY_OFFSET,
+          8,
+          value,
+          PlatformDependent.LONG_ARRAY_OFFSET,
+          8);
+        newMemory = map.getTotalMemoryConsumption();
+        if (i % numRecordsPerPage == 0) {
+          // We allocated a new page for this record, so peak memory should change
+          assertEquals(previousMemory + pageSizeBytes, newMemory);
+        } else {
+          assertEquals(previousMemory, newMemory);
+        }
+        previousMemory = newMemory;
+      }
+    } finally {
+      map.free();
+    }
+  }
 }
diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
index 52fa8bcd57e79701aa65b84cab9a72b263bcc6f5..c11949d57a0ea5a06db9aeb76195080f3ce8773e 100644
--- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
+++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
@@ -247,4 +247,50 @@ public class UnsafeExternalSorterSuite {
     assertSpillFilesWereCleanedUp();
   }
 
+  @Test
+  public void testPeakMemoryUsed() throws Exception {
+    final long recordLengthBytes = 8;
+    final long pageSizeBytes = 256;
+    final long numRecordsPerPage = pageSizeBytes / recordLengthBytes;
+    final UnsafeExternalSorter sorter = UnsafeExternalSorter.create(
+      taskMemoryManager,
+      shuffleMemoryManager,
+      blockManager,
+      taskContext,
+      recordComparator,
+      prefixComparator,
+      1024,
+      pageSizeBytes);
+
+    // Peak memory should be monotonically increasing. More specifically, every time
+    // we allocate a new page it should increase by exactly the size of the page.
+    long previousPeakMemory = sorter.getPeakMemoryUsedBytes();
+    long newPeakMemory;
+    try {
+      for (int i = 0; i < numRecordsPerPage * 10; i++) {
+        insertNumber(sorter, i);
+        newPeakMemory = sorter.getPeakMemoryUsedBytes();
+        if (i % numRecordsPerPage == 0) {
+          // We allocated a new page for this record, so peak memory should change
+          assertEquals(previousPeakMemory + pageSizeBytes, newPeakMemory);
+        } else {
+          assertEquals(previousPeakMemory, newPeakMemory);
+        }
+        previousPeakMemory = newPeakMemory;
+      }
+
+      // Spilling should not change peak memory
+      sorter.spill();
+      newPeakMemory = sorter.getPeakMemoryUsedBytes();
+      assertEquals(previousPeakMemory, newPeakMemory);
+      for (int i = 0; i < numRecordsPerPage; i++) {
+        insertNumber(sorter, i);
+      }
+      newPeakMemory = sorter.getPeakMemoryUsedBytes();
+      assertEquals(previousPeakMemory, newPeakMemory);
+    } finally {
+      sorter.freeMemory();
+    }
+  }
+
 }
diff --git a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala
index e942d6579b2fd9eefa37c6c35b352ee39c98e668..48f549575f4d19599fc2a1ddad5842851f2deaff 100644
--- a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala
@@ -18,13 +18,17 @@
 package org.apache.spark
 
 import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
 import scala.ref.WeakReference
 
 import org.scalatest.Matchers
+import org.scalatest.exceptions.TestFailedException
 
+import org.apache.spark.scheduler._
 
-class AccumulatorSuite extends SparkFunSuite with Matchers with LocalSparkContext {
 
+class AccumulatorSuite extends SparkFunSuite with Matchers with LocalSparkContext {
+  import InternalAccumulator._
 
   implicit def setAccum[A]: AccumulableParam[mutable.Set[A], A] =
     new AccumulableParam[mutable.Set[A], A] {
@@ -155,4 +159,191 @@ class AccumulatorSuite extends SparkFunSuite with Matchers with LocalSparkContex
     assert(!Accumulators.originals.get(accId).isDefined)
   }
 
+  test("internal accumulators in TaskContext") {
+    val accums = InternalAccumulator.create()
+    val taskContext = new TaskContextImpl(0, 0, 0, 0, null, null, accums)
+    val internalMetricsToAccums = taskContext.internalMetricsToAccumulators
+    val collectedInternalAccums = taskContext.collectInternalAccumulators()
+    val collectedAccums = taskContext.collectAccumulators()
+    assert(internalMetricsToAccums.size > 0)
+    assert(internalMetricsToAccums.values.forall(_.isInternal))
+    assert(internalMetricsToAccums.contains(TEST_ACCUMULATOR))
+    val testAccum = internalMetricsToAccums(TEST_ACCUMULATOR)
+    assert(collectedInternalAccums.size === internalMetricsToAccums.size)
+    assert(collectedInternalAccums.size === collectedAccums.size)
+    assert(collectedInternalAccums.contains(testAccum.id))
+    assert(collectedAccums.contains(testAccum.id))
+  }
+
+  test("internal accumulators in a stage") {
+    val listener = new SaveInfoListener
+    val numPartitions = 10
+    sc = new SparkContext("local", "test")
+    sc.addSparkListener(listener)
+    // Have each task add 1 to the internal accumulator
+    sc.parallelize(1 to 100, numPartitions).mapPartitions { iter =>
+      TaskContext.get().internalMetricsToAccumulators(TEST_ACCUMULATOR) += 1
+      iter
+    }.count()
+    val stageInfos = listener.getCompletedStageInfos
+    val taskInfos = listener.getCompletedTaskInfos
+    assert(stageInfos.size === 1)
+    assert(taskInfos.size === numPartitions)
+    // The accumulator values should be merged in the stage
+    val stageAccum = findAccumulableInfo(stageInfos.head.accumulables.values, TEST_ACCUMULATOR)
+    assert(stageAccum.value.toLong === numPartitions)
+    // The accumulator should be updated locally on each task
+    val taskAccumValues = taskInfos.map { taskInfo =>
+      val taskAccum = findAccumulableInfo(taskInfo.accumulables, TEST_ACCUMULATOR)
+      assert(taskAccum.update.isDefined)
+      assert(taskAccum.update.get.toLong === 1)
+      taskAccum.value.toLong
+    }
+    // Each task's accumulable info should reflect the partial value merged so far, i.e. 1, 2, ..., numPartitions
+    assert(taskAccumValues.sorted === (1L to numPartitions).toSeq)
+  }
+
+  test("internal accumulators in multiple stages") {
+    val listener = new SaveInfoListener
+    val numPartitions = 10
+    sc = new SparkContext("local", "test")
+    sc.addSparkListener(listener)
+    // Each stage creates its own set of internal accumulators so the
+    // values for the same metric should not be mixed up across stages
+    sc.parallelize(1 to 100, numPartitions)
+      .map { i => (i, i) }
+      .mapPartitions { iter =>
+        TaskContext.get().internalMetricsToAccumulators(TEST_ACCUMULATOR) += 1
+        iter
+      }
+      .reduceByKey { case (x, y) => x + y }
+      .mapPartitions { iter =>
+        TaskContext.get().internalMetricsToAccumulators(TEST_ACCUMULATOR) += 10
+        iter
+      }
+      .repartition(numPartitions * 2)
+      .mapPartitions { iter =>
+        TaskContext.get().internalMetricsToAccumulators(TEST_ACCUMULATOR) += 100
+        iter
+      }
+      .count()
+    // We ran 3 stages, and the accumulator values should be distinct
+    val stageInfos = listener.getCompletedStageInfos
+    assert(stageInfos.size === 3)
+    val firstStageAccum = findAccumulableInfo(stageInfos(0).accumulables.values, TEST_ACCUMULATOR)
+    val secondStageAccum = findAccumulableInfo(stageInfos(1).accumulables.values, TEST_ACCUMULATOR)
+    val thirdStageAccum = findAccumulableInfo(stageInfos(2).accumulables.values, TEST_ACCUMULATOR)
+    assert(firstStageAccum.value.toLong === numPartitions)
+    assert(secondStageAccum.value.toLong === numPartitions * 10)
+    assert(thirdStageAccum.value.toLong === numPartitions * 2 * 100)
+  }
+
+  test("internal accumulators in fully resubmitted stages") {
+    testInternalAccumulatorsWithFailedTasks((i: Int) => true) // fail all tasks
+  }
+
+  test("internal accumulators in partially resubmitted stages") {
+    testInternalAccumulatorsWithFailedTasks((i: Int) => i % 2 == 0) // fail a subset
+  }
+
+  /**
+   * Return the accumulable info that matches the specified name.
+   */
+  private def findAccumulableInfo(
+      accums: Iterable[AccumulableInfo],
+      name: String): AccumulableInfo = {
+    accums.find { a => a.name == name }.getOrElse {
+      throw new TestFailedException(s"internal accumulator '$name' not found", 0)
+    }
+  }
+
+  /**
+   * Test whether internal accumulators are merged properly if some tasks fail.
+   */
+  private def testInternalAccumulatorsWithFailedTasks(failCondition: (Int => Boolean)): Unit = {
+    val listener = new SaveInfoListener
+    val numPartitions = 10
+    val numFailedPartitions = (0 until numPartitions).count(failCondition)
+    // Use 1 core and allow each task up to 2 attempts (i.e. a failed task is retried once)
+    sc = new SparkContext("local[1, 2]", "test")
+    sc.addSparkListener(listener)
+    sc.parallelize(1 to 100, numPartitions).mapPartitionsWithIndex { case (i, iter) =>
+      val taskContext = TaskContext.get()
+      taskContext.internalMetricsToAccumulators(TEST_ACCUMULATOR) += 1
+      // Fail the first attempts of a subset of the tasks
+      if (failCondition(i) && taskContext.attemptNumber() == 0) {
+        throw new Exception("Failing a task intentionally.")
+      }
+      iter
+    }.count()
+    val stageInfos = listener.getCompletedStageInfos
+    val taskInfos = listener.getCompletedTaskInfos
+    assert(stageInfos.size === 1)
+    assert(taskInfos.size === numPartitions + numFailedPartitions)
+    val stageAccum = findAccumulableInfo(stageInfos.head.accumulables.values, TEST_ACCUMULATOR)
+    // We should not double count values in the merged accumulator
+    assert(stageAccum.value.toLong === numPartitions)
+    val taskAccumValues = taskInfos.flatMap { taskInfo =>
+      if (!taskInfo.failed) {
+        // If a task succeeded, its update value should always be 1
+        val taskAccum = findAccumulableInfo(taskInfo.accumulables, TEST_ACCUMULATOR)
+        assert(taskAccum.update.isDefined)
+        assert(taskAccum.update.get.toLong === 1)
+        Some(taskAccum.value.toLong)
+      } else {
+        // If a task failed, we should not get its accumulator values
+        assert(taskInfo.accumulables.isEmpty)
+        None
+      }
+    }
+    assert(taskAccumValues.sorted === (1L to numPartitions).toSeq)
+  }
+
+}
+
+private[spark] object AccumulatorSuite {
+
+  /**
+   * Run one or more Spark jobs and verify that the peak execution memory accumulator
+   * is updated afterwards.
+   */
+  def verifyPeakExecutionMemorySet(
+      sc: SparkContext,
+      testName: String)(testBody: => Unit): Unit = {
+    val listener = new SaveInfoListener
+    sc.addSparkListener(listener)
+    // Run a dummy job to verify that the peak execution memory accumulator is not set by default
+    sc.parallelize(1 to 10).count()
+    val accums = listener.getCompletedStageInfos.flatMap(_.accumulables.values)
+    assert(!accums.exists(_.name == InternalAccumulator.PEAK_EXECUTION_MEMORY))
+    testBody
+    // Verify that peak execution memory is updated
+    val accum = listener.getCompletedStageInfos
+      .flatMap(_.accumulables.values)
+      .find(_.name == InternalAccumulator.PEAK_EXECUTION_MEMORY)
+      .getOrElse {
+        throw new TestFailedException(
+          s"peak execution memory accumulator not set in '$testName'", 0)
+      }
+    assert(accum.value.toLong > 0)
+  }
+}
+
+/**
+ * A simple listener that keeps track of the StageInfos of all completed stages
+ * and the TaskInfos of all completed tasks.
+ */
+private class SaveInfoListener extends SparkListener {
+  private val completedStageInfos: ArrayBuffer[StageInfo] = new ArrayBuffer[StageInfo]
+  private val completedTaskInfos: ArrayBuffer[TaskInfo] = new ArrayBuffer[TaskInfo]
+
+  def getCompletedStageInfos: Seq[StageInfo] = completedStageInfos.toArray.toSeq
+  def getCompletedTaskInfos: Seq[TaskInfo] = completedTaskInfos.toArray.toSeq
+
+  override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = {
+    completedStageInfos += stageCompleted.stageInfo
+  }
+
+  override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
+    completedTaskInfos += taskEnd.taskInfo
+  }
 }
diff --git a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala
index 618a5fb24710fc4ecf024411cb04a46b53f385d4..cb8bd04e496a7eb83d6f41d64f6eaa60e691a03e 100644
--- a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala
@@ -21,7 +21,7 @@ import org.mockito.Mockito._
 import org.scalatest.BeforeAndAfter
 import org.scalatest.mock.MockitoSugar
 
-import org.apache.spark.executor.DataReadMethod
+import org.apache.spark.executor.{DataReadMethod, TaskMetrics}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.storage._
 
@@ -65,7 +65,7 @@ class CacheManagerSuite extends SparkFunSuite with LocalSparkContext with Before
     // in blockManager.put is a losing battle. You have been warned.
     blockManager = sc.env.blockManager
     cacheManager = sc.env.cacheManager
-    val context = new TaskContextImpl(0, 0, 0, 0, null, null)
+    val context = TaskContext.empty()
     val computeValue = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY)
     val getValue = blockManager.get(RDDBlockId(rdd.id, split.index))
     assert(computeValue.toList === List(1, 2, 3, 4))
@@ -77,7 +77,7 @@ class CacheManagerSuite extends SparkFunSuite with LocalSparkContext with Before
     val result = new BlockResult(Array(5, 6, 7).iterator, DataReadMethod.Memory, 12)
     when(blockManager.get(RDDBlockId(0, 0))).thenReturn(Some(result))
 
-    val context = new TaskContextImpl(0, 0, 0, 0, null, null)
+    val context = TaskContext.empty()
     val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY)
     assert(value.toList === List(5, 6, 7))
   }
@@ -86,14 +86,14 @@ class CacheManagerSuite extends SparkFunSuite with LocalSparkContext with Before
     // Local computation should not persist the resulting value, so don't expect a put().
     when(blockManager.get(RDDBlockId(0, 0))).thenReturn(None)
 
-    val context = new TaskContextImpl(0, 0, 0, 0, null, null, true)
+    val context = new TaskContextImpl(0, 0, 0, 0, null, null, Seq.empty, runningLocally = true)
     val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY)
     assert(value.toList === List(1, 2, 3, 4))
   }
 
   test("verify task metrics updated correctly") {
     cacheManager = sc.env.cacheManager
-    val context = new TaskContextImpl(0, 0, 0, 0, null, null)
+    val context = TaskContext.empty()
     cacheManager.getOrCompute(rdd3, split, context, StorageLevel.MEMORY_ONLY)
     assert(context.taskMetrics.updatedBlocks.getOrElse(Seq()).size === 2)
   }
diff --git a/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala
index 3e8816a4c65be1e33ccd19df32daa1d3a6f74ae3..5f73ec8675966958d823307c7bf55ea2c65dcead 100644
--- a/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala
@@ -175,7 +175,7 @@ class PipedRDDSuite extends SparkFunSuite with SharedSparkContext {
       }
       val hadoopPart1 = generateFakeHadoopPartition()
       val pipedRdd = new PipedRDD(nums, "printenv " + varName)
-      val tContext = new TaskContextImpl(0, 0, 0, 0, null, null)
+      val tContext = TaskContext.empty()
       val rddIter = pipedRdd.compute(hadoopPart1, tContext)
       val arr = rddIter.toArray
       assert(arr(0) == "/some/path")
diff --git a/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala b/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala
index b3ca150195a5f7c8b6871e4e4d6c630754c0a632..f7e16af9d3a928f20e20764747aaf10c8249106e 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala
@@ -19,9 +19,11 @@ package org.apache.spark.scheduler
 
 import org.apache.spark.TaskContext
 
-class FakeTask(stageId: Int, prefLocs: Seq[TaskLocation] = Nil) extends Task[Int](stageId, 0, 0) {
+class FakeTask(
+    stageId: Int,
+    prefLocs: Seq[TaskLocation] = Nil)
+  extends Task[Int](stageId, 0, 0, Seq.empty) {
   override def runTask(context: TaskContext): Int = 0
-
   override def preferredLocations: Seq[TaskLocation] = prefLocs
 }
 
diff --git a/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala b/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala
index 383855caefa2f84cd91b39dca4c3d787a385be1e..f33324792495b9d1c890403300d9ee263804b2b2 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala
@@ -25,7 +25,7 @@ import org.apache.spark.TaskContext
  * A Task implementation that fails to serialize.
  */
 private[spark] class NotSerializableFakeTask(myId: Int, stageId: Int)
-  extends Task[Array[Byte]](stageId, 0, 0) {
+  extends Task[Array[Byte]](stageId, 0, 0, Seq.empty) {
 
   override def runTask(context: TaskContext): Array[Byte] = Array.empty[Byte]
   override def preferredLocations: Seq[TaskLocation] = Seq[TaskLocation]()
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
index 9201d1e1f328b5e719c77795c0336ca07f74617d..450ab7b9fe92bc2d29dd85f9df9cebfad5c7f221 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
@@ -57,8 +57,9 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
     }
     val closureSerializer = SparkEnv.get.closureSerializer.newInstance()
     val func = (c: TaskContext, i: Iterator[String]) => i.next()
-    val task = new ResultTask[String, String](0, 0,
-      sc.broadcast(closureSerializer.serialize((rdd, func)).array), rdd.partitions(0), Seq(), 0)
+    val taskBinary = sc.broadcast(closureSerializer.serialize((rdd, func)).array)
+    val task = new ResultTask[String, String](
+      0, 0, taskBinary, rdd.partitions(0), Seq.empty, 0, Seq.empty)
     intercept[RuntimeException] {
       task.run(0, 0, null)
     }
@@ -66,7 +67,7 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
   }
 
   test("all TaskCompletionListeners should be called even if some fail") {
-    val context = new TaskContextImpl(0, 0, 0, 0, null, null)
+    val context = TaskContext.empty()
     val listener = mock(classOf[TaskCompletionListener])
     context.addTaskCompletionListener(_ => throw new Exception("blah"))
     context.addTaskCompletionListener(listener)
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
index 3abb99c4b2b54f320d0cabc6d24cdb63446c1cb3..f7cc4bb61d57487cad95086ae0f16f7fcd1c0893 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
@@ -136,7 +136,7 @@ class FakeTaskScheduler(sc: SparkContext, liveExecutors: (String, String)* /* ex
 /**
  * A Task implementation that results in a large serialized task.
  */
-class LargeTask(stageId: Int) extends Task[Array[Byte]](stageId, 0, 0) {
+class LargeTask(stageId: Int) extends Task[Array[Byte]](stageId, 0, 0, Seq.empty) {
   val randomBuffer = new Array[Byte](TaskSetManager.TASK_SIZE_TO_WARN_KB * 1024)
   val random = new Random(0)
   random.nextBytes(randomBuffer)
diff --git a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala
index db718ecabbdb92488a6b36ed841a32f047048009..05b3afef5b83961cc9e4e3e56bb457e79d7800b4 100644
--- a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala
+++ b/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala
@@ -138,7 +138,7 @@ class HashShuffleReaderSuite extends SparkFunSuite with LocalSparkContext {
       shuffleHandle,
       reduceId,
       reduceId + 1,
-      new TaskContextImpl(0, 0, 0, 0, null, null),
+      TaskContext.empty(),
       blockManager,
       mapOutputTracker)
 
diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
index cf8bd8ae6962502f164314b30943e19afee80125..828153bdbfc44cfd338c2f35fcb3e10187120049 100644
--- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
@@ -29,7 +29,7 @@ import org.mockito.invocation.InvocationOnMock
 import org.mockito.stubbing.Answer
 import org.scalatest.PrivateMethodTester
 
-import org.apache.spark.{SparkFunSuite, TaskContextImpl}
+import org.apache.spark.{SparkFunSuite, TaskContext}
 import org.apache.spark.network._
 import org.apache.spark.network.buffer.ManagedBuffer
 import org.apache.spark.network.shuffle.BlockFetchingListener
@@ -95,7 +95,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
     )
 
     val iterator = new ShuffleBlockFetcherIterator(
-      new TaskContextImpl(0, 0, 0, 0, null, null),
+      TaskContext.empty(),
       transfer,
       blockManager,
       blocksByAddress,
@@ -165,7 +165,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
     val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
       (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq))
 
-    val taskContext = new TaskContextImpl(0, 0, 0, 0, null, null)
+    val taskContext = TaskContext.empty()
     val iterator = new ShuffleBlockFetcherIterator(
       taskContext,
       transfer,
@@ -227,7 +227,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
     val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
       (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq))
 
-    val taskContext = new TaskContextImpl(0, 0, 0, 0, null, null)
+    val taskContext = TaskContext.empty()
     val iterator = new ShuffleBlockFetcherIterator(
       taskContext,
       transfer,
diff --git a/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala b/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala
new file mode 100644
index 0000000000000000000000000000000000000000..98f9314f31dff474efd41132aebabda2bb834561
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ui
+
+import javax.servlet.http.HttpServletRequest
+
+import scala.xml.Node
+
+import org.mockito.Mockito.{mock, when, RETURNS_SMART_NULLS}
+
+import org.apache.spark.{LocalSparkContext, SparkConf, SparkFunSuite, Success}
+import org.apache.spark.executor.TaskMetrics
+import org.apache.spark.scheduler._
+import org.apache.spark.ui.jobs.{JobProgressListener, StagePage, StagesTab}
+import org.apache.spark.ui.scope.RDDOperationGraphListener
+
+class StagePageSuite extends SparkFunSuite with LocalSparkContext {
+
+  test("peak execution memory only displayed if unsafe is enabled") {
+    val unsafeConf = "spark.sql.unsafe.enabled"
+    val conf = new SparkConf().set(unsafeConf, "true")
+    val html = renderStagePage(conf).toString().toLowerCase
+    val targetString = "peak execution memory"
+    assert(html.contains(targetString))
+    // Disable unsafe and make sure it's not there
+    val conf2 = new SparkConf().set(unsafeConf, "false")
+    val html2 = renderStagePage(conf2).toString().toLowerCase
+    assert(!html2.contains(targetString))
+  }
+
+  /**
+   * Render a stage page with the given conf and return the HTML.
+   * This also simulates a completed stage via listener events so the page has content to render.
+   */
+  private def renderStagePage(conf: SparkConf): Seq[Node] = {
+    val jobListener = new JobProgressListener(conf)
+    val graphListener = new RDDOperationGraphListener(conf)
+    val tab = mock(classOf[StagesTab], RETURNS_SMART_NULLS)
+    val request = mock(classOf[HttpServletRequest])
+    when(tab.conf).thenReturn(conf)
+    when(tab.progressListener).thenReturn(jobListener)
+    when(tab.operationGraphListener).thenReturn(graphListener)
+    when(tab.appName).thenReturn("testing")
+    when(tab.headerTabs).thenReturn(Seq.empty)
+    when(request.getParameter("id")).thenReturn("0")
+    when(request.getParameter("attempt")).thenReturn("0")
+    val page = new StagePage(tab)
+
+    // Simulate a stage in job progress listener
+    val stageInfo = new StageInfo(0, 0, "dummy", 1, Seq.empty, Seq.empty, "details")
+    val taskInfo = new TaskInfo(0, 0, 0, 0, "0", "localhost", TaskLocality.ANY, false)
+    jobListener.onStageSubmitted(SparkListenerStageSubmitted(stageInfo))
+    jobListener.onTaskStart(SparkListenerTaskStart(0, 0, taskInfo))
+    taskInfo.markSuccessful()
+    jobListener.onTaskEnd(
+      SparkListenerTaskEnd(0, 0, "result", Success, taskInfo, TaskMetrics.empty))
+    jobListener.onStageCompleted(SparkListenerStageCompleted(stageInfo))
+    page.render(request)
+  }
+
+}
diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala
index 9c362f0de70768c93617dba336eec40007261b11..12e9bafcc92c121116fe184a53305485caf8adb6 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala
@@ -399,4 +399,19 @@ class ExternalAppendOnlyMapSuite extends SparkFunSuite with LocalSparkContext {
     sc.stop()
   }
 
+  test("external aggregation updates peak execution memory") {
+    val conf = createSparkConf(loadDefaults = false)
+      .set("spark.shuffle.memoryFraction", "0.001")
+      .set("spark.shuffle.manager", "hash") // make sure we're not also using ExternalSorter
+    sc = new SparkContext("local", "test", conf)
+    // No spilling
+    AccumulatorSuite.verifyPeakExecutionMemorySet(sc, "external map without spilling") {
+      sc.parallelize(1 to 10, 2).map { i => (i, i) }.reduceByKey(_ + _).count()
+    }
+    // With spilling
+    AccumulatorSuite.verifyPeakExecutionMemorySet(sc, "external map with spilling") {
+      sc.parallelize(1 to 1000 * 1000, 2).map { i => (i, i) }.reduceByKey(_ + _).count()
+    }
+  }
+
 }
diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala
index 986cd8623d14547a9c4bfb84dca30b76bf5281d1..bdb0f4d507a7e14170332c64944c65952406ca3e 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala
@@ -692,7 +692,7 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext {
     sortWithoutBreakingSortingContracts(createSparkConf(true, false))
   }
 
-  def sortWithoutBreakingSortingContracts(conf: SparkConf) {
+  private def sortWithoutBreakingSortingContracts(conf: SparkConf) {
     conf.set("spark.shuffle.memoryFraction", "0.01")
     conf.set("spark.shuffle.manager", "sort")
     sc = new SparkContext("local-cluster[1,1,1024]", "test", conf)
@@ -743,5 +743,15 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext {
     }
 
     sorter2.stop()
- }
+  }
+
+  test("sorting updates peak execution memory") {
+    val conf = createSparkConf(loadDefaults = false, kryo = false)
+      .set("spark.shuffle.manager", "sort")
+    sc = new SparkContext("local", "test", conf)
+    // Avoid aggregating here to make sure we're not also using ExternalAppendOnlyMap
+    AccumulatorSuite.verifyPeakExecutionMemorySet(sc, "external sorter") {
+      sc.parallelize(1 to 1000, 2).repartition(100).count()
+    }
+  }
 }
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
index 5e4c6232c947171830c4948b9a14a69a4092b350..193906d24790eccf90e4bb5c0eda2306af93ce70 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
@@ -106,6 +106,13 @@ final class UnsafeExternalRowSorter {
     sorter.spill();
   }
 
+  /**
+   * Return the peak memory used so far, in bytes.
+   */
+  public long getPeakMemoryUsage() {
+    return sorter.getPeakMemoryUsedBytes();
+  }
+
   private void cleanupResources() {
     sorter.freeMemory();
   }
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
index 9e2c9334a7bee38c8923ed86a626a8e925a68d81..43d06ce9bdfa31cfc8c68dc3431410eba308321c 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
@@ -208,6 +208,14 @@ public final class UnsafeFixedWidthAggregationMap {
     };
   }
 
+  /**
+   * The memory used by this map's managed structures, in bytes.
+   * Note that this is also the peak memory used by this map, since the map is append-only.
+   */
+  public long getMemoryUsage() {
+    return map.getTotalMemoryConsumption();
+  }
+
   /**
    * Free the memory associated with this map. This is idempotent and can be called multiple times.
    */
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
index cd87b8deba0c2c4d160f5372ec10decac09fefed..bf4905dc1eef99a2beab3ef199b24b31f4473dd9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution
 
 import java.io.IOException
 
-import org.apache.spark.{SparkEnv, TaskContext}
+import org.apache.spark.{InternalAccumulator, SparkEnv, TaskContext}
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
@@ -263,11 +263,12 @@ case class GeneratedAggregate(
         assert(iter.hasNext, "There should be at least one row for this path")
         log.info("Using Unsafe-based aggregator")
         val pageSizeBytes = SparkEnv.get.conf.getSizeAsBytes("spark.buffer.pageSize", "64m")
+        val taskContext = TaskContext.get()
         val aggregationMap = new UnsafeFixedWidthAggregationMap(
           newAggregationBuffer(EmptyRow),
           aggregationBufferSchema,
           groupKeySchema,
-          TaskContext.get.taskMemoryManager(),
+          taskContext.taskMemoryManager(),
           SparkEnv.get.shuffleMemoryManager,
           1024 * 16, // initial capacity
           pageSizeBytes,
@@ -284,6 +285,10 @@ case class GeneratedAggregate(
           updateProjection.target(aggregationBuffer)(joinedRow(aggregationBuffer, currentRow))
         }
 
+        // Record the memory used by the aggregation map as this task's peak execution memory
+        taskContext.internalMetricsToAccumulators(
+          InternalAccumulator.PEAK_EXECUTION_MEMORY).add(aggregationMap.getMemoryUsage)
+
         new Iterator[InternalRow] {
           private[this] val mapIterator = aggregationMap.iterator()
           private[this] val resultProjection = resultProjectionBuilder()
@@ -300,7 +305,7 @@ case class GeneratedAggregate(
               } else {
                 // This is the last element in the iterator, so let's free the buffer. Before we do,
                 // though, we need to make a defensive copy of the result so that we don't return an
-                // object that might contain dangling pointers to the freed memory
+                // object that might contain dangling pointers to the freed memory.
                 val resultCopy = result.copy()
                 aggregationMap.free()
                 resultCopy
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
index 624efc1b1d7346fea87e1db7e2c6feadb6bfcb0f..e73e2523a777f3e8dcf0a3dfe15897315c53aa62 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.joins
 import scala.concurrent._
 import scala.concurrent.duration._
 
+import org.apache.spark.{InternalAccumulator, TaskContext}
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
@@ -70,7 +71,14 @@ case class BroadcastHashJoin(
     val broadcastRelation = Await.result(broadcastFuture, timeout)
 
     streamedPlan.execute().mapPartitions { streamedIter =>
-      hashJoin(streamedIter, broadcastRelation.value)
+      val hashedRelation = broadcastRelation.value
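+      // If the broadcast relation was rebuilt as a binary map on this executor,
+      // record its size in the peak execution memory accumulator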
+      hashedRelation match {
+        case unsafe: UnsafeHashedRelation =>
+          TaskContext.get().internalMetricsToAccumulators(
+            InternalAccumulator.PEAK_EXECUTION_MEMORY).add(unsafe.getUnsafeSize)
+        case _ =>
+      }
+      hashJoin(streamedIter, hashedRelation)
     }
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala
index 309716a0efcc01763fccb1b390d01a5216d83f9b..c35e439cc9deb6bfb0e78b7814e873911a57bc3c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.joins
 import scala.concurrent._
 import scala.concurrent.duration._
 
+import org.apache.spark.{InternalAccumulator, TaskContext}
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
@@ -75,6 +76,13 @@ case class BroadcastHashOuterJoin(
       val hashTable = broadcastRelation.value
       val keyGenerator = streamedKeyGenerator
 
+      hashTable match {
+        case unsafe: UnsafeHashedRelation =>
+          TaskContext.get().internalMetricsToAccumulators(
+            InternalAccumulator.PEAK_EXECUTION_MEMORY).add(unsafe.getUnsafeSize)
+        case _ =>
+      }
+
       joinType match {
         case LeftOuter =>
           streamedIter.flatMap(currentRow => {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala
index a60593911f94fc8a4140ba38b3823a44946ad8d4..5bd06fbdca605600f395447c9c02107c9da42740 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.execution.joins
 
+import org.apache.spark.{InternalAccumulator, TaskContext}
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
@@ -51,7 +52,14 @@ case class BroadcastLeftSemiJoinHash(
       val broadcastedRelation = sparkContext.broadcast(hashRelation)
 
       left.execute().mapPartitions { streamIter =>
-        hashSemiJoin(streamIter, broadcastedRelation.value)
+        val hashedRelation = broadcastedRelation.value
+        hashedRelation match {
+          case unsafe: UnsafeHashedRelation =>
+            TaskContext.get().internalMetricsToAccumulators(
+              InternalAccumulator.PEAK_EXECUTION_MEMORY).add(unsafe.getUnsafeSize)
+          case _ =>
+        }
+        hashSemiJoin(streamIter, hashedRelation)
       }
     }
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
index cc8bbfd2f89438f7d080bca0ec8ff8d6d9fbe7f6..58b4236f7b5b5b6d77b7f2e0cefa021a4f9e9f39 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
@@ -183,8 +183,27 @@ private[joins] final class UnsafeHashedRelation(
   private[joins] def this() = this(null)  // Needed for serialization
 
   // Use BytesToBytesMap in executor for better performance (it's created when deserialization)
+  // This is used in broadcast joins and distributed mode only
   @transient private[this] var binaryMap: BytesToBytesMap = _
 
+  /**
+   * Return the size of the unsafe map on the executors.
+   *
+   * For broadcast joins, this hashed relation is bigger on the driver because it is
+   * represented as a Java hash map there. When the map is shipped to the executors,
+   * however, its contents are rehashed into a binary map during deserialization to
+   * reduce the memory footprint there.
+   *
+   * For non-broadcast joins or in local mode, return 0.
+   */
+  def getUnsafeSize: Long = {
+    if (binaryMap != null) {
+      binaryMap.getTotalMemoryConsumption
+    } else {
+      0
+    }
+  }
+
   override def get(key: InternalRow): Seq[InternalRow] = {
     val unsafeKey = key.asInstanceOf[UnsafeRow]
 
@@ -214,7 +233,7 @@ private[joins] final class UnsafeHashedRelation(
       }
 
     } else {
-      // Use the JavaHashMap in Local mode or ShuffleHashJoin
+      // Use the Java HashMap in local mode or for non-broadcast joins (e.g. ShuffleHashJoin)
       hashTable.get(unsafeKey)
     }
   }
@@ -316,6 +335,7 @@ private[joins] object UnsafeHashedRelation {
       keyGenerator: UnsafeProjection,
       sizeEstimate: Int): HashedRelation = {
 
+    // Use a Java hash table here because unsafe maps expect fixed size records
     val hashTable = new JavaHashMap[UnsafeRow, CompactBuffer[UnsafeRow]](sizeEstimate)
 
     // Create a mapping of buildKeys -> rows
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala
index 92cf328c76cbcd5dd03599e0445c00189de96855..3192b6ebe9075dbbe651530cbbe34bbfa8b930fa 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.execution
 
+import org.apache.spark.{InternalAccumulator, TaskContext}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.errors._
@@ -76,6 +77,11 @@ case class ExternalSort(
       val sorter = new ExternalSorter[InternalRow, Null, InternalRow](ordering = Some(ordering))
       sorter.insertAll(iterator.map(r => (r.copy(), null)))
       val baseIterator = sorter.iterator.map(_._1)
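+      // Propagate the sorter's spill metrics and peak memory usage to this task's metrics and accumulators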
+      val context = TaskContext.get()
+      context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled)
+      context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled)
+      context.internalMetricsToAccumulators(
+        InternalAccumulator.PEAK_EXECUTION_MEMORY).add(sorter.peakMemoryUsedBytes)
       // TODO(marmbrus): The complex type signature below thwarts inference for no reason.
       CompletionIterator[InternalRow, Iterator[InternalRow]](baseIterator, sorter.stop())
     }, preservesPartitioning = true)
@@ -137,7 +143,11 @@ case class TungstenSort(
       if (testSpillFrequency > 0) {
         sorter.setTestSpillFrequency(testSpillFrequency)
       }
-      sorter.sort(iter.asInstanceOf[Iterator[UnsafeRow]])
+      val sortedIterator = sorter.sort(iter.asInstanceOf[Iterator[UnsafeRow]])
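+      // Record the peak amount of memory the unsafe sorter used for this partition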
+      val taskContext = TaskContext.get()
+      taskContext.internalMetricsToAccumulators(
+        InternalAccumulator.PEAK_EXECUTION_MEMORY).add(sorter.getPeakMemoryUsage)
+      sortedIterator
     }, preservesPartitioning = true)
   }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index f1abae0720058c919af15c02cba903c27065cd97..29dfcf2575227ec21325d880482f2c4be9ebfeb5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -21,6 +21,7 @@ import java.sql.Timestamp
 
 import org.scalatest.BeforeAndAfterAll
 
+import org.apache.spark.AccumulatorSuite
 import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
 import org.apache.spark.sql.catalyst.DefaultParserDialect
 import org.apache.spark.sql.catalyst.errors.DialectException
@@ -258,6 +259,23 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
     }
   }
 
+  private def testCodeGen(sqlText: String, expectedResults: Seq[Row]): Unit = {
+    val df = sql(sqlText)
+    // First, check if we have GeneratedAggregate.
+    val hasGeneratedAgg = df.queryExecution.executedPlan
+      .collect { case _: GeneratedAggregate | _: aggregate.Aggregate => true }
+      .nonEmpty
+    if (!hasGeneratedAgg) {
+      fail(
+        s"""
+           |Codegen is enabled, but query $sqlText does not have GeneratedAggregate in the plan.
+           |${df.queryExecution.simpleString}
+         """.stripMargin)
+    }
+    // Then, check results.
+    checkAnswer(df, expectedResults)
+  }
+
   test("aggregation with codegen") {
     val originalValue = sqlContext.conf.codegenEnabled
     sqlContext.setConf(SQLConf.CODEGEN_ENABLED, true)
@@ -267,26 +285,6 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
       .unionAll(sqlContext.table("testData"))
       .registerTempTable("testData3x")
 
-    def testCodeGen(sqlText: String, expectedResults: Seq[Row]): Unit = {
-      val df = sql(sqlText)
-      // First, check if we have GeneratedAggregate.
-      var hasGeneratedAgg = false
-      df.queryExecution.executedPlan.foreach {
-        case generatedAgg: GeneratedAggregate => hasGeneratedAgg = true
-        case newAggregate: aggregate.Aggregate => hasGeneratedAgg = true
-        case _ =>
-      }
-      if (!hasGeneratedAgg) {
-        fail(
-          s"""
-             |Codegen is enabled, but query $sqlText does not have GeneratedAggregate in the plan.
-             |${df.queryExecution.simpleString}
-           """.stripMargin)
-      }
-      // Then, check results.
-      checkAnswer(df, expectedResults)
-    }
-
     try {
       // Just to group rows.
       testCodeGen(
@@ -1605,6 +1603,28 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
       Row(new CalendarInterval(-(12 * 3 - 3), -(7L * MICROS_PER_WEEK + 123))))
   }
 
+  test("aggregation with codegen updates peak execution memory") {
+    withSQLConf(
+        (SQLConf.CODEGEN_ENABLED.key, "true"),
+        (SQLConf.USE_SQL_AGGREGATE2.key, "false")) {
+      val sc = sqlContext.sparkContext
+      AccumulatorSuite.verifyPeakExecutionMemorySet(sc, "aggregation with codegen") {
+        testCodeGen(
+          "SELECT key, count(value) FROM testData GROUP BY key",
+          (1 to 100).map(i => Row(i, 1)))
+      }
+    }
+  }
+
+  test("external sorting updates peak execution memory") {
+    withSQLConf((SQLConf.EXTERNAL_SORT.key, "true")) {
+      val sc = sqlContext.sparkContext
+      AccumulatorSuite.verifyPeakExecutionMemorySet(sc, "external sort") {
+        sortTest()
+      }
+    }
+  }
+
   test("SPARK-9511: error with table starting with number") {
     val df = sqlContext.sparkContext.parallelize(1 to 10).map(i => (i, i.toString))
       .toDF("num", "str")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala
index c7949848513cf404d488dd8ac69a453e751ef612..88bce0e319f9e8696c1b567ca9b5684a53f67f8f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala
@@ -21,6 +21,7 @@ import scala.util.Random
 
 import org.scalatest.BeforeAndAfterAll
 
+import org.apache.spark.AccumulatorSuite
 import org.apache.spark.sql.{RandomDataGenerator, Row, SQLConf}
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.test.TestSQLContext
@@ -59,6 +60,17 @@ class TungstenSortSuite extends SparkPlanTest with BeforeAndAfterAll {
     )
   }
 
+  test("sorting updates peak execution memory") {
+    val sc = TestSQLContext.sparkContext
+    AccumulatorSuite.verifyPeakExecutionMemorySet(sc, "unsafe external sort") {
+      checkThatPlansAgree(
+        (1 to 100).map(v => Tuple1(v)).toDF("a"),
+        (child: SparkPlan) => TungstenSort('a.asc :: Nil, true, child),
+        (child: SparkPlan) => Sort('a.asc :: Nil, global = true, child),
+        sortAnswers = false)
+    }
+  }
+
   // Test sorting on different data types
   for (
     dataType <- DataTypeTestUtils.atomicTypes ++ Set(NullType);
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
index 7c591f6143b9ea75175d35f3226c9f79c0568575..ef827b0fe9b5b8f016f2984b32385be3a1824461 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
@@ -69,7 +69,8 @@ class UnsafeFixedWidthAggregationMapSuite extends SparkFunSuite with Matchers {
         taskAttemptId = Random.nextInt(10000),
         attemptNumber = 0,
         taskMemoryManager = taskMemoryManager,
-        metricsSystem = null))
+        metricsSystem = null,
+        internalAccumulators = Seq.empty))
 
       try {
         f
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala
index 0282b25b9dd501c1c531e1b39e755e0d801dd061..601a5a07ad00295ec7f0259e85b5d046ba45d0a7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala
@@ -76,7 +76,8 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite {
         taskAttemptId = 98456,
         attemptNumber = 0,
         taskMemoryManager = taskMemMgr,
-        metricsSystem = null))
+        metricsSystem = null,
+        internalAccumulators = Seq.empty))
 
       // Create the data converters
       val kExternalConverter = CatalystTypeConverters.createToCatalystConverter(keySchema)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
new file mode 100644
index 0000000000000000000000000000000000000000..0554e11d252ba65462e48ed12b35fc57c94db577
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
@@ -0,0 +1,94 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one or more
+* contributor license agreements.  See the NOTICE file distributed with
+* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License.  You may obtain a copy of the License at
+*
+*    http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+// TODO: uncomment the tests here! They are currently failing due to a bad
+// interaction with org.apache.spark.sql.test.TestSQLContext.
+
+// scalastyle:off
+//package org.apache.spark.sql.execution.joins
+//
+//import scala.reflect.ClassTag
+//
+//import org.scalatest.BeforeAndAfterAll
+//
+//import org.apache.spark.{AccumulatorSuite, SparkConf, SparkContext}
+//import org.apache.spark.sql.functions._
+//import org.apache.spark.sql.{SQLConf, SQLContext, QueryTest}
+//
+///**
+// * Test various broadcast join operators with unsafe enabled.
+// *
+// * This needs to be its own suite because [[org.apache.spark.sql.test.TestSQLContext]] runs
+// * in local mode, but for tests in this suite we need to run Spark in local-cluster mode.
+// * In particular, the use of [[org.apache.spark.unsafe.map.BytesToBytesMap]] in
+// * [[org.apache.spark.sql.execution.joins.UnsafeHashedRelation]] is not triggered without
+// * serializing the hashed relation, which does not happen in local mode.
+// */
+//class BroadcastJoinSuite extends QueryTest with BeforeAndAfterAll {
+//  private var sc: SparkContext = null
+//  private var sqlContext: SQLContext = null
+//
+//  /**
+//   * Create a new [[SQLContext]] running in local-cluster mode with unsafe and codegen enabled.
+//   */
+//  override def beforeAll(): Unit = {
+//    super.beforeAll()
+//    val conf = new SparkConf()
+//      .setMaster("local-cluster[2,1,1024]")
+//      .setAppName("testing")
+//    sc = new SparkContext(conf)
+//    sqlContext = new SQLContext(sc)
+//    sqlContext.setConf(SQLConf.UNSAFE_ENABLED, true)
+//    sqlContext.setConf(SQLConf.CODEGEN_ENABLED, true)
+//  }
+//
+//  override def afterAll(): Unit = {
+//    sc.stop()
+//    sc = null
+//    sqlContext = null
+//  }
+//
+//  /**
+//   * Test whether the specified broadcast join updates the peak execution memory accumulator.
+//   */
+//  private def testBroadcastJoin[T: ClassTag](name: String, joinType: String): Unit = {
+//    AccumulatorSuite.verifyPeakExecutionMemorySet(sc, name) {
+//      val df1 = sqlContext.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value")
+//      val df2 = sqlContext.createDataFrame(Seq((1, "1"), (2, "2"))).toDF("key", "value")
+//      // Comparison at the end is for broadcast left semi join
+//      val joinExpression = df1("key") === df2("key") && df1("value") > df2("value")
+//      val df3 = df1.join(broadcast(df2), joinExpression, joinType)
+//      val plan = df3.queryExecution.executedPlan
+//      assert(plan.collect { case p: T => p }.size === 1)
+//      plan.executeCollect()
+//    }
+//  }
+//
+//  test("unsafe broadcast hash join updates peak execution memory") {
+//    testBroadcastJoin[BroadcastHashJoin]("unsafe broadcast hash join", "inner")
+//  }
+//
+//  test("unsafe broadcast hash outer join updates peak execution memory") {
+//    testBroadcastJoin[BroadcastHashOuterJoin]("unsafe broadcast hash outer join", "left_outer")
+//  }
+//
+//  test("unsafe broadcast left semi join updates peak execution memory") {
+//    testBroadcastJoin[BroadcastLeftSemiJoinHash]("unsafe broadcast left semi join", "leftsemi")
+//  }
+//
+//}
+// scalastyle:on