diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java
index 1d460432be9ffaef1bde0b37450f229f68dbeacf..1aa6ba42012618e84cd6773547c07cb2e12b3560 100644
--- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java
@@ -59,14 +59,14 @@ final class UnsafeShuffleExternalSorter {
 
   private final Logger logger = LoggerFactory.getLogger(UnsafeShuffleExternalSorter.class);
 
-  private static final int PAGE_SIZE = PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES;
   @VisibleForTesting
   static final int DISK_WRITE_BUFFER_SIZE = 1024 * 1024;
-  @VisibleForTesting
-  static final int MAX_RECORD_SIZE = PAGE_SIZE - 4;
 
   private final int initialSize;
   private final int numPartitions;
+  private final int pageSizeBytes;
+  @VisibleForTesting
+  final int maxRecordSizeBytes;
   private final TaskMemoryManager memoryManager;
   private final ShuffleMemoryManager shuffleMemoryManager;
   private final BlockManager blockManager;
@@ -109,7 +109,10 @@ final class UnsafeShuffleExternalSorter {
     this.numPartitions = numPartitions;
     // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided
     this.fileBufferSizeBytes = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024;
-
+    this.pageSizeBytes = (int) Math.min(
+      PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES,
+      conf.getSizeAsBytes("spark.buffer.pageSize", "64m"));
+    this.maxRecordSizeBytes = pageSizeBytes - 4;
     this.writeMetrics = writeMetrics;
     initializeForWriting();
   }
@@ -272,7 +275,11 @@ final class UnsafeShuffleExternalSorter {
   }
 
   private long getMemoryUsage() {
-    return sorter.getMemoryUsage() + (allocatedPages.size() * (long) PAGE_SIZE);
+    long totalPageSize = 0;
+    for (MemoryBlock page : allocatedPages) {
+      totalPageSize += page.size();
+    }
+    return sorter.getMemoryUsage() + totalPageSize;
   }
 
   private long freeMemory() {
@@ -346,23 +353,23 @@ final class UnsafeShuffleExternalSorter {
       // TODO: we should track metrics on the amount of space wasted when we roll over to a new page
       // without using the free space at the end of the current page. We should also do this for
       // BytesToBytesMap.
-      if (requiredSpace > PAGE_SIZE) {
+      if (requiredSpace > pageSizeBytes) {
         throw new IOException("Required space " + requiredSpace + " is greater than page size (" +
-          PAGE_SIZE + ")");
+          pageSizeBytes + ")");
       } else {
-        final long memoryAcquired = shuffleMemoryManager.tryToAcquire(PAGE_SIZE);
-        if (memoryAcquired < PAGE_SIZE) {
+        final long memoryAcquired = shuffleMemoryManager.tryToAcquire(pageSizeBytes);
+        if (memoryAcquired < pageSizeBytes) {
           shuffleMemoryManager.release(memoryAcquired);
           spill();
-          final long memoryAcquiredAfterSpilling = shuffleMemoryManager.tryToAcquire(PAGE_SIZE);
-          if (memoryAcquiredAfterSpilling != PAGE_SIZE) {
+          final long memoryAcquiredAfterSpilling = shuffleMemoryManager.tryToAcquire(pageSizeBytes);
+          if (memoryAcquiredAfterSpilling != pageSizeBytes) {
             shuffleMemoryManager.release(memoryAcquiredAfterSpilling);
-            throw new IOException("Unable to acquire " + PAGE_SIZE + " bytes of memory");
+            throw new IOException("Unable to acquire " + pageSizeBytes + " bytes of memory");
           }
         }
-        currentPage = memoryManager.allocatePage(PAGE_SIZE);
+        currentPage = memoryManager.allocatePage(pageSizeBytes);
         currentPagePosition = currentPage.getBaseOffset();
-        freeSpaceInCurrentPage = PAGE_SIZE;
+        freeSpaceInCurrentPage = pageSizeBytes;
         allocatedPages.add(currentPage);
       }
     }
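
Both sorters now run this acquire-spill-retry protocol against a configurable pageSizeBytes instead of a fixed PAGE_SIZE constant. A minimal sketch of the protocol in isolation, with the memory manager and spill hook stubbed out (all names below are illustrative assumptions, not Spark's API):

    import java.io.IOException;

    // Hypothetical stand-ins for ShuffleMemoryManager and spill(); names and
    // signatures here are assumptions for illustration only.
    interface PageMemoryManager {
      long tryToAcquire(long bytes); // may grant less than requested
      void release(long bytes);
    }

    final class PageAllocator {
      private final PageMemoryManager mgr;
      private final long pageSizeBytes;
      private final Runnable spill; // frees memory by spilling in-memory data to disk

      PageAllocator(PageMemoryManager mgr, long pageSizeBytes, Runnable spill) {
        this.mgr = mgr;
        this.pageSizeBytes = pageSizeBytes;
        this.spill = spill;
      }

      // Acquire exactly one page worth of memory, spilling once on shortfall.
      long acquirePage() throws IOException {
        long acquired = mgr.tryToAcquire(pageSizeBytes);
        if (acquired < pageSizeBytes) {
          mgr.release(acquired); // give back the partial grant before spilling
          spill.run();
          acquired = mgr.tryToAcquire(pageSizeBytes);
          if (acquired != pageSizeBytes) {
            mgr.release(acquired);
            throw new IOException("Unable to acquire " + pageSizeBytes + " bytes of memory");
          }
        }
        return acquired;
      }
    }
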
diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java
index 764578b181422b9f6ffe458cba404d83d4c38ca1..d47d6fc9c2ac4772d021b224cfd7d1b532b319be 100644
--- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java
@@ -129,6 +129,11 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
     open();
   }
 
+  @VisibleForTesting
+  public int maxRecordSizeBytes() {
+    return sorter.maxRecordSizeBytes;
+  }
+
   /**
    * This convenience method should only be called in test code.
    */
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
index 80b03d7e99e2bc3a7d7c192de97a9d649736a924..c21990f4e477865e5485c9d89f22b1a1d1b19273 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
@@ -41,10 +41,7 @@ public final class UnsafeExternalSorter {
 
   private final Logger logger = LoggerFactory.getLogger(UnsafeExternalSorter.class);
 
-  private static final int PAGE_SIZE = 1 << 27;  // 128 megabytes
-  @VisibleForTesting
-  static final int MAX_RECORD_SIZE = PAGE_SIZE - 4;
-
+  private final long pageSizeBytes;
   private final PrefixComparator prefixComparator;
   private final RecordComparator recordComparator;
   private final int initialSize;
@@ -91,6 +88,7 @@ public final class UnsafeExternalSorter {
     this.initialSize = initialSize;
     // Use getSizeAsKb (not bytes) to maintain backwards compatibility for units
     this.fileBufferSizeBytes = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024;
+    this.pageSizeBytes = conf.getSizeAsBytes("spark.buffer.pageSize", "64m");
     initializeForWriting();
   }
 
@@ -147,7 +145,11 @@ public final class UnsafeExternalSorter {
   }
 
   private long getMemoryUsage() {
-    return sorter.getMemoryUsage() + (allocatedPages.size() * (long) PAGE_SIZE);
+    long totalPageSize = 0;
+    for (MemoryBlock page : allocatedPages) {
+      totalPageSize += page.size();
+    }
+    return sorter.getMemoryUsage() + totalPageSize;
   }
 
   @VisibleForTesting
@@ -214,23 +216,23 @@ public final class UnsafeExternalSorter {
       // TODO: we should track metrics on the amount of space wasted when we roll over to a new page
       // without using the free space at the end of the current page. We should also do this for
       // BytesToBytesMap.
-      if (requiredSpace > PAGE_SIZE) {
+      if (requiredSpace > pageSizeBytes) {
         throw new IOException("Required space " + requiredSpace + " is greater than page size (" +
-          PAGE_SIZE + ")");
+          pageSizeBytes + ")");
       } else {
-        final long memoryAcquired = shuffleMemoryManager.tryToAcquire(PAGE_SIZE);
-        if (memoryAcquired < PAGE_SIZE) {
+        final long memoryAcquired = shuffleMemoryManager.tryToAcquire(pageSizeBytes);
+        if (memoryAcquired < pageSizeBytes) {
           shuffleMemoryManager.release(memoryAcquired);
           spill();
-          final long memoryAcquiredAfterSpilling = shuffleMemoryManager.tryToAcquire(PAGE_SIZE);
-          if (memoryAcquiredAfterSpilling != PAGE_SIZE) {
+          final long memoryAcquiredAfterSpilling = shuffleMemoryManager.tryToAcquire(pageSizeBytes);
+          if (memoryAcquiredAfterSpilling != pageSizeBytes) {
             shuffleMemoryManager.release(memoryAcquiredAfterSpilling);
-            throw new IOException("Unable to acquire " + PAGE_SIZE + " bytes of memory");
+            throw new IOException("Unable to acquire " + pageSizeBytes + " bytes of memory");
           }
         }
-        currentPage = memoryManager.allocatePage(PAGE_SIZE);
+        currentPage = memoryManager.allocatePage(pageSizeBytes);
         currentPagePosition = currentPage.getBaseOffset();
-        freeSpaceInCurrentPage = PAGE_SIZE;
+        freeSpaceInCurrentPage = pageSizeBytes;
         allocatedPages.add(currentPage);
       }
     }
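
As in the shuffle sorter above, getMemoryUsage can no longer multiply a page count by a fixed constant once page sizes may vary, so it sums each page's actual size. The same accounting in isolation (SizedPage is a reduced stand-in for Spark's MemoryBlock, keeping only the one accessor used here):

    import java.util.ArrayList;
    import java.util.List;

    // Reduced stand-in for org.apache.spark.unsafe.memory.MemoryBlock.
    final class SizedPage {
      private final long size;
      SizedPage(long size) { this.size = size; }
      long size() { return size; }
    }

    final class PageAccounting {
      private final List<SizedPage> allocatedPages = new ArrayList<SizedPage>();

      void add(SizedPage page) { allocatedPages.add(page); }

      // Sum actual sizes: correct even if pages were allocated with different sizes.
      long totalPageSize() {
        long total = 0;
        for (SizedPage page : allocatedPages) {
          total += page.size();
        }
        return total;
      }
    }
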
diff --git a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java
index 10c3eedbf4b46b81b4d59895fedf7fb352ff7448..04fc09b323dbb2073beeeede6d0fb75a6a849f67 100644
--- a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java
+++ b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java
@@ -111,7 +111,7 @@ public class UnsafeShuffleWriterSuite {
     mergedOutputFile = File.createTempFile("mergedoutput", "", tempDir);
     partitionSizesInMergedFile = null;
     spillFilesCreated.clear();
-    conf = new SparkConf();
+    conf = new SparkConf().set("spark.buffer.pageSize", "128m");
     taskMetrics = new TaskMetrics();
 
     when(shuffleMemoryManager.tryToAcquire(anyLong())).then(returnsFirstArg());
@@ -512,12 +512,12 @@ public class UnsafeShuffleWriterSuite {
     writer.insertRecordIntoSorter(new Tuple2<Object, Object>(new byte[1], new byte[1]));
     writer.forceSorterToSpill();
     // We should be able to write a record that's right _at_ the max record size
-    final byte[] atMaxRecordSize = new byte[UnsafeShuffleExternalSorter.MAX_RECORD_SIZE];
+    final byte[] atMaxRecordSize = new byte[writer.maxRecordSizeBytes()];
     new Random(42).nextBytes(atMaxRecordSize);
     writer.insertRecordIntoSorter(new Tuple2<Object, Object>(new byte[0], atMaxRecordSize));
     writer.forceSorterToSpill();
     // Inserting a record that's larger than the max record size should fail:
-    final byte[] exceedsMaxRecordSize = new byte[UnsafeShuffleExternalSorter.MAX_RECORD_SIZE + 1];
+    final byte[] exceedsMaxRecordSize = new byte[writer.maxRecordSizeBytes() + 1];
     new Random(42).nextBytes(exceedsMaxRecordSize);
     Product2<Object, Object> hugeRecord =
       new Tuple2<Object, Object>(new byte[0], exceedsMaxRecordSize);
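
The suite now pins the page size through configuration rather than referencing the removed MAX_RECORD_SIZE constant. Setting and reading the property (a usage sketch; the "128m" value mirrors the test setup above):

    import org.apache.spark.SparkConf;

    public class PageSizeConfExample {
      public static void main(String[] args) {
        SparkConf conf = new SparkConf().set("spark.buffer.pageSize", "128m");
        // getSizeAsBytes parses size strings with units, e.g. "4m", "64m", "128m".
        long pageSizeBytes = conf.getSizeAsBytes("spark.buffer.pageSize", "64m");
        System.out.println(pageSizeBytes); // 134217728
      }
    }
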
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java
index 684de6e81d67c2aa76554b81fbd26e8c5ac23eb4..03f4c3ed8e6bbfc31aeb0e22ed8c7e770722fb5d 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java
@@ -95,6 +95,7 @@ public final class UnsafeFixedWidthAggregationMap {
    * @param groupingKeySchema the schema of the grouping key, used for row conversion.
    * @param memoryManager the memory manager used to allocate our Unsafe memory structures.
    * @param initialCapacity the initial capacity of the map (a sizing hint to avoid re-hashing).
+   * @param pageSizeBytes the data page size, in bytes; limits the maximum record size.
    * @param enablePerfMetrics if true, performance metrics will be recorded (has minor perf impact)
    */
   public UnsafeFixedWidthAggregationMap(
@@ -103,11 +104,13 @@ public final class UnsafeFixedWidthAggregationMap {
       StructType groupingKeySchema,
       TaskMemoryManager memoryManager,
       int initialCapacity,
+      long pageSizeBytes,
       boolean enablePerfMetrics) {
     this.aggregationBufferSchema = aggregationBufferSchema;
     this.groupingKeyProjection = UnsafeProjection.create(groupingKeySchema);
     this.groupingKeySchema = groupingKeySchema;
-    this.map = new BytesToBytesMap(memoryManager, initialCapacity, enablePerfMetrics);
+    this.map =
+      new BytesToBytesMap(memoryManager, initialCapacity, pageSizeBytes, enablePerfMetrics);
     this.enablePerfMetrics = enablePerfMetrics;
 
     // Initialize the buffer for aggregation value
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala
index 48b7dc57451a3881e1d3ba20fd2ad2305e05eeea..6a907290f2dbee3f2404a94a1b9efb5f6790fb98 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala
@@ -39,6 +39,7 @@ class UnsafeFixedWidthAggregationMapSuite
   private val groupKeySchema = StructType(StructField("product", StringType) :: Nil)
   private val aggBufferSchema = StructType(StructField("salePrice", IntegerType) :: Nil)
   private def emptyAggregationBuffer: InternalRow = InternalRow(0)
+  private val PAGE_SIZE_BYTES: Long = 1L << 26 // 64 megabytes
 
   private var memoryManager: TaskMemoryManager = null
 
@@ -69,7 +70,8 @@ class UnsafeFixedWidthAggregationMapSuite
       aggBufferSchema,
       groupKeySchema,
       memoryManager,
-      1024, // initial capacity
+      1024, // initial capacity
+      PAGE_SIZE_BYTES,
       false // disable perf metrics
     )
     assert(!map.iterator().hasNext)
@@ -83,6 +85,7 @@ class UnsafeFixedWidthAggregationMapSuite
       groupKeySchema,
       memoryManager,
       1024, // initial capacity
+      PAGE_SIZE_BYTES,
       false // disable perf metrics
     )
     val groupKey = InternalRow(UTF8String.fromString("cats"))
@@ -109,6 +112,7 @@ class UnsafeFixedWidthAggregationMapSuite
       groupKeySchema,
       memoryManager,
       128, // initial capacity
+      PAGE_SIZE_BYTES,
       false // disable perf metrics
     )
     val rand = new Random(42)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
index 1cd1420480f035bded84e4c360b6d8786cfbf5e4..b85aada9d9d4c91f7cd1c095bb3d54a97727ddd2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.execution
 
-import org.apache.spark.TaskContext
+import org.apache.spark.{SparkEnv, TaskContext}
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
@@ -260,12 +260,14 @@ case class GeneratedAggregate(
       } else if (unsafeEnabled && schemaSupportsUnsafe) {
         assert(iter.hasNext, "There should be at least one row for this path")
         log.info("Using Unsafe-based aggregator")
+        val pageSizeBytes = SparkEnv.get.conf.getSizeAsBytes("spark.buffer.pageSize", "64m")
         val aggregationMap = new UnsafeFixedWidthAggregationMap(
           newAggregationBuffer(EmptyRow),
           aggregationBufferSchema,
           groupKeySchema,
           TaskContext.get.taskMemoryManager(),
           1024 * 16, // initial capacity
+          pageSizeBytes,
           false // disable tracking of performance metrics
         )
 
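
GeneratedAggregate resolves the page size once per operator from the executor-side SparkEnv. A hedged sketch of the same lookup with a null guard for environments where no SparkEnv is running (the guard is added here for illustration and is not part of this patch):

    import org.apache.spark.SparkConf;
    import org.apache.spark.SparkEnv;

    public final class PageSizeLookup {
      // Resolve the configured page size, defaulting to 64m.
      public static long pageSizeBytes() {
        SparkEnv env = SparkEnv.get();
        if (env == null) {
          return 64L * 1024 * 1024; // no running environment (e.g. plain unit tests)
        }
        SparkConf conf = env.conf();
        return conf.getSizeAsBytes("spark.buffer.pageSize", "64m");
      }
    }
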
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
index 9c058f1f72fe411594fd0ad2082b2680acf05337..7a507391316a98f6d3dd43edf7448db09c258ff7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
@@ -21,6 +21,7 @@ import java.io.{Externalizable, ObjectInput, ObjectOutput}
 import java.nio.ByteOrder
 import java.util.{HashMap => JavaHashMap}
 
+import org.apache.spark.{SparkEnv, TaskContext}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.execution.SparkSqlSerializer
@@ -259,7 +260,11 @@ private[joins] final class UnsafeHashedRelation(
     val nKeys = in.readInt()
     // This is used in Broadcast, shared by multiple tasks, so we use on-heap memory
     val memoryManager = new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP))
-    binaryMap = new BytesToBytesMap(memoryManager, nKeys * 2) // reduce hash collision
+    val pageSizeBytes = SparkEnv.get.conf.getSizeAsBytes("spark.buffer.pageSize", "64m")
+    binaryMap = new BytesToBytesMap(
+      memoryManager,
+      nKeys * 2, // reduce hash collision
+      pageSizeBytes)
 
     var i = 0
     var keyBuffer = new Array[Byte](1024)
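
Because a broadcast relation is shared by all tasks on an executor, the map is backed by the on-heap allocator rather than task-scoped memory. Constructing such a map directly (a sketch using only constructors that appear in this patch; the key count is illustrative):

    import org.apache.spark.unsafe.map.BytesToBytesMap;
    import org.apache.spark.unsafe.memory.ExecutorMemoryManager;
    import org.apache.spark.unsafe.memory.MemoryAllocator;
    import org.apache.spark.unsafe.memory.TaskMemoryManager;

    public final class OnHeapMapExample {
      public static void main(String[] args) {
        // On-heap allocation, as in UnsafeHashedRelation.readExternal above.
        TaskMemoryManager memoryManager =
          new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP));
        int nKeys = 1000;
        long pageSizeBytes = 1L << 26; // 64 megabytes, the default
        // Size the table at twice the key count to reduce hash collisions.
        BytesToBytesMap binaryMap = new BytesToBytesMap(memoryManager, nKeys * 2, pageSizeBytes);
        binaryMap.free();
      }
    }
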
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala
index 3662a4352f55d7fffc241ee51c9663724396391a..7bbdef90cd6b9fff9ecabd2642aa2439992f6835 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala
@@ -56,6 +56,7 @@ object TestHive
         .set("spark.sql.test", "")
         .set("spark.sql.hive.metastore.barrierPrefixes",
           "org.apache.spark.sql.hive.execution.PairSerDe")
+        .set("spark.buffer.pageSize", "4m")
         // SPARK-8910
         .set("spark.ui.enabled", "false")))
 
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/unsafe/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
index d0bde69cc106835d267fdc8b49bf36f4ee81441f..198e0684f32f8d330b1c0e54c271f8b2edb6bb95 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
@@ -74,12 +74,6 @@ public final class BytesToBytesMap {
    */
   private long pageCursor = 0;
 
-  /**
-   * The size of the data pages that hold key and value data. Map entries cannot span multiple
-   * pages, so this limits the maximum entry size.
-   */
-  private static final long PAGE_SIZE_BYTES = 1L << 26; // 64 megabytes
-
   /**
    * The maximum number of keys that BytesToBytesMap supports. The hash table has to be
    * power-of-2-sized and its backing Java array can contain at most (1 << 30) elements, since
@@ -117,6 +111,12 @@ public final class BytesToBytesMap {
 
   private final double loadFactor;
 
+  /**
+   * The size of the data pages that hold key and value data. Map entries cannot span multiple
+   * pages, so this limits the maximum entry size.
+   */
+  private final long pageSizeBytes;
+
   /**
    * Number of keys defined in the map.
    */
@@ -153,10 +153,12 @@ public final class BytesToBytesMap {
       TaskMemoryManager memoryManager,
       int initialCapacity,
       double loadFactor,
+      long pageSizeBytes,
       boolean enablePerfMetrics) {
     this.memoryManager = memoryManager;
     this.loadFactor = loadFactor;
     this.loc = new Location();
+    this.pageSizeBytes = pageSizeBytes;
     this.enablePerfMetrics = enablePerfMetrics;
     if (initialCapacity <= 0) {
       throw new IllegalArgumentException("Initial capacity must be greater than 0");
@@ -165,18 +167,26 @@ public final class BytesToBytesMap {
       throw new IllegalArgumentException(
         "Initial capacity " + initialCapacity + " exceeds maximum capacity of " + MAX_CAPACITY);
     }
+    if (pageSizeBytes > TaskMemoryManager.MAXIMUM_PAGE_SIZE_BYTES) {
+      throw new IllegalArgumentException("Page size " + pageSizeBytes + " cannot exceed " +
+        TaskMemoryManager.MAXIMUM_PAGE_SIZE_BYTES);
+    }
     allocate(initialCapacity);
   }
 
-  public BytesToBytesMap(TaskMemoryManager memoryManager, int initialCapacity) {
-    this(memoryManager, initialCapacity, 0.70, false);
+  public BytesToBytesMap(
+      TaskMemoryManager memoryManager,
+      int initialCapacity,
+      long pageSizeBytes) {
+    this(memoryManager, initialCapacity, 0.70, pageSizeBytes, false);
   }
 
   public BytesToBytesMap(
       TaskMemoryManager memoryManager,
       int initialCapacity,
+      long pageSizeBytes,
       boolean enablePerfMetrics) {
-    this(memoryManager, initialCapacity, 0.70, enablePerfMetrics);
+    this(memoryManager, initialCapacity, 0.70, pageSizeBytes, enablePerfMetrics);
   }
 
   /**
@@ -443,20 +453,20 @@ public final class BytesToBytesMap {
       // must be stored in the same memory page.
       // (8 byte key length) (key) (8 byte value length) (value)
       final long requiredSize = 8 + keyLengthBytes + 8 + valueLengthBytes;
-      assert (requiredSize <= PAGE_SIZE_BYTES - 8); // Reserve 8 bytes for the end-of-page marker.
+      assert (requiredSize <= pageSizeBytes - 8); // Reserve 8 bytes for the end-of-page marker.
       size++;
       bitset.set(pos);
 
       // If there's not enough space in the current page, allocate a new page (8 bytes are reserved
       // for the end-of-page marker).
-      if (currentDataPage == null || PAGE_SIZE_BYTES - 8 - pageCursor < requiredSize) {
+      if (currentDataPage == null || pageSizeBytes - 8 - pageCursor < requiredSize) {
         if (currentDataPage != null) {
           // There wasn't enough space in the current page, so write an end-of-page marker:
           final Object pageBaseObject = currentDataPage.getBaseObject();
           final long lengthOffsetInPage = currentDataPage.getBaseOffset() + pageCursor;
           PlatformDependent.UNSAFE.putLong(pageBaseObject, lengthOffsetInPage, END_OF_PAGE_MARKER);
         }
-        MemoryBlock newPage = memoryManager.allocatePage(PAGE_SIZE_BYTES);
+        MemoryBlock newPage = memoryManager.allocatePage(pageSizeBytes);
         dataPages.add(newPage);
         pageCursor = 0;
         currentDataPage = newPage;
@@ -538,10 +548,11 @@ public final class BytesToBytesMap {
 
   /** Returns the total amount of memory, in bytes, consumed by this map's managed structures. */
   public long getTotalMemoryConsumption() {
-    return (
-      dataPages.size() * PAGE_SIZE_BYTES +
-      bitset.memoryBlock().size() +
-      longArray.memoryBlock().size());
+    long totalDataPagesSize = 0L;
+    for (MemoryBlock dataPage : dataPages) {
+      totalDataPagesSize += dataPage.size();
+    }
+    return totalDataPagesSize + bitset.memoryBlock().size() + longArray.memoryBlock().size();
   }
 
   /**
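
Entries cannot span pages, so each page holds at most pageSizeBytes - 8 bytes of records, with 8 bytes reserved for the end-of-page marker. Working the layout arithmetic from the hunk above with the 72-byte records used in the test suite (plain arithmetic, no Spark dependencies):

    public final class RecordSizeMath {
      public static void main(String[] args) {
        long pageSizeBytes = 1L << 26; // 64 megabytes
        int keyLengthBytes = 16;
        int valueLengthBytes = 40;
        // Per-entry layout: (8 byte key length) (key) (8 byte value length) (value)
        long requiredSize = 8 + keyLengthBytes + 8 + valueLengthBytes; // 72 bytes
        // 8 bytes per page are reserved for the end-of-page marker.
        long usablePerPage = pageSizeBytes - 8;
        System.out.println(requiredSize <= usablePerPage); // true
        System.out.println(usablePerPage / requiredSize);  // 932067 records per page
      }
    }

The remainder of that division is the wasted space at the end of each page that the test suite's end-of-page-marker test relies on.
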
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java
index 10881969dbc784d3cebafd52988efa125e33c338..dd70df3b1f791c2e1d08736e1526819af94472ae 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java
@@ -58,8 +58,13 @@ public class TaskMemoryManager {
   /** The number of entries in the page table. */
   private static final int PAGE_TABLE_SIZE = 1 << PAGE_NUMBER_BITS;
 
-  /** Maximum supported data page size */
-  private static final long MAXIMUM_PAGE_SIZE = (1L << OFFSET_BITS);
+  /**
+   * Maximum supported data page size (in bytes). In principle, the maximum addressable page size is
+   * (1L << OFFSET_BITS) bytes, which is 2+ petabytes. However, the on-heap allocator's maximum page
+   * size is limited by the maximum amount of data that can be stored in a long[] array, which is
+   * (2^31 - 1) * 8 bytes (just under 16 gigabytes). Therefore, we cap this at 16 gigabytes.
+   */
+  public static final long MAXIMUM_PAGE_SIZE_BYTES = ((1L << 31) - 1) * 8L;
 
   /** Bit mask for the lower 51 bits of a long. */
   private static final long MASK_LONG_LOWER_51_BITS = 0x7FFFFFFFFFFFFL;
@@ -110,9 +115,9 @@ public class TaskMemoryManager {
    * intended for allocating large blocks of memory that will be shared between operators.
    */
   public MemoryBlock allocatePage(long size) {
-    if (size > MAXIMUM_PAGE_SIZE) {
+    if (size > MAXIMUM_PAGE_SIZE_BYTES) {
       throw new IllegalArgumentException(
-        "Cannot allocate a page with more than " + MAXIMUM_PAGE_SIZE + " bytes");
+        "Cannot allocate a page with more than " + MAXIMUM_PAGE_SIZE_BYTES + " bytes");
     }
 
     final int pageNumber;
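
The 16-gigabyte cap follows from the on-heap allocator: a Java long[] holds at most Integer.MAX_VALUE = 2^31 - 1 eight-byte words. Checking the arithmetic behind MAXIMUM_PAGE_SIZE_BYTES:

    public final class MaxPageSizeMath {
      public static void main(String[] args) {
        long maxWords = (1L << 31) - 1;        // Integer.MAX_VALUE array elements
        long maxPageSizeBytes = maxWords * 8L; // 17179869176 bytes
        // 2^34 bytes is 16 GiB exactly; the cap falls 8 bytes short of it.
        System.out.println(maxPageSizeBytes == (1L << 34) - 8); // true
      }
    }
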
diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
index dae47e4bab0cb3b1c89b6825e53df263372a7110..0be94ad37125580d7d2203eaba2784696cf46d8d 100644
--- a/unsafe/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
+++ b/unsafe/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
@@ -43,6 +43,7 @@ public abstract class AbstractBytesToBytesMapSuite {
 
   private TaskMemoryManager memoryManager;
   private TaskMemoryManager sizeLimitedMemoryManager;
+  private static final long PAGE_SIZE_BYTES = 1L << 26; // 64 megabytes
 
   @Before
   public void setup() {
@@ -110,7 +111,7 @@ public abstract class AbstractBytesToBytesMapSuite {
 
   @Test
   public void emptyMap() {
-    BytesToBytesMap map = new BytesToBytesMap(memoryManager, 64);
+    BytesToBytesMap map = new BytesToBytesMap(memoryManager, 64, PAGE_SIZE_BYTES);
     try {
       Assert.assertEquals(0, map.size());
       final int keyLengthInWords = 10;
@@ -125,7 +126,7 @@ public abstract class AbstractBytesToBytesMapSuite {
 
   @Test
   public void setAndRetrieveAKey() {
-    BytesToBytesMap map = new BytesToBytesMap(memoryManager, 64);
+    BytesToBytesMap map = new BytesToBytesMap(memoryManager, 64, PAGE_SIZE_BYTES);
     final int recordLengthWords = 10;
     final int recordLengthBytes = recordLengthWords * 8;
     final byte[] keyData = getRandomByteArray(recordLengthWords);
@@ -177,7 +178,7 @@ public abstract class AbstractBytesToBytesMapSuite {
   @Test
   public void iteratorTest() throws Exception {
     final int size = 4096;
-    BytesToBytesMap map = new BytesToBytesMap(memoryManager, size / 2);
+    BytesToBytesMap map = new BytesToBytesMap(memoryManager, size / 2, PAGE_SIZE_BYTES);
     try {
       for (long i = 0; i < size; i++) {
         final long[] value = new long[] { i };
@@ -235,7 +236,7 @@ public abstract class AbstractBytesToBytesMapSuite {
     final int NUM_ENTRIES = 1000 * 1000;
     final int KEY_LENGTH = 16;
     final int VALUE_LENGTH = 40;
-    final BytesToBytesMap map = new BytesToBytesMap(memoryManager, NUM_ENTRIES);
+    final BytesToBytesMap map = new BytesToBytesMap(memoryManager, NUM_ENTRIES, PAGE_SIZE_BYTES);
     // Each record will take 8 + 8 + 16 + 40 = 72 bytes of space in the data page. Our 64-megabyte
     // pages won't be evenly-divisible by records of this size, which will cause us to waste some
     // space at the end of the page. This is necessary in order for us to take the end-of-record
@@ -304,7 +305,7 @@ public abstract class AbstractBytesToBytesMapSuite {
     // Java arrays' hashCodes() aren't based on the arrays' contents, so we need to wrap arrays
     // into ByteBuffers in order to use them as keys here.
     final Map<ByteBuffer, byte[]> expected = new HashMap<ByteBuffer, byte[]>();
-    final BytesToBytesMap map = new BytesToBytesMap(memoryManager, size);
+    final BytesToBytesMap map = new BytesToBytesMap(memoryManager, size, PAGE_SIZE_BYTES);
 
     try {
       // Fill the map to 90% full so that we can trigger probing
@@ -353,14 +354,15 @@ public abstract class AbstractBytesToBytesMapSuite {
   @Test
   public void initialCapacityBoundsChecking() {
     try {
-      new BytesToBytesMap(sizeLimitedMemoryManager, 0);
+      new BytesToBytesMap(sizeLimitedMemoryManager, 0, PAGE_SIZE_BYTES);
       Assert.fail("Expected IllegalArgumentException to be thrown");
     } catch (IllegalArgumentException e) {
       // expected exception
     }
 
     try {
-      new BytesToBytesMap(sizeLimitedMemoryManager, BytesToBytesMap.MAX_CAPACITY + 1);
+      new BytesToBytesMap(
+        sizeLimitedMemoryManager, BytesToBytesMap.MAX_CAPACITY + 1, PAGE_SIZE_BYTES);
       Assert.fail("Expected IllegalArgumentException to be thrown");
     } catch (IllegalArgumentException e) {
       // expected exception
@@ -368,15 +370,15 @@ public abstract class AbstractBytesToBytesMapSuite {
 
    // Can allocate _at_ the max capacity
     BytesToBytesMap map =
-      new BytesToBytesMap(sizeLimitedMemoryManager, BytesToBytesMap.MAX_CAPACITY);
+      new BytesToBytesMap(sizeLimitedMemoryManager, BytesToBytesMap.MAX_CAPACITY, PAGE_SIZE_BYTES);
     map.free();
   }
 
   @Test
   public void resizingLargeMap() {
     // As long as a map's capacity is below the max, we should be able to resize up to the max
-    BytesToBytesMap map =
-      new BytesToBytesMap(sizeLimitedMemoryManager, BytesToBytesMap.MAX_CAPACITY - 64);
+    BytesToBytesMap map = new BytesToBytesMap(
+      sizeLimitedMemoryManager, BytesToBytesMap.MAX_CAPACITY - 64, PAGE_SIZE_BYTES);
     map.growAndRehash();
     map.free();
   }
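
Taken together, callers now choose the page size at construction time: production paths read spark.buffer.pageSize (default 64m), while tests can shrink pages (TestHive sets 4m above) to bound memory use. An end-to-end sketch using only constructors and accessors that appear in this patch:

    import org.apache.spark.unsafe.map.BytesToBytesMap;
    import org.apache.spark.unsafe.memory.ExecutorMemoryManager;
    import org.apache.spark.unsafe.memory.MemoryAllocator;
    import org.apache.spark.unsafe.memory.TaskMemoryManager;

    public final class ConfigurablePageSizeExample {
      public static void main(String[] args) {
        TaskMemoryManager memoryManager =
          new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP));
        long smallPageSizeBytes = 4L * 1024 * 1024; // "4m", as in TestHive
        BytesToBytesMap map = new BytesToBytesMap(memoryManager, 64, smallPageSizeBytes);
        // The bit set and long array are allocated eagerly; data pages are
        // allocated lazily on first insert, so only the former is counted here.
        System.out.println(map.getTotalMemoryConsumption());
        map.free();
      }
    }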