diff --git a/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java b/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java
index 840f13b39464c0967867cf17a588383f3a667958..38a21a896e1fe46a104802a0339d2399214448d8 100644
--- a/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java
+++ b/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java
@@ -31,15 +31,24 @@ public abstract class MemoryConsumer {
 
   protected final TaskMemoryManager taskMemoryManager;
   private final long pageSize;
+  private final MemoryMode mode;
   protected long used;
 
-  protected MemoryConsumer(TaskMemoryManager taskMemoryManager, long pageSize) {
+  protected MemoryConsumer(TaskMemoryManager taskMemoryManager, long pageSize, MemoryMode mode) {
     this.taskMemoryManager = taskMemoryManager;
     this.pageSize = pageSize;
+    this.mode = mode;
   }
 
   protected MemoryConsumer(TaskMemoryManager taskMemoryManager) {
-    this(taskMemoryManager, taskMemoryManager.pageSizeBytes());
+    this(taskMemoryManager, taskMemoryManager.pageSizeBytes(), MemoryMode.ON_HEAP);
+  }
+
+  /**
+   * Returns the memory mode, ON_HEAP or OFF_HEAP.
+   */
+  public MemoryMode getMode() {
+    return mode;
   }
 
   /**
@@ -132,19 +141,19 @@ public abstract class MemoryConsumer {
   }
 
   /**
-   * Allocates a heap memory of `size`.
+   * Acquires `size` bytes of memory in this consumer's mode, returning the bytes actually granted.
    */
-  public long acquireOnHeapMemory(long size) {
-    long granted = taskMemoryManager.acquireExecutionMemory(size, MemoryMode.ON_HEAP, this);
+  public long acquireMemory(long size) {
+    long granted = taskMemoryManager.acquireExecutionMemory(size, this);
     used += granted;
     return granted;
   }
 
   /**
-   * Release N bytes of heap memory.
+   * Releases `size` bytes of memory in this consumer's mode.
    */
-  public void freeOnHeapMemory(long size) {
-    taskMemoryManager.releaseExecutionMemory(size, MemoryMode.ON_HEAP, this);
+  public void freeMemory(long size) {
+    taskMemoryManager.releaseExecutionMemory(size, this);
     used -= size;
   }
 }
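With this change a MemoryConsumer's MemoryMode is fixed at construction, and acquireMemory/freeMemory operate in that mode instead of being hard-coded to ON_HEAP. A minimal sketch of a consumer built against the new constructor (the class name and the trivial spill body are illustrative, not part of this patch):

    import java.io.IOException;
    import org.apache.spark.memory.MemoryConsumer;
    import org.apache.spark.memory.MemoryMode;
    import org.apache.spark.memory.TaskMemoryManager;

    public class SketchConsumer extends MemoryConsumer {
      public SketchConsumer(TaskMemoryManager tmm, MemoryMode mode) {
        super(tmm, tmm.pageSizeBytes(), mode);  // the mode now travels with the consumer
      }

      @Override
      public long spill(long size, MemoryConsumer trigger) throws IOException {
        long released = used;   // 'used' is MemoryConsumer's protected bookkeeping field
        freeMemory(released);   // releases in this consumer's own mode
        return released;
      }
    }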
diff --git a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java
index a05a79c88df76c70f8a9969a70d59adac11a9c05..a4a571f15a8c01c7b30bb16122960979924ba345 100644
--- a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java
+++ b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java
@@ -76,9 +76,6 @@ public class TaskMemoryManager {
   /** Bit mask for the lower 51 bits of a long. */
   private static final long MASK_LONG_LOWER_51_BITS = 0x7FFFFFFFFFFFFL;
 
-  /** Bit mask for the upper 13 bits of a long */
-  private static final long MASK_LONG_UPPER_13_BITS = ~MASK_LONG_LOWER_51_BITS;
-
   /**
    * Similar to an operating system's page table, this array maps page numbers into base object
    * pointers, allowing us to translate between the hashtable's internal 64-bit address
@@ -132,11 +129,10 @@ public class TaskMemoryManager {
    *
    * @return number of bytes successfully granted (<= N).
    */
-  public long acquireExecutionMemory(
-      long required,
-      MemoryMode mode,
-      MemoryConsumer consumer) {
+  public long acquireExecutionMemory(long required, MemoryConsumer consumer) {
     assert(required >= 0);
+    assert(consumer != null);
+    MemoryMode mode = consumer.getMode();
     // If we are allocating Tungsten pages off-heap and receive a request to allocate on-heap
     // memory here, then it may not make sense to spill since that would only end up freeing
     // off-heap memory. This is subject to change, though, so it may be risky to make this
@@ -149,10 +145,10 @@ public class TaskMemoryManager {
       if (got < required) {
         // Call spill() on other consumers to release memory
         for (MemoryConsumer c: consumers) {
-          if (c != consumer && c.getUsed() > 0) {
+          if (c != consumer && c.getUsed() > 0 && c.getMode() == mode) {
             try {
               long released = c.spill(required - got, consumer);
-              if (released > 0 && mode == tungstenMemoryMode) {
+              if (released > 0) {
                 logger.debug("Task {} released {} from {} for {}", taskAttemptId,
                   Utils.bytesToString(released), c, consumer);
                 got += memoryManager.acquireExecutionMemory(required - got, taskAttemptId, mode);
@@ -170,10 +166,10 @@ public class TaskMemoryManager {
       }
 
       // call spill() on itself
-      if (got < required && consumer != null) {
+      if (got < required) {
         try {
           long released = consumer.spill(required - got, consumer);
-          if (released > 0 && mode == tungstenMemoryMode) {
+          if (released > 0) {
             logger.debug("Task {} released {} from itself ({})", taskAttemptId,
               Utils.bytesToString(released), consumer);
             got += memoryManager.acquireExecutionMemory(required - got, taskAttemptId, mode);
@@ -185,9 +181,7 @@ public class TaskMemoryManager {
         }
       }
 
-      if (consumer != null) {
-        consumers.add(consumer);
-      }
+      consumers.add(consumer);
       logger.debug("Task {} acquire {} for {}", taskAttemptId, Utils.bytesToString(got), consumer);
       return got;
     }
@@ -196,9 +190,9 @@ public class TaskMemoryManager {
   /**
    * Release N bytes of execution memory for a MemoryConsumer.
    */
-  public void releaseExecutionMemory(long size, MemoryMode mode, MemoryConsumer consumer) {
+  public void releaseExecutionMemory(long size, MemoryConsumer consumer) {
     logger.debug("Task {} release {} from {}", taskAttemptId, Utils.bytesToString(size), consumer);
-    memoryManager.releaseExecutionMemory(size, taskAttemptId, mode);
+    memoryManager.releaseExecutionMemory(size, taskAttemptId, consumer.getMode());
   }
 
   /**
@@ -241,12 +235,14 @@ public class TaskMemoryManager {
    * contains fewer bytes than requested, so callers should verify the size of returned pages.
    */
   public MemoryBlock allocatePage(long size, MemoryConsumer consumer) {
+    assert(consumer != null);
+    assert(consumer.getMode() == tungstenMemoryMode);
     if (size > MAXIMUM_PAGE_SIZE_BYTES) {
       throw new IllegalArgumentException(
         "Cannot allocate a page with more than " + MAXIMUM_PAGE_SIZE_BYTES + " bytes");
     }
 
-    long acquired = acquireExecutionMemory(size, tungstenMemoryMode, consumer);
+    long acquired = acquireExecutionMemory(size, consumer);
     if (acquired <= 0) {
       return null;
     }
@@ -255,7 +251,7 @@ public class TaskMemoryManager {
     synchronized (this) {
       pageNumber = allocatedPages.nextClearBit(0);
       if (pageNumber >= PAGE_TABLE_SIZE) {
-        releaseExecutionMemory(acquired, tungstenMemoryMode, consumer);
+        releaseExecutionMemory(acquired, consumer);
         throw new IllegalStateException(
           "Have already allocated a maximum of " + PAGE_TABLE_SIZE + " pages");
       }
@@ -299,7 +295,7 @@ public class TaskMemoryManager {
     }
     long pageSize = page.size();
     memoryManager.tungstenMemoryAllocator().free(page);
-    releaseExecutionMemory(pageSize, tungstenMemoryMode, consumer);
+    releaseExecutionMemory(pageSize, consumer);
   }
 
   /**
@@ -396,8 +392,7 @@ public class TaskMemoryManager {
       Arrays.fill(pageTable, null);
     }
 
-
-    // release the memory that is not used by any consumer.
+    // Release memory that was acquired for pages (in Tungsten mode) but is not used by any consumer.
     memoryManager.releaseExecutionMemory(acquiredButNotUsed, taskAttemptId, tungstenMemoryMode);
 
     return memoryManager.releaseAllExecutionMemoryForTask(taskAttemptId);
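Callers now let TaskMemoryManager read the mode off the consumer itself, and forced spilling only considers consumers in the same mode. A rough usage sketch against the new signatures, built on the test helpers from this patch (the sizes are arbitrary; with off-heap disabled, the Tungsten mode is ON_HEAP):

    import org.apache.spark.SparkConf;
    import org.apache.spark.memory.*;
    import org.apache.spark.unsafe.memory.MemoryBlock;

    TaskMemoryManager tmm = new TaskMemoryManager(new TestMemoryManager(new SparkConf()), 0);
    TestMemoryConsumer consumer = new TestMemoryConsumer(tmm, MemoryMode.ON_HEAP);

    long got = tmm.acquireExecutionMemory(64, consumer);  // mode comes from consumer.getMode()
    MemoryBlock page = tmm.allocatePage(4096, consumer);  // asserts consumer mode == tungstenMemoryMode
    if (page != null) {
      tmm.freePage(page, consumer);                       // frees the page, then releases its bytes
    }
    tmm.releaseExecutionMemory(got, consumer);            // releases in the consumer's mode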
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java
index 2be5a16b2d1e4d52c34633600e331ca510a76055..014aef86b5cc641d63e8dd0458a9befa8315f864 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java
@@ -104,8 +104,9 @@ final class ShuffleExternalSorter extends MemoryConsumer {
       int numPartitions,
       SparkConf conf,
       ShuffleWriteMetrics writeMetrics) {
-    super(memoryManager, (int) Math.min(PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES,
-      memoryManager.pageSizeBytes()));
+    super(memoryManager,
+      (int) Math.min(PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES, memoryManager.pageSizeBytes()),
+      memoryManager.getTungstenMemoryMode());
     this.taskMemoryManager = memoryManager;
     this.blockManager = blockManager;
     this.taskContext = taskContext;
diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
index 6807710f9fef13f129f75046d933cabb9ab8e204..6c00608302c4e2e056a53d1c24ec723a88cf1f81 100644
--- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
+++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
@@ -182,7 +182,7 @@ public final class BytesToBytesMap extends MemoryConsumer {
       double loadFactor,
       long pageSizeBytes,
       boolean enablePerfMetrics) {
-    super(taskMemoryManager, pageSizeBytes);
+    super(taskMemoryManager, pageSizeBytes, taskMemoryManager.getTungstenMemoryMode());
     this.taskMemoryManager = taskMemoryManager;
     this.blockManager = blockManager;
     this.serializerManager = serializerManager;
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
index 7dc0508784981540b2a9014fd7399a7fd220261e..e14a23f4a6a833b8068ea7181a0c3073fa27519d 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
@@ -124,7 +124,7 @@ public final class UnsafeExternalSorter extends MemoryConsumer {
       long pageSizeBytes,
       @Nullable UnsafeInMemorySorter existingInMemorySorter,
       boolean canUseRadixSort) {
-    super(taskMemoryManager, pageSizeBytes);
+    super(taskMemoryManager, pageSizeBytes, taskMemoryManager.getTungstenMemoryMode());
     this.taskMemoryManager = taskMemoryManager;
     this.blockManager = blockManager;
     this.serializerManager = serializerManager;
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index b695aecc13ea1a8dee7f648d387ed0daaaa3e788..9a017f29f7d2112620aaef1bf940751c28ebf39d 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -281,20 +281,20 @@ private[spark] class Executor(
           val releasedLocks = env.blockManager.releaseAllLocksForTask(taskId)
           val freedMemory = taskMemoryManager.cleanUpAllAllocatedMemory()
 
-          if (freedMemory > 0) {
+          if (freedMemory > 0 && !threwException) {
             val errMsg = s"Managed memory leak detected; size = $freedMemory bytes, TID = $taskId"
-            if (conf.getBoolean("spark.unsafe.exceptionOnMemoryLeak", false) && !threwException) {
+            if (conf.getBoolean("spark.unsafe.exceptionOnMemoryLeak", false)) {
               throw new SparkException(errMsg)
             } else {
-              logError(errMsg)
+              logWarning(errMsg)
             }
           }
 
-          if (releasedLocks.nonEmpty) {
+          if (releasedLocks.nonEmpty && !threwException) {
             val errMsg =
               s"${releasedLocks.size} block locks were not released by TID = $taskId:\n" +
                 releasedLocks.mkString("[", ", ", "]")
-            if (conf.getBoolean("spark.storage.exceptionOnPinLeak", false) && !threwException) {
+            if (conf.getBoolean("spark.storage.exceptionOnPinLeak", false)) {
               throw new SparkException(errMsg)
             } else {
               logWarning(errMsg)
diff --git a/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala b/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala
index bdcbd22fd814ba283172318ca404765c20e8826e..8183f825592c0a34a6296cad02ec1b343bdc5a12 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala
@@ -83,7 +83,7 @@ private[spark] abstract class Spillable[C](taskMemoryManager: TaskMemoryManager)
     if (elementsRead % 32 == 0 && currentMemory >= myMemoryThreshold) {
       // Claim up to double our current memory from the shuffle memory pool
       val amountToRequest = 2 * currentMemory - myMemoryThreshold
-      val granted = acquireOnHeapMemory(amountToRequest)
+      val granted = acquireMemory(amountToRequest)
       myMemoryThreshold += granted
       // If we were granted too little memory to grow further (either tryToAcquire returned 0,
       // or we already had more memory than myMemoryThreshold), spill the current collection
@@ -131,7 +131,7 @@ private[spark] abstract class Spillable[C](taskMemoryManager: TaskMemoryManager)
    * Release our memory back to the execution pool so that other tasks can grab it.
    */
   def releaseMemory(): Unit = {
-    freeOnHeapMemory(myMemoryThreshold - initialMemoryThreshold)
+    freeMemory(myMemoryThreshold - initialMemoryThreshold)
     myMemoryThreshold = initialMemoryThreshold
   }
 
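The request above implements threshold doubling: asking for 2 * currentMemory - myMemoryThreshold means a full grant leaves the threshold at exactly twice the current collection size. A worked sketch of the same bookkeeping from inside a MemoryConsumer subclass (field names mirror the Scala code; 5 MB is the default spark.shuffle.spill.initialMemoryThreshold):

    long myMemoryThreshold = 5L * 1024 * 1024;   // execution memory already granted
    long currentMemory = 8L * 1024 * 1024;       // estimated size of the in-memory collection
    long amountToRequest = 2 * currentMemory - myMemoryThreshold;   // 11 MB
    long granted = acquireMemory(amountToRequest);                  // may be less than requested
    myMemoryThreshold += granted;                                   // 16 MB on a full grant
    boolean shouldSpill = currentMemory >= myMemoryThreshold;       // spill if we could not grow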
diff --git a/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java b/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java
index 127789b632b44d3c3d5041d8c65c1579ba2d4a34..ad755529dec642a725966be64a317b378f428b6f 100644
--- a/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java
+++ b/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java
@@ -34,7 +34,8 @@ public class TaskMemoryManagerSuite {
         Long.MAX_VALUE,
         1),
       0);
-    manager.allocatePage(4096, null);  // leak memory
+    final MemoryConsumer c = new TestMemoryConsumer(manager);
+    manager.allocatePage(4096, c);  // leak memory
     Assert.assertEquals(4096, manager.getMemoryConsumptionForThisTask());
     Assert.assertEquals(4096, manager.cleanUpAllAllocatedMemory());
   }
@@ -45,7 +46,8 @@ public class TaskMemoryManagerSuite {
       .set("spark.memory.offHeap.enabled", "true")
       .set("spark.memory.offHeap.size", "1000");
     final TaskMemoryManager manager = new TaskMemoryManager(new TestMemoryManager(conf), 0);
-    final MemoryBlock dataPage = manager.allocatePage(256, null);
+    final MemoryConsumer c = new TestMemoryConsumer(manager, MemoryMode.OFF_HEAP);
+    final MemoryBlock dataPage = manager.allocatePage(256, c);
     // In off-heap mode, an offset is an absolute address that may require more than 51 bits to
     // encode. This test exercises that corner-case:
     final long offset = ((1L << TaskMemoryManager.OFFSET_BITS) + 10);
@@ -58,7 +60,8 @@ public class TaskMemoryManagerSuite {
   public void encodePageNumberAndOffsetOnHeap() {
     final TaskMemoryManager manager = new TaskMemoryManager(
       new TestMemoryManager(new SparkConf().set("spark.memory.offHeap.enabled", "false")), 0);
-    final MemoryBlock dataPage = manager.allocatePage(256, null);
+    final MemoryConsumer c = new TestMemoryConsumer(manager, MemoryMode.ON_HEAP);
+    final MemoryBlock dataPage = manager.allocatePage(256, c);
     final long encodedAddress = manager.encodePageNumberAndOffset(dataPage, 64);
     Assert.assertEquals(dataPage.getBaseObject(), manager.getPage(encodedAddress));
     Assert.assertEquals(64, manager.getOffsetInPage(encodedAddress));
@@ -106,6 +109,25 @@ public class TaskMemoryManagerSuite {
     Assert.assertEquals(0, manager.cleanUpAllAllocatedMemory());
   }
 
+  @Test
+  public void shouldNotForceSpillingInDifferentModes() {
+    final TestMemoryManager memoryManager = new TestMemoryManager(new SparkConf());
+    memoryManager.limit(100);
+    final TaskMemoryManager manager = new TaskMemoryManager(memoryManager, 0);
+
+    TestMemoryConsumer c1 = new TestMemoryConsumer(manager, MemoryMode.ON_HEAP);
+    TestMemoryConsumer c2 = new TestMemoryConsumer(manager, MemoryMode.OFF_HEAP);
+    c1.use(80);
+    Assert.assertEquals(80, c1.getUsed());
+    c2.use(80);
+    Assert.assertEquals(20, c2.getUsed());  // not enough memory
+    Assert.assertEquals(80, c1.getUsed());  // not spilled
+
+    c2.use(10);
+    Assert.assertEquals(10, c2.getUsed());  // spilled
+    Assert.assertEquals(80, c1.getUsed());  // not spilled
+  }
+
   @Test
   public void offHeapConfigurationBackwardsCompatibility() {
     // Tests backwards-compatibility with the old `spark.unsafe.offHeap` configuration, which
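For reference, the expected values in shouldNotForceSpillingInDifferentModes fall out of TestMemoryManager's single 100-byte pool (my reading of the new code path, not part of the patch):

    c1.use(80);  // ON_HEAP:  granted 80; 20 bytes left in the pool
    c2.use(80);  // OFF_HEAP: granted only 20; c1 is skipped as a spill target (different mode)
                 //           and c2 itself has nothing to spill yet
    c2.use(10);  // OFF_HEAP: pool exhausted, so c2 spills its own 20 bytes, then re-acquires 10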
diff --git a/core/src/test/java/org/apache/spark/memory/TestMemoryConsumer.java b/core/src/test/java/org/apache/spark/memory/TestMemoryConsumer.java
index e6e16fff8040141434abbd92616a9bbd4abec990..db91329c94cb68d060cef3cef8e1b9be1b41d55d 100644
--- a/core/src/test/java/org/apache/spark/memory/TestMemoryConsumer.java
+++ b/core/src/test/java/org/apache/spark/memory/TestMemoryConsumer.java
@@ -20,8 +20,11 @@ package org.apache.spark.memory;
 import java.io.IOException;
 
 public class TestMemoryConsumer extends MemoryConsumer {
+  public TestMemoryConsumer(TaskMemoryManager memoryManager, MemoryMode mode) {
+    super(memoryManager, 1024L, mode);
+  }
   public TestMemoryConsumer(TaskMemoryManager memoryManager) {
-    super(memoryManager);
+    this(memoryManager, MemoryMode.ON_HEAP);
   }
 
   @Override
@@ -32,19 +35,13 @@ public class TestMemoryConsumer extends MemoryConsumer {
   }
 
   void use(long size) {
-    long got = taskMemoryManager.acquireExecutionMemory(
-      size,
-      taskMemoryManager.tungstenMemoryMode,
-      this);
+    long got = taskMemoryManager.acquireExecutionMemory(size, this);
     used += got;
   }
 
   void free(long size) {
     used -= size;
-    taskMemoryManager.releaseExecutionMemory(
-      size,
-      taskMemoryManager.tungstenMemoryMode,
-      this);
+    taskMemoryManager.releaseExecutionMemory(size, this);
   }
 }
 
diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/PackedRecordPointerSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/PackedRecordPointerSuite.java
index fe5abc5c230491b8217da6c9868052ed94fade97..354efe18dbde733e45f3a8d7c24283d07a98aff9 100644
--- a/core/src/test/java/org/apache/spark/shuffle/sort/PackedRecordPointerSuite.java
+++ b/core/src/test/java/org/apache/spark/shuffle/sort/PackedRecordPointerSuite.java
@@ -22,8 +22,7 @@ import java.io.IOException;
 import org.junit.Test;
 
 import org.apache.spark.SparkConf;
-import org.apache.spark.memory.TestMemoryManager;
-import org.apache.spark.memory.TaskMemoryManager;
+import org.apache.spark.memory.*;
 import org.apache.spark.unsafe.memory.MemoryBlock;
 
 import static org.apache.spark.shuffle.sort.PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES;
@@ -38,8 +37,9 @@ public class PackedRecordPointerSuite {
     final SparkConf conf = new SparkConf().set("spark.memory.offHeap.enabled", "false");
     final TaskMemoryManager memoryManager =
       new TaskMemoryManager(new TestMemoryManager(conf), 0);
-    final MemoryBlock page0 = memoryManager.allocatePage(128, null);
-    final MemoryBlock page1 = memoryManager.allocatePage(128, null);
+    final MemoryConsumer c = new TestMemoryConsumer(memoryManager, MemoryMode.ON_HEAP);
+    final MemoryBlock page0 = memoryManager.allocatePage(128, c);
+    final MemoryBlock page1 = memoryManager.allocatePage(128, c);
     final long addressInPage1 = memoryManager.encodePageNumberAndOffset(page1,
       page1.getBaseOffset() + 42);
     PackedRecordPointer packedPointer = new PackedRecordPointer();
@@ -59,8 +59,9 @@ public class PackedRecordPointerSuite {
       .set("spark.memory.offHeap.size", "10000");
     final TaskMemoryManager memoryManager =
       new TaskMemoryManager(new TestMemoryManager(conf), 0);
-    final MemoryBlock page0 = memoryManager.allocatePage(128, null);
-    final MemoryBlock page1 = memoryManager.allocatePage(128, null);
+    final MemoryConsumer c = new TestMemoryConsumer(memoryManager, MemoryMode.OFF_HEAP);
+    final MemoryBlock page0 = memoryManager.allocatePage(128, c);
+    final MemoryBlock page1 = memoryManager.allocatePage(128, c);
     final long addressInPage1 = memoryManager.encodePageNumberAndOffset(page1,
       page1.getBaseOffset() + 42);
     PackedRecordPointer packedPointer = new PackedRecordPointer();
diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java
index 278a827644db77957e88c55ad5c516d9b4a221de..694352ee2af448ec00762334899b25ce0643fec8 100644
--- a/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java
+++ b/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java
@@ -26,6 +26,7 @@ import org.junit.Test;
 
 import org.apache.spark.HashPartitioner;
 import org.apache.spark.SparkConf;
+import org.apache.spark.memory.MemoryConsumer;
 import org.apache.spark.memory.TaskMemoryManager;
 import org.apache.spark.memory.TestMemoryConsumer;
 import org.apache.spark.memory.TestMemoryManager;
@@ -71,7 +72,8 @@ public class ShuffleInMemorySorterSuite {
     final SparkConf conf = new SparkConf().set("spark.memory.offHeap.enabled", "false");
     final TaskMemoryManager memoryManager =
       new TaskMemoryManager(new TestMemoryManager(conf), 0);
-    final MemoryBlock dataPage = memoryManager.allocatePage(2048, null);
+    final MemoryConsumer c = new TestMemoryConsumer(memoryManager);
+    final MemoryBlock dataPage = memoryManager.allocatePage(2048, c);
     final Object baseObject = dataPage.getBaseObject();
     final ShuffleInMemorySorter sorter = new ShuffleInMemorySorter(
       consumer, 4, shouldUseRadixSort());
diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java
index 4a2f65a0ed2b2db9ad8d9d4b7eedbc2db1723247..383c5b3b0884ab12de04b18705cb3eb6a5644e97 100644
--- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java
+++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java
@@ -78,7 +78,7 @@ public class UnsafeInMemorySorterSuite {
     final TaskMemoryManager memoryManager = new TaskMemoryManager(
       new TestMemoryManager(new SparkConf().set("spark.memory.offHeap.enabled", "false")), 0);
     final TestMemoryConsumer consumer = new TestMemoryConsumer(memoryManager);
-    final MemoryBlock dataPage = memoryManager.allocatePage(2048, null);
+    final MemoryBlock dataPage = memoryManager.allocatePage(2048, consumer);
     final Object baseObject = dataPage.getBaseObject();
     // Write the records into the data page:
     long position = dataPage.getBaseOffset();
diff --git a/core/src/test/scala/org/apache/spark/FailureSuite.scala b/core/src/test/scala/org/apache/spark/FailureSuite.scala
index 3def8b0b1850e7a263605e06b59600e802221eb0..333c23bdaf6d6626aea65ca5484f8e2a9e7a961c 100644
--- a/core/src/test/scala/org/apache/spark/FailureSuite.scala
+++ b/core/src/test/scala/org/apache/spark/FailureSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark
 
 import java.io.{IOException, NotSerializableException, ObjectInputStream}
 
+import org.apache.spark.memory.TestMemoryConsumer
 import org.apache.spark.util.NonSerializable
 
 // Common state shared by FailureSuite-launched tasks. We use a global object
@@ -149,7 +150,8 @@ class FailureSuite extends SparkFunSuite with LocalSparkContext {
     // cause is preserved
     val thrownDueToTaskFailure = intercept[SparkException] {
       sc.parallelize(Seq(0)).mapPartitions { iter =>
-        TaskContext.get().taskMemoryManager().allocatePage(128, null)
+        val c = new TestMemoryConsumer(TaskContext.get().taskMemoryManager())
+        TaskContext.get().taskMemoryManager().allocatePage(128, c)
         throw new Exception("intentional task failure")
         iter
       }.count()
@@ -159,7 +161,8 @@ class FailureSuite extends SparkFunSuite with LocalSparkContext {
     // If the task succeeded but memory was leaked, then the task should fail due to that leak
     val thrownDueToMemoryLeak = intercept[SparkException] {
       sc.parallelize(Seq(0)).mapPartitions { iter =>
-        TaskContext.get().taskMemoryManager().allocatePage(128, null)
+        val c = new TestMemoryConsumer(TaskContext.get().taskMemoryManager())
+        TaskContext.get().taskMemoryManager().allocatePage(128, c)
         iter
       }.count()
     }
diff --git a/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala b/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala
index 2c4928ab907ad7d631744206c758dec09876a977..38bf7e5e5aec3636d0248e78be2542f822193add 100644
--- a/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala
@@ -162,39 +162,42 @@ private[memory] trait MemoryManagerSuite extends SparkFunSuite with BeforeAndAft
   test("single task requesting on-heap execution memory") {
     val manager = createMemoryManager(1000L)
     val taskMemoryManager = new TaskMemoryManager(manager, 0)
+    val c = new TestMemoryConsumer(taskMemoryManager)
 
-    assert(taskMemoryManager.acquireExecutionMemory(100L, MemoryMode.ON_HEAP, null) === 100L)
-    assert(taskMemoryManager.acquireExecutionMemory(400L, MemoryMode.ON_HEAP, null) === 400L)
-    assert(taskMemoryManager.acquireExecutionMemory(400L, MemoryMode.ON_HEAP, null) === 400L)
-    assert(taskMemoryManager.acquireExecutionMemory(200L, MemoryMode.ON_HEAP, null) === 100L)
-    assert(taskMemoryManager.acquireExecutionMemory(100L, MemoryMode.ON_HEAP, null) === 0L)
-    assert(taskMemoryManager.acquireExecutionMemory(100L, MemoryMode.ON_HEAP, null) === 0L)
+    assert(taskMemoryManager.acquireExecutionMemory(100L, c) === 100L)
+    assert(taskMemoryManager.acquireExecutionMemory(400L, c) === 400L)
+    assert(taskMemoryManager.acquireExecutionMemory(400L, c) === 400L)
+    assert(taskMemoryManager.acquireExecutionMemory(200L, c) === 100L)
+    assert(taskMemoryManager.acquireExecutionMemory(100L, c) === 0L)
+    assert(taskMemoryManager.acquireExecutionMemory(100L, c) === 0L)
 
-    taskMemoryManager.releaseExecutionMemory(500L, MemoryMode.ON_HEAP, null)
-    assert(taskMemoryManager.acquireExecutionMemory(300L, MemoryMode.ON_HEAP, null) === 300L)
-    assert(taskMemoryManager.acquireExecutionMemory(300L, MemoryMode.ON_HEAP, null) === 200L)
+    taskMemoryManager.releaseExecutionMemory(500L, c)
+    assert(taskMemoryManager.acquireExecutionMemory(300L, c) === 300L)
+    assert(taskMemoryManager.acquireExecutionMemory(300L, c) === 200L)
 
     taskMemoryManager.cleanUpAllAllocatedMemory()
-    assert(taskMemoryManager.acquireExecutionMemory(1000L, MemoryMode.ON_HEAP, null) === 1000L)
-    assert(taskMemoryManager.acquireExecutionMemory(100L, MemoryMode.ON_HEAP, null) === 0L)
+    assert(taskMemoryManager.acquireExecutionMemory(1000L, c) === 1000L)
+    assert(taskMemoryManager.acquireExecutionMemory(100L, c) === 0L)
   }
 
   test("two tasks requesting full on-heap execution memory") {
     val memoryManager = createMemoryManager(1000L)
     val t1MemManager = new TaskMemoryManager(memoryManager, 1)
     val t2MemManager = new TaskMemoryManager(memoryManager, 2)
+    val c1 = new TestMemoryConsumer(t1MemManager)
+    val c2 = new TestMemoryConsumer(t2MemManager)
     val futureTimeout: Duration = 20.seconds
 
     // Have both tasks request 500 bytes, then wait until both requests have been granted:
-    val t1Result1 = Future { t1MemManager.acquireExecutionMemory(500L, MemoryMode.ON_HEAP, null) }
-    val t2Result1 = Future { t2MemManager.acquireExecutionMemory(500L, MemoryMode.ON_HEAP, null) }
+    val t1Result1 = Future { t1MemManager.acquireExecutionMemory(500L, c1) }
+    val t2Result1 = Future { t2MemManager.acquireExecutionMemory(500L, c2) }
     assert(ThreadUtils.awaitResult(t1Result1, futureTimeout) === 500L)
     assert(ThreadUtils.awaitResult(t2Result1, futureTimeout) === 500L)
 
     // Have both tasks each request 500 bytes more; both should immediately return 0 as they are
     // both now at 1 / N
-    val t1Result2 = Future { t1MemManager.acquireExecutionMemory(500L, MemoryMode.ON_HEAP, null) }
-    val t2Result2 = Future { t2MemManager.acquireExecutionMemory(500L, MemoryMode.ON_HEAP, null) }
+    val t1Result2 = Future { t1MemManager.acquireExecutionMemory(500L, c1) }
+    val t2Result2 = Future { t2MemManager.acquireExecutionMemory(500L, c2) }
     assert(ThreadUtils.awaitResult(t1Result2, 200.millis) === 0L)
     assert(ThreadUtils.awaitResult(t2Result2, 200.millis) === 0L)
   }
@@ -203,18 +206,20 @@ private[memory] trait MemoryManagerSuite extends SparkFunSuite with BeforeAndAft
     val memoryManager = createMemoryManager(1000L)
     val t1MemManager = new TaskMemoryManager(memoryManager, 1)
     val t2MemManager = new TaskMemoryManager(memoryManager, 2)
+    val c1 = new TestMemoryConsumer(t1MemManager)
+    val c2 = new TestMemoryConsumer(t2MemManager)
     val futureTimeout: Duration = 20.seconds
 
     // Have both tasks request 250 bytes, then wait until both requests have been granted:
-    val t1Result1 = Future { t1MemManager.acquireExecutionMemory(250L, MemoryMode.ON_HEAP, null) }
-    val t2Result1 = Future { t2MemManager.acquireExecutionMemory(250L, MemoryMode.ON_HEAP, null) }
+    val t1Result1 = Future { t1MemManager.acquireExecutionMemory(250L, c1) }
+    val t2Result1 = Future { t2MemManager.acquireExecutionMemory(250L, c2) }
     assert(ThreadUtils.awaitResult(t1Result1, futureTimeout) === 250L)
     assert(ThreadUtils.awaitResult(t2Result1, futureTimeout) === 250L)
 
     // Have both tasks each request 500 bytes more.
     // We should only grant 250 bytes to each of them on this second request
-    val t1Result2 = Future { t1MemManager.acquireExecutionMemory(500L, MemoryMode.ON_HEAP, null) }
-    val t2Result2 = Future { t2MemManager.acquireExecutionMemory(500L, MemoryMode.ON_HEAP, null) }
+    val t1Result2 = Future { t1MemManager.acquireExecutionMemory(500L, c1) }
+    val t2Result2 = Future { t2MemManager.acquireExecutionMemory(500L, c2) }
     assert(ThreadUtils.awaitResult(t1Result2, futureTimeout) === 250L)
     assert(ThreadUtils.awaitResult(t2Result2, futureTimeout) === 250L)
   }
@@ -223,20 +228,22 @@ private[memory] trait MemoryManagerSuite extends SparkFunSuite with BeforeAndAft
     val memoryManager = createMemoryManager(1000L)
     val t1MemManager = new TaskMemoryManager(memoryManager, 1)
     val t2MemManager = new TaskMemoryManager(memoryManager, 2)
+    val c1 = new TestMemoryConsumer(t1MemManager)
+    val c2 = new TestMemoryConsumer(t2MemManager)
     val futureTimeout: Duration = 20.seconds
 
     // t1 grabs 1000 bytes and then waits until t2 is ready to make a request.
-    val t1Result1 = Future { t1MemManager.acquireExecutionMemory(1000L, MemoryMode.ON_HEAP, null) }
+    val t1Result1 = Future { t1MemManager.acquireExecutionMemory(1000L, c1) }
     assert(ThreadUtils.awaitResult(t1Result1, futureTimeout) === 1000L)
-    val t2Result1 = Future { t2MemManager.acquireExecutionMemory(250L, MemoryMode.ON_HEAP, null) }
+    val t2Result1 = Future { t2MemManager.acquireExecutionMemory(250L, c2) }
     // Make sure that t2 didn't grab the memory right away. This is hacky but it would be difficult
     // to make sure the other thread blocks for some time otherwise.
     Thread.sleep(300)
-    t1MemManager.releaseExecutionMemory(250L, MemoryMode.ON_HEAP, null)
+    t1MemManager.releaseExecutionMemory(250L, c1)
     // The memory freed from t1 should now be granted to t2.
     assert(ThreadUtils.awaitResult(t2Result1, futureTimeout) === 250L)
     // Further requests by t2 should be denied immediately because it now has 1 / 2N of the memory.
-    val t2Result2 = Future { t2MemManager.acquireExecutionMemory(100L, MemoryMode.ON_HEAP, null) }
+    val t2Result2 = Future { t2MemManager.acquireExecutionMemory(100L, c2) }
     assert(ThreadUtils.awaitResult(t2Result2, 200.millis) === 0L)
   }
 
@@ -244,21 +251,23 @@ private[memory] trait MemoryManagerSuite extends SparkFunSuite with BeforeAndAft
     val memoryManager = createMemoryManager(1000L)
     val t1MemManager = new TaskMemoryManager(memoryManager, 1)
     val t2MemManager = new TaskMemoryManager(memoryManager, 2)
+    val c1 = new TestMemoryConsumer(t1MemManager)
+    val c2 = new TestMemoryConsumer(t2MemManager)
     val futureTimeout: Duration = 20.seconds
 
     // t1 grabs 1000 bytes and then waits until t2 is ready to make a request.
-    val t1Result1 = Future { t1MemManager.acquireExecutionMemory(1000L, MemoryMode.ON_HEAP, null) }
+    val t1Result1 = Future { t1MemManager.acquireExecutionMemory(1000L, c1) }
     assert(ThreadUtils.awaitResult(t1Result1, futureTimeout) === 1000L)
-    val t2Result1 = Future { t2MemManager.acquireExecutionMemory(500L, MemoryMode.ON_HEAP, null) }
+    val t2Result1 = Future { t2MemManager.acquireExecutionMemory(500L, c2) }
     // Make sure that t2 didn't grab the memory right away. This is hacky but it would be difficult
     // to make sure the other thread blocks for some time otherwise.
     Thread.sleep(300)
     // t1 releases all of its memory, so t2 should be able to grab all of the memory
     t1MemManager.cleanUpAllAllocatedMemory()
     assert(ThreadUtils.awaitResult(t2Result1, futureTimeout) === 500L)
-    val t2Result2 = Future { t2MemManager.acquireExecutionMemory(500L, MemoryMode.ON_HEAP, null) }
+    val t2Result2 = Future { t2MemManager.acquireExecutionMemory(500L, c2) }
     assert(ThreadUtils.awaitResult(t2Result2, futureTimeout) === 500L)
-    val t2Result3 = Future { t2MemManager.acquireExecutionMemory(500L, MemoryMode.ON_HEAP, null) }
+    val t2Result3 = Future { t2MemManager.acquireExecutionMemory(500L, c2) }
     assert(ThreadUtils.awaitResult(t2Result3, 200.millis) === 0L)
   }
 
@@ -267,15 +276,17 @@ private[memory] trait MemoryManagerSuite extends SparkFunSuite with BeforeAndAft
     val memoryManager = createMemoryManager(1000L)
     val t1MemManager = new TaskMemoryManager(memoryManager, 1)
     val t2MemManager = new TaskMemoryManager(memoryManager, 2)
+    val c1 = new TestMemoryConsumer(t1MemManager)
+    val c2 = new TestMemoryConsumer(t2MemManager)
     val futureTimeout: Duration = 20.seconds
 
-    val t1Result1 = Future { t1MemManager.acquireExecutionMemory(700L, MemoryMode.ON_HEAP, null) }
+    val t1Result1 = Future { t1MemManager.acquireExecutionMemory(700L, c1) }
     assert(ThreadUtils.awaitResult(t1Result1, futureTimeout) === 700L)
 
-    val t2Result1 = Future { t2MemManager.acquireExecutionMemory(300L, MemoryMode.ON_HEAP, null) }
+    val t2Result1 = Future { t2MemManager.acquireExecutionMemory(300L, c2) }
     assert(ThreadUtils.awaitResult(t2Result1, futureTimeout) === 300L)
 
-    val t1Result2 = Future { t1MemManager.acquireExecutionMemory(300L, MemoryMode.ON_HEAP, null) }
+    val t1Result2 = Future { t1MemManager.acquireExecutionMemory(300L, c1) }
     assert(ThreadUtils.awaitResult(t1Result2, 200.millis) === 0L)
   }
 
@@ -285,17 +296,18 @@ private[memory] trait MemoryManagerSuite extends SparkFunSuite with BeforeAndAft
       maxOffHeapExecutionMemory = 1000L)
 
     val tMemManager = new TaskMemoryManager(memoryManager, 1)
-    val result1 = Future { tMemManager.acquireExecutionMemory(1000L, MemoryMode.OFF_HEAP, null) }
+    val c = new TestMemoryConsumer(tMemManager, MemoryMode.OFF_HEAP)
+    val result1 = Future { tMemManager.acquireExecutionMemory(1000L, c) }
     assert(ThreadUtils.awaitResult(result1, 200.millis) === 1000L)
     assert(tMemManager.getMemoryConsumptionForThisTask === 1000L)
 
-    val result2 = Future { tMemManager.acquireExecutionMemory(300L, MemoryMode.OFF_HEAP, null) }
+    val result2 = Future { tMemManager.acquireExecutionMemory(300L, c) }
     assert(ThreadUtils.awaitResult(result2, 200.millis) === 0L)
 
     assert(tMemManager.getMemoryConsumptionForThisTask === 1000L)
-    tMemManager.releaseExecutionMemory(500L, MemoryMode.OFF_HEAP, null)
+    tMemManager.releaseExecutionMemory(500L, c)
     assert(tMemManager.getMemoryConsumptionForThisTask === 500L)
-    tMemManager.releaseExecutionMemory(500L, MemoryMode.OFF_HEAP, null)
+    tMemManager.releaseExecutionMemory(500L, c)
     assert(tMemManager.getMemoryConsumptionForThisTask === 0L)
   }
 }
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 45f7297048382a2b25aac759ffcf083f29456283..4e99a0965780b5571e54b8e13f84505f0a0ce3e3 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -40,6 +40,7 @@ object MimaExcludes {
         excludePackage("org.spark-project.jetty"),
         excludePackage("org.apache.spark.unused"),
         excludePackage("org.apache.spark.unsafe"),
+        excludePackage("org.apache.spark.memory"),
         excludePackage("org.apache.spark.util.collection.unsafe"),
         excludePackage("org.apache.spark.sql.catalyst"),
         excludePackage("org.apache.spark.sql.execution"),
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
index 315ef6a8796f05ea906b5eb335a672fb0b91f506..cb41457b6653fa096ab16c3c6e63f2ed61463474 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
@@ -398,9 +398,9 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
       0)
   }
 
-  private def acquireMemory(size: Long): Unit = {
+  private def ensureAcquireMemory(size: Long): Unit = {
     // do not support spilling
-    val got = mm.acquireExecutionMemory(size, MemoryMode.ON_HEAP, this)
+    val got = acquireMemory(size)
     if (got < size) {
       freeMemory(got)
       throw new SparkException(s"Can't acquire $size bytes memory to build hash relation, " +
@@ -408,15 +408,11 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
     }
   }
 
-  private def freeMemory(size: Long): Unit = {
-    mm.releaseExecutionMemory(size, MemoryMode.ON_HEAP, this)
-  }
-
   private def init(): Unit = {
     if (mm != null) {
       var n = 1
       while (n < capacity) n *= 2
-      acquireMemory(n * 2 * 8 + (1 << 20))
+      ensureAcquireMemory(n * 2 * 8 + (1 << 20))
       array = new Array[Long](n * 2)
       mask = n * 2 - 2
       page = new Array[Long](1 << 17)  // 1M bytes
@@ -538,7 +534,7 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
       if (used >= (1 << 30)) {
         sys.error("Can not build a HashedRelation that is larger than 8G")
       }
-      acquireMemory(used * 8L * 2)
+      ensureAcquireMemory(used * 8L * 2)
       val newPage = new Array[Long](used * 2)
       Platform.copyMemory(page, Platform.LONG_ARRAY_OFFSET, newPage, Platform.LONG_ARRAY_OFFSET,
         cursor - Platform.LONG_ARRAY_OFFSET)
@@ -591,7 +587,7 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
     var old_array = array
     val n = array.length
     numKeys = 0
-    acquireMemory(n * 2 * 8L)
+    ensureAcquireMemory(n * 2 * 8L)
     array = new Array[Long](n * 2)
     mask = n * 2 - 2
     var i = 0
@@ -613,7 +609,7 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
     // Convert to dense mode if it does not require more memory or could fit within L1 cache
     if (range < array.length || range < 1024) {
       try {
-        acquireMemory((range + 1) * 8)
+        ensureAcquireMemory((range + 1) * 8)
       } catch {
         case e: SparkException =>
           // there is not enough memory to convert
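The rename makes the all-or-nothing contract explicit: LongToUnsafeRowMap cannot spill, so a partial grant is handed straight back and reported as a failure instead of being kept. The underlying pattern, as a hedged Java sketch against the MemoryConsumer API (the exception type and message are illustrative):

    // Inside a MemoryConsumer subclass: acquire exactly `size` bytes or fail cleanly.
    long got = acquireMemory(size);   // may grant fewer bytes than requested
    if (got < size) {
      freeMemory(got);                // hand back the partial grant so it is not leaked
      throw new IllegalStateException(
        "Can't acquire " + size + " bytes of memory to build the hash relation");
    }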