diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
index 0f42950e6ed8b61490dc70d4e71e2be465ca1232..481375f493a5045f90bd94ae2148082f366a27e3 100644
--- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
+++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
@@ -17,7 +17,6 @@
 
 package org.apache.spark.unsafe.map;
 
-import java.io.IOException;
 import java.lang.Override;
 import java.lang.UnsupportedOperationException;
 import java.util.Iterator;
@@ -212,7 +211,7 @@ public final class BytesToBytesMap {
    */
   public int numElements() { return numElements; }
 
-  private static final class BytesToBytesMapIterator implements Iterator<Location> {
+  public static final class BytesToBytesMapIterator implements Iterator<Location> {
 
     private final int numRecords;
     private final Iterator<MemoryBlock> dataPagesIterator;
@@ -222,7 +221,8 @@ public final class BytesToBytesMap {
     private Object pageBaseObject;
     private long offsetInPage;
 
-    BytesToBytesMapIterator(int numRecords, Iterator<MemoryBlock> dataPagesIterator, Location loc) {
+    private BytesToBytesMapIterator(
+        int numRecords, Iterator<MemoryBlock> dataPagesIterator, Location loc) {
       this.numRecords = numRecords;
       this.dataPagesIterator = dataPagesIterator;
       this.loc = loc;
@@ -244,13 +244,13 @@ public final class BytesToBytesMap {
 
     @Override
     public Location next() {
-      int keyLength = (int) PlatformDependent.UNSAFE.getLong(pageBaseObject, offsetInPage);
-      if (keyLength == END_OF_PAGE_MARKER) {
+      int totalLength = PlatformDependent.UNSAFE.getInt(pageBaseObject, offsetInPage);
+      if (totalLength == END_OF_PAGE_MARKER) {
         advanceToNextPage();
-        keyLength = (int) PlatformDependent.UNSAFE.getLong(pageBaseObject, offsetInPage);
+        totalLength = PlatformDependent.UNSAFE.getInt(pageBaseObject, offsetInPage);
       }
       loc.with(pageBaseObject, offsetInPage);
-      offsetInPage += 8 + 8 + keyLength + loc.getValueLength();
+      offsetInPage += 8 + totalLength;
       currentRecordNumber++;
       return loc;
     }
@@ -269,7 +269,7 @@ public final class BytesToBytesMap {
    * If any other lookups or operations are performed on this map while iterating over it, including
    * `lookup()`, the behavior of the returned iterator is undefined.
    */
-  public Iterator<Location> iterator() {
+  public BytesToBytesMapIterator iterator() {
     return new BytesToBytesMapIterator(numElements, dataPages.iterator(), loc);
   }
 
@@ -352,15 +352,18 @@ public final class BytesToBytesMap {
         taskMemoryManager.getOffsetInPage(fullKeyAddress));
     }
 
-    private void updateAddressesAndSizes(Object page, long keyOffsetInPage) {
-        long position = keyOffsetInPage;
-        keyLength = (int) PlatformDependent.UNSAFE.getLong(page, position);
-        position += 8; // word used to store the key size
-        keyMemoryLocation.setObjAndOffset(page, position);
-        position += keyLength;
-        valueLength = (int) PlatformDependent.UNSAFE.getLong(page, position);
-        position += 8; // word used to store the key size
-        valueMemoryLocation.setObjAndOffset(page, position);
+    private void updateAddressesAndSizes(final Object page, final long keyOffsetInPage) {
+      long position = keyOffsetInPage;
+      final int totalLength = PlatformDependent.UNSAFE.getInt(page, position);
+      position += 4;
+      keyLength = PlatformDependent.UNSAFE.getInt(page, position);
+      position += 4;
+      valueLength = totalLength - keyLength;
+
+      keyMemoryLocation.setObjAndOffset(page, position);
+
+      position += keyLength;
+      valueMemoryLocation.setObjAndOffset(page, position);
     }
 
     Location with(int pos, int keyHashcode, boolean isDefined) {
@@ -478,7 +481,7 @@ public final class BytesToBytesMap {
       // the key address instead of storing the absolute address of the value, the key and value
       // must be stored in the same memory page.
       // (8 byte key length) (key) (8 byte value length) (value)
-      final long requiredSize = 8 + keyLengthBytes + 8 + valueLengthBytes;
+      final long requiredSize = 8 + keyLengthBytes + valueLengthBytes;
 
       // --- Figure out where to insert the new record ---------------------------------------------
 
@@ -508,7 +511,7 @@ public final class BytesToBytesMap {
           // There wasn't enough space in the current page, so write an end-of-page marker:
           final Object pageBaseObject = currentDataPage.getBaseObject();
           final long lengthOffsetInPage = currentDataPage.getBaseOffset() + pageCursor;
-          PlatformDependent.UNSAFE.putLong(pageBaseObject, lengthOffsetInPage, END_OF_PAGE_MARKER);
+          PlatformDependent.UNSAFE.putInt(pageBaseObject, lengthOffsetInPage, END_OF_PAGE_MARKER);
         }
         final long memoryGranted = shuffleMemoryManager.tryToAcquire(pageSizeBytes);
         if (memoryGranted != pageSizeBytes) {
@@ -535,21 +538,22 @@ public final class BytesToBytesMap {
       long insertCursor = dataPageInsertOffset;
 
       // Compute all of our offsets up-front:
-      final long keySizeOffsetInPage = insertCursor;
-      insertCursor += 8; // word used to store the key size
+      final long totalLengthOffset = insertCursor;
+      insertCursor += 4;
+      final long keyLengthOffset = insertCursor;
+      insertCursor += 4;
       final long keyDataOffsetInPage = insertCursor;
       insertCursor += keyLengthBytes;
-      final long valueSizeOffsetInPage = insertCursor;
-      insertCursor += 8; // word used to store the value size
       final long valueDataOffsetInPage = insertCursor;
       insertCursor += valueLengthBytes; // word used to store the value size
 
+      PlatformDependent.UNSAFE.putInt(dataPageBaseObject, totalLengthOffset,
+        keyLengthBytes + valueLengthBytes);
+      PlatformDependent.UNSAFE.putInt(dataPageBaseObject, keyLengthOffset, keyLengthBytes);
       // Copy the key
-      PlatformDependent.UNSAFE.putLong(dataPageBaseObject, keySizeOffsetInPage, keyLengthBytes);
       PlatformDependent.copyMemory(
         keyBaseObject, keyBaseOffset, dataPageBaseObject, keyDataOffsetInPage, keyLengthBytes);
       // Copy the value
-      PlatformDependent.UNSAFE.putLong(dataPageBaseObject, valueSizeOffsetInPage, valueLengthBytes);
       PlatformDependent.copyMemory(valueBaseObject, valueBaseOffset, dataPageBaseObject,
         valueDataOffsetInPage, valueLengthBytes);
 
@@ -557,7 +561,7 @@ public final class BytesToBytesMap {
 
       if (useOverflowPage) {
         // Store the end-of-page marker at the end of the data page
-        PlatformDependent.UNSAFE.putLong(dataPageBaseObject, insertCursor, END_OF_PAGE_MARKER);
+        PlatformDependent.UNSAFE.putInt(dataPageBaseObject, insertCursor, END_OF_PAGE_MARKER);
       } else {
         pageCursor += requiredSize;
       }
@@ -565,7 +569,7 @@ public final class BytesToBytesMap {
       numElements++;
       bitset.set(pos);
       final long storedKeyAddress = taskMemoryManager.encodePageNumberAndOffset(
-        dataPage, keySizeOffsetInPage);
+        dataPage, totalLengthOffset);
       longArray.set(pos * 2, storedKeyAddress);
       longArray.set(pos * 2 + 1, keyHashcode);
       updateAddressesAndSizes(storedKeyAddress);
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
index 866e0b41515771574baa498aa54ab0c7b715f403..c05f2c332eee3596500a3a1bf336c66b867114b6 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
@@ -282,6 +282,21 @@ public final class UnsafeExternalSorter {
     sorter.insertRecord(recordAddress, prefix);
   }
 
+  /**
+   * Write a record to the sorter. The record is broken down into two different parts, and
+   *
+   */
+  public void insertRecord(
+      Object recordBaseObject1,
+      long recordBaseOffset1,
+      int lengthInBytes1,
+      Object recordBaseObject2,
+      long recordBaseOffset2,
+      int lengthInBytes2,
+      long prefix) throws IOException {
+
+  }
+
   public UnsafeSorterIterator getSortedIterator() throws IOException {
     final UnsafeSorterIterator inMemoryIterator = sorter.getSortedIterator();
     int numIteratorsToMerge = spillWriters.size() + (inMemoryIterator.hasNext() ? 1 : 0);
diff --git a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
index 60f483acbcb80581d60fbfb0cdd2154f05400ba6..70f8ca4d213450c14c5c5702299dc79373080088 100644
--- a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
+++ b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
@@ -243,17 +243,17 @@ public abstract class AbstractBytesToBytesMapSuite {
   @Test
   public void iteratingOverDataPagesWithWastedSpace() throws Exception {
     final int NUM_ENTRIES = 1000 * 1000;
-    final int KEY_LENGTH = 16;
+    final int KEY_LENGTH = 24;
     final int VALUE_LENGTH = 40;
     final BytesToBytesMap map = new BytesToBytesMap(
       taskMemoryManager, shuffleMemoryManager, NUM_ENTRIES, PAGE_SIZE_BYTES);
-    // Each record will take 8 + 8 + 16 + 40 = 72 bytes of space in the data page. Our 64-megabyte
+    // Each record will take 8 + 24 + 40 = 72 bytes of space in the data page. Our 64-megabyte
     // pages won't be evenly-divisible by records of this size, which will cause us to waste some
     // space at the end of the page. This is necessary in order for us to take the end-of-record
     // handling branch in iterator().
     try {
       for (int i = 0; i < NUM_ENTRIES; i++) {
-        final long[] key = new long[] { i, i };  // 2 * 8 = 16 bytes
+        final long[] key = new long[] { i, i, i };  // 3 * 8 = 24 bytes
         final long[] value = new long[] { i, i, i, i, i }; // 5 * 8 = 40 bytes
         final BytesToBytesMap.Location loc = map.lookup(
           key,
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeKeyValueSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeKeyValueSorter.java
new file mode 100644
index 0000000000000000000000000000000000000000..59c774da74acf95e2012be470467c0ea420e9de8
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeKeyValueSorter.java
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution;
+
+import java.io.IOException;
+
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
+import org.apache.spark.unsafe.KVIterator;
+
+public abstract class UnsafeKeyValueSorter {
+
+  public abstract void insert(UnsafeRow key, UnsafeRow value);
+
+  public abstract KVIterator<UnsafeRow, UnsafeRow> sort() throws IOException;
+}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
index 08a98cdd94a4c28b4d92c54164ad0f0cb46c1d0c..c18b6dea6b2e14a3ba9f9dbfdc6c86f401fe3fa8 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
@@ -17,9 +17,6 @@
 
 package org.apache.spark.sql.execution;
 
-import java.io.IOException;
-import java.util.Iterator;
-
 import org.apache.spark.shuffle.ShuffleMemoryManager;
 import org.apache.spark.sql.catalyst.InternalRow;
 import org.apache.spark.sql.catalyst.expressions.UnsafeProjection;
@@ -28,6 +25,7 @@ import org.apache.spark.sql.types.Decimal;
 import org.apache.spark.sql.types.DecimalType;
 import org.apache.spark.sql.types.StructField;
 import org.apache.spark.sql.types.StructType;
+import org.apache.spark.unsafe.KVIterator;
 import org.apache.spark.unsafe.PlatformDependent;
 import org.apache.spark.unsafe.map.BytesToBytesMap;
 import org.apache.spark.unsafe.memory.MemoryLocation;
@@ -156,54 +154,55 @@ public final class UnsafeFixedWidthAggregationMap {
     return currentAggregationBuffer;
   }
 
-  /**
-   * Mutable pair object returned by {@link UnsafeFixedWidthAggregationMap#iterator()}.
-   */
-  public static class MapEntry {
-    private MapEntry() { };
-    public final UnsafeRow key = new UnsafeRow();
-    public final UnsafeRow value = new UnsafeRow();
-  }
-
   /**
    * Returns an iterator over the keys and values in this map.
    *
    * For efficiency, each call returns the same object.
    */
-  public Iterator<MapEntry> iterator() {
-    return new Iterator<MapEntry>() {
+  public KVIterator<UnsafeRow, UnsafeRow> iterator() {
+    return new KVIterator<UnsafeRow, UnsafeRow>() {
+
+      private final BytesToBytesMap.BytesToBytesMapIterator mapLocationIterator = map.iterator();
+      private final UnsafeRow key = new UnsafeRow();
+      private final UnsafeRow value = new UnsafeRow();
 
-      private final MapEntry entry = new MapEntry();
-      private final Iterator<BytesToBytesMap.Location> mapLocationIterator = map.iterator();
+      @Override
+      public boolean next() {
+        if (mapLocationIterator.hasNext()) {
+          final BytesToBytesMap.Location loc = mapLocationIterator.next();
+          final MemoryLocation keyAddress = loc.getKeyAddress();
+          final MemoryLocation valueAddress = loc.getValueAddress();
+          key.pointTo(
+            keyAddress.getBaseObject(),
+            keyAddress.getBaseOffset(),
+            groupingKeySchema.length(),
+            loc.getKeyLength()
+          );
+          value.pointTo(
+            valueAddress.getBaseObject(),
+            valueAddress.getBaseOffset(),
+            aggregationBufferSchema.length(),
+            loc.getValueLength()
+          );
+          return true;
+        } else {
+          return false;
+        }
+      }
 
       @Override
-      public boolean hasNext() {
-        return mapLocationIterator.hasNext();
+      public UnsafeRow getKey() {
+        return key;
       }
 
       @Override
-      public MapEntry next() {
-        final BytesToBytesMap.Location loc = mapLocationIterator.next();
-        final MemoryLocation keyAddress = loc.getKeyAddress();
-        final MemoryLocation valueAddress = loc.getValueAddress();
-        entry.key.pointTo(
-          keyAddress.getBaseObject(),
-          keyAddress.getBaseOffset(),
-          groupingKeySchema.length(),
-          loc.getKeyLength()
-        );
-        entry.value.pointTo(
-          valueAddress.getBaseObject(),
-          valueAddress.getBaseOffset(),
-          aggregationBufferSchema.length(),
-          loc.getValueLength()
-        );
-        return entry;
+      public UnsafeRow getValue() {
+        return value;
       }
 
       @Override
-      public void remove() {
-        throw new UnsupportedOperationException();
+      public void close() {
+        // Do nothing.
       }
     };
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
index 469de6ca8e1012d109ec72742de1e7f951a7235b..cd87b8deba0c2c4d160f5372ec10decac09fefed 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
@@ -287,21 +287,26 @@ case class GeneratedAggregate(
         new Iterator[InternalRow] {
           private[this] val mapIterator = aggregationMap.iterator()
           private[this] val resultProjection = resultProjectionBuilder()
+          private[this] var _hasNext = mapIterator.next()
 
-          def hasNext: Boolean = mapIterator.hasNext
+          def hasNext: Boolean = _hasNext
 
           def next(): InternalRow = {
-            val entry = mapIterator.next()
-            val result = resultProjection(joinedRow(entry.key, entry.value))
-            if (hasNext) {
-              result
+            if (_hasNext) {
+              val result = resultProjection(joinedRow(mapIterator.getKey, mapIterator.getValue))
+              _hasNext = mapIterator.next()
+              if (_hasNext) {
+                result
+              } else {
+                // This is the last element in the iterator, so let's free the buffer. Before we do,
+                // though, we need to make a defensive copy of the result so that we don't return an
+                // object that might contain dangling pointers to the freed memory
+                val resultCopy = result.copy()
+                aggregationMap.free()
+                resultCopy
+              }
             } else {
-              // This is the last element in the iterator, so let's free the buffer. Before we do,
-              // though, we need to make a defensive copy of the result so that we don't return an
-              // object that might contain dangling pointers to the freed memory
-              val resultCopy = result.copy()
-              aggregationMap.free()
-              resultCopy
+              throw new java.util.NoSuchElementException
             }
           }
         }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
index 79fd52dacda52d6b03da39baca5ba4f2d3c35fde..6a2c51ca88ac3d77264ed1a9f8df44ac76a2e6a4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution
 import org.scalatest.{BeforeAndAfterEach, Matchers}
 
 import scala.collection.JavaConverters._
+import scala.collection.mutable
 import scala.util.Random
 
 import org.apache.spark.SparkFunSuite
@@ -52,7 +53,7 @@ class UnsafeFixedWidthAggregationMapSuite
 
   override def afterEach(): Unit = {
     if (taskMemoryManager != null) {
-      val leakedShuffleMemory = shuffleMemoryManager.getMemoryConsumptionForThisTask
+      val leakedShuffleMemory = shuffleMemoryManager.getMemoryConsumptionForThisTask()
       assert(taskMemoryManager.cleanUpAllAllocatedMemory() === 0)
       assert(leakedShuffleMemory === 0)
       taskMemoryManager = null
@@ -80,7 +81,7 @@ class UnsafeFixedWidthAggregationMapSuite
       PAGE_SIZE_BYTES,
       false // disable perf metrics
     )
-    assert(!map.iterator().hasNext)
+    assert(!map.iterator().next())
     map.free()
   }
 
@@ -100,13 +101,13 @@ class UnsafeFixedWidthAggregationMapSuite
     // Looking up a key stores a zero-entry in the map (like Python Counters or DefaultDicts)
     assert(map.getAggregationBuffer(groupKey) != null)
     val iter = map.iterator()
-    val entry = iter.next()
-    assert(!iter.hasNext)
-    entry.key.getString(0) should be ("cats")
-    entry.value.getInt(0) should be (0)
+    assert(iter.next())
+    iter.getKey.getString(0) should be ("cats")
+    iter.getValue.getInt(0) should be (0)
+    assert(!iter.next())
 
     // Modifications to rows retrieved from the map should update the values in the map
-    entry.value.setInt(0, 42)
+    iter.getValue.setInt(0, 42)
     map.getAggregationBuffer(groupKey).getInt(0) should be (42)
 
     map.free()
@@ -128,12 +129,14 @@ class UnsafeFixedWidthAggregationMapSuite
     groupKeys.foreach { keyString =>
       assert(map.getAggregationBuffer(InternalRow(UTF8String.fromString(keyString))) != null)
     }
-    val seenKeys: Set[String] = map.iterator().asScala.map { entry =>
-      entry.key.getString(0)
-    }.toSet
-    seenKeys.size should be (groupKeys.size)
-    seenKeys should be (groupKeys)
 
+    val seenKeys = new mutable.HashSet[String]
+    val iter = map.iterator()
+    while (iter.next()) {
+      seenKeys += iter.getKey.getString(0)
+    }
+    assert(seenKeys.size === groupKeys.size)
+    assert(seenKeys === groupKeys)
     map.free()
   }
 
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/KVIterator.java b/unsafe/src/main/java/org/apache/spark/unsafe/KVIterator.java
new file mode 100644
index 0000000000000000000000000000000000000000..fb163401c0d274743e6cb3f72869cc0a2f821adf
--- /dev/null
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/KVIterator.java
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.unsafe;
+
+public abstract class KVIterator<K, V> {
+
+  public abstract boolean next();
+
+  public abstract K getKey();
+
+  public abstract V getValue();
+
+  public abstract void close();
+}