diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
index d2fcdea4f2ceeb6cf3f042b1b139556a0d9be882..44120e591f2fba22f966794d491dddc363d2449b 100644
--- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
+++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
@@ -170,6 +170,8 @@ public final class BytesToBytesMap extends MemoryConsumer {
 
   private long peakMemoryUsedBytes = 0L;
 
+  private final int initialCapacity;
+
   private final BlockManager blockManager;
   private final SerializerManager serializerManager;
   private volatile MapIterator destructiveIterator = null;
@@ -202,6 +204,7 @@ public final class BytesToBytesMap extends MemoryConsumer {
       throw new IllegalArgumentException("Page size " + pageSizeBytes + " cannot exceed " +
         TaskMemoryManager.MAXIMUM_PAGE_SIZE_BYTES);
     }
+    this.initialCapacity = initialCapacity;
     allocate(initialCapacity);
   }
 
@@ -902,12 +905,12 @@ public final class BytesToBytesMap extends MemoryConsumer {
   public void reset() {
     numKeys = 0;
     numValues = 0;
-    longArray.zeroOut();
-
+    freeArray(longArray);
     while (dataPages.size() > 0) {
       MemoryBlock dataPage = dataPages.removeLast();
       freePage(dataPage);
     }
+    allocate(initialCapacity);
     currentPage = null;
     pageCursor = 0;
   }