diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
similarity index 78%
rename from unsafe/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
rename to core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
index 198e0684f32f8d330b1c0e54c271f8b2edb6bb95..0f42950e6ed8b61490dc70d4e71e2be465ca1232 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
+++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
@@ -17,6 +17,7 @@
 
 package org.apache.spark.unsafe.map;
 
+import java.io.IOException;
 import java.lang.Override;
 import java.lang.UnsupportedOperationException;
 import java.util.Iterator;
@@ -24,7 +25,10 @@ import java.util.LinkedList;
 import java.util.List;
 
 import com.google.common.annotations.VisibleForTesting;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 
+import org.apache.spark.shuffle.ShuffleMemoryManager;
 import org.apache.spark.unsafe.*;
 import org.apache.spark.unsafe.array.ByteArrayMethods;
 import org.apache.spark.unsafe.array.LongArray;
@@ -45,6 +49,8 @@ import org.apache.spark.unsafe.memory.*;
  */
 public final class BytesToBytesMap {
 
+  private final Logger logger = LoggerFactory.getLogger(BytesToBytesMap.class);
+
   private static final Murmur3_x86_32 HASHER = new Murmur3_x86_32(0);
 
   private static final HashMapGrowthStrategy growthStrategy = HashMapGrowthStrategy.DOUBLING;
@@ -54,7 +60,9 @@ public final class BytesToBytesMap {
    */
   private static final int END_OF_PAGE_MARKER = -1;
 
-  private final TaskMemoryManager memoryManager;
+  private final TaskMemoryManager taskMemoryManager;
+
+  private final ShuffleMemoryManager shuffleMemoryManager;
 
   /**
    * A linked list for tracking all allocated data pages so that we can free all of our memory.
@@ -120,7 +128,7 @@ public final class BytesToBytesMap {
   /**
    * Number of keys defined in the map.
    */
-  private int size;
+  private int numElements;
 
   /**
    * The map will be expanded once the number of keys exceeds this threshold.
@@ -150,12 +158,14 @@ public final class BytesToBytesMap {
   private long numHashCollisions = 0;
 
   public BytesToBytesMap(
-      TaskMemoryManager memoryManager,
+      TaskMemoryManager taskMemoryManager,
+      ShuffleMemoryManager shuffleMemoryManager,
       int initialCapacity,
       double loadFactor,
       long pageSizeBytes,
       boolean enablePerfMetrics) {
-    this.memoryManager = memoryManager;
+    this.taskMemoryManager = taskMemoryManager;
+    this.shuffleMemoryManager = shuffleMemoryManager;
     this.loadFactor = loadFactor;
     this.loc = new Location();
     this.pageSizeBytes = pageSizeBytes;
@@ -175,24 +185,32 @@ public final class BytesToBytesMap {
   }
 
   public BytesToBytesMap(
-      TaskMemoryManager memoryManager,
+      TaskMemoryManager taskMemoryManager,
+      ShuffleMemoryManager shuffleMemoryManager,
       int initialCapacity,
       long pageSizeBytes) {
-    this(memoryManager, initialCapacity, 0.70, pageSizeBytes, false);
+    this(taskMemoryManager, shuffleMemoryManager, initialCapacity, 0.70, pageSizeBytes, false);
   }
 
   public BytesToBytesMap(
-      TaskMemoryManager memoryManager,
+      TaskMemoryManager taskMemoryManager,
+      ShuffleMemoryManager shuffleMemoryManager,
       int initialCapacity,
       long pageSizeBytes,
       boolean enablePerfMetrics) {
-    this(memoryManager, initialCapacity, 0.70, pageSizeBytes, enablePerfMetrics);
+    this(
+      taskMemoryManager,
+      shuffleMemoryManager,
+      initialCapacity,
+      0.70,
+      pageSizeBytes,
+      enablePerfMetrics);
   }
 
   /**
    * Returns the number of keys defined in the map.
    */
-  public int size() { return size; }
+  public int numElements() { return numElements; }
 
   private static final class BytesToBytesMapIterator implements Iterator<Location> {
 
@@ -252,7 +270,7 @@ public final class BytesToBytesMap {
    * `lookup()`, the behavior of the returned iterator is undefined.
    */
   public Iterator<Location> iterator() {
-    return new BytesToBytesMapIterator(size, dataPages.iterator(), loc);
+    return new BytesToBytesMapIterator(numElements, dataPages.iterator(), loc);
   }
 
   /**
@@ -330,7 +348,8 @@ public final class BytesToBytesMap {
 
     private void updateAddressesAndSizes(long fullKeyAddress) {
       updateAddressesAndSizes(
-        memoryManager.getPage(fullKeyAddress), memoryManager.getOffsetInPage(fullKeyAddress));
+        taskMemoryManager.getPage(fullKeyAddress),
+        taskMemoryManager.getOffsetInPage(fullKeyAddress));
     }
 
     private void updateAddressesAndSizes(Object page, long keyOffsetInPage) {
@@ -411,7 +430,8 @@ public final class BytesToBytesMap {
     /**
      * Store a new key and value. This method may only be called once for a given key; if you want
      * to update the value associated with a key, then you can directly manipulate the bytes stored
-     * at the value address.
+     * at the value address. The return value indicates whether the put succeeded or whether it
+     * failed because additional memory could not be acquired.
      * <p>
      * It is only valid to call this method immediately after calling `lookup()` using the same key.
      * </p>
@@ -428,14 +448,19 @@ public final class BytesToBytesMap {
      * <pre>
      *   Location loc = map.lookup(keyBaseObject, keyBaseOffset, keyLengthInBytes);
      *   if (!loc.isDefined()) {
-     *     loc.putNewKey(keyBaseObject, keyBaseOffset, keyLengthInBytes, ...)
+     *     if (!loc.putNewKey(keyBaseObject, keyBaseOffset, keyLengthInBytes, ...)) {
+     *       // handle failure to grow map (by spilling, for example)
+     *     }
      *   }
      * </pre>
      * <p>
      * Unspecified behavior if the key is not defined.
      * </p>
+     *
+     * @return true if the put() was successful and false if the put() failed because memory could
+     *         not be acquired.
      */
-    public void putNewKey(
+    public boolean putNewKey(
         Object keyBaseObject,
         long keyBaseOffset,
         int keyLengthBytes,
@@ -445,63 +470,110 @@ public final class BytesToBytesMap {
       assert (!isDefined) : "Can only set value once for a key";
       assert (keyLengthBytes % 8 == 0);
       assert (valueLengthBytes % 8 == 0);
-      if (size == MAX_CAPACITY) {
+      if (numElements == MAX_CAPACITY) {
         throw new IllegalStateException("BytesToBytesMap has reached maximum capacity");
       }
+
       // Here, we'll copy the data into our data pages. Because we only store a relative offset from
       // the key address instead of storing the absolute address of the value, the key and value
       // must be stored in the same memory page.
       // (8 byte key length) (key) (8 byte value length) (value)
       final long requiredSize = 8 + keyLengthBytes + 8 + valueLengthBytes;
-      assert (requiredSize <= pageSizeBytes - 8); // Reserve 8 bytes for the end-of-page marker.
-      size++;
-      bitset.set(pos);
 
-      // If there's not enough space in the current page, allocate a new page (8 bytes are reserved
-      // for the end-of-page marker).
-      if (currentDataPage == null || pageSizeBytes - 8 - pageCursor < requiredSize) {
+      // --- Figure out where to insert the new record ---------------------------------------------
+
+      final MemoryBlock dataPage;
+      final Object dataPageBaseObject;
+      final long dataPageInsertOffset;
+      final boolean useOverflowPage = requiredSize > pageSizeBytes - 8;
+      if (useOverflowPage) {
+        // The record is larger than the page size, so allocate a special overflow page just to hold
+        // that record.
+        final long memoryRequested = requiredSize + 8;
+        final long memoryGranted = shuffleMemoryManager.tryToAcquire(memoryRequested);
+        if (memoryGranted != memoryRequested) {
+          shuffleMemoryManager.release(memoryGranted);
+          logger.debug("Failed to acquire {} bytes of memory", memoryRequested);
+          return false;
+        }
+        MemoryBlock overflowPage = taskMemoryManager.allocatePage(memoryRequested);
+        dataPages.add(overflowPage);
+        dataPage = overflowPage;
+        dataPageBaseObject = overflowPage.getBaseObject();
+        dataPageInsertOffset = overflowPage.getBaseOffset();
+      } else if (currentDataPage == null || pageSizeBytes - 8 - pageCursor < requiredSize) {
+        // The record can fit in a data page, but either we have not allocated any pages yet or
+        // the current page does not have enough space.
         if (currentDataPage != null) {
           // There wasn't enough space in the current page, so write an end-of-page marker:
           final Object pageBaseObject = currentDataPage.getBaseObject();
           final long lengthOffsetInPage = currentDataPage.getBaseOffset() + pageCursor;
           PlatformDependent.UNSAFE.putLong(pageBaseObject, lengthOffsetInPage, END_OF_PAGE_MARKER);
         }
-        MemoryBlock newPage = memoryManager.allocatePage(pageSizeBytes);
+        final long memoryGranted = shuffleMemoryManager.tryToAcquire(pageSizeBytes);
+        if (memoryGranted != pageSizeBytes) {
+          shuffleMemoryManager.release(memoryGranted);
+          logger.debug("Failed to acquire {} bytes of memory", pageSizeBytes);
+          return false;
+        }
+        MemoryBlock newPage = taskMemoryManager.allocatePage(pageSizeBytes);
         dataPages.add(newPage);
         pageCursor = 0;
         currentDataPage = newPage;
+        dataPage = currentDataPage;
+        dataPageBaseObject = currentDataPage.getBaseObject();
+        dataPageInsertOffset = currentDataPage.getBaseOffset();
+      } else {
+        // There is enough space in the current data page.
+        dataPage = currentDataPage;
+        dataPageBaseObject = currentDataPage.getBaseObject();
+        dataPageInsertOffset = currentDataPage.getBaseOffset() + pageCursor;
       }
 
+      // --- Append the key and value data to the current data page --------------------------------
+
+      long insertCursor = dataPageInsertOffset;
+
       // Compute all of our offsets up-front:
-      final Object pageBaseObject = currentDataPage.getBaseObject();
-      final long pageBaseOffset = currentDataPage.getBaseOffset();
-      final long keySizeOffsetInPage = pageBaseOffset + pageCursor;
-      pageCursor += 8; // word used to store the key size
-      final long keyDataOffsetInPage = pageBaseOffset + pageCursor;
-      pageCursor += keyLengthBytes;
-      final long valueSizeOffsetInPage = pageBaseOffset + pageCursor;
-      pageCursor += 8; // word used to store the value size
-      final long valueDataOffsetInPage = pageBaseOffset + pageCursor;
-      pageCursor += valueLengthBytes;
+      final long keySizeOffsetInPage = insertCursor;
+      insertCursor += 8; // word used to store the key size
+      final long keyDataOffsetInPage = insertCursor;
+      insertCursor += keyLengthBytes;
+      final long valueSizeOffsetInPage = insertCursor;
+      insertCursor += 8; // word used to store the value size
+      final long valueDataOffsetInPage = insertCursor;
+      insertCursor += valueLengthBytes;
 
       // Copy the key
-      PlatformDependent.UNSAFE.putLong(pageBaseObject, keySizeOffsetInPage, keyLengthBytes);
+      PlatformDependent.UNSAFE.putLong(dataPageBaseObject, keySizeOffsetInPage, keyLengthBytes);
       PlatformDependent.copyMemory(
-        keyBaseObject, keyBaseOffset, pageBaseObject, keyDataOffsetInPage, keyLengthBytes);
+        keyBaseObject, keyBaseOffset, dataPageBaseObject, keyDataOffsetInPage, keyLengthBytes);
       // Copy the value
-      PlatformDependent.UNSAFE.putLong(pageBaseObject, valueSizeOffsetInPage, valueLengthBytes);
-      PlatformDependent.copyMemory(
-        valueBaseObject, valueBaseOffset, pageBaseObject, valueDataOffsetInPage, valueLengthBytes);
+      PlatformDependent.UNSAFE.putLong(dataPageBaseObject, valueSizeOffsetInPage, valueLengthBytes);
+      PlatformDependent.copyMemory(valueBaseObject, valueBaseOffset, dataPageBaseObject,
+        valueDataOffsetInPage, valueLengthBytes);
+
+      // --- Update bookkeeping data structures ----------------------------------------------------
+
+      if (useOverflowPage) {
+        // Store the end-of-page marker at the end of the data page
+        PlatformDependent.UNSAFE.putLong(dataPageBaseObject, insertCursor, END_OF_PAGE_MARKER);
+      } else {
+        pageCursor += requiredSize;
+      }
 
-      final long storedKeyAddress = memoryManager.encodePageNumberAndOffset(
-        currentDataPage, keySizeOffsetInPage);
+      numElements++;
+      bitset.set(pos);
+      final long storedKeyAddress = taskMemoryManager.encodePageNumberAndOffset(
+        dataPage, keySizeOffsetInPage);
       longArray.set(pos * 2, storedKeyAddress);
       longArray.set(pos * 2 + 1, keyHashcode);
       updateAddressesAndSizes(storedKeyAddress);
       isDefined = true;
-      if (size > growthThreshold && longArray.size() < MAX_CAPACITY) {
+      if (numElements > growthThreshold && longArray.size() < MAX_CAPACITY) {
         growAndRehash();
       }
+      return true;
     }
   }
 
@@ -516,7 +588,7 @@ public final class BytesToBytesMap {
     // The capacity needs to be divisible by 64 so that our bit set can be sized properly
     capacity = Math.max((int) Math.min(MAX_CAPACITY, nextPowerOf2(capacity)), 64);
     assert (capacity <= MAX_CAPACITY);
-    longArray = new LongArray(memoryManager.allocate(capacity * 8L * 2));
+    longArray = new LongArray(MemoryBlock.fromLongArray(new long[capacity * 2]));
     bitset = new BitSet(MemoryBlock.fromLongArray(new long[capacity / 64]));
 
     this.growthThreshold = (int) (capacity * loadFactor);
@@ -530,18 +602,14 @@ public final class BytesToBytesMap {
    * This method is idempotent.
    */
   public void free() {
-    if (longArray != null) {
-      memoryManager.free(longArray.memoryBlock());
-      longArray = null;
-    }
-    if (bitset != null) {
-      // The bitset's heap memory isn't managed by a memory manager, so no need to free it here.
-      bitset = null;
-    }
+    longArray = null;
+    bitset = null;
     Iterator<MemoryBlock> dataPagesIterator = dataPages.iterator();
     while (dataPagesIterator.hasNext()) {
-      memoryManager.freePage(dataPagesIterator.next());
+      MemoryBlock dataPage = dataPagesIterator.next();
       dataPagesIterator.remove();
+      taskMemoryManager.freePage(dataPage);
+      shuffleMemoryManager.release(dataPage.size());
     }
     assert(dataPages.isEmpty());
   }
@@ -628,8 +696,6 @@ public final class BytesToBytesMap {
       }
     }
 
-    // Deallocate the old data structures.
-    memoryManager.free(oldLongArray.memoryBlock());
     if (enablePerfMetrics) {
       timeSpentResizingNs += System.nanoTime() - resizeStartTime;
     }
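For context while reviewing: a minimal caller sketch of the new putNewKey() contract (the map instance and the key/value buffer names are illustrative assumptions; only the lookup-then-put flow and the boolean return come from this patch):

```java
// Hypothetical caller; `map` and the key/value buffers are assumptions.
BytesToBytesMap.Location loc = map.lookup(keyBase, keyOffset, keyLengthBytes);
if (!loc.isDefined()) {
  boolean putSucceeded = loc.putNewKey(
    keyBase, keyOffset, keyLengthBytes,
    valueBase, valueOffset, valueLengthBytes);
  if (!putSucceeded) {
    // The shuffle memory pool refused to grant more memory; recover here
    // (e.g. by spilling) or fail, as the SQL callers later in this diff do.
  }
}
```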
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/map/HashMapGrowthStrategy.java b/core/src/main/java/org/apache/spark/unsafe/map/HashMapGrowthStrategy.java
similarity index 100%
rename from unsafe/src/main/java/org/apache/spark/unsafe/map/HashMapGrowthStrategy.java
rename to core/src/main/java/org/apache/spark/unsafe/map/HashMapGrowthStrategy.java
diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala
index f038b722957b8abac85e4ab0ce56d305938578b1..00c1e078a441c75ccd51231c8e42c2d3b00af56f 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala
@@ -85,7 +85,7 @@ private[spark] class ShuffleMemoryManager(maxMemory: Long) extends Logging {
           return toGrant
         } else {
           logInfo(
-            s"Thread $taskAttemptId waiting for at least 1/2N of shuffle memory pool to be free")
+            s"TID $taskAttemptId waiting for at least 1/2N of shuffle memory pool to be free")
           wait()
         }
       } else {
@@ -116,6 +116,12 @@ private[spark] class ShuffleMemoryManager(maxMemory: Long) extends Logging {
     taskMemory.remove(taskAttemptId)
     notifyAll()  // Notify waiters who locked "this" in tryToAcquire that memory has been freed
   }
+
+  /** Returns the memory consumption, in bytes, for the current task */
+  def getMemoryConsumptionForThisTask(): Long = synchronized {
+    val taskAttemptId = currentTaskAttemptId()
+    taskMemory.getOrElse(taskAttemptId, 0L)
+  }
 }
 
 private object ShuffleMemoryManager {
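A note on the tryToAcquire() semantics that the map changes above rely on: the manager may grant less than the requested amount, so callers must compare the grant against the request and release any partial grant before reporting failure. A minimal sketch of that pattern, mirroring the page-allocation code in putNewKey() (the surrounding method is hypothetical):

```java
// Sketch only: acquire from the cross-task shuffle pool before allocating a page.
long granted = shuffleMemoryManager.tryToAcquire(pageSizeBytes);
if (granted != pageSizeBytes) {
  // A partial grant is useless for a fixed-size page; hand it back so other
  // tasks can use it, then report failure to the caller.
  shuffleMemoryManager.release(granted);
  return false;
}
MemoryBlock newPage = taskMemoryManager.allocatePage(pageSizeBytes);
dataPages.add(newPage);
```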
diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
similarity index 64%
rename from unsafe/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
rename to core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
index 0be94ad37125580d7d2203eaba2784696cf46d8d..60f483acbcb80581d60fbfb0cdd2154f05400ba6 100644
--- a/unsafe/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
+++ b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
@@ -21,15 +21,14 @@ import java.lang.Exception;
 import java.nio.ByteBuffer;
 import java.util.*;
 
-import org.junit.After;
-import org.junit.Assert;
-import org.junit.Before;
-import org.junit.Test;
+import org.junit.*;
 import org.mockito.invocation.InvocationOnMock;
 import org.mockito.stubbing.Answer;
+import static org.hamcrest.Matchers.greaterThan;
 import static org.mockito.AdditionalMatchers.geq;
 import static org.mockito.Mockito.*;
 
+import org.apache.spark.shuffle.ShuffleMemoryManager;
 import org.apache.spark.unsafe.array.ByteArrayMethods;
 import org.apache.spark.unsafe.memory.*;
 import org.apache.spark.unsafe.PlatformDependent;
@@ -41,32 +40,39 @@ public abstract class AbstractBytesToBytesMapSuite {
 
   private final Random rand = new Random(42);
 
-  private TaskMemoryManager memoryManager;
-  private TaskMemoryManager sizeLimitedMemoryManager;
+  private ShuffleMemoryManager shuffleMemoryManager;
+  private TaskMemoryManager taskMemoryManager;
+  private TaskMemoryManager sizeLimitedTaskMemoryManager;
   private final long PAGE_SIZE_BYTES = 1L << 26; // 64 megabytes
 
   @Before
   public void setup() {
-    memoryManager = new TaskMemoryManager(new ExecutorMemoryManager(getMemoryAllocator()));
+    shuffleMemoryManager = new ShuffleMemoryManager(Long.MAX_VALUE);
+    taskMemoryManager = new TaskMemoryManager(new ExecutorMemoryManager(getMemoryAllocator()));
     // Mocked memory manager for tests that check the maximum array size, since actually allocating
     // such large arrays will cause us to run out of memory in our tests.
-    sizeLimitedMemoryManager = spy(memoryManager);
-    when(sizeLimitedMemoryManager.allocate(geq(1L << 20))).thenAnswer(new Answer<MemoryBlock>() {
-      @Override
-      public MemoryBlock answer(InvocationOnMock invocation) throws Throwable {
-        if (((Long) invocation.getArguments()[0] / 8) > Integer.MAX_VALUE) {
-          throw new OutOfMemoryError("Requested array size exceeds VM limit");
+    sizeLimitedTaskMemoryManager = mock(TaskMemoryManager.class);
+    when(sizeLimitedTaskMemoryManager.allocate(geq(1L << 20))).thenAnswer(
+      new Answer<MemoryBlock>() {
+        @Override
+        public MemoryBlock answer(InvocationOnMock invocation) throws Throwable {
+          if (((Long) invocation.getArguments()[0] / 8) > Integer.MAX_VALUE) {
+            throw new OutOfMemoryError("Requested array size exceeds VM limit");
+          }
+          return new MemoryBlock(null, 0, (Long) invocation.getArguments()[0]);
         }
-        return memoryManager.allocate(1L << 20);
       }
-    });
+    );
   }
 
   @After
   public void tearDown() {
-    if (memoryManager != null) {
-      memoryManager.cleanUpAllAllocatedMemory();
-      memoryManager = null;
+    if (taskMemoryManager != null) {
+      long leakedShuffleMemory = shuffleMemoryManager.getMemoryConsumptionForThisTask();
+      Assert.assertEquals(0, taskMemoryManager.cleanUpAllAllocatedMemory());
+      Assert.assertEquals(0, leakedShuffleMemory);
+      shuffleMemoryManager = null;
+      taskMemoryManager = null;
     }
   }
 
@@ -85,7 +91,7 @@ public abstract class AbstractBytesToBytesMapSuite {
   }
 
   private byte[] getRandomByteArray(int numWords) {
-    Assert.assertTrue(numWords > 0);
+    Assert.assertTrue(numWords >= 0);
     final int lengthInBytes = numWords * 8;
     final byte[] bytes = new byte[lengthInBytes];
     rand.nextBytes(bytes);
@@ -111,9 +117,10 @@ public abstract class AbstractBytesToBytesMapSuite {
 
   @Test
   public void emptyMap() {
-    BytesToBytesMap map = new BytesToBytesMap(memoryManager, 64, PAGE_SIZE_BYTES);
+    BytesToBytesMap map = new BytesToBytesMap(
+      taskMemoryManager, shuffleMemoryManager, 64, PAGE_SIZE_BYTES);
     try {
-      Assert.assertEquals(0, map.size());
+      Assert.assertEquals(0, map.numElements());
       final int keyLengthInWords = 10;
       final int keyLengthInBytes = keyLengthInWords * 8;
       final byte[] key = getRandomByteArray(keyLengthInWords);
@@ -126,7 +133,8 @@ public abstract class AbstractBytesToBytesMapSuite {
 
   @Test
   public void setAndRetrieveAKey() {
-    BytesToBytesMap map = new BytesToBytesMap(memoryManager, 64, PAGE_SIZE_BYTES);
+    BytesToBytesMap map = new BytesToBytesMap(
+      taskMemoryManager, shuffleMemoryManager, 64, PAGE_SIZE_BYTES);
     final int recordLengthWords = 10;
     final int recordLengthBytes = recordLengthWords * 8;
     final byte[] keyData = getRandomByteArray(recordLengthWords);
@@ -135,14 +143,14 @@ public abstract class AbstractBytesToBytesMapSuite {
       final BytesToBytesMap.Location loc =
         map.lookup(keyData, BYTE_ARRAY_OFFSET, recordLengthBytes);
       Assert.assertFalse(loc.isDefined());
-      loc.putNewKey(
+      Assert.assertTrue(loc.putNewKey(
         keyData,
         BYTE_ARRAY_OFFSET,
         recordLengthBytes,
         valueData,
         BYTE_ARRAY_OFFSET,
         recordLengthBytes
-      );
+      ));
       // After storing the key and value, the other location methods should return results that
       // reflect the result of this store without us having to call lookup() again on the same key.
       Assert.assertEquals(recordLengthBytes, loc.getKeyLength());
@@ -158,14 +166,14 @@ public abstract class AbstractBytesToBytesMapSuite {
       Assert.assertArrayEquals(valueData, getByteArray(loc.getValueAddress(), recordLengthBytes));
 
       try {
-        loc.putNewKey(
+        Assert.assertTrue(loc.putNewKey(
           keyData,
           BYTE_ARRAY_OFFSET,
           recordLengthBytes,
           valueData,
           BYTE_ARRAY_OFFSET,
           recordLengthBytes
-        );
+        ));
         Assert.fail("Should not be able to set a new value for a key");
       } catch (AssertionError e) {
         // Expected exception; do nothing.
@@ -178,7 +186,8 @@ public abstract class AbstractBytesToBytesMapSuite {
   @Test
   public void iteratorTest() throws Exception {
     final int size = 4096;
-    BytesToBytesMap map = new BytesToBytesMap(memoryManager, size / 2, PAGE_SIZE_BYTES);
+    BytesToBytesMap map = new BytesToBytesMap(
+      taskMemoryManager, shuffleMemoryManager, size / 2, PAGE_SIZE_BYTES);
     try {
       for (long i = 0; i < size; i++) {
         final long[] value = new long[] { i };
@@ -187,23 +196,23 @@ public abstract class AbstractBytesToBytesMapSuite {
         Assert.assertFalse(loc.isDefined());
         // Ensure that we store some zero-length keys
         if (i % 5 == 0) {
-          loc.putNewKey(
+          Assert.assertTrue(loc.putNewKey(
             null,
             PlatformDependent.LONG_ARRAY_OFFSET,
             0,
             value,
             PlatformDependent.LONG_ARRAY_OFFSET,
             8
-          );
+          ));
         } else {
-          loc.putNewKey(
+          Assert.assertTrue(loc.putNewKey(
             value,
             PlatformDependent.LONG_ARRAY_OFFSET,
             8,
             value,
             PlatformDependent.LONG_ARRAY_OFFSET,
             8
-          );
+          ));
         }
       }
       final java.util.BitSet valuesSeen = new java.util.BitSet(size);
@@ -236,7 +245,8 @@ public abstract class AbstractBytesToBytesMapSuite {
     final int NUM_ENTRIES = 1000 * 1000;
     final int KEY_LENGTH = 16;
     final int VALUE_LENGTH = 40;
-    final BytesToBytesMap map = new BytesToBytesMap(memoryManager, NUM_ENTRIES, PAGE_SIZE_BYTES);
+    final BytesToBytesMap map = new BytesToBytesMap(
+      taskMemoryManager, shuffleMemoryManager, NUM_ENTRIES, PAGE_SIZE_BYTES);
     // Each record will take 8 + 8 + 16 + 40 = 72 bytes of space in the data page. Our 64-megabyte
     // pages won't be evenly-divisible by records of this size, which will cause us to waste some
     // space at the end of the page. This is necessary in order for us to take the end-of-record
@@ -251,14 +261,14 @@ public abstract class AbstractBytesToBytesMapSuite {
           KEY_LENGTH
         );
         Assert.assertFalse(loc.isDefined());
-        loc.putNewKey(
+        Assert.assertTrue(loc.putNewKey(
           key,
           LONG_ARRAY_OFFSET,
           KEY_LENGTH,
           value,
           LONG_ARRAY_OFFSET,
           VALUE_LENGTH
-        );
+        ));
       }
       Assert.assertEquals(2, map.getNumDataPages());
 
@@ -305,7 +315,8 @@ public abstract class AbstractBytesToBytesMapSuite {
     // Java arrays' hashCodes() aren't based on the arrays' contents, so we need to wrap arrays
     // into ByteBuffers in order to use them as keys here.
     final Map<ByteBuffer, byte[]> expected = new HashMap<ByteBuffer, byte[]>();
-    final BytesToBytesMap map = new BytesToBytesMap(memoryManager, size, PAGE_SIZE_BYTES);
+    final BytesToBytesMap map = new BytesToBytesMap(
+      taskMemoryManager, shuffleMemoryManager, size, PAGE_SIZE_BYTES);
 
     try {
       // Fill the map to 90% full so that we can trigger probing
@@ -320,14 +331,14 @@ public abstract class AbstractBytesToBytesMapSuite {
             key.length
           );
           Assert.assertFalse(loc.isDefined());
-          loc.putNewKey(
+          Assert.assertTrue(loc.putNewKey(
             key,
             BYTE_ARRAY_OFFSET,
             key.length,
             value,
             BYTE_ARRAY_OFFSET,
             value.length
-          );
+          ));
           // After calling putNewKey, the following should be true, even before calling
           // lookup():
           Assert.assertTrue(loc.isDefined());
@@ -351,10 +362,102 @@ public abstract class AbstractBytesToBytesMapSuite {
     }
   }
 
+  @Test
+  public void randomizedTestWithRecordsLargerThanPageSize() {
+    final long pageSizeBytes = 128;
+    final BytesToBytesMap map = new BytesToBytesMap(
+      taskMemoryManager, shuffleMemoryManager, 64, pageSizeBytes);
+    // Java arrays' hashCodes() aren't based on the arrays' contents, so we need to wrap arrays
+    // into ByteBuffers in order to use them as keys here.
+    final Map<ByteBuffer, byte[]> expected = new HashMap<ByteBuffer, byte[]>();
+    try {
+      for (int i = 0; i < 1000; i++) {
+        final byte[] key = getRandomByteArray(rand.nextInt(128));
+        final byte[] value = getRandomByteArray(rand.nextInt(128));
+        if (!expected.containsKey(ByteBuffer.wrap(key))) {
+          expected.put(ByteBuffer.wrap(key), value);
+          final BytesToBytesMap.Location loc = map.lookup(
+            key,
+            BYTE_ARRAY_OFFSET,
+            key.length
+          );
+          Assert.assertFalse(loc.isDefined());
+          Assert.assertTrue(loc.putNewKey(
+            key,
+            BYTE_ARRAY_OFFSET,
+            key.length,
+            value,
+            BYTE_ARRAY_OFFSET,
+            value.length
+          ));
+          // After calling putNewKey, the following should be true, even before calling
+          // lookup():
+          Assert.assertTrue(loc.isDefined());
+          Assert.assertEquals(key.length, loc.getKeyLength());
+          Assert.assertEquals(value.length, loc.getValueLength());
+          Assert.assertTrue(arrayEquals(key, loc.getKeyAddress(), key.length));
+          Assert.assertTrue(arrayEquals(value, loc.getValueAddress(), value.length));
+        }
+      }
+      for (Map.Entry<ByteBuffer, byte[]> entry : expected.entrySet()) {
+        final byte[] key = entry.getKey().array();
+        final byte[] value = entry.getValue();
+        final BytesToBytesMap.Location loc = map.lookup(key, BYTE_ARRAY_OFFSET, key.length);
+        Assert.assertTrue(loc.isDefined());
+        Assert.assertTrue(arrayEquals(key, loc.getKeyAddress(), loc.getKeyLength()));
+        Assert.assertTrue(arrayEquals(value, loc.getValueAddress(), loc.getValueLength()));
+      }
+    } finally {
+      map.free();
+    }
+  }
+
+  @Test
+  public void failureToAllocateFirstPage() {
+    shuffleMemoryManager = new ShuffleMemoryManager(1024);
+    BytesToBytesMap map =
+      new BytesToBytesMap(taskMemoryManager, shuffleMemoryManager, 1, PAGE_SIZE_BYTES);
+    try {
+      final long[] emptyArray = new long[0];
+      final BytesToBytesMap.Location loc =
+        map.lookup(emptyArray, PlatformDependent.LONG_ARRAY_OFFSET, 0);
+      Assert.assertFalse(loc.isDefined());
+      Assert.assertFalse(loc.putNewKey(
+        emptyArray, LONG_ARRAY_OFFSET, 0,
+        emptyArray, LONG_ARRAY_OFFSET, 0
+      ));
+    } finally {
+      map.free();
+    }
+  }
+
+  @Test
+  public void failureToGrow() {
+    shuffleMemoryManager = new ShuffleMemoryManager(1024 * 10);
+    BytesToBytesMap map = new BytesToBytesMap(taskMemoryManager, shuffleMemoryManager, 1, 1024);
+    try {
+      boolean success = true;
+      int i;
+      for (i = 0; i < 1024; i++) {
+        final long[] arr = new long[]{i};
+        final BytesToBytesMap.Location loc =
+          map.lookup(arr, PlatformDependent.LONG_ARRAY_OFFSET, 8);
+        success = loc.putNewKey(arr, LONG_ARRAY_OFFSET, 8, arr, LONG_ARRAY_OFFSET, 8);
+        if (!success) {
+          break;
+        }
+      }
+      Assert.assertThat(i, greaterThan(0));
+      Assert.assertFalse(success);
+    } finally {
+      map.free();
+    }
+  }
+
   @Test
   public void initialCapacityBoundsChecking() {
     try {
-      new BytesToBytesMap(sizeLimitedMemoryManager, 0, PAGE_SIZE_BYTES);
+      new BytesToBytesMap(sizeLimitedTaskMemoryManager, shuffleMemoryManager, 0, PAGE_SIZE_BYTES);
       Assert.fail("Expected IllegalArgumentException to be thrown");
     } catch (IllegalArgumentException e) {
       // expected exception
@@ -362,23 +465,34 @@ public abstract class AbstractBytesToBytesMapSuite {
 
     try {
       new BytesToBytesMap(
-        sizeLimitedMemoryManager, BytesToBytesMap.MAX_CAPACITY + 1, PAGE_SIZE_BYTES);
+        sizeLimitedTaskMemoryManager,
+        shuffleMemoryManager,
+        BytesToBytesMap.MAX_CAPACITY + 1,
+        PAGE_SIZE_BYTES);
       Assert.fail("Expected IllegalArgumentException to be thrown");
     } catch (IllegalArgumentException e) {
       // expected exception
     }
 
-   // Can allocate _at_ the max capacity
-    BytesToBytesMap map =
-      new BytesToBytesMap(sizeLimitedMemoryManager, BytesToBytesMap.MAX_CAPACITY, PAGE_SIZE_BYTES);
-    map.free();
+    // Ignored because this can OOM now that we allocate the long array w/o a TaskMemoryManager
+    // Can allocate _at_ the max capacity
+    //    BytesToBytesMap map = new BytesToBytesMap(
+    //      sizeLimitedTaskMemoryManager,
+    //      shuffleMemoryManager,
+    //      BytesToBytesMap.MAX_CAPACITY,
+    //      PAGE_SIZE_BYTES);
+    //    map.free();
   }
 
-  @Test
+  // Ignored because this can OOM now that we allocate the long array w/o a TaskMemoryManager
+  @Ignore
   public void resizingLargeMap() {
     // As long as a map's capacity is below the max, we should be able to resize up to the max
     BytesToBytesMap map = new BytesToBytesMap(
-      sizeLimitedMemoryManager, BytesToBytesMap.MAX_CAPACITY - 64, PAGE_SIZE_BYTES);
+      sizeLimitedTaskMemoryManager,
+      shuffleMemoryManager,
+      BytesToBytesMap.MAX_CAPACITY - 64,
+      PAGE_SIZE_BYTES);
     map.growAndRehash();
     map.free();
   }
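The randomizedTestWithRecordsLargerThanPageSize test above exercises the new overflow-page path. For reference, the record layout this suite validates, per the comments in putNewKey(): each record is stored as (8-byte key length)(key)(8-byte value length)(value), and when a page is closed out (or when an overflow page is written) its final length word holds END_OF_PAGE_MARKER (-1). The reader loop below is illustrative, not code from the patch:

```java
// Hypothetical scan of a closed data page, assuming `page` is one of dataPages.
Object base = page.getBaseObject();
long offset = page.getBaseOffset();
while (true) {
  long keyLength = PlatformDependent.UNSAFE.getLong(base, offset);
  if (keyLength == END_OF_PAGE_MARKER) {
    break; // no more records in this page
  }
  offset += 8 + keyLength;   // skip the key length word and the key bytes
  long valueLength = PlatformDependent.UNSAFE.getLong(base, offset);
  offset += 8 + valueLength; // skip the value length word and the value bytes
}
```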
diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOffHeapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOffHeapSuite.java
similarity index 100%
rename from unsafe/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOffHeapSuite.java
rename to core/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOffHeapSuite.java
diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOnHeapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOnHeapSuite.java
similarity index 100%
rename from unsafe/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOnHeapSuite.java
rename to core/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOnHeapSuite.java
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
similarity index 89%
rename from sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java
rename to sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
index f3b462778dc103b4fceda1bad0ddb0454260728e..66012e3c94d27e0fd2a3740f489fc84af9b2efd4 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
@@ -15,11 +15,15 @@
  * limitations under the License.
  */
 
-package org.apache.spark.sql.catalyst.expressions;
+package org.apache.spark.sql.execution;
 
+import java.io.IOException;
 import java.util.Iterator;
 
+import org.apache.spark.shuffle.ShuffleMemoryManager;
 import org.apache.spark.sql.catalyst.InternalRow;
+import org.apache.spark.sql.catalyst.expressions.UnsafeProjection;
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
 import org.apache.spark.sql.types.Decimal;
 import org.apache.spark.sql.types.DecimalType;
 import org.apache.spark.sql.types.StructField;
@@ -87,7 +91,9 @@ public final class UnsafeFixedWidthAggregationMap {
    * @param emptyAggregationBuffer the default value for new keys (a "zero" of the agg. function)
    * @param aggregationBufferSchema the schema of the aggregation buffer, used for row conversion.
    * @param groupingKeySchema the schema of the grouping key, used for row conversion.
-   * @param memoryManager the memory manager used to allocate our Unsafe memory structures.
+   * @param taskMemoryManager the memory manager used to allocate our Unsafe memory structures.
+   * @param shuffleMemoryManager the shuffle memory manager, for coordinating our memory usage with
+   *                             other tasks.
    * @param initialCapacity the initial capacity of the map (a sizing hint to avoid re-hashing).
    * @param pageSizeBytes the data page size, in bytes; limits the maximum record size.
    * @param enablePerfMetrics if true, performance metrics will be recorded (has minor perf impact)
@@ -96,15 +102,16 @@ public final class UnsafeFixedWidthAggregationMap {
       InternalRow emptyAggregationBuffer,
       StructType aggregationBufferSchema,
       StructType groupingKeySchema,
-      TaskMemoryManager memoryManager,
+      TaskMemoryManager taskMemoryManager,
+      ShuffleMemoryManager shuffleMemoryManager,
       int initialCapacity,
       long pageSizeBytes,
       boolean enablePerfMetrics) {
     this.aggregationBufferSchema = aggregationBufferSchema;
     this.groupingKeyProjection = UnsafeProjection.create(groupingKeySchema);
     this.groupingKeySchema = groupingKeySchema;
-    this.map =
-      new BytesToBytesMap(memoryManager, initialCapacity, pageSizeBytes, enablePerfMetrics);
+    this.map = new BytesToBytesMap(
+      taskMemoryManager, shuffleMemoryManager, initialCapacity, pageSizeBytes, enablePerfMetrics);
     this.enablePerfMetrics = enablePerfMetrics;
 
     // Initialize the buffer for aggregation value
@@ -116,7 +123,8 @@ public final class UnsafeFixedWidthAggregationMap {
 
   /**
    * Return the aggregation buffer for the current group. For efficiency, all calls to this method
-   * return the same object.
+   * return the same object. If additional memory could not be allocated, then this method will
+   * signal an error by returning null.
    */
   public UnsafeRow getAggregationBuffer(InternalRow groupingKey) {
     final UnsafeRow unsafeGroupingKeyRow = this.groupingKeyProjection.apply(groupingKey);
@@ -129,7 +137,7 @@ public final class UnsafeFixedWidthAggregationMap {
     if (!loc.isDefined()) {
       // This is the first time that we've seen this grouping key, so we'll insert a copy of the
       // empty aggregation buffer into the map:
-      loc.putNewKey(
+      boolean putSucceeded = loc.putNewKey(
         unsafeGroupingKeyRow.getBaseObject(),
         unsafeGroupingKeyRow.getBaseOffset(),
         unsafeGroupingKeyRow.getSizeInBytes(),
@@ -137,6 +145,9 @@ public final class UnsafeFixedWidthAggregationMap {
         PlatformDependent.BYTE_ARRAY_OFFSET,
         emptyAggregationBuffer.length
       );
+      if (!putSucceeded) {
+        return null;
+      }
     }
 
     // Reset the pointer to point to the value that we just stored or looked up:
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
index d851eae3fcc71c437fa42fba934c8fbb617f1af2..469de6ca8e1012d109ec72742de1e7f951a7235b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.execution
 
+import java.io.IOException
+
 import org.apache.spark.{SparkEnv, TaskContext}
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.rdd.RDD
@@ -266,6 +268,7 @@ case class GeneratedAggregate(
           aggregationBufferSchema,
           groupKeySchema,
           TaskContext.get.taskMemoryManager(),
+          SparkEnv.get.shuffleMemoryManager,
           1024 * 16, // initial capacity
           pageSizeBytes,
           false // disable tracking of performance metrics
@@ -275,6 +278,9 @@ case class GeneratedAggregate(
           val currentRow: InternalRow = iter.next()
           val groupKey: InternalRow = groupProjection(currentRow)
           val aggregationBuffer = aggregationMap.getAggregationBuffer(groupKey)
+          if (aggregationBuffer == null) {
+            throw new IOException("Could not allocate memory to grow aggregation buffer")
+          }
           updateProjection.target(aggregationBuffer)(joinedRow(aggregationBuffer, currentRow))
         }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
index f88a45f48aee9e3c9f76aead2cc2ac5d2567b59d..cc8bbfd2f89438f7d080bca0ec8ff8d6d9fbe7f6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
@@ -17,10 +17,11 @@
 
 package org.apache.spark.sql.execution.joins
 
-import java.io.{Externalizable, ObjectInput, ObjectOutput}
+import java.io.{Externalizable, IOException, ObjectInput, ObjectOutput}
 import java.nio.ByteOrder
 import java.util.{HashMap => JavaHashMap}
 
+import org.apache.spark.shuffle.ShuffleMemoryManager
 import org.apache.spark.{SparkConf, SparkEnv, TaskContext}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
@@ -28,6 +29,7 @@ import org.apache.spark.sql.execution.SparkSqlSerializer
 import org.apache.spark.unsafe.PlatformDependent
 import org.apache.spark.unsafe.map.BytesToBytesMap
 import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager}
+import org.apache.spark.util.Utils
 import org.apache.spark.util.collection.CompactBuffer
 
 
@@ -217,7 +219,7 @@ private[joins] final class UnsafeHashedRelation(
     }
   }
 
-  override def writeExternal(out: ObjectOutput): Unit = {
+  override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException {
     out.writeInt(hashTable.size())
 
     val iter = hashTable.entrySet().iterator()
@@ -256,16 +258,26 @@ private[joins] final class UnsafeHashedRelation(
     }
   }
 
-  override def readExternal(in: ObjectInput): Unit = {
+  override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException {
     val nKeys = in.readInt()
     // This is used in Broadcast, shared by multiple tasks, so we use on-heap memory
-    val memoryManager = new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP))
+    val taskMemoryManager = new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP))
+
+    // Dummy shuffle memory manager that always grants all memory allocation requests.
+    // We use this because it doesn't make sense to count shared broadcast variables' memory usage
+    // towards individual tasks' quotas. In the future, we should devise a better way of handling
+    // this.
+    val shuffleMemoryManager = new ShuffleMemoryManager(new SparkConf()) {
+      override def tryToAcquire(numBytes: Long): Long = numBytes
+      override def release(numBytes: Long): Unit = {}
+    }
 
     val pageSizeBytes = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf())
       .getSizeAsBytes("spark.buffer.pageSize", "64m")
 
     binaryMap = new BytesToBytesMap(
-      memoryManager,
+      taskMemoryManager,
+      shuffleMemoryManager,
       nKeys * 2, // reduce hash collision
       pageSizeBytes)
 
@@ -287,8 +299,11 @@ private[joins] final class UnsafeHashedRelation(
       // put it into binary map
       val loc = binaryMap.lookup(keyBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, keySize)
       assert(!loc.isDefined, "Duplicated key found!")
-      loc.putNewKey(keyBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, keySize,
+      val putSucceeded = loc.putNewKey(keyBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, keySize,
         valuesBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, valuesSize)
+      if (!putSucceeded) {
+        throw new IOException("Could not allocate memory to grow BytesToBytesMap")
+      }
       i += 1
     }
   }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
similarity index 78%
rename from sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala
rename to sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
index c6b4c729de2f96d04c38a9ca3d7933d67d32ef28..79fd52dacda52d6b03da39baca5ba4f2d3c35fde 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
@@ -15,17 +15,18 @@
  * limitations under the License.
  */
 
-package org.apache.spark.sql.catalyst.expressions
+package org.apache.spark.sql.execution
+
+import org.scalatest.{BeforeAndAfterEach, Matchers}
 
 import scala.collection.JavaConverters._
 import scala.util.Random
 
-import org.scalatest.{BeforeAndAfterEach, Matchers}
-
 import org.apache.spark.SparkFunSuite
+import org.apache.spark.shuffle.ShuffleMemoryManager
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, TaskMemoryManager, MemoryAllocator}
+import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager}
 import org.apache.spark.unsafe.types.UTF8String
 
 
@@ -41,16 +42,20 @@ class UnsafeFixedWidthAggregationMapSuite
   private def emptyAggregationBuffer: InternalRow = InternalRow(0)
   private val PAGE_SIZE_BYTES: Long = 1L << 26; // 64 megabytes
 
-  private var memoryManager: TaskMemoryManager = null
+  private var taskMemoryManager: TaskMemoryManager = null
+  private var shuffleMemoryManager: ShuffleMemoryManager = null
 
   override def beforeEach(): Unit = {
-    memoryManager = new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP))
+    taskMemoryManager = new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP))
+    shuffleMemoryManager = new ShuffleMemoryManager(Long.MaxValue)
   }
 
   override def afterEach(): Unit = {
-    if (memoryManager != null) {
-      memoryManager.cleanUpAllAllocatedMemory()
-      memoryManager = null
+    if (taskMemoryManager != null) {
+      val leakedShuffleMemory = shuffleMemoryManager.getMemoryConsumptionForThisTask
+      assert(taskMemoryManager.cleanUpAllAllocatedMemory() === 0)
+      assert(leakedShuffleMemory === 0)
+      taskMemoryManager = null
+      shuffleMemoryManager = null
     }
   }
 
@@ -69,7 +74,8 @@ class UnsafeFixedWidthAggregationMapSuite
       emptyAggregationBuffer,
       aggBufferSchema,
       groupKeySchema,
-      memoryManager,
+      taskMemoryManager,
+      shuffleMemoryManager,
       1024, // initial capacity,
       PAGE_SIZE_BYTES,
       false // disable perf metrics
@@ -83,7 +89,8 @@ class UnsafeFixedWidthAggregationMapSuite
       emptyAggregationBuffer,
       aggBufferSchema,
       groupKeySchema,
-      memoryManager,
+      taskMemoryManager,
+      shuffleMemoryManager,
       1024, // initial capacity
       PAGE_SIZE_BYTES,
       false // disable perf metrics
@@ -91,7 +98,7 @@ class UnsafeFixedWidthAggregationMapSuite
     val groupKey = InternalRow(UTF8String.fromString("cats"))
 
     // Looking up a key stores a zero-entry in the map (like Python Counters or DefaultDicts)
-    map.getAggregationBuffer(groupKey)
+    assert(map.getAggregationBuffer(groupKey) != null)
     val iter = map.iterator()
     val entry = iter.next()
     assert(!iter.hasNext)
@@ -110,7 +117,8 @@ class UnsafeFixedWidthAggregationMapSuite
       emptyAggregationBuffer,
       aggBufferSchema,
       groupKeySchema,
-      memoryManager,
+      taskMemoryManager,
+      shuffleMemoryManager,
       128, // initial capacity
       PAGE_SIZE_BYTES,
       false // disable perf metrics
@@ -118,7 +126,7 @@ class UnsafeFixedWidthAggregationMapSuite
     val rand = new Random(42)
     val groupKeys: Set[String] = Seq.fill(512)(rand.nextString(1024)).toSet
     groupKeys.foreach { keyString =>
-      map.getAggregationBuffer(InternalRow(UTF8String.fromString(keyString)))
+      assert(map.getAggregationBuffer(InternalRow(UTF8String.fromString(keyString))) != null)
     }
     val seenKeys: Set[String] = map.iterator().asScala.map { entry =>
       entry.key.getString(0)
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java
index 3dc82d8c2eb399ac86571126a29724c0d994c914..91be46ba21ff81b7c9fa29d2c4eef0a35ffe2ffc 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java
@@ -34,7 +34,7 @@ public class MemoryBlock extends MemoryLocation {
    */
   int pageNumber = -1;
 
-  MemoryBlock(@Nullable Object obj, long offset, long length) {
+  public MemoryBlock(@Nullable Object obj, long offset, long length) {
     super(obj, offset);
     this.length = length;
   }
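The constructor is made public so that code outside the package (such as the mocked TaskMemoryManager in the test suite above) can fabricate blocks directly. A sketch of the two on-heap construction routes now available (values are illustrative):

```java
// Wrap an existing long[] (as the map now does for its longArray and bitset):
MemoryBlock wrapped = MemoryBlock.fromLongArray(new long[1024]);
// Fabricate a block with no backing object, as the size-limit mock does;
// such a block only carries a size and must never actually be dereferenced.
MemoryBlock fake = new MemoryBlock(null, 0, 1L << 20);
```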
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java
index dd70df3b1f791c2e1d08736e1526819af94472ae..358bb372501582831c9b4b6a5edaf850b530777d 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java
@@ -164,6 +164,7 @@ public class TaskMemoryManager {
    * top-level Javadoc for more details).
    */
   public MemoryBlock allocate(long size) throws OutOfMemoryError {
+    assert(size > 0) : "Size must be positive, but got " + size;
     final MemoryBlock memory = executorMemoryManager.allocate(size);
     allocatedNonPageMemory.add(memory);
     return memory;