diff --git a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java
index d2a88864f7ac9e026ba0eb44736a50b8937ad43d..8757dff36f15925ddc269cad90228bba7cb506c7 100644
--- a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java
+++ b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java
@@ -111,6 +111,11 @@ public class TaskMemoryManager {
   @GuardedBy("this")
   private final HashSet<MemoryConsumer> consumers;
 
+  /**
+   * Memory acquired from the MemoryManager but not backed by a successfully allocated page.
+   */
+  private long acquiredButNotUsed = 0L;
+
   /**
    * Construct a new TaskMemoryManager.
    */
@@ -256,7 +261,20 @@ public class TaskMemoryManager {
       }
       allocatedPages.set(pageNumber);
     }
-    final MemoryBlock page = memoryManager.tungstenMemoryAllocator().allocate(acquired);
+    MemoryBlock page = null;
+    try {
+      page = memoryManager.tungstenMemoryAllocator().allocate(acquired);
+    } catch (OutOfMemoryError e) {
+      logger.warn("Failed to allocate a page ({} bytes), try again.", acquired);
+      // There is not enough memory in practice: the actual free memory is smaller than what the
+      // MemoryManager believes, so keep the acquired memory reserved instead of releasing it.
+      acquiredButNotUsed += acquired;
+      synchronized (this) {
+        allocatedPages.clear(pageNumber);
+      }
+      // Retrying the allocation below may trigger spilling to free some pages.
+      return allocatePage(size, consumer);
+    }
     page.pageNumber = pageNumber;
     pageTable[pageNumber] = page;
     if (logger.isTraceEnabled()) {
@@ -378,6 +396,9 @@ public class TaskMemoryManager {
     }
     Arrays.fill(pageTable, null);
 
+    // Release the memory that was acquired but never backed by an allocated page.
+    memoryManager.releaseExecutionMemory(acquiredButNotUsed, taskAttemptId, tungstenMemoryMode);
+
     return memoryManager.releaseAllExecutionMemoryForTask(taskAttemptId);
   }
 
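To make the bookkeeping above concrete, here is a minimal, self-contained sketch of the same pattern: keep the reservation when the physical allocation fails, hand the page number back, retry, and only return the unused reservation at cleanup. The names (SimpleMemoryManager, reserve, release, a byte[] standing in for a page) are hypothetical illustrations, not Spark APIs; in Spark the retry goes back through execution-memory acquisition, which is what can trigger spilling.

import java.util.BitSet;

public class AllocationRetrySketch {

  /** Hypothetical stand-in for the MemoryManager's execution-memory accounting. */
  static class SimpleMemoryManager {
    private long free = 1L << 20;              // bytes the manager *thinks* are free
    synchronized long reserve(long bytes) {    // may grant less than requested
      long granted = Math.min(bytes, free);
      free -= granted;
      return granted;
    }
    synchronized void release(long bytes) { free += bytes; }
  }

  private final SimpleMemoryManager manager = new SimpleMemoryManager();
  private final BitSet allocatedPages = new BitSet(8192);
  private long acquiredButNotUsed = 0L;

  /** Allocate a page, keeping the reservation if the physical allocation fails. */
  byte[] allocatePage(long size) {
    long acquired = manager.reserve(size);
    int pageNumber;
    synchronized (this) {
      pageNumber = allocatedPages.nextClearBit(0);
      allocatedPages.set(pageNumber);
    }
    try {
      return new byte[(int) acquired];         // stand-in for the real allocator call
    } catch (OutOfMemoryError e) {
      // The JVM has less free memory than the manager believed: keep the reservation so the
      // manager's accounting stays conservative, give the page number back, and retry.
      acquiredButNotUsed += acquired;
      synchronized (this) {
        allocatedPages.clear(pageNumber);
      }
      return allocatePage(size);
    }
  }

  /** On task completion, hand the unused reservation back to the manager. */
  void cleanUp() {
    manager.release(acquiredButNotUsed);
    acquiredButNotUsed = 0L;
  }
}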
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java
index 2c84de5bf2a5a38050526ebf626b8ce6d3e4a336..f97e76d7ed0d994432467f528685c746bc037fae 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java
@@ -320,15 +320,7 @@ final class ShuffleExternalSorter extends MemoryConsumer {
     assert(inMemSorter != null);
     if (!inMemSorter.hasSpaceForAnotherRecord()) {
       long used = inMemSorter.getMemoryUsage();
-      LongArray array;
-      try {
-        // could trigger spilling
-        array = allocateArray(used / 8 * 2);
-      } catch (OutOfMemoryError e) {
-        // should have trigger spilling
-        assert(inMemSorter.hasSpaceForAnotherRecord());
-        return;
-      }
+      LongArray array = allocateArray(used / 8 * 2);
       // check if spilling is triggered or not
       if (inMemSorter.hasSpaceForAnotherRecord()) {
         freeArray(array);
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java
index 58ad88e1ed87bde0a5a5b5d913923aa182410dec..d74602cd205add869d49abb2e52afdd5cc61e747 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java
@@ -104,7 +104,7 @@ final class ShuffleInMemorySorter {
    */
   public void insertRecord(long recordPointer, int partitionId) {
     if (!hasSpaceForAnotherRecord()) {
-      expandPointerArray(consumer.allocateArray(array.size() * 2));
+      throw new IllegalStateException("There is no space for a new record");
     }
     array.set(pos, PackedRecordPointer.packPointer(recordPointer, partitionId));
     pos++;
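Since insertRecord now throws instead of growing the pointer array itself, callers are expected to check for space and expand through their MemoryConsumer first. Below is a minimal sketch of that caller-side pattern; the helper name insertWithGrowth is hypothetical, and consumer and sorter stand for the enclosing MemoryConsumer and its ShuffleInMemorySorter.

private void insertWithGrowth(MemoryConsumer consumer, ShuffleInMemorySorter sorter,
                              long recordPointer, int partitionId) {
  if (!sorter.hasSpaceForAnotherRecord()) {
    // Growing through the consumer lets the TaskMemoryManager account for, and if
    // necessary spill to obtain, the memory for the larger array.
    sorter.expandPointerArray(consumer.allocateArray(sorter.numRecords() * 2));
  }
  sorter.insertRecord(recordPointer, partitionId);
}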
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
index a6edc1ad3f66560bbaf476c09d22ad63495413ae..296bf722fc178360cefdf4c55b04f59757107203 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
@@ -293,15 +293,7 @@ public final class UnsafeExternalSorter extends MemoryConsumer {
     assert(inMemSorter != null);
     if (!inMemSorter.hasSpaceForAnotherRecord()) {
       long used = inMemSorter.getMemoryUsage();
-      LongArray array;
-      try {
-        // could trigger spilling
-        array = allocateArray(used / 8 * 2);
-      } catch (OutOfMemoryError e) {
-        // should have trigger spilling
-        assert(inMemSorter.hasSpaceForAnotherRecord());
-        return;
-      }
+      LongArray array = allocateArray(used / 8 * 2);
       // check if spilling is triggered or not
       if (inMemSorter.hasSpaceForAnotherRecord()) {
         freeArray(array);
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
index d1b0bc5d11f460ed2928f06c40fc1a107efe526b..cea0f0a0c6c11a38e5cc5a8a6aef8cf15580d60d 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
@@ -164,7 +164,7 @@ public final class UnsafeInMemorySorter {
    */
   public void insertRecord(long recordPointer, long keyPrefix) {
     if (!hasSpaceForAnotherRecord()) {
-      expandPointerArray(consumer.allocateArray(array.size() * 2));
+      throw new IllegalStateException("There is no space for a new record");
     }
     array.set(pos, recordPointer);
     pos++;
diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java
index 0328e63e4543939174e30948ea63f44006979a2d..eb1da8e1b43eb75aeb6ea3a7c36a349afe22daa0 100644
--- a/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java
+++ b/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java
@@ -75,6 +75,9 @@ public class ShuffleInMemorySorterSuite {
     // Write the records into the data page and store pointers into the sorter
     long position = dataPage.getBaseOffset();
     for (String str : dataToSort) {
+      if (!sorter.hasSpaceForAnotherRecord()) {
+        sorter.expandPointerArray(consumer.allocateArray(sorter.numRecords() * 2));
+      }
       final long recordAddress = memoryManager.encodePageNumberAndOffset(dataPage, position);
       final byte[] strBytes = str.getBytes("utf-8");
       Platform.putInt(baseObject, position, strBytes.length);
@@ -114,6 +117,9 @@ public class ShuffleInMemorySorterSuite {
     int[] numbersToSort = new int[128000];
     Random random = new Random(16);
     for (int i = 0; i < numbersToSort.length; i++) {
+      if (!sorter.hasSpaceForAnotherRecord()) {
+        sorter.expandPointerArray(consumer.allocateArray(sorter.numRecords() * 2));
+      }
       numbersToSort[i] = random.nextInt(PackedRecordPointer.MAXIMUM_PARTITION_ID + 1);
       sorter.insertRecord(0, numbersToSort[i]);
     }
diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java
index 93efd033eb940bfe074fb3f07e4847e1596a6e68..8e557ec0ab0b42feb76432b69b5a71761140533a 100644
--- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java
+++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java
@@ -111,6 +111,9 @@ public class UnsafeInMemorySorterSuite {
     // Given a page of records, insert those records into the sorter one-by-one:
     position = dataPage.getBaseOffset();
     for (int i = 0; i < dataToSort.length; i++) {
+      if (!sorter.hasSpaceForAnotherRecord()) {
+        sorter.expandPointerArray(consumer.allocateArray(sorter.numRecords() * 2 * 2));
+      }
       // position now points to the start of a record (which holds its length).
       final int recordLength = Platform.getInt(baseObject, position);
       final long address = memoryManager.encodePageNumberAndOffset(dataPage, position);