diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
index 0427db4e3bf2550818a2fd401f9e8b71fb65c9ce..b280c76c70a61a5c492bc2b132d18bb4294b9077 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
@@ -173,8 +173,8 @@ private[joins] class UnsafeHashedRelation(
   override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException {
     out.writeInt(numFields)
     // TODO: move these into BytesToBytesMap
-    out.writeInt(binaryMap.numKeys())
-    out.writeInt(binaryMap.numValues())
+    out.writeLong(binaryMap.numKeys())
+    out.writeLong(binaryMap.numValues())

     var buffer = new Array[Byte](64)
     def write(base: Object, offset: Long, length: Int): Unit = {
@@ -199,8 +199,8 @@ private[joins] class UnsafeHashedRelation(
   override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException {
     numFields = in.readInt()
     resultRow = new UnsafeRow(numFields)
-    val nKeys = in.readInt()
-    val nValues = in.readInt()
+    val nKeys = in.readLong()
+    val nValues = in.readLong()
     // This is used in Broadcast, shared by multiple tasks, so we use on-heap memory
     // TODO(josh): This needs to be revisited before we merge this patch; making this change now
     // so that tests compile:
@@ -345,16 +345,20 @@ private[execution] final class LongToUnsafeRowMap(var mm: TaskMemoryManager, cap

   // The page to store all bytes of UnsafeRow and the pointer to next rows.
   // [row1][pointer1] [row2][pointer2]
-  private var page: Array[Byte] = null
+  private var page: Array[Long] = null

   // Current write cursor in the page.
-  private var cursor = Platform.BYTE_ARRAY_OFFSET
+  private var cursor: Long = Platform.LONG_ARRAY_OFFSET
+
+  // The number of bits for size in address
+  private val SIZE_BITS = 28
+  private val SIZE_MASK = 0xfffffff

   // The total number of values of all keys.
-  private var numValues = 0
+  private var numValues = 0L

   // The number of unique keys.
-  private var numKeys = 0
+  private var numKeys = 0L

   // needed by serializer
   def this() = {
@@ -390,7 +394,7 @@ private[execution] final class LongToUnsafeRowMap(var mm: TaskMemoryManager, cap
       acquireMemory(n * 2 * 8 + (1 << 20))
       array = new Array[Long](n * 2)
       mask = n * 2 - 2
-      page = new Array[Byte](1 << 20) // 1M bytes
+      page = new Array[Long](1 << 17) // 1M bytes
     }
   }

@@ -406,7 +410,7 @@ private[execution] final class LongToUnsafeRowMap(var mm: TaskMemoryManager, cap
   /**
    * Returns total memory consumption.
    */
-  def getTotalMemoryConsumption: Long = array.length * 8 + page.length
+  def getTotalMemoryConsumption: Long = array.length * 8L + page.length * 8L

   /**
    * Returns the first slot of array that store the keys (sparse mode).
@@ -422,8 +426,8 @@ private[execution] final class LongToUnsafeRowMap(var mm: TaskMemoryManager, cap
   private def nextSlot(pos: Int): Int = (pos + 2) & mask

   private def getRow(address: Long, resultRow: UnsafeRow): UnsafeRow = {
-    val offset = address >>> 32
-    val size = address & 0xffffffffL
+    val offset = address >>> SIZE_BITS
+    val size = address & SIZE_MASK
     resultRow.pointTo(page, offset, size.toInt)
     resultRow
   }
@@ -450,15 +454,15 @@ private[execution] final class LongToUnsafeRowMap(var mm: TaskMemoryManager, cap
   }

   /**
-   * Returns an interator of UnsafeRow for multiple linked values.
+   * Returns an iterator of UnsafeRow for multiple linked values.
    */
   private def valueIter(address: Long, resultRow: UnsafeRow): Iterator[UnsafeRow] = {
     new Iterator[UnsafeRow] {
       var addr = address
       override def hasNext: Boolean = addr != 0
       override def next(): UnsafeRow = {
-        val offset = addr >>> 32
-        val size = addr & 0xffffffffL
+        val offset = addr >>> SIZE_BITS
+        val size = addr & SIZE_MASK
         resultRow.pointTo(page, offset, size.toInt)
         addr = Platform.getLong(page, offset + size)
         resultRow
@@ -491,6 +495,11 @@ private[execution] final class LongToUnsafeRowMap(var mm: TaskMemoryManager, cap
    * Appends the key and row into this map.
    */
   def append(key: Long, row: UnsafeRow): Unit = {
+    val sizeInBytes = row.getSizeInBytes
+    if (sizeInBytes >= (1 << SIZE_BITS)) {
+      sys.error("Does not support a row that is larger than 256M")
+    }
+
     if (key < minKey) {
       minKey = key
     }
@@ -499,16 +508,17 @@ private[execution] final class LongToUnsafeRowMap(var mm: TaskMemoryManager, cap
     }

     // There is 8 bytes for the pointer to next value
-    if (cursor + 8 + row.getSizeInBytes > page.length + Platform.BYTE_ARRAY_OFFSET) {
+    if (cursor + 8 + row.getSizeInBytes > page.length * 8L + Platform.LONG_ARRAY_OFFSET) {
       val used = page.length
-      if (used * 2L > (1L << 31)) {
-        sys.error("Can't allocate a page that is larger than 2G")
+      if (used >= (1 << 30)) {
+        sys.error("Cannot build a HashedRelation that is larger than 8G")
       }
-      acquireMemory(used * 2)
-      val newPage = new Array[Byte](used * 2)
-      System.arraycopy(page, 0, newPage, 0, cursor - Platform.BYTE_ARRAY_OFFSET)
+      acquireMemory(used * 8L * 2)
+      val newPage = new Array[Long](used * 2)
+      Platform.copyMemory(page, Platform.LONG_ARRAY_OFFSET, newPage, Platform.LONG_ARRAY_OFFSET,
+        cursor - Platform.LONG_ARRAY_OFFSET)
       page = newPage
-      freeMemory(used)
+      freeMemory(used * 8)
     }

     // copy the bytes of UnsafeRow
@@ -518,7 +528,7 @@ private[execution] final class LongToUnsafeRowMap(var mm: TaskMemoryManager, cap
     Platform.putLong(page, cursor, 0)
     cursor += 8
     numValues += 1
-    updateIndex(key, (offset.toLong << 32) | row.getSizeInBytes)
+    updateIndex(key, (offset.toLong << SIZE_BITS) | row.getSizeInBytes)
   }

   /**
@@ -536,11 +546,17 @@ private[execution] final class LongToUnsafeRowMap(var mm: TaskMemoryManager, cap
       numKeys += 1
       if (numKeys * 4 > array.length) {
         // reach half of the capacity
-        growArray()
+        if (array.length < (1 << 30)) {
+          // Cannot allocate an array with 2G elements
+          growArray()
+        } else if (numKeys > array.length / 2 * 0.75) {
+          // The fill ratio should be less than 0.75
+          sys.error("Cannot build a HashedRelation with more than 1/3 billion unique keys")
+        }
       }
     } else {
       // there are some values for this key, put the address in the front of them.
-      val pointer = (address >>> 32) + (address & 0xffffffffL)
+      val pointer = (address >>> SIZE_BITS) + (address & SIZE_MASK)
       Platform.putLong(page, pointer, array(pos + 1))
       array(pos + 1) = address
     }
@@ -550,7 +566,7 @@ private[execution] final class LongToUnsafeRowMap(var mm: TaskMemoryManager, cap
     var old_array = array
     val n = array.length
     numKeys = 0
-    acquireMemory(n * 2 * 8)
+    acquireMemory(n * 2 * 8L)
     array = new Array[Long](n * 2)
     mask = n * 2 - 2
     var i = 0
@@ -599,7 +615,7 @@ private[execution] final class LongToUnsafeRowMap(var mm: TaskMemoryManager, cap
    */
   def free(): Unit = {
     if (page != null) {
-      freeMemory(page.length)
+      freeMemory(page.length * 8)
       page = null
     }
     if (array != null) {
@@ -608,52 +624,58 @@ private[execution] final class LongToUnsafeRowMap(var mm: TaskMemoryManager, cap
     }
   }

+  private def writeLongArray(out: ObjectOutput, arr: Array[Long], len: Int): Unit = {
+    val buffer = new Array[Byte](4 << 10)
+    var offset: Long = Platform.LONG_ARRAY_OFFSET
+    val end = len * 8L + Platform.LONG_ARRAY_OFFSET
+    while (offset < end) {
+      val size = Math.min(buffer.length, (end - offset).toInt)
+      Platform.copyMemory(arr, offset, buffer, Platform.BYTE_ARRAY_OFFSET, size)
+      out.write(buffer, 0, size)
+      offset += size
+    }
+  }
+
   override def writeExternal(out: ObjectOutput): Unit = {
     out.writeBoolean(isDense)
     out.writeLong(minKey)
     out.writeLong(maxKey)
-    out.writeInt(numKeys)
-    out.writeInt(numValues)
+    out.writeLong(numKeys)
+    out.writeLong(numValues)
+
+    out.writeLong(array.length)
+    writeLongArray(out, array, array.length)
+    val used = ((cursor - Platform.LONG_ARRAY_OFFSET) / 8).toInt
+    out.writeLong(used)
+    writeLongArray(out, page, used)
+  }

-    out.writeInt(array.length)
+  private def readLongArray(in: ObjectInput, length: Int): Array[Long] = {
+    val array = new Array[Long](length)
     val buffer = new Array[Byte](4 << 10)
-    var offset = Platform.LONG_ARRAY_OFFSET
-    val end = array.length * 8 + Platform.LONG_ARRAY_OFFSET
+    var offset: Long = Platform.LONG_ARRAY_OFFSET
+    val end = length * 8L + Platform.LONG_ARRAY_OFFSET
     while (offset < end) {
-      val size = Math.min(buffer.length, end - offset)
-      Platform.copyMemory(array, offset, buffer, Platform.BYTE_ARRAY_OFFSET, size)
-      out.write(buffer, 0, size)
+      val size = Math.min(buffer.length, (end - offset).toInt)
+      in.readFully(buffer, 0, size)
+      Platform.copyMemory(buffer, Platform.BYTE_ARRAY_OFFSET, array, offset, size)
       offset += size
     }
-
-    val used = cursor - Platform.BYTE_ARRAY_OFFSET
-    out.writeInt(used)
-    out.write(page, 0, used)
+    array
   }

   override def readExternal(in: ObjectInput): Unit = {
     isDense = in.readBoolean()
     minKey = in.readLong()
     maxKey = in.readLong()
-    numKeys = in.readInt()
-    numValues = in.readInt()
+    numKeys = in.readLong()
+    numValues = in.readLong()

-    val length = in.readInt()
-    array = new Array[Long](length)
+    val length = in.readLong().toInt
     mask = length - 2
-    val buffer = new Array[Byte](4 << 10)
-    var offset = Platform.LONG_ARRAY_OFFSET
-    val end = length * 8 + Platform.LONG_ARRAY_OFFSET
-    while (offset < end) {
-      val size = Math.min(buffer.length, end - offset)
-      in.readFully(buffer, 0, size)
-      Platform.copyMemory(buffer, Platform.BYTE_ARRAY_OFFSET, array, offset, size)
-      offset += size
-    }
-
-    val numBytes = in.readInt()
-    page = new Array[Byte](numBytes)
-    in.readFully(page)
+    array = readLongArray(in, length)
+    val pageLength = in.readLong().toInt
+    page = readLongArray(in, pageLength)
   }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
index 371a9ed617d65c0c90811764893289a869fad0a9..3ee25c0996035e34ad656b53afce7a48ac108bd2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
@@ -24,8 +24,9 @@ import org.apache.spark.memory.{StaticMemoryManager, TaskMemoryManager}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.test.SharedSQLContext
-import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
+import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
 import org.apache.spark.unsafe.map.BytesToBytesMap
+import org.apache.spark.unsafe.types.UTF8String
 import org.apache.spark.util.collection.CompactBuffer

 class HashedRelationSuite extends SparkFunSuite with SharedSQLContext {
@@ -149,4 +150,31 @@ class HashedRelationSuite extends SparkFunSuite with SharedSQLContext {
       assert(rows(1).getInt(1) === i + 1)
     }
   }
+
+  // This test requires a 4G heap to run, so it should be run manually
+  ignore("build HashedRelation that is larger than 1G") {
+    val unsafeProj = UnsafeProjection.create(
+      Seq(BoundReference(0, IntegerType, false),
+        BoundReference(1, StringType, true)))
+    val unsafeRow = unsafeProj(InternalRow(0, UTF8String.fromString(" " * 100)))
+    val key = Seq(BoundReference(0, IntegerType, false))
+    val rows = (0 until (1 << 24)).iterator.map { i =>
+      unsafeRow.setInt(0, i % 1000000)
+      unsafeRow.setInt(1, i)
+      unsafeRow
+    }
+
+    val unsafeRelation = UnsafeHashedRelation(rows, key, 1000, mm)
+    assert(unsafeRelation.estimatedSize > (2L << 30))
+    unsafeRelation.close()
+
+    val rows2 = (0 until (1 << 24)).iterator.map { i =>
+      unsafeRow.setInt(0, i % 1000000)
+      unsafeRow.setInt(1, i)
+      unsafeRow
+    }
+    val longRelation = LongHashedRelation(rows2, key, 1000, mm)
+    assert(longRelation.estimatedSize > (2L << 30))
+    longRelation.close()
+  }
 }
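
Reviewer note: the heart of this change is the new value-address layout in LongToUnsafeRowMap. An address is still a single Long, but the low SIZE_BITS = 28 bits now hold the row size in bytes and the high 36 bits hold the row's byte offset inside the page, where the old layout split the Long 32/32 over an Array[Byte] page. A minimal standalone sketch of the encoding follows; the object and method names are illustrative, not part of the patch:

object AddressEncoding {
  val SIZE_BITS = 28
  val SIZE_MASK = 0xfffffff

  // Pack a page offset and a row size into one Long, as append() passes to updateIndex().
  def pack(offset: Long, sizeInBytes: Int): Long = {
    // A row must fit in 28 bits; this is the "256M" guard in append().
    require(sizeInBytes < (1 << SIZE_BITS), "row larger than 256M")
    (offset << SIZE_BITS) | sizeInBytes
  }

  // Unpack, as getRow() and valueIter() do.
  def offsetOf(address: Long): Long = address >>> SIZE_BITS
  def sizeOf(address: Long): Int = (address & SIZE_MASK).toInt

  def main(args: Array[String]): Unit = {
    val addr = pack(offset = (1L << 33) + 16, sizeInBytes = 100) // offsets beyond 4G now fit
    assert(offsetOf(addr) == (1L << 33) + 16)
    assert(sizeOf(addr) == 100)
  }
}

Shrinking the size field from 32 to 28 bits is what lets the offset field address the new page ceiling of 1 << 30 longs (8G); the old 32-bit offset could not reach past 4G, hence the previous 2G page limit.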
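Reviewer note on the key-count guard added to updateIndex(): growArray() may only run while array.length < 1 << 30, each key occupies two Long slots, and insertion fails once numKeys exceeds array.length / 2 * 0.75. Worked out at the cap:

// The bound behind the "1/3 billion unique keys" error message.
val maxArrayLen = 1 << 30              // growArray() refuses to double past this
val maxKeys = (maxArrayLen / 2) * 0.75 // two Long slots per key, 0.75 max fill ratio
assert(maxKeys.toLong == 402653184L)   // about 0.4 billion keys

So the error message is conservative: the hard limit is roughly 402 million unique keys, slightly above a third of a billion.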
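Reviewer note: the new writeLongArray/readLongArray helpers stream the Long array through a reusable 4 KB byte buffer rather than materializing one large byte[]. Below is a self-contained sketch of the same chunked round trip, rewritten against java.nio instead of Platform.copyMemory so it runs anywhere; the object name is illustrative:

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
import java.nio.ByteBuffer

object ChunkedLongArrayIO {
  def writeLongArray(out: DataOutputStream, arr: Array[Long], len: Int): Unit = {
    val buffer = new Array[Byte](4 << 10)            // 4 KB staging buffer, as in the patch
    var i = 0
    while (i < len) {
      val n = math.min(buffer.length / 8, len - i)   // how many longs fit in one chunk
      ByteBuffer.wrap(buffer).asLongBuffer().put(arr, i, n)
      out.write(buffer, 0, n * 8)
      i += n
    }
  }

  def readLongArray(in: DataInputStream, length: Int): Array[Long] = {
    val arr = new Array[Long](length)
    val buffer = new Array[Byte](4 << 10)
    var i = 0
    while (i < length) {
      val n = math.min(buffer.length / 8, length - i)
      in.readFully(buffer, 0, n * 8)
      ByteBuffer.wrap(buffer).asLongBuffer().get(arr, i, n)
      i += n
    }
    arr
  }

  def main(args: Array[String]): Unit = {
    val original = Array.tabulate(100000)(i => i.toLong * 31)
    val bytes = new ByteArrayOutputStream()
    writeLongArray(new DataOutputStream(bytes), original, original.length)
    val copy = readLongArray(
      new DataInputStream(new ByteArrayInputStream(bytes.toByteArray)), original.length)
    assert(original.sameElements(copy))
  }
}

Caveat: ByteBuffer defaults to big-endian while Platform.copyMemory copies in native byte order, so this sketch demonstrates the chunking pattern only, not byte-for-byte wire compatibility with the patch.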