diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
index 07241c827c2ae5e5017a2b6c3bbb999a668be07f..6656fd1d0bc5973e985820bb9928ade33a697df8 100644
--- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
+++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
@@ -20,6 +20,7 @@ package org.apache.spark.unsafe.map;
 import javax.annotation.Nullable;
 import java.io.File;
 import java.io.IOException;
+import java.util.Arrays;
 import java.util.Iterator;
 import java.util.LinkedList;
 
@@ -638,7 +639,11 @@ public final class BytesToBytesMap extends MemoryConsumer {
       assert (valueLength % 8 == 0);
       assert(longArray != null);
 
-      if (numElements == MAX_CAPACITY || !canGrowArray) {
+
+      if (numElements == MAX_CAPACITY
+        // The map could be reused from last spill (because of no enough memory to grow),
+        // then we don't try to grow again if hit the `growthThreshold`.
+        || !canGrowArray && numElements > growthThreshold) {
         return false;
       }
 
@@ -730,25 +735,18 @@ public final class BytesToBytesMap extends MemoryConsumer {
   }
 
   /**
-   * Free the memory used by longArray.
+   * Free all allocated memory associated with this map, including the storage for keys and values
+   * as well as the hash map array itself.
+   *
+   * This method is idempotent and can be called multiple times.
    */
-  public void freeArray() {
+  public void free() {
     updatePeakMemoryUsed();
     if (longArray != null) {
       long used = longArray.memoryBlock().size();
       longArray = null;
       releaseMemory(used);
     }
-  }
-
-  /**
-   * Free all allocated memory associated with this map, including the storage for keys and values
-   * as well as the hash map array itself.
-   *
-   * This method is idempotent and can be called multiple times.
-   */
-  public void free() {
-    freeArray();
     Iterator<MemoryBlock> dataPagesIterator = dataPages.iterator();
     while (dataPagesIterator.hasNext()) {
       MemoryBlock dataPage = dataPagesIterator.next();
@@ -833,6 +831,28 @@ public final class BytesToBytesMap extends MemoryConsumer {
     return dataPages.size();
   }
 
+  /**
+   * Returns the underline long[] of longArray.
+   */
+  public long[] getArray() {
+    assert(longArray != null);
+    return (long[]) longArray.memoryBlock().getBaseObject();
+  }
+
+  /**
+   * Reset this map to initialized state.
+   */
+  public void reset() {
+    numElements = 0;
+    Arrays.fill(getArray(), 0);
+    while (dataPages.size() > 0) {
+      MemoryBlock dataPage = dataPages.removeLast();
+      freePage(dataPage);
+    }
+    currentPage = null;
+    pageCursor = 0;
+  }
+
   /**
    * Grows the size of the hash table and re-hash everything.
    */
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
index 509fb0a044c0c43e220f5cb03e176093bda1c576..cba043bc48cc877a382c0b89a1a7a9fd109f5aa9 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
@@ -79,9 +79,13 @@ public final class UnsafeExternalSorter extends MemoryConsumer {
       PrefixComparator prefixComparator,
       int initialSize,
       long pageSizeBytes,
-      UnsafeInMemorySorter inMemorySorter) {
-    return new UnsafeExternalSorter(taskMemoryManager, blockManager,
+      UnsafeInMemorySorter inMemorySorter) throws IOException {
+    UnsafeExternalSorter sorter = new UnsafeExternalSorter(taskMemoryManager, blockManager,
       taskContext, recordComparator, prefixComparator, initialSize, pageSizeBytes, inMemorySorter);
+    sorter.spill(Long.MAX_VALUE, sorter);
+    // The external sorter will be used to insert records, in-memory sorter is not needed.
+    sorter.inMemSorter = null;
+    return sorter;
   }
 
   public static UnsafeExternalSorter create(
@@ -124,7 +128,6 @@ public final class UnsafeExternalSorter extends MemoryConsumer {
       acquireMemory(inMemSorter.getMemoryUsage());
     } else {
       this.inMemSorter = existingInMemorySorter;
-      // will acquire after free the map
     }
     this.peakMemoryUsedBytes = getMemoryUsage();
 
@@ -157,12 +160,9 @@ public final class UnsafeExternalSorter extends MemoryConsumer {
    */
   @Override
   public long spill(long size, MemoryConsumer trigger) throws IOException {
-    assert(inMemSorter != null);
     if (trigger != this) {
       if (readingIterator != null) {
         return readingIterator.spill();
-      } else {
-
       }
       return 0L; // this should throw exception
     }
@@ -388,25 +388,38 @@ public final class UnsafeExternalSorter extends MemoryConsumer {
     inMemSorter.insertRecord(recordAddress, prefix);
   }
 
+  /**
+   * Merges another UnsafeExternalSorters into this one, the other one will be emptied.
+   *
+   * @throws IOException
+   */
+  public void merge(UnsafeExternalSorter other) throws IOException {
+    other.spill();
+    spillWriters.addAll(other.spillWriters);
+    // remove them from `spillWriters`, or the files will be deleted in `cleanupResources`.
+    other.spillWriters.clear();
+    other.cleanupResources();
+  }
+
   /**
    * Returns a sorted iterator. It is the caller's responsibility to call `cleanupResources()`
    * after consuming this iterator.
    */
   public UnsafeSorterIterator getSortedIterator() throws IOException {
-    assert(inMemSorter != null);
-    readingIterator = new SpillableIterator(inMemSorter.getSortedIterator());
-    int numIteratorsToMerge = spillWriters.size() + (readingIterator.hasNext() ? 1 : 0);
     if (spillWriters.isEmpty()) {
+      assert(inMemSorter != null);
+      readingIterator = new SpillableIterator(inMemSorter.getSortedIterator());
       return readingIterator;
     } else {
       final UnsafeSorterSpillMerger spillMerger =
-        new UnsafeSorterSpillMerger(recordComparator, prefixComparator, numIteratorsToMerge);
+        new UnsafeSorterSpillMerger(recordComparator, prefixComparator, spillWriters.size());
       for (UnsafeSorterSpillWriter spillWriter : spillWriters) {
         spillMerger.addSpillIfNotEmpty(spillWriter.getReader(blockManager));
       }
-      spillWriters.clear();
-      spillMerger.addSpillIfNotEmpty(readingIterator);
-
+      if (inMemSorter != null) {
+        readingIterator = new SpillableIterator(inMemSorter.getSortedIterator());
+        spillMerger.addSpillIfNotEmpty(readingIterator);
+      }
       return spillMerger.getSortedIterator();
     }
   }
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
index 1480f0681ed9ca9e60772a2b7c984243cec37464..d57213b9b8bfc839208da0fc1ee51119ea714d61 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
@@ -19,9 +19,9 @@ package org.apache.spark.util.collection.unsafe.sort;
 
 import java.util.Comparator;
 
+import org.apache.spark.memory.TaskMemoryManager;
 import org.apache.spark.unsafe.Platform;
 import org.apache.spark.util.collection.Sorter;
-import org.apache.spark.memory.TaskMemoryManager;
 
 /**
  * Sorts records using an AlphaSort-style key-prefix sort. This sort stores pointers to records
@@ -77,13 +77,20 @@ public final class UnsafeInMemorySorter {
    */
   private int pos = 0;
 
+  public UnsafeInMemorySorter(
+    final TaskMemoryManager memoryManager,
+    final RecordComparator recordComparator,
+    final PrefixComparator prefixComparator,
+    int initialSize) {
+    this(memoryManager, recordComparator, prefixComparator, new long[initialSize * 2]);
+  }
+
   public UnsafeInMemorySorter(
       final TaskMemoryManager memoryManager,
       final RecordComparator recordComparator,
       final PrefixComparator prefixComparator,
-      int initialSize) {
-    assert (initialSize > 0);
-    this.array = new long[initialSize * 2];
+      long[] array) {
+    this.array = array;
     this.memoryManager = memoryManager;
     this.sorter = new Sorter<>(UnsafeSortDataFormat.INSTANCE);
     this.sortComparator = new SortComparator(recordComparator, prefixComparator, memoryManager);
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
index d4b6d75b4d981b0d24a8755e42fd13b60c229eb3..a2f99d566d4711caef021ddbe9230d6dd6b68283 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
@@ -236,16 +236,13 @@ public final class UnsafeFixedWidthAggregationMap {
 
   /**
    * Sorts the map's records in place, spill them to disk, and returns an [[UnsafeKVExternalSorter]]
-   * that can be used to insert more records to do external sorting.
    *
-   * The only memory that is allocated is the address/prefix array, 16 bytes per record.
-   *
-   * Note that this destroys the map, and as a result, the map cannot be used anymore after this.
+   * Note that the map will be reset for inserting new records, and the returned sorter can NOT be used
+   * to insert records.
    */
   public UnsafeKVExternalSorter destructAndCreateExternalSorter() throws IOException {
-    UnsafeKVExternalSorter sorter = new UnsafeKVExternalSorter(
+    return new UnsafeKVExternalSorter(
       groupingKeySchema, aggregationBufferSchema,
       SparkEnv.get().blockManager(), map.getPageSizeBytes(), map);
-    return sorter;
   }
 }
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
index 845f2ae6859b70fca0e98001588c771152d61ec0..e2898ef2e215839823cfb99318ff95543c209b72 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
@@ -83,11 +83,10 @@ public final class UnsafeKVExternalSorter {
         /* initialSize */ 4096,
         pageSizeBytes);
     } else {
-      // The memory needed for UnsafeInMemorySorter should be less than longArray in map.
-      map.freeArray();
-      // The memory used by UnsafeInMemorySorter will be counted later (end of this block)
+      // During spilling, the array in map will not be used, so we can borrow that and use it
+      // as the underline array for in-memory sorter (it's always large enough).
       final UnsafeInMemorySorter inMemSorter = new UnsafeInMemorySorter(
-        taskMemoryManager, recordComparator, prefixComparator, Math.max(1, map.numElements()));
+        taskMemoryManager, recordComparator, prefixComparator, map.getArray());
 
       // We cannot use the destructive iterator here because we are reusing the existing memory
       // pages in BytesToBytesMap to hold records during sorting.
@@ -123,10 +122,9 @@ public final class UnsafeKVExternalSorter {
         pageSizeBytes,
         inMemSorter);
 
-      sorter.spill();
-      map.free();
-      // counting the memory used UnsafeInMemorySorter
-      taskMemoryManager.acquireExecutionMemory(inMemSorter.getMemoryUsage(), sorter);
+      // reset the map, so we can re-use it to insert new records. the inMemSorter will not used
+      // anymore, so the underline array could be used by map again.
+      map.reset();
     }
   }
 
@@ -142,6 +140,15 @@ public final class UnsafeKVExternalSorter {
       value.getBaseObject(), value.getBaseOffset(), value.getSizeInBytes(), prefix);
   }
 
+  /**
+   * Merges another UnsafeKVExternalSorter into `this`, the other one will be emptied.
+   *
+   * @throws IOException
+   */
+  public void merge(UnsafeKVExternalSorter other) throws IOException {
+    sorter.merge(other.sorter);
+  }
+
   /**
    * Returns a sorted iterator. It is the caller's responsibility to call `cleanupResources()`
    * after consuming this iterator.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
index 713a4db0cd59bc255f135d82c19b8eac8abfda69..ce8d592c368eeb61f827421026e98c8df487d79b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
@@ -34,14 +34,18 @@ import org.apache.spark.sql.types.StructType
  *
  * This iterator first uses hash-based aggregation to process input rows. It uses
  * a hash map to store groups and their corresponding aggregation buffers. If we
- * this map cannot allocate memory from memory manager,
- * it switches to sort-based aggregation. The process of the switch has the following step:
+ * this map cannot allocate memory from memory manager, it spill the map into disk
+ * and create a new one. After processed all the input, then merge all the spills
+ * together using external sorter, and do sort-based aggregation.
+ *
+ * The process has the following step:
+ *  - Step 0: Do hash-based aggregation.
  *  - Step 1: Sort all entries of the hash map based on values of grouping expressions and
  *            spill them to disk.
- *  - Step 2: Create a external sorter based on the spilled sorted map entries.
- *  - Step 3: Redirect all input rows to the external sorter.
- *  - Step 4: Get a sorted [[KVIterator]] from the external sorter.
- *  - Step 5: Initialize sort-based aggregation.
+ *  - Step 2: Create a external sorter based on the spilled sorted map entries and reset the map.
+ *  - Step 3: Get a sorted [[KVIterator]] from the external sorter.
+ *  - Step 4: Repeat step 0 until no more input.
+ *  - Step 5: Initialize sort-based aggregation on the sorted iterator.
  * Then, this iterator works in the way of sort-based aggregation.
  *
  * The code of this class is organized as follows:
@@ -488,9 +492,10 @@ class TungstenAggregationIterator(
 
   // The function used to read and process input rows. When processing input rows,
   // it first uses hash-based aggregation by putting groups and their buffers in
-  // hashMap. If we could not allocate more memory for the map, we switch to
-  // sort-based aggregation (by calling switchToSortBasedAggregation).
-  private def processInputs(): Unit = {
+  // hashMap. If there is not enough memory, it will multiple hash-maps, spilling
+  // after each becomes full then using sort to merge these spills, finally do sort
+  // based aggregation.
+  private def processInputs(fallbackStartsAt: Int): Unit = {
     if (groupingExpressions.isEmpty) {
       // If there is no grouping expressions, we can just reuse the same buffer over and over again.
       // Note that it would be better to eliminate the hash map entirely in the future.
@@ -502,44 +507,40 @@ class TungstenAggregationIterator(
         processRow(buffer, newInput)
       }
     } else {
-      while (!sortBased && inputIter.hasNext) {
+      var i = 0
+      while (inputIter.hasNext) {
         val newInput = inputIter.next()
         numInputRows += 1
         val groupingKey = groupProjection.apply(newInput)
-        val buffer: UnsafeRow = hashMap.getAggregationBufferFromUnsafeRow(groupingKey)
+        var buffer: UnsafeRow = null
+        if (i < fallbackStartsAt) {
+          buffer = hashMap.getAggregationBufferFromUnsafeRow(groupingKey)
+        }
         if (buffer == null) {
-          // buffer == null means that we could not allocate more memory.
-          // Now, we need to spill the map and switch to sort-based aggregation.
-          switchToSortBasedAggregation(groupingKey, newInput)
-        } else {
-          processRow(buffer, newInput)
+          val sorter = hashMap.destructAndCreateExternalSorter()
+          if (externalSorter == null) {
+            externalSorter = sorter
+          } else {
+            externalSorter.merge(sorter)
+          }
+          i = 0
+          buffer = hashMap.getAggregationBufferFromUnsafeRow(groupingKey)
+          if (buffer == null) {
+            // failed to allocate the first page
+            throw new OutOfMemoryError("No enough memory for aggregation")
+          }
         }
+        processRow(buffer, newInput)
+        i += 1
       }
-    }
-  }
 
-  // This function is only used for testing. It basically the same as processInputs except
-  // that it switch to sort-based aggregation after `fallbackStartsAt` input rows have
-  // been processed.
-  private def processInputsWithControlledFallback(fallbackStartsAt: Int): Unit = {
-    var i = 0
-    while (!sortBased && inputIter.hasNext) {
-      val newInput = inputIter.next()
-      numInputRows += 1
-      val groupingKey = groupProjection.apply(newInput)
-      val buffer: UnsafeRow = if (i < fallbackStartsAt) {
-        hashMap.getAggregationBufferFromUnsafeRow(groupingKey)
-      } else {
-        null
-      }
-      if (buffer == null) {
-        // buffer == null means that we could not allocate more memory.
-        // Now, we need to spill the map and switch to sort-based aggregation.
-        switchToSortBasedAggregation(groupingKey, newInput)
-      } else {
-        processRow(buffer, newInput)
+      if (externalSorter != null) {
+        val sorter = hashMap.destructAndCreateExternalSorter()
+        externalSorter.merge(sorter)
+        hashMap.free()
+
+        switchToSortBasedAggregation()
       }
-      i += 1
     }
   }
 
@@ -561,88 +562,8 @@ class TungstenAggregationIterator(
   /**
    * Switch to sort-based aggregation when the hash-based approach is unable to acquire memory.
    */
-  private def switchToSortBasedAggregation(firstKey: UnsafeRow, firstInput: InternalRow): Unit = {
+  private def switchToSortBasedAggregation(): Unit = {
     logInfo("falling back to sort based aggregation.")
-    // Step 1: Get the ExternalSorter containing sorted entries of the map.
-    externalSorter = hashMap.destructAndCreateExternalSorter()
-
-    // Step 2: If we have aggregate function with mode Partial or Complete,
-    // we need to process input rows to get aggregation buffer.
-    // So, later in the sort-based aggregation iterator, we can do merge.
-    // If aggregate functions are with mode Final and PartialMerge,
-    // we just need to project the aggregation buffer from an input row.
-    val needsProcess = aggregationMode match {
-      case (Some(Partial), None) => true
-      case (None, Some(Complete)) => true
-      case (Some(Final), Some(Complete)) => true
-      case _ => false
-    }
-
-    // Note: Since we spill the sorter's contents immediately after creating it, we must insert
-    // something into the sorter here to ensure that we acquire at least a page of memory.
-    // This is done through `externalSorter.insertKV`, which will trigger the page allocation.
-    // Otherwise, children operators may steal the window of opportunity and starve our sorter.
-
-    if (needsProcess) {
-      // First, we create a buffer.
-      val buffer = createNewAggregationBuffer()
-
-      // Process firstKey and firstInput.
-      // Initialize buffer.
-      buffer.copyFrom(initialAggregationBuffer)
-      processRow(buffer, firstInput)
-      externalSorter.insertKV(firstKey, buffer)
-
-      // Process the rest of input rows.
-      while (inputIter.hasNext) {
-        val newInput = inputIter.next()
-        numInputRows += 1
-        val groupingKey = groupProjection.apply(newInput)
-        buffer.copyFrom(initialAggregationBuffer)
-        processRow(buffer, newInput)
-        externalSorter.insertKV(groupingKey, buffer)
-      }
-    } else {
-      // When needsProcess is false, the format of input rows is groupingKey + aggregation buffer.
-      // We need to project the aggregation buffer part from an input row.
-      val buffer = createNewAggregationBuffer()
-      // In principle, we could use `allAggregateFunctions.flatMap(_.inputAggBufferAttributes)` to
-      // extract the aggregation buffer. In practice, however, we extract it positionally by relying
-      // on it being present at the end of the row. The reason for this relates to how the different
-      // aggregates handle input binding.
-      //
-      // ImperativeAggregate uses field numbers and field number offsets to manipulate its buffers,
-      // so its correctness does not rely on attribute bindings. When we fall back to sort-based
-      // aggregation, these field number offsets (mutableAggBufferOffset and inputAggBufferOffset)
-      // need to be updated and any internal state in the aggregate functions themselves must be
-      // reset, so we call withNewMutableAggBufferOffset and withNewInputAggBufferOffset to reset
-      // this state and update the offsets.
-      //
-      // The updated ImperativeAggregate will have different attribute ids for its
-      // aggBufferAttributes and inputAggBufferAttributes. This isn't a problem for the actual
-      // ImperativeAggregate evaluation, but it means that
-      // `allAggregateFunctions.flatMap(_.inputAggBufferAttributes)` will no longer match the
-      // attributes in `originalInputAttributes`, which is why we can't use those attributes here.
-      //
-      // For more details, see the discussion on PR #9038.
-      val bufferExtractor = newMutableProjection(
-        originalInputAttributes.drop(initialInputBufferOffset),
-        originalInputAttributes)()
-      bufferExtractor.target(buffer)
-
-      // Insert firstKey and its buffer.
-      bufferExtractor(firstInput)
-      externalSorter.insertKV(firstKey, buffer)
-
-      // Insert the rest of input rows.
-      while (inputIter.hasNext) {
-        val newInput = inputIter.next()
-        numInputRows += 1
-        val groupingKey = groupProjection.apply(newInput)
-        bufferExtractor(newInput)
-        externalSorter.insertKV(groupingKey, buffer)
-      }
-    }
 
     // Set aggregationMode, processRow, and generateOutput for sort-based aggregation.
     val newAggregationMode = aggregationMode match {
@@ -762,15 +683,7 @@ class TungstenAggregationIterator(
   /**
    * Start processing input rows.
    */
-  testFallbackStartsAt match {
-    case None =>
-      processInputs()
-    case Some(fallbackStartsAt) =>
-      // This is the testing path. processInputsWithControlledFallback is same as processInputs
-      // except that it switches to sort-based aggregation after `fallbackStartsAt` input rows
-      // have been processed.
-      processInputsWithControlledFallback(fallbackStartsAt)
-  }
+  processInputs(testFallbackStartsAt.getOrElse(Int.MaxValue))
 
   // If we did not switch to sort-based aggregation in processInputs,
   // we pre-load the first key-value pair from the map (to make hasNext idempotent).
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
index a38623623a441c1357ede81307466ec441dbe35f..7ceaee38d131bb53d6d2a544bf956aedd14a1e37 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
@@ -170,9 +170,6 @@ class UnsafeFixedWidthAggregationMapSuite
   }
 
   testWithMemoryLeakDetection("test external sorting") {
-    // Memory consumption in the beginning of the task.
-    val initialMemoryConsumption = taskMemoryManager.getMemoryConsumptionForThisTask()
-
     val map = new UnsafeFixedWidthAggregationMap(
       emptyAggregationBuffer,
       aggBufferSchema,
@@ -189,35 +186,33 @@ class UnsafeFixedWidthAggregationMapSuite
       buf.setInt(0, keyString.length)
       assert(buf != null)
     }
-
-    // Convert the map into a sorter
     val sorter = map.destructAndCreateExternalSorter()
 
     // Add more keys to the sorter and make sure the results come out sorted.
     val additionalKeys = randomStrings(1024)
-    val keyConverter = UnsafeProjection.create(groupKeySchema)
-    val valueConverter = UnsafeProjection.create(aggBufferSchema)
-
     additionalKeys.zipWithIndex.foreach { case (str, i) =>
-      val k = InternalRow(UTF8String.fromString(str))
-      val v = InternalRow(str.length)
-      sorter.insertKV(keyConverter.apply(k), valueConverter.apply(v))
+      val buf = map.getAggregationBuffer(InternalRow(UTF8String.fromString(str)))
+      buf.setInt(0, str.length)
 
       if ((i % 100) == 0) {
-        memoryManager.markExecutionAsOutOfMemoryOnce()
-        sorter.closeCurrentPage()
+        val sorter2 = map.destructAndCreateExternalSorter()
+        sorter.merge(sorter2)
       }
     }
+    val sorter2 = map.destructAndCreateExternalSorter()
+    sorter.merge(sorter2)
 
     val out = new scala.collection.mutable.ArrayBuffer[String]
     val iter = sorter.sortedIterator()
     while (iter.next()) {
-      assert(iter.getKey.getString(0).length === iter.getValue.getInt(0))
-      out += iter.getKey.getString(0)
+      // At here, we also test if copy is correct.
+      val key = iter.getKey.copy()
+      val value = iter.getValue.copy()
+      assert(key.getString(0).length === value.getInt(0))
+      out += key.getString(0)
     }
 
     assert(out === (keys ++ additionalKeys).sorted)
-
     map.free()
   }
 
@@ -232,25 +227,21 @@ class UnsafeFixedWidthAggregationMapSuite
       PAGE_SIZE_BYTES,
       false // disable perf metrics
     )
-
-    // Convert the map into a sorter
     val sorter = map.destructAndCreateExternalSorter()
 
     // Add more keys to the sorter and make sure the results come out sorted.
     val additionalKeys = randomStrings(1024)
-    val keyConverter = UnsafeProjection.create(groupKeySchema)
-    val valueConverter = UnsafeProjection.create(aggBufferSchema)
-
     additionalKeys.zipWithIndex.foreach { case (str, i) =>
-      val k = InternalRow(UTF8String.fromString(str))
-      val v = InternalRow(str.length)
-      sorter.insertKV(keyConverter.apply(k), valueConverter.apply(v))
+      val buf = map.getAggregationBuffer(InternalRow(UTF8String.fromString(str)))
+      buf.setInt(0, str.length)
 
       if ((i % 100) == 0) {
-        memoryManager.markExecutionAsOutOfMemoryOnce()
-        sorter.closeCurrentPage()
+        val sorter2 = map.destructAndCreateExternalSorter()
+        sorter.merge(sorter2)
       }
     }
+    val sorter2 = map.destructAndCreateExternalSorter()
+    sorter.merge(sorter2)
 
     val out = new scala.collection.mutable.ArrayBuffer[String]
     val iter = sorter.sortedIterator()
@@ -262,16 +253,12 @@ class UnsafeFixedWidthAggregationMapSuite
       out += key.getString(0)
     }
 
-    assert(out === (additionalKeys).sorted)
-
+    assert(out === additionalKeys.sorted)
     map.free()
   }
 
   testWithMemoryLeakDetection("test external sorting with empty records") {
 
-    // Memory consumption in the beginning of the task.
-    val initialMemoryConsumption = taskMemoryManager.getMemoryConsumptionForThisTask()
-
     val map = new UnsafeFixedWidthAggregationMap(
       emptyAggregationBuffer,
       StructType(Nil),
@@ -281,7 +268,6 @@ class UnsafeFixedWidthAggregationMapSuite
       PAGE_SIZE_BYTES,
       false // disable perf metrics
     )
-
     (1 to 10).foreach { i =>
       val buf = map.getAggregationBuffer(UnsafeRow.createFromByteArray(0, 0))
       assert(buf != null)
@@ -292,13 +278,15 @@ class UnsafeFixedWidthAggregationMapSuite
 
     // Add more keys to the sorter and make sure the results come out sorted.
     (1 to 4096).foreach { i =>
-      sorter.insertKV(UnsafeRow.createFromByteArray(0, 0), UnsafeRow.createFromByteArray(0, 0))
+      map.getAggregationBufferFromUnsafeRow(UnsafeRow.createFromByteArray(0, 0))
 
       if ((i % 100) == 0) {
-        memoryManager.markExecutionAsOutOfMemoryOnce()
-        sorter.closeCurrentPage()
+        val sorter2 = map.destructAndCreateExternalSorter()
+        sorter.merge(sorter2)
       }
     }
+    val sorter2 = map.destructAndCreateExternalSorter()
+    sorter.merge(sorter2)
 
     var count = 0
     val iter = sorter.sortedIterator()
@@ -309,9 +297,8 @@ class UnsafeFixedWidthAggregationMapSuite
       count += 1
     }
 
-    // 1 record was from the map and 4096 records were explicitly inserted.
-    assert(count === 4097)
-
+    // 1 record per map, spilled 42 times.
+    assert(count === 42)
     map.free()
   }
 
@@ -345,6 +332,7 @@ class UnsafeFixedWidthAggregationMapSuite
     var sorter: UnsafeKVExternalSorter = null
     try {
       sorter = map.destructAndCreateExternalSorter()
+      map.free()
     } finally {
       if (sorter != null) {
         sorter.cleanupResources()
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
index 74061db0f28af51247343b7c373bb62c5c173a76..ea80060e370e02f1d2e4a1338f7b80c56f7d240f 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
@@ -22,13 +22,12 @@ import scala.collection.JavaConverters._
 import org.apache.spark.SparkException
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
-import org.apache.spark.sql.execution.aggregate
 import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
 import org.apache.spark.sql.functions._
-import org.apache.spark.sql.test.SQLTestUtils
-import org.apache.spark.sql.types._
 import org.apache.spark.sql.hive.aggregate.{MyDoubleAvg, MyDoubleSum}
 import org.apache.spark.sql.hive.test.TestHiveSingleton
+import org.apache.spark.sql.test.SQLTestUtils
+import org.apache.spark.sql.types._
 
 class ScalaAggregateFunction(schema: StructType) extends UserDefinedAggregateFunction {
 
@@ -702,6 +701,13 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
     }
   }
 
+  test("no aggregation function (SPARK-11486)") {
+    val df = sqlContext.range(20).selectExpr("id", "repeat(id, 1) as s")
+      .groupBy("s").count()
+      .groupBy().count()
+    checkAnswer(df, Row(20) :: Nil)
+  }
+
   test("udaf with all data types") {
     val struct =
       StructType(