diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
index 131efea20f31e0f2f58e4a998bdbcd3819c79399..4ca2d85406bb77e40b9cd4a6eeb1214e0b82789d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
@@ -38,6 +38,7 @@ trait CodegenSupport extends SparkPlan {
   /** Prefix used in the current operator's variable names. */
   private def variablePrefix: String = this match {
     case _: TungstenAggregate => "agg"
+    case _: BroadcastHashJoin => "bhj"
     case _ => nodeName.toLowerCase
   }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
index 943ad31c0cef52bd950081b0a21fbb30cebf938c..cbd549763ac95876b14bfc388de0530b2c0307a7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
@@ -90,8 +90,14 @@ case class BroadcastHashJoin(
         // The following line doesn't run in a job so we cannot track the metric value. However, we
         // have already tracked it in the above lines. So here we can use
         // `SQLMetrics.nullLongMetric` to ignore it.
-        val hashed = HashedRelation(
-          input.iterator, SQLMetrics.nullLongMetric, buildSideKeyGenerator, input.size)
+        // TODO: move this check into HashedRelation
+        val hashed = if (canJoinKeyFitWithinLong) {
+          LongHashedRelation(
+            input.iterator, SQLMetrics.nullLongMetric, buildSideKeyGenerator, input.size)
+        } else {
+          HashedRelation(
+            input.iterator, SQLMetrics.nullLongMetric, buildSideKeyGenerator, input.size)
+        }
         sparkContext.broadcast(hashed)
       }
     }(BroadcastHashJoin.broadcastHashJoinExecutionContext)
@@ -112,15 +118,12 @@ case class BroadcastHashJoin(
 
     streamedPlan.execute().mapPartitions { streamedIter =>
       val hashedRelation = broadcastRelation.value
-      hashedRelation match {
-        case unsafe: UnsafeHashedRelation =>
-          TaskContext.get().taskMetrics().incPeakExecutionMemory(unsafe.getUnsafeSize)
-        case _ =>
-      }
+      TaskContext.get().taskMetrics().incPeakExecutionMemory(hashedRelation.getMemorySize)
       hashJoin(streamedIter, numStreamedRows, hashedRelation, numOutputRows)
     }
   }
 
+  private var broadcastRelation: Broadcast[HashedRelation] = _
   // the term for hash relation
   private var relationTerm: String = _
 
@@ -129,16 +132,15 @@ case class BroadcastHashJoin(
   }
 
   override def doProduce(ctx: CodegenContext): String = {
-    // create a name for HashRelation
-    val broadcastRelation = Await.result(broadcastFuture, timeout)
+    // create a name for HashedRelation
+    broadcastRelation = Await.result(broadcastFuture, timeout)
     val broadcast = ctx.addReferenceObj("broadcast", broadcastRelation)
     relationTerm = ctx.freshName("relation")
-    // TODO: create specialized HashRelation for single join key
-    val clsName = classOf[UnsafeHashedRelation].getName
+    val clsName = broadcastRelation.value.getClass.getName
     ctx.addMutableState(clsName, relationTerm,
       s"""
          | $relationTerm = ($clsName) $broadcast.value();
-         | incPeakExecutionMemory($relationTerm.getUnsafeSize());
+         | incPeakExecutionMemory($relationTerm.getMemorySize());
        """.stripMargin)
 
     s"""
@@ -147,23 +149,24 @@ case class BroadcastHashJoin(
   }
 
   override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
-    // generate the key as UnsafeRow
+    // generate the key as UnsafeRow or Long
     ctx.currentVars = input
-    val keyExpr = streamedKeys.map(BindReferences.bindReference(_, streamedPlan.output))
-    val keyVal = GenerateUnsafeProjection.createCode(ctx, keyExpr)
-    val keyTerm = keyVal.value
-    val anyNull = if (keyExpr.exists(_.nullable)) s"$keyTerm.anyNull()" else "false"
+    val (keyVal, anyNull) = if (canJoinKeyFitWithinLong) {
+      val expr = rewriteKeyExpr(streamedKeys).head
+      val ev = BindReferences.bindReference(expr, streamedPlan.output).gen(ctx)
+      (ev, ev.isNull)
+    } else {
+      val keyExpr = streamedKeys.map(BindReferences.bindReference(_, streamedPlan.output))
+      val ev = GenerateUnsafeProjection.createCode(ctx, keyExpr)
+      (ev, s"${ev.value}.anyNull()")
+    }
 
     // find the matches from HashedRelation
-    val matches = ctx.freshName("matches")
-    val bufferType = classOf[CompactBuffer[UnsafeRow]].getName
-    val i = ctx.freshName("i")
-    val size = ctx.freshName("size")
-    val row = ctx.freshName("row")
+    val matched = ctx.freshName("matched")
 
     // create variables for output
     ctx.currentVars = null
-    ctx.INPUT_ROW = row
+    ctx.INPUT_ROW = matched
     val buildColumns = buildPlan.output.zipWithIndex.map { case (a, i) =>
       BoundReference(i, a.dataType, a.nullable).gen(ctx)
     }
@@ -172,7 +175,7 @@ case class BroadcastHashJoin(
       case BuildRight => input ++ buildColumns
     }
 
-    val ouputCode = if (condition.isDefined) {
+    val outputCode = if (condition.isDefined) {
       // filter the output via condition
       ctx.currentVars = resultVars
       val ev = BindReferences.bindReference(condition.get, this.output).gen(ctx)
@@ -186,20 +189,39 @@ case class BroadcastHashJoin(
       consume(ctx, resultVars)
     }
 
-    s"""
-       | // generate join key
-       | ${keyVal.code}
-       | // find matches from HashRelation
-       | $bufferType $matches = $anyNull ? null : ($bufferType) $relationTerm.get($keyTerm);
-       | if ($matches != null) {
-       |   int $size = $matches.size();
-       |   for (int $i = 0; $i < $size; $i++) {
-       |     UnsafeRow $row = (UnsafeRow) $matches.apply($i);
-       |     ${buildColumns.map(_.code).mkString("\n")}
-       |     $ouputCode
-       |   }
-       | }
+    if (broadcastRelation.value.isInstanceOf[UniqueHashedRelation]) {
+      s"""
+         | // generate join key
+         | ${keyVal.code}
+         | // find matches from HashedRelation
+         | UnsafeRow $matched = $anyNull ? null: (UnsafeRow)$relationTerm.getValue(${keyVal.value});
+         | if ($matched != null) {
+         |   ${buildColumns.map(_.code).mkString("\n")}
+         |   $outputCode
+         | }
      """.stripMargin
+
+    } else {
+      val matches = ctx.freshName("matches")
+      val bufferType = classOf[CompactBuffer[UnsafeRow]].getName
+      val i = ctx.freshName("i")
+      val size = ctx.freshName("size")
+      s"""
+         | // generate join key
+         | ${keyVal.code}
+         | // find matches from HashRelation
+         | $bufferType $matches = ${anyNull} ? null :
+         |  ($bufferType) $relationTerm.get(${keyVal.value});
+         | if ($matches != null) {
+         |   int $size = $matches.size();
+         |   for (int $i = 0; $i < $size; $i++) {
+         |     UnsafeRow $matched = (UnsafeRow) $matches.apply($i);
+         |     ${buildColumns.map(_.code).mkString("\n")}
+         |     $outputCode
+         |   }
+         | }
+     """.stripMargin
+    }
   }
 }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala
index f48fc3b84864d9872eca4b26d4a312357da99525..ad3275696e637237ba5201d1a30096f0bdaabe89 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala
@@ -116,12 +116,7 @@ case class BroadcastHashOuterJoin(
       val joinedRow = new JoinedRow()
       val hashTable = broadcastRelation.value
       val keyGenerator = streamedKeyGenerator
-
-      hashTable match {
-        case unsafe: UnsafeHashedRelation =>
-          TaskContext.get().taskMetrics().incPeakExecutionMemory(unsafe.getUnsafeSize)
-        case _ =>
-      }
+      TaskContext.get().taskMetrics().incPeakExecutionMemory(hashTable.getMemorySize)
 
       val resultProj = resultProjection
       joinType match {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala
index 8929dc3af19121497f4c916ff3236e222ad31ae4..d0e18dfcf3d9003e268925becbded24f178bd5d4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala
@@ -64,11 +64,7 @@ case class BroadcastLeftSemiJoinHash(
 
       left.execute().mapPartitionsInternal { streamIter =>
         val hashedRelation = broadcastedRelation.value
-        hashedRelation match {
-          case unsafe: UnsafeHashedRelation =>
-            TaskContext.get().taskMetrics().incPeakExecutionMemory(unsafe.getUnsafeSize)
-          case _ =>
-        }
+        TaskContext.get().taskMetrics().incPeakExecutionMemory(hashedRelation.getMemorySize)
         hashSemiJoin(streamIter, numLeftRows, hashedRelation, numOutputRows)
       }
     }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
index 8ef854001f4de063afd4038649fcbd165061451c..ecbb1ac64b7c08e2574571b7f83a3116671d5ee9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.metric.LongSQLMetric
-
+import org.apache.spark.sql.types.{IntegralType, LongType}
 
 trait HashJoin {
   self: SparkPlan =>
@@ -47,11 +47,49 @@ trait HashJoin {
 
   override def output: Seq[Attribute] = left.output ++ right.output
 
+  /**
+    * Try to rewrite the key as LongType so we can use getLong(), if they key can fit with a long.
+    *
+    * If not, returns the original expressions.
+    */
+  def rewriteKeyExpr(keys: Seq[Expression]): Seq[Expression] = {
+    var keyExpr: Expression = null
+    var width = 0
+    keys.foreach { e =>
+      e.dataType match {
+        case dt: IntegralType if dt.defaultSize <= 8 - width =>
+          if (width == 0) {
+            if (e.dataType != LongType) {
+              keyExpr = Cast(e, LongType)
+            } else {
+              keyExpr = e
+            }
+            width = dt.defaultSize
+          } else {
+            val bits = dt.defaultSize * 8
+            keyExpr = BitwiseOr(ShiftLeft(keyExpr, Literal(bits)),
+              BitwiseAnd(Cast(e, LongType), Literal((1L << bits) - 1)))
+            width -= bits
+          }
+        // TODO: support BooleanType, DateType and TimestampType
+        case other =>
+          return keys
+      }
+    }
+    keyExpr :: Nil
+  }
+
+  protected val canJoinKeyFitWithinLong: Boolean = {
+    val sameTypes = buildKeys.map(_.dataType) == streamedKeys.map(_.dataType)
+    val key = rewriteKeyExpr(buildKeys)
+    sameTypes && key.length == 1 && key.head.dataType.isInstanceOf[LongType]
+  }
+
   protected def buildSideKeyGenerator: Projection =
-    UnsafeProjection.create(buildKeys, buildPlan.output)
+    UnsafeProjection.create(rewriteKeyExpr(buildKeys), buildPlan.output)
 
   protected def streamSideKeyGenerator: Projection =
-    UnsafeProjection.create(streamedKeys, streamedPlan.output)
+    UnsafeProjection.create(rewriteKeyExpr(streamedKeys), streamedPlan.output)
 
   @transient private[this] lazy val boundCondition = if (condition.isDefined) {
     newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
index ee7a1bdc343c02dd6e4f312ba9605fff29ac8d27..c94d6c195b1d8b050bbbeab70fd555eef6bb112c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
@@ -39,8 +39,23 @@ import org.apache.spark.util.collection.CompactBuffer
  * object.
  */
 private[execution] sealed trait HashedRelation {
+  /**
+    * Returns matched rows.
+    */
   def get(key: InternalRow): Seq[InternalRow]
 
+  /**
+    * Returns matched rows for a key that has only one column with LongType.
+    */
+  def get(key: Long): Seq[InternalRow] = {
+    throw new UnsupportedOperationException
+  }
+
+  /**
+    * Returns the size of used memory.
+    */
+  def getMemorySize: Long = 1L  // to make the test happy
+
   // This is a helper method to implement Externalizable, and is used by
   // GeneralHashedRelation and UniqueKeyHashedRelation
   protected def writeBytes(out: ObjectOutput, serialized: Array[Byte]): Unit = {
@@ -58,11 +73,48 @@ private[execution] sealed trait HashedRelation {
   }
 }
 
+/**
+  * Interface for a hashed relation that have only one row per key.
+  *
+  * We should call getValue() for better performance.
+  */
+private[execution] trait UniqueHashedRelation extends HashedRelation {
+
+  /**
+    * Returns the matched single row.
+    */
+  def getValue(key: InternalRow): InternalRow
+
+  /**
+    * Returns the matched single row with key that have only one column of LongType.
+    */
+  def getValue(key: Long): InternalRow = {
+    throw new UnsupportedOperationException
+  }
+
+  override def get(key: InternalRow): Seq[InternalRow] = {
+    val row = getValue(key)
+    if (row != null) {
+      CompactBuffer[InternalRow](row)
+    } else {
+      null
+    }
+  }
+
+  override def get(key: Long): Seq[InternalRow] = {
+    val row = getValue(key)
+    if (row != null) {
+      CompactBuffer[InternalRow](row)
+    } else {
+      null
+    }
+  }
+}
 
 /**
  * A general [[HashedRelation]] backed by a hash map that maps the key into a sequence of values.
  */
-private[joins] final class GeneralHashedRelation(
+private[joins] class GeneralHashedRelation(
     private var hashTable: JavaHashMap[InternalRow, CompactBuffer[InternalRow]])
   extends HashedRelation with Externalizable {
 
@@ -85,19 +137,14 @@ private[joins] final class GeneralHashedRelation(
  * A specialized [[HashedRelation]] that maps key into a single value. This implementation
  * assumes the key is unique.
  */
-private[joins]
-final class UniqueKeyHashedRelation(private var hashTable: JavaHashMap[InternalRow, InternalRow])
-  extends HashedRelation with Externalizable {
+private[joins] class UniqueKeyHashedRelation(
+  private var hashTable: JavaHashMap[InternalRow, InternalRow])
+  extends UniqueHashedRelation with Externalizable {
 
   // Needed for serialization (it is public to make Java serialization work)
   def this() = this(null)
 
-  override def get(key: InternalRow): Seq[InternalRow] = {
-    val v = hashTable.get(key)
-    if (v eq null) null else CompactBuffer(v)
-  }
-
-  def getValue(key: InternalRow): InternalRow = hashTable.get(key)
+  override def getValue(key: InternalRow): InternalRow = hashTable.get(key)
 
   override def writeExternal(out: ObjectOutput): Unit = {
     writeBytes(out, SparkSqlSerializer.serialize(hashTable))
@@ -108,8 +155,6 @@ final class UniqueKeyHashedRelation(private var hashTable: JavaHashMap[InternalR
   }
 }
 
-// TODO(rxin): a version of [[HashedRelation]] backed by arrays for consecutive integer keys.
-
 
 private[execution] object HashedRelation {
 
@@ -208,7 +253,7 @@ private[joins] final class UnsafeHashedRelation(
    *
    * For non-broadcast joins or in local mode, return 0.
    */
-  def getUnsafeSize: Long = {
+  override def getMemorySize: Long = {
     if (binaryMap != null) {
       binaryMap.getTotalMemoryConsumption
     } else {
@@ -408,6 +453,232 @@ private[joins] object UnsafeHashedRelation {
       }
     }
 
+    // TODO: create UniqueUnsafeRelation
     new UnsafeHashedRelation(hashTable)
   }
 }
+
+/**
+  * An interface for a hashed relation that the key is a Long.
+  */
+private[joins] trait LongHashedRelation extends HashedRelation {
+  override def get(key: InternalRow): Seq[InternalRow] = {
+    get(key.getLong(0))
+  }
+}
+
+private[joins] final class GeneralLongHashedRelation(
+  private var hashTable: JavaHashMap[Long, CompactBuffer[UnsafeRow]])
+  extends LongHashedRelation with Externalizable {
+
+  // Needed for serialization (it is public to make Java serialization work)
+  def this() = this(null)
+
+  override def get(key: Long): Seq[InternalRow] = hashTable.get(key)
+
+  override def writeExternal(out: ObjectOutput): Unit = {
+    writeBytes(out, SparkSqlSerializer.serialize(hashTable))
+  }
+
+  override def readExternal(in: ObjectInput): Unit = {
+    hashTable = SparkSqlSerializer.deserialize(readBytes(in))
+  }
+}
+
+private[joins] final class UniqueLongHashedRelation(
+  private var hashTable: JavaHashMap[Long, UnsafeRow])
+  extends UniqueHashedRelation with LongHashedRelation with Externalizable {
+
+  // Needed for serialization (it is public to make Java serialization work)
+  def this() = this(null)
+
+  override def getValue(key: InternalRow): InternalRow = {
+    getValue(key.getLong(0))
+  }
+
+  override def getValue(key: Long): InternalRow = {
+    hashTable.get(key)
+  }
+
+  override def writeExternal(out: ObjectOutput): Unit = {
+    writeBytes(out, SparkSqlSerializer.serialize(hashTable))
+  }
+
+  override def readExternal(in: ObjectInput): Unit = {
+    hashTable = SparkSqlSerializer.deserialize(readBytes(in))
+  }
+}
+
+/**
+  * A relation that pack all the rows into a byte array, together with offsets and sizes.
+  *
+  * All the bytes of UnsafeRow are packed together as `bytes`:
+  *
+  *  [  Row0  ][  Row1  ][] ... [  RowN  ]
+  *
+  * With keys:
+  *
+  *   start    start+1   ...       start+N
+  *
+  * `offsets` are offsets of UnsafeRows in the `bytes`
+  * `sizes` are the numbers of bytes of UnsafeRows, 0 means no row for this key.
+  *
+  *  For example, two UnsafeRows (24 bytes and 32 bytes), with keys as 3 and 5 will stored as:
+  *
+  *  start   = 3
+  *  offsets = [0, 0, 24]
+  *  sizes   = [24, 0, 32]
+  *  bytes   = [0 - 24][][24 - 56]
+  */
+private[joins] final class LongArrayRelation(
+    private var numFields: Int,
+    private var start: Long,
+    private var offsets: Array[Int],
+    private var sizes: Array[Int],
+    private var bytes: Array[Byte]
+  ) extends UniqueHashedRelation with LongHashedRelation with Externalizable {
+
+  // Needed for serialization (it is public to make Java serialization work)
+  def this() = this(0, 0L, null, null, null)
+
+  override def getValue(key: InternalRow): InternalRow = {
+    getValue(key.getLong(0))
+  }
+
+  override def getMemorySize: Long = {
+    offsets.length * 4 + sizes.length * 4 + bytes.length
+  }
+
+  override def getValue(key: Long): InternalRow = {
+    val idx = (key - start).toInt
+    if (idx >= 0 && idx < sizes.length && sizes(idx) > 0) {
+      val result = new UnsafeRow(numFields)
+      result.pointTo(bytes, Platform.BYTE_ARRAY_OFFSET + offsets(idx), sizes(idx))
+      result
+    } else {
+      null
+    }
+  }
+
+  override def writeExternal(out: ObjectOutput): Unit = {
+    out.writeInt(numFields)
+    out.writeLong(start)
+    out.writeInt(sizes.length)
+    var i = 0
+    while (i < sizes.length) {
+      out.writeInt(sizes(i))
+      i += 1
+    }
+    out.writeInt(bytes.length)
+    out.write(bytes)
+  }
+
+  override def readExternal(in: ObjectInput): Unit = {
+    numFields = in.readInt()
+    start = in.readLong()
+    val length = in.readInt()
+    // read sizes of rows
+    sizes = new Array[Int](length)
+    offsets = new Array[Int](length)
+    var i = 0
+    var offset = 0
+    while (i < length) {
+      offsets(i) = offset
+      sizes(i) = in.readInt()
+      offset += sizes(i)
+      i += 1
+    }
+    // read all the bytes
+    val total = in.readInt()
+    assert(total == offset)
+    bytes = new Array[Byte](total)
+    in.readFully(bytes)
+  }
+}
+
+/**
+  * Create hashed relation with key that is long.
+  */
+private[joins] object LongHashedRelation {
+
+  val DENSE_FACTOR = 0.2
+
+  def apply(
+    input: Iterator[InternalRow],
+    numInputRows: LongSQLMetric,
+    keyGenerator: Projection,
+    sizeEstimate: Int): HashedRelation = {
+
+    // Use a Java hash table here because unsafe maps expect fixed size records
+    val hashTable = new JavaHashMap[Long, CompactBuffer[UnsafeRow]](sizeEstimate)
+
+    // Create a mapping of key -> rows
+    var numFields = 0
+    var keyIsUnique = true
+    var minKey = Long.MaxValue
+    var maxKey = Long.MinValue
+    while (input.hasNext) {
+      val unsafeRow = input.next().asInstanceOf[UnsafeRow]
+      numFields = unsafeRow.numFields()
+      numInputRows += 1
+      val rowKey = keyGenerator(unsafeRow)
+      if (!rowKey.anyNull) {
+        val key = rowKey.getLong(0)
+        minKey = math.min(minKey, key)
+        maxKey = math.max(maxKey, key)
+        val existingMatchList = hashTable.get(key)
+        val matchList = if (existingMatchList == null) {
+          val newMatchList = new CompactBuffer[UnsafeRow]()
+          hashTable.put(key, newMatchList)
+          newMatchList
+        } else {
+          keyIsUnique = false
+          existingMatchList
+        }
+        matchList += unsafeRow.copy()
+      }
+    }
+
+    if (keyIsUnique) {
+      if (hashTable.size() > (maxKey - minKey) * DENSE_FACTOR) {
+        // The keys are dense enough, so use LongArrayRelation
+        val length = (maxKey - minKey).toInt + 1
+        val sizes = new Array[Int](length)
+        val offsets = new Array[Int](length)
+        var offset = 0
+        var i = 0
+        while (i < length) {
+          val rows = hashTable.get(i + minKey)
+          if (rows != null) {
+            offsets(i) = offset
+            sizes(i) = rows(0).getSizeInBytes
+            offset += sizes(i)
+          }
+          i += 1
+        }
+        val bytes = new Array[Byte](offset)
+        i = 0
+        while (i < length) {
+          val rows = hashTable.get(i + minKey)
+          if (rows != null) {
+            rows(0).writeToMemory(bytes, Platform.BYTE_ARRAY_OFFSET + offsets(i))
+          }
+          i += 1
+        }
+        new LongArrayRelation(numFields, minKey, offsets, sizes, bytes)
+
+      } else {
+        // all the keys are unique, one row per key.
+        val uniqHashTable = new JavaHashMap[Long, UnsafeRow](hashTable.size)
+        val iter = hashTable.entrySet().iterator()
+        while (iter.hasNext) {
+          val entry = iter.next()
+          uniqHashTable.put(entry.getKey, entry.getValue()(0))
+        }
+        new UniqueLongHashedRelation(uniqHashTable)
+      }
+    } else {
+      new GeneralLongHashedRelation(hashTable)
+    }
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
index 33d4976403d9ae22b5880d5e6be5b00146716d9d..f015d297048a33e4a8bd38ec4eb972d4db7bfd5f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
@@ -22,6 +22,7 @@ import org.apache.spark.memory.{StaticMemoryManager, TaskMemoryManager}
 import org.apache.spark.sql.SQLContext
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
 import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types.IntegerType
 import org.apache.spark.unsafe.Platform
 import org.apache.spark.unsafe.hash.Murmur3_x86_32
 import org.apache.spark.unsafe.map.BytesToBytesMap
@@ -122,10 +123,10 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
   }
 
   ignore("broadcast hash join") {
-    val N = 20 << 20
+    val N = 100 << 20
     val dim = broadcast(sqlContext.range(1 << 16).selectExpr("id as k", "cast(id as string) as v"))
 
-    runBenchmark("BroadcastHashJoin", N) {
+    runBenchmark("Join w long", N) {
       sqlContext.range(N).join(dim, (col("id") % 60000) === col("k")).count()
     }
 
@@ -133,9 +134,27 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
     Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
     BroadcastHashJoin:                  Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
     -------------------------------------------------------------------------------------------
-    BroadcastHashJoin codegen=false          4405 / 6147          4.0         250.0       1.0X
-    BroadcastHashJoin codegen=true           1857 / 1878         11.0          90.9       2.4X
+    Join w long codegen=false        10174 / 10317         10.0         100.0       1.0X
+    Join w long codegen=true           1069 / 1107         98.0          10.2       9.5X
+    */
+
+    val dim2 = broadcast(sqlContext.range(1 << 16)
+      .selectExpr("cast(id as int) as k1", "cast(id as int) as k2", "cast(id as string) as v"))
+
+    runBenchmark("Join w 2 ints", N) {
+      sqlContext.range(N).join(dim2,
+        (col("id") bitwiseAND 60000).cast(IntegerType) === col("k1")
+          && (col("id") bitwiseAND 50000).cast(IntegerType) === col("k2")).count()
+    }
+
+    /**
+    Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
+    BroadcastHashJoin:                  Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
+    -------------------------------------------------------------------------------------------
+    Join w 2 ints codegen=false           11435 / 11530          9.0         111.1       1.0X
+    Join w 2 ints codegen=true              1265 / 1424         82.0          12.2       9.0X
     */
+
   }
 
   ignore("hash and BytesToBytesMap") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
index e5fd9e277fc61e288f35521637bf1a2b21a726ff..f985dfbd8ade9fdacade3f0467f3d09d35063714 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
@@ -27,7 +27,6 @@ import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
 import org.apache.spark.util.collection.CompactBuffer
 
-
 class HashedRelationSuite extends SparkFunSuite with SharedSQLContext {
 
   // Key is simply the record itself
@@ -134,4 +133,32 @@ class HashedRelationSuite extends SparkFunSuite with SharedSQLContext {
     out2.flush()
     assert(java.util.Arrays.equals(os2.toByteArray, os.toByteArray))
   }
+
+  test("LongArrayRelation") {
+    val unsafeProj = UnsafeProjection.create(
+      Seq(BoundReference(0, IntegerType, false), BoundReference(1, IntegerType, true)))
+    val rows = (0 until 100).map(i => unsafeProj(InternalRow(i, i + 1)).copy())
+    val keyProj = UnsafeProjection.create(Seq(BoundReference(0, IntegerType, false)))
+    val longRelation = LongHashedRelation(rows.iterator, SQLMetrics.nullLongMetric, keyProj, 100)
+    assert(longRelation.isInstanceOf[LongArrayRelation])
+    val longArrayRelation = longRelation.asInstanceOf[LongArrayRelation]
+    (0 until 100).foreach { i =>
+      val row = longArrayRelation.getValue(i)
+      assert(row.getInt(0) === i)
+      assert(row.getInt(1) === i + 1)
+    }
+
+    val os = new ByteArrayOutputStream()
+    val out = new ObjectOutputStream(os)
+    longArrayRelation.writeExternal(out)
+    out.flush()
+    val in = new ObjectInputStream(new ByteArrayInputStream(os.toByteArray))
+    val relation = new LongArrayRelation()
+    relation.readExternal(in)
+    (0 until 100).foreach { i =>
+      val row = longArrayRelation.getValue(i)
+      assert(row.getInt(0) === i)
+      assert(row.getInt(1) === i + 1)
+    }
+  }
 }