From 745425332f41e2ae94649f9d1ad675243f36f743 Mon Sep 17 00:00:00 2001
From: Davies Liu <davies@databricks.com>
Date: Mon, 4 Apr 2016 10:01:24 -0700
Subject: [PATCH] [SPARK-14137] [SQL] Cleanup hash join

## What changes were proposed in this pull request?

This PR does some cleanup of HashedRelation and HashJoin:

1) Merge HashedRelation and UniqueHashedRelation together.
2) Return an iterator from HashedRelation, so we do not need to create many UnsafeRow objects (illustrated in the sketch below).
3) Return a read-only copy of the HashedRelation for thread safety in BroadcastHashJoin, so we can re-use the UnsafeRow objects.
4) Clean up HashJoin, sharing most of the code between BroadcastHashJoin and ShuffledHashJoin.
5) Remove UniqueLongHashedRelation, which will be replaced by LongUnsafeMap (in another PR).
6) Update the benchmarks; before this patch, the selectivity of the joins was too high.

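Below is a simplified, self-contained sketch of how the consolidated interface is meant to be used after this change. The names here (`SimpleHashedRelation`, `MapHashedRelation`, `innerJoin`) are illustrative stand-ins, not the real Spark classes: `get()` returns an iterator (or null when there is no match), `getValue()` returns the single matched row when keys are unique, and `asReadOnlyCopy()` gives each task its own copy so that re-used mutable rows stay thread-safe.

```scala
// Illustrative stand-ins for the consolidated HashedRelation interface;
// String plays the role of UnsafeRow to keep the sketch self-contained.
trait SimpleHashedRelation {
  def get(key: Long): Iterator[String]          // null if the key has no matches
  def getValue(key: Long): String               // the single match, or null
  def keyIsUnique: Boolean
  def asReadOnlyCopy(): SimpleHashedRelation
}

class MapHashedRelation(table: Map[Long, Seq[String]]) extends SimpleHashedRelation {
  def get(key: Long): Iterator[String] = table.get(key).map(_.iterator).orNull
  def getValue(key: Long): String = table.get(key).map(_.head).orNull
  def keyIsUnique: Boolean = table.values.forall(_.size == 1)
  // The backing map is immutable here, so a shallow copy per task is enough.
  def asReadOnlyCopy(): SimpleHashedRelation = new MapHashedRelation(table)
}

// A probe loop in the spirit of the shared HashJoin.join(): each task takes
// its own read-only copy and streams matches from the returned iterator
// instead of materializing them into a buffer.
def innerJoin(stream: Iterator[Long],
              shared: SimpleHashedRelation): Iterator[(Long, String)] = {
  val hashed = shared.asReadOnlyCopy()
  stream.flatMap { key =>
    val matches = hashed.get(key)
    if (matches != null) matches.map(v => (key, v)) else Iterator.empty
  }
}
```
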
## How was this patch tested?

Existing tests.

Author: Davies Liu <davies@databricks.com>

Closes #12102 from davies/cleanup_hash.
---
 .../execution/joins/BroadcastHashJoin.scala   |  78 ++----
 .../spark/sql/execution/joins/HashJoin.scala  | 217 ++++++--------
 .../sql/execution/joins/HashedRelation.scala  | 264 ++++++++----------
 .../execution/joins/ShuffledHashJoin.scala    |  32 +--
 .../BenchmarkWholeStageCodegen.scala          |  64 +++--
 .../execution/joins/HashedRelationSuite.scala |  14 +-
 6 files changed, 268 insertions(+), 401 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
index 41e566c27b..67ac9e94ff 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
@@ -68,37 +68,9 @@ case class BroadcastHashJoin(
 
     val broadcastRelation = buildPlan.executeBroadcast[HashedRelation]()
     streamedPlan.execute().mapPartitions { streamedIter =>
-      val joinedRow = new JoinedRow()
-      val hashTable = broadcastRelation.value
-      TaskContext.get().taskMetrics().incPeakExecutionMemory(hashTable.getMemorySize)
-      val keyGenerator = streamSideKeyGenerator
-      val resultProj = createResultProjection
-
-      joinType match {
-        case Inner =>
-          hashJoin(streamedIter, hashTable, numOutputRows)
-
-        case LeftOuter =>
-          streamedIter.flatMap { currentRow =>
-            val rowKey = keyGenerator(currentRow)
-            joinedRow.withLeft(currentRow)
-            leftOuterIterator(rowKey, joinedRow, hashTable.get(rowKey), resultProj, numOutputRows)
-          }
-
-        case RightOuter =>
-          streamedIter.flatMap { currentRow =>
-            val rowKey = keyGenerator(currentRow)
-            joinedRow.withRight(currentRow)
-            rightOuterIterator(rowKey, hashTable.get(rowKey), joinedRow, resultProj, numOutputRows)
-          }
-
-        case LeftSemi =>
-          hashSemiJoin(streamedIter, hashTable, numOutputRows)
-
-        case x =>
-          throw new IllegalArgumentException(
-            s"BroadcastHashJoin should not take $x as the JoinType")
-      }
+      val hashed = broadcastRelation.value.asReadOnlyCopy()
+      TaskContext.get().taskMetrics().incPeakExecutionMemory(hashed.getMemorySize)
+      join(streamedIter, hashed, numOutputRows)
     }
   }
 
@@ -132,7 +104,7 @@ case class BroadcastHashJoin(
     val clsName = broadcastRelation.value.getClass.getName
     ctx.addMutableState(clsName, relationTerm,
       s"""
-         | $relationTerm = ($clsName) $broadcast.value();
+         | $relationTerm = (($clsName) $broadcast.value()).asReadOnlyCopy();
          | incPeakExecutionMemory($relationTerm.getMemorySize());
        """.stripMargin)
     (broadcastRelation, relationTerm)
@@ -217,7 +189,7 @@ case class BroadcastHashJoin(
       case BuildLeft => buildVars ++ input
       case BuildRight => input ++ buildVars
     }
-    if (broadcastRelation.value.isInstanceOf[UniqueHashedRelation]) {
+    if (broadcastRelation.value.keyIsUnique) {
       s"""
          |// generate join key for stream side
          |${keyEv.code}
@@ -232,18 +204,15 @@ case class BroadcastHashJoin(
     } else {
       ctx.copyResult = true
       val matches = ctx.freshName("matches")
-      val bufferType = classOf[CompactBuffer[UnsafeRow]].getName
-      val i = ctx.freshName("i")
-      val size = ctx.freshName("size")
+      val iteratorCls = classOf[Iterator[UnsafeRow]].getName
       s"""
          |// generate join key for stream side
          |${keyEv.code}
          |// find matches from HashRelation
-         |$bufferType $matches = $anyNull ? null : ($bufferType)$relationTerm.get(${keyEv.value});
+         |$iteratorCls $matches = $anyNull ? null : ($iteratorCls)$relationTerm.get(${keyEv.value});
          |if ($matches == null) continue;
-         |int $size = $matches.size();
-         |for (int $i = 0; $i < $size; $i++) {
-         |  UnsafeRow $matched = (UnsafeRow) $matches.apply($i);
+         |while ($matches.hasNext()) {
+         |  UnsafeRow $matched = (UnsafeRow) $matches.next();
          |  $checkCondition
          |  $numOutput.add(1);
          |  ${consume(ctx, resultVars)}
@@ -287,7 +256,7 @@ case class BroadcastHashJoin(
       case BuildLeft => buildVars ++ input
       case BuildRight => input ++ buildVars
     }
-    if (broadcastRelation.value.isInstanceOf[UniqueHashedRelation]) {
+    if (broadcastRelation.value.keyIsUnique) {
       s"""
          |// generate join key for stream side
          |${keyEv.code}
@@ -306,22 +275,21 @@ case class BroadcastHashJoin(
     } else {
       ctx.copyResult = true
       val matches = ctx.freshName("matches")
-      val bufferType = classOf[CompactBuffer[UnsafeRow]].getName
+      val iteratorCls = classOf[Iterator[UnsafeRow]].getName
       val i = ctx.freshName("i")
-      val size = ctx.freshName("size")
       val found = ctx.freshName("found")
       s"""
          |// generate join key for stream side
          |${keyEv.code}
          |// find matches from HashRelation
-         |$bufferType $matches = $anyNull ? null : ($bufferType)$relationTerm.get(${keyEv.value});
-         |int $size = $matches != null ? $matches.size() : 0;
+         |$iteratorCls $matches = $anyNull ? null : ($iteratorCls)$relationTerm.get(${keyEv.value});
          |boolean $found = false;
          |// the last iteration of this loop is to emit an empty row if there is no matched rows.
-         |for (int $i = 0; $i <= $size; $i++) {
-         |  UnsafeRow $matched = $i < $size ? (UnsafeRow) $matches.apply($i) : null;
+         |while ($matches != null && $matches.hasNext() || !$found) {
+         |  UnsafeRow $matched = $matches != null && $matches.hasNext() ?
+         |    (UnsafeRow) $matches.next() : null;
          |  ${checkCondition.trim}
-         |  if (!$conditionPassed || ($i == $size && $found)) continue;
+         |  if (!$conditionPassed) continue;
          |  $found = true;
          |  $numOutput.add(1);
          |  ${consume(ctx, resultVars)}
@@ -356,7 +324,7 @@ case class BroadcastHashJoin(
       ""
     }
 
-    if (broadcastRelation.value.isInstanceOf[UniqueHashedRelation]) {
+    if (broadcastRelation.value.keyIsUnique) {
       s"""
          |// generate join key for stream side
          |${keyEv.code}
@@ -369,23 +337,19 @@ case class BroadcastHashJoin(
        """.stripMargin
     } else {
       val matches = ctx.freshName("matches")
-      val bufferType = classOf[CompactBuffer[UnsafeRow]].getName
-      val i = ctx.freshName("i")
-      val size = ctx.freshName("size")
+      val iteratorCls = classOf[Iterator[UnsafeRow]].getName
       val found = ctx.freshName("found")
       s"""
          |// generate join key for stream side
          |${keyEv.code}
          |// find matches from HashRelation
-         |$bufferType $matches = $anyNull ? null : ($bufferType)$relationTerm.get(${keyEv.value});
+         |$iteratorCls $matches = $anyNull ? null : ($iteratorCls)$relationTerm.get(${keyEv.value});
          |if ($matches == null) continue;
-         |int $size = $matches.size();
          |boolean $found = false;
-         |for (int $i = 0; $i < $size; $i++) {
-         |  UnsafeRow $matched = (UnsafeRow) $matches.apply($i);
+         |while (!$found && $matches.hasNext()) {
+         |  UnsafeRow $matched = (UnsafeRow) $matches.next();
          |  $checkCondition
          |  $found = true;
-         |  break;
          |}
          |if (!$found) continue;
          |$numOutput.add(1);
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
index c298b7dee0..b7c0f3e7d1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
@@ -19,10 +19,11 @@ package org.apache.spark.sql.execution.joins
 
 import java.util.NoSuchElementException
 
+import org.apache.spark.TaskContext
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans._
-import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.{RowIterator, SparkPlan}
 import org.apache.spark.sql.execution.metric.LongSQLMetric
 import org.apache.spark.sql.types.{IntegerType, IntegralType, LongType}
 import org.apache.spark.util.collection.CompactBuffer
@@ -110,169 +111,113 @@ trait HashJoin {
     sameTypes && key.length == 1 && key.head.dataType.isInstanceOf[LongType]
   }
 
-  protected def buildSideKeyGenerator: Projection =
+  protected def buildSideKeyGenerator(): Projection =
     UnsafeProjection.create(rewriteKeyExpr(buildKeys), buildPlan.output)
 
-  protected def streamSideKeyGenerator: Projection =
+  protected def streamSideKeyGenerator(): UnsafeProjection =
     UnsafeProjection.create(rewriteKeyExpr(streamedKeys), streamedPlan.output)
 
   @transient private[this] lazy val boundCondition = if (condition.isDefined) {
-    newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output)
+    newPredicate(condition.get, streamedPlan.output ++ buildPlan.output)
   } else {
     (r: InternalRow) => true
   }
 
-  protected def createResultProjection: (InternalRow) => InternalRow =
-    UnsafeProjection.create(self.schema)
-
-  protected def hashJoin(
-      streamIter: Iterator[InternalRow],
-      hashedRelation: HashedRelation,
-      numOutputRows: LongSQLMetric): Iterator[InternalRow] = {
-    new Iterator[InternalRow] {
-      private[this] var currentStreamedRow: InternalRow = _
-      private[this] var currentHashMatches: Seq[InternalRow] = _
-      private[this] var currentMatchPosition: Int = -1
-
-      // Mutable per row objects.
-      private[this] val joinRow = new JoinedRow
-      private[this] val resultProjection = createResultProjection
-
-      private[this] val joinKeys = streamSideKeyGenerator
-
-      override final def hasNext: Boolean = {
-        while (true) {
-          // check if it's end of current matches
-          if (currentHashMatches != null && currentMatchPosition == currentHashMatches.length) {
-            currentHashMatches = null
-            currentMatchPosition = -1
-          }
-
-          // find the next match
-          while (currentHashMatches == null && streamIter.hasNext) {
-            currentStreamedRow = streamIter.next()
-            val key = joinKeys(currentStreamedRow)
-            if (!key.anyNull) {
-              currentHashMatches = hashedRelation.get(key)
-              if (currentHashMatches != null) {
-                currentMatchPosition = 0
-              }
-            }
-          }
-          if (currentHashMatches == null) {
-            return false
-          }
-
-          // found some matches
-          buildSide match {
-            case BuildRight => joinRow(currentStreamedRow, currentHashMatches(currentMatchPosition))
-            case BuildLeft => joinRow(currentHashMatches(currentMatchPosition), currentStreamedRow)
-          }
-          if (boundCondition(joinRow)) {
-            return true
-          } else {
-            currentMatchPosition += 1
-          }
-        }
-        false  // unreachable
-      }
-
-      override final def next(): InternalRow = {
-        // next() could be called without calling hasNext()
-        if (hasNext) {
-          currentMatchPosition += 1
-          numOutputRows += 1
-          resultProjection(joinRow)
-        } else {
-          throw new NoSuchElementException
-        }
-      }
+  protected def createResultProjection(): (InternalRow) => InternalRow = {
+    if (joinType == LeftSemi) {
+      UnsafeProjection.create(output, output)
+    } else {
+      // Always put the stream side on left to simplify implementation
+      // both of left and right side could be null
+      UnsafeProjection.create(
+        output, (streamedPlan.output ++ buildPlan.output).map(_.withNullability(true)))
     }
   }
 
-  @transient protected[this] lazy val EMPTY_LIST = CompactBuffer[InternalRow]()
-
-  @transient private[this] lazy val leftNullRow = new GenericInternalRow(left.output.length)
-  @transient private[this] lazy val rightNullRow = new GenericInternalRow(right.output.length)
-
-  protected[this] def leftOuterIterator(
-      key: InternalRow,
-      joinedRow: JoinedRow,
-      rightIter: Iterable[InternalRow],
-      resultProjection: InternalRow => InternalRow,
-      numOutputRows: LongSQLMetric): Iterator[InternalRow] = {
-    val ret: Iterable[InternalRow] = {
-      if (!key.anyNull) {
-        val temp = if (rightIter != null) {
-          rightIter.collect {
-            case r if boundCondition(joinedRow.withRight(r)) => {
-              numOutputRows += 1
-              resultProjection(joinedRow).copy()
-            }
-          }
-        } else {
-          List.empty
-        }
-        if (temp.isEmpty) {
-          numOutputRows += 1
-          resultProjection(joinedRow.withRight(rightNullRow)) :: Nil
-        } else {
-          temp
-        }
+  private def innerJoin(
+      streamIter: Iterator[InternalRow],
+      hashedRelation: HashedRelation): Iterator[InternalRow] = {
+    val joinRow = new JoinedRow
+    val joinKeys = streamSideKeyGenerator()
+    streamIter.flatMap { srow =>
+      joinRow.withLeft(srow)
+      val matches = hashedRelation.get(joinKeys(srow))
+      if (matches != null) {
+        matches.map(joinRow.withRight(_)).filter(boundCondition)
       } else {
-        numOutputRows += 1
-        resultProjection(joinedRow.withRight(rightNullRow)) :: Nil
+        Seq.empty
       }
     }
-    ret.iterator
   }
 
-  protected[this] def rightOuterIterator(
-      key: InternalRow,
-      leftIter: Iterable[InternalRow],
-      joinedRow: JoinedRow,
-      resultProjection: InternalRow => InternalRow,
-      numOutputRows: LongSQLMetric): Iterator[InternalRow] = {
-    val ret: Iterable[InternalRow] = {
-      if (!key.anyNull) {
-        val temp = if (leftIter != null) {
-          leftIter.collect {
-            case l if boundCondition(joinedRow.withLeft(l)) => {
-              numOutputRows += 1
-              resultProjection(joinedRow).copy()
+  private def outerJoin(
+      streamedIter: Iterator[InternalRow],
+    hashedRelation: HashedRelation): Iterator[InternalRow] = {
+    val joinedRow = new JoinedRow()
+    val keyGenerator = streamSideKeyGenerator()
+    val nullRow = new GenericInternalRow(buildPlan.output.length)
+
+    streamedIter.flatMap { currentRow =>
+      val rowKey = keyGenerator(currentRow)
+      joinedRow.withLeft(currentRow)
+      val buildIter = hashedRelation.get(rowKey)
+      new RowIterator {
+        private var found = false
+        override def advanceNext(): Boolean = {
+          while (buildIter != null && buildIter.hasNext) {
+            val nextBuildRow = buildIter.next()
+            if (boundCondition(joinedRow.withRight(nextBuildRow))) {
+              found = true
+              return true
             }
           }
-        } else {
-          List.empty
-        }
-        if (temp.isEmpty) {
-          numOutputRows += 1
-          resultProjection(joinedRow.withLeft(leftNullRow)) :: Nil
-        } else {
-          temp
+          if (!found) {
+            joinedRow.withRight(nullRow)
+            found = true
+            return true
+          }
+          false
         }
-      } else {
-        numOutputRows += 1
-        resultProjection(joinedRow.withLeft(leftNullRow)) :: Nil
-      }
+        override def getRow: InternalRow = joinedRow
+      }.toScala
     }
-    ret.iterator
   }
 
-  protected def hashSemiJoin(
-    streamIter: Iterator[InternalRow],
-    hashedRelation: HashedRelation,
-    numOutputRows: LongSQLMetric): Iterator[InternalRow] = {
-    val joinKeys = streamSideKeyGenerator
+  private def semiJoin(
+      streamIter: Iterator[InternalRow],
+      hashedRelation: HashedRelation): Iterator[InternalRow] = {
+    val joinKeys = streamSideKeyGenerator()
     val joinedRow = new JoinedRow
     streamIter.filter { current =>
       val key = joinKeys(current)
-      lazy val rowBuffer = hashedRelation.get(key)
-      val r = !key.anyNull && rowBuffer != null && (condition.isEmpty || rowBuffer.exists {
+      lazy val buildIter = hashedRelation.get(key)
+      !key.anyNull && buildIter != null && (condition.isEmpty || buildIter.exists {
         (row: InternalRow) => boundCondition(joinedRow(current, row))
       })
-      if (r) numOutputRows += 1
-      r
+    }
+  }
+
+  protected def join(
+      streamedIter: Iterator[InternalRow],
+      hashed: HashedRelation,
+      numOutputRows: LongSQLMetric): Iterator[InternalRow] = {
+
+    val joinedIter = joinType match {
+      case Inner =>
+        innerJoin(streamedIter, hashed)
+      case LeftOuter | RightOuter =>
+        outerJoin(streamedIter, hashed)
+      case LeftSemi =>
+        semiJoin(streamedIter, hashed)
+      case x =>
+        throw new IllegalArgumentException(
+          s"BroadcastHashJoin should not take $x as the JoinType")
+    }
+
+    val resultProj = createResultProjection
+    joinedIter.map { r =>
+      numOutputRows += 1
+      resultProj(r)
     }
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
index 91c470d187..5ccb435686 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
@@ -27,7 +27,6 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical.BroadcastMode
 import org.apache.spark.sql.execution.SparkSqlSerializer
 import org.apache.spark.unsafe.Platform
-import org.apache.spark.unsafe.hash.Murmur3_x86_32
 import org.apache.spark.unsafe.map.BytesToBytesMap
 import org.apache.spark.util.{KnownSizeEstimation, Utils}
 import org.apache.spark.util.collection.CompactBuffer
@@ -39,16 +38,42 @@ import org.apache.spark.util.collection.CompactBuffer
 private[execution] sealed trait HashedRelation {
   /**
    * Returns matched rows.
+   *
+   * Returns null if there is no matched rows.
    */
-  def get(key: InternalRow): Seq[InternalRow]
+  def get(key: InternalRow): Iterator[InternalRow]
 
   /**
    * Returns matched rows for a key that has only one column with LongType.
+   *
+   * Returns null if there is no matched rows.
+   */
+  def get(key: Long): Iterator[InternalRow] = {
+    throw new UnsupportedOperationException
+  }
+
+  /**
+   * Returns the matched single row.
    */
-  def get(key: Long): Seq[InternalRow] = {
+  def getValue(key: InternalRow): InternalRow
+
+  /**
+   * Returns the matched single row with key that have only one column of LongType.
+   */
+  def getValue(key: Long): InternalRow = {
     throw new UnsupportedOperationException
   }
 
+  /**
+   * Returns true iff all the keys are unique.
+   */
+  def keyIsUnique: Boolean
+
+  /**
+   * Returns a read-only copy of this, to be safely used in current thread.
+   */
+  def asReadOnlyCopy(): HashedRelation
+
   /**
    * Returns the size of used memory.
    */
@@ -76,44 +101,6 @@ private[execution] sealed trait HashedRelation {
   }
 }
 
-/**
- * Interface for a hashed relation that have only one row per key.
- *
- * We should call getValue() for better performance.
- */
-private[execution] trait UniqueHashedRelation extends HashedRelation {
-
-  /**
-   * Returns the matched single row.
-   */
-  def getValue(key: InternalRow): InternalRow
-
-  /**
-   * Returns the matched single row with key that have only one column of LongType.
-   */
-  def getValue(key: Long): InternalRow = {
-    throw new UnsupportedOperationException
-  }
-
-  override def get(key: InternalRow): Seq[InternalRow] = {
-    val row = getValue(key)
-    if (row != null) {
-      CompactBuffer[InternalRow](row)
-    } else {
-      null
-    }
-  }
-
-  override def get(key: Long): Seq[InternalRow] = {
-    val row = getValue(key)
-    if (row != null) {
-      CompactBuffer[InternalRow](row)
-    } else {
-      null
-    }
-  }
-}
-
 private[execution] object HashedRelation {
 
   /**
@@ -150,6 +137,11 @@ private[joins] class UnsafeHashedRelation(
 
   private[joins] def this() = this(0, null)  // Needed for serialization
 
+  override def keyIsUnique: Boolean = binaryMap.numKeys() == binaryMap.numValues()
+
+  override def asReadOnlyCopy(): UnsafeHashedRelation =
+    new UnsafeHashedRelation(numFields, binaryMap)
+
   override def getMemorySize: Long = {
     binaryMap.getTotalMemoryConsumption
   }
@@ -158,23 +150,39 @@ private[joins] class UnsafeHashedRelation(
     binaryMap.getTotalMemoryConsumption
   }
 
-  override def get(key: InternalRow): Seq[InternalRow] = {
+  // re-used in get()/getValue()
+  var resultRow = new UnsafeRow(numFields)
+
+  override def get(key: InternalRow): Iterator[InternalRow] = {
     val unsafeKey = key.asInstanceOf[UnsafeRow]
     val map = binaryMap  // avoid the compiler error
     val loc = new map.Location  // this could be allocated in stack
     binaryMap.safeLookup(unsafeKey.getBaseObject, unsafeKey.getBaseOffset,
       unsafeKey.getSizeInBytes, loc, unsafeKey.hashCode())
     if (loc.isDefined) {
-      val buffer = CompactBuffer[UnsafeRow]()
-      val row = new UnsafeRow(numFields)
-      row.pointTo(loc.getValueBase, loc.getValueOffset, loc.getValueLength)
-      buffer += row
-      while (loc.nextValue()) {
-        val row = new UnsafeRow(numFields)
-        row.pointTo(loc.getValueBase, loc.getValueOffset, loc.getValueLength)
-        buffer += row
+      new Iterator[UnsafeRow] {
+        private var _hasNext = true
+        override def hasNext: Boolean = _hasNext
+        override def next(): UnsafeRow = {
+          resultRow.pointTo(loc.getValueBase, loc.getValueOffset, loc.getValueLength)
+          _hasNext = loc.nextValue()
+          resultRow
+        }
       }
-      buffer
+    } else {
+      null
+    }
+  }
+
+  def getValue(key: InternalRow): InternalRow = {
+    val unsafeKey = key.asInstanceOf[UnsafeRow]
+    val map = binaryMap  // avoid the compiler error
+    val loc = new map.Location  // this could be allocated in stack
+    binaryMap.safeLookup(unsafeKey.getBaseObject, unsafeKey.getBaseOffset,
+      unsafeKey.getSizeInBytes, loc, unsafeKey.hashCode())
+    if (loc.isDefined) {
+      resultRow.pointTo(loc.getValueBase, loc.getValueOffset, loc.getValueLength)
+      resultRow
     } else {
       null
     }
@@ -212,6 +220,7 @@ private[joins] class UnsafeHashedRelation(
 
   override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException {
     numFields = in.readInt()
+    resultRow = new UnsafeRow(numFields)
     val nKeys = in.readInt()
     val nValues = in.readInt()
     // This is used in Broadcast, shared by multiple tasks, so we use on-heap memory
@@ -263,29 +272,6 @@ private[joins] class UnsafeHashedRelation(
   }
 }
 
-/**
- * A HashedRelation for UnsafeRow with unique keys.
- */
-private[joins] final class UniqueUnsafeHashedRelation(
-    private var numFields: Int,
-    private var binaryMap: BytesToBytesMap)
-  extends UnsafeHashedRelation(numFields, binaryMap) with UniqueHashedRelation {
-  def getValue(key: InternalRow): InternalRow = {
-    val unsafeKey = key.asInstanceOf[UnsafeRow]
-    val map = binaryMap  // avoid the compiler error
-    val loc = new map.Location  // this could be allocated in stack
-    binaryMap.safeLookup(unsafeKey.getBaseObject, unsafeKey.getBaseOffset,
-      unsafeKey.getSizeInBytes, loc, unsafeKey.hashCode())
-    if (loc.isDefined) {
-      val row = new UnsafeRow(numFields)
-      row.pointTo(loc.getValueBase, loc.getValueOffset, loc.getValueLength)
-      row
-    } else {
-      null
-    }
-  }
-}
-
 private[joins] object UnsafeHashedRelation {
 
   def apply(
@@ -315,17 +301,12 @@ private[joins] object UnsafeHashedRelation {
 
     // Create a mapping of buildKeys -> rows
     var numFields = 0
-    // Whether all the keys are unique or not
-    var allUnique: Boolean = true
     while (input.hasNext) {
       val row = input.next().asInstanceOf[UnsafeRow]
       numFields = row.numFields()
       val key = keyGenerator(row)
       if (!key.anyNull) {
         val loc = binaryMap.lookup(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes)
-        if (loc.isDefined) {
-          allUnique = false
-        }
         val success = loc.append(
           key.getBaseObject, key.getBaseOffset, key.getSizeInBytes,
           row.getBaseObject, row.getBaseOffset, row.getSizeInBytes)
@@ -336,11 +317,7 @@ private[joins] object UnsafeHashedRelation {
       }
     }
 
-    if (allUnique) {
-      new UniqueUnsafeHashedRelation(numFields, binaryMap)
-    } else {
-      new UnsafeHashedRelation(numFields, binaryMap)
-    }
+    new UnsafeHashedRelation(numFields, binaryMap)
   }
 }
 
@@ -348,9 +325,12 @@ private[joins] object UnsafeHashedRelation {
  * An interface for a hashed relation that the key is a Long.
  */
 private[joins] trait LongHashedRelation extends HashedRelation {
-  override def get(key: InternalRow): Seq[InternalRow] = {
+  override def get(key: InternalRow): Iterator[InternalRow] = {
     get(key.getLong(0))
   }
+  override def getValue(key: InternalRow): InternalRow = {
+    getValue(key.getLong(0))
+  }
 }
 
 private[joins] final class GeneralLongHashedRelation(
@@ -360,30 +340,18 @@ private[joins] final class GeneralLongHashedRelation(
   // Needed for serialization (it is public to make Java serialization work)
   def this() = this(null)
 
-  override def get(key: Long): Seq[InternalRow] = hashTable.get(key)
+  override def keyIsUnique: Boolean = false
 
-  override def writeExternal(out: ObjectOutput): Unit = {
-    writeBytes(out, SparkSqlSerializer.serialize(hashTable))
-  }
-
-  override def readExternal(in: ObjectInput): Unit = {
-    hashTable = SparkSqlSerializer.deserialize(readBytes(in))
-  }
-}
+  override def asReadOnlyCopy(): GeneralLongHashedRelation =
+    new GeneralLongHashedRelation(hashTable)
 
-private[joins] final class UniqueLongHashedRelation(
-  private var hashTable: JavaHashMap[Long, UnsafeRow])
-  extends UniqueHashedRelation with LongHashedRelation with Externalizable {
-
-  // Needed for serialization (it is public to make Java serialization work)
-  def this() = this(null)
-
-  override def getValue(key: InternalRow): InternalRow = {
-    getValue(key.getLong(0))
-  }
-
-  override def getValue(key: Long): InternalRow = {
-    hashTable.get(key)
+  override def get(key: Long): Iterator[InternalRow] = {
+    val rows = hashTable.get(key)
+    if (rows != null) {
+      rows.toIterator
+    } else {
+      null
+    }
   }
 
   override def writeExternal(out: ObjectOutput): Unit = {
@@ -422,25 +390,36 @@ private[joins] final class LongArrayRelation(
     private var offsets: Array[Int],
     private var sizes: Array[Int],
     private var bytes: Array[Byte]
-  ) extends UniqueHashedRelation with LongHashedRelation with Externalizable {
+  ) extends LongHashedRelation with Externalizable {
 
   // Needed for serialization (it is public to make Java serialization work)
   def this() = this(0, 0L, null, null, null)
 
-  override def getValue(key: InternalRow): InternalRow = {
-    getValue(key.getLong(0))
+  override def keyIsUnique: Boolean = true
+
+  override def asReadOnlyCopy(): LongArrayRelation = {
+    new LongArrayRelation(numFields, start, offsets, sizes, bytes)
   }
 
   override def getMemorySize: Long = {
     offsets.length * 4 + sizes.length * 4 + bytes.length
   }
 
+  override def get(key: Long): Iterator[InternalRow] = {
+    val row = getValue(key)
+    if (row != null) {
+      Seq(row).toIterator
+    } else {
+      null
+    }
+  }
+
+  var resultRow = new UnsafeRow(numFields)
   override def getValue(key: Long): InternalRow = {
     val idx = (key - start).toInt
     if (idx >= 0 && idx < sizes.length && sizes(idx) > 0) {
-      val result = new UnsafeRow(numFields)
-      result.pointTo(bytes, Platform.BYTE_ARRAY_OFFSET + offsets(idx), sizes(idx))
-      result
+      resultRow.pointTo(bytes, Platform.BYTE_ARRAY_OFFSET + offsets(idx), sizes(idx))
+      resultRow
     } else {
       null
     }
@@ -461,6 +440,7 @@ private[joins] final class LongArrayRelation(
 
   override def readExternal(in: ObjectInput): Unit = {
     numFields = in.readInt()
+    resultRow = new UnsafeRow(numFields)
     start = in.readLong()
     val length = in.readInt()
     // read sizes of rows
@@ -523,44 +503,32 @@ private[joins] object LongHashedRelation {
       }
     }
 
-    if (keyIsUnique) {
-      if (hashTable.size() > (maxKey - minKey) * DENSE_FACTOR) {
-        // The keys are dense enough, so use LongArrayRelation
-        val length = (maxKey - minKey).toInt + 1
-        val sizes = new Array[Int](length)
-        val offsets = new Array[Int](length)
-        var offset = 0
-        var i = 0
-        while (i < length) {
-          val rows = hashTable.get(i + minKey)
-          if (rows != null) {
-            offsets(i) = offset
-            sizes(i) = rows(0).getSizeInBytes
-            offset += sizes(i)
-          }
-          i += 1
-        }
-        val bytes = new Array[Byte](offset)
-        i = 0
-        while (i < length) {
-          val rows = hashTable.get(i + minKey)
-          if (rows != null) {
-            rows(0).writeToMemory(bytes, Platform.BYTE_ARRAY_OFFSET + offsets(i))
-          }
-          i += 1
+    if (keyIsUnique && hashTable.size() > (maxKey - minKey) * DENSE_FACTOR) {
+      // The keys are dense enough, so use LongArrayRelation
+      val length = (maxKey - minKey).toInt + 1
+      val sizes = new Array[Int](length)
+      val offsets = new Array[Int](length)
+      var offset = 0
+      var i = 0
+      while (i < length) {
+        val rows = hashTable.get(i + minKey)
+        if (rows != null) {
+          offsets(i) = offset
+          sizes(i) = rows(0).getSizeInBytes
+          offset += sizes(i)
         }
-        new LongArrayRelation(numFields, minKey, offsets, sizes, bytes)
-
-      } else {
-        // all the keys are unique, one row per key.
-        val uniqHashTable = new JavaHashMap[Long, UnsafeRow](hashTable.size)
-        val iter = hashTable.entrySet().iterator()
-        while (iter.hasNext) {
-          val entry = iter.next()
-          uniqHashTable.put(entry.getKey, entry.getValue()(0))
+        i += 1
+      }
+      val bytes = new Array[Byte](offset)
+      i = 0
+      while (i < length) {
+        val rows = hashTable.get(i + minKey)
+        if (rows != null) {
+          rows(0).writeToMemory(bytes, Platform.BYTE_ARRAY_OFFSET + offsets(i))
         }
-        new UniqueLongHashedRelation(uniqHashTable)
+        i += 1
       }
+      new LongArrayRelation(numFields, minKey, offsets, sizes, bytes)
     } else {
       new GeneralLongHashedRelation(hashTable)
     }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala
index e3a2eaea5d..c63faacf33 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala
@@ -102,39 +102,9 @@ case class ShuffledHashJoin(
 
   protected override def doExecute(): RDD[InternalRow] = {
     val numOutputRows = longMetric("numOutputRows")
-
     streamedPlan.execute().zipPartitions(buildPlan.execute()) { (streamIter, buildIter) =>
       val hashed = buildHashedRelation(buildIter.asInstanceOf[Iterator[UnsafeRow]])
-      val joinedRow = new JoinedRow
-      joinType match {
-        case Inner =>
-          hashJoin(streamIter, hashed, numOutputRows)
-
-        case LeftSemi =>
-          hashSemiJoin(streamIter, hashed, numOutputRows)
-
-        case LeftOuter =>
-          val keyGenerator = streamSideKeyGenerator
-          val resultProj = createResultProjection
-          streamIter.flatMap(currentRow => {
-            val rowKey = keyGenerator(currentRow)
-            joinedRow.withLeft(currentRow)
-            leftOuterIterator(rowKey, joinedRow, hashed.get(rowKey), resultProj, numOutputRows)
-          })
-
-        case RightOuter =>
-          val keyGenerator = streamSideKeyGenerator
-          val resultProj = createResultProjection
-          streamIter.flatMap(currentRow => {
-            val rowKey = keyGenerator(currentRow)
-            joinedRow.withRight(currentRow)
-            rightOuterIterator(rowKey, hashed.get(rowKey), joinedRow, resultProj, numOutputRows)
-          })
-
-        case x =>
-          throw new IllegalArgumentException(
-            s"ShuffledHashJoin should not take $x as the JoinType")
-      }
+      join(streamIter, hashed, numOutputRows)
     }
   }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
index 289e1b6db9..3566ef3043 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
@@ -166,20 +166,35 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
   }
 
   ignore("broadcast hash join") {
-    val N = 100 << 20
+    val N = 20 << 20
     val M = 1 << 16
     val dim = broadcast(sqlContext.range(M).selectExpr("id as k", "cast(id as string) as v"))
 
     runBenchmark("Join w long", N) {
-      sqlContext.range(N).join(dim, (col("id") bitwiseAND M) === col("k")).count()
+      sqlContext.range(N).join(dim, (col("id") % M) === col("k")).count()
     }
 
     /*
+    Java HotSpot(TM) 64-Bit Server VM 1.7.0_60-b19 on Mac OS X 10.9.5
     Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
     Join w long:                        Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
     -------------------------------------------------------------------------------------------
-    Join w long codegen=false                5744 / 5814         18.3          54.8       1.0X
-    Join w long codegen=true                  735 /  853        142.7           7.0       7.8X
+    Join w long codegen=false                5351 / 5531          3.9         255.1       1.0X
+    Join w long codegen=true                  275 /  352         76.2          13.1      19.4X
+    */
+
+    runBenchmark("Join w long duplicated", N) {
+      val dim = broadcast(sqlContext.range(M).selectExpr("cast(id/10 as long) as k"))
+      sqlContext.range(N).join(dim, (col("id") % M) === col("k")).count()
+    }
+
+    /**
+    Java HotSpot(TM) 64-Bit Server VM 1.7.0_60-b19 on Mac OS X 10.9.5
+    Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
+    Join w long duplicated:             Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
+    -------------------------------------------------------------------------------------------
+    Join w long duplicated codegen=false      4752 / 4906          4.4         226.6       1.0X
+    Join w long duplicated codegen=true       722 /  760         29.0          34.4       6.6X
     */
 
     val dim2 = broadcast(sqlContext.range(M)
@@ -187,16 +202,17 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
 
     runBenchmark("Join w 2 ints", N) {
       sqlContext.range(N).join(dim2,
-        (col("id") bitwiseAND M).cast(IntegerType) === col("k1")
-          && (col("id") bitwiseAND M).cast(IntegerType) === col("k2")).count()
+        (col("id") % M).cast(IntegerType) === col("k1")
+          && (col("id") % M).cast(IntegerType) === col("k2")).count()
     }
 
     /**
+    Java HotSpot(TM) 64-Bit Server VM 1.7.0_60-b19 on Mac OS X 10.9.5
     Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
     Join w 2 ints:                      Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
     -------------------------------------------------------------------------------------------
-    Join w 2 ints codegen=false              7159 / 7224         14.6          68.3       1.0X
-    Join w 2 ints codegen=true               1135 / 1197         92.4          10.8       6.3X
+    Join w 2 ints codegen=false              9011 / 9121          2.3         429.7       1.0X
+    Join w 2 ints codegen=true               2565 / 2816          8.2         122.3       3.5X
     */
 
     val dim3 = broadcast(sqlContext.range(M)
@@ -204,16 +220,17 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
 
     runBenchmark("Join w 2 longs", N) {
       sqlContext.range(N).join(dim3,
-        (col("id") bitwiseAND M) === col("k1") && (col("id") bitwiseAND M) === col("k2"))
+        (col("id") % M) === col("k1") && (col("id") % M) === col("k2"))
         .count()
     }
 
     /**
+    Java HotSpot(TM) 64-Bit Server VM 1.7.0_60-b19 on Mac OS X 10.9.5
     Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
     Join w 2 longs:                     Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
     -------------------------------------------------------------------------------------------
-    Join w 2 longs codegen=false           12725 / 13158          8.2         121.4       1.0X
-    Join w 2 longs codegen=true              6044 / 6771         17.3          57.6       2.1X
+    Join w 2 longs codegen=false             5905 / 6123          3.6         281.6       1.0X
+    Join w 2 longs codegen=true              2230 / 2529          9.4         106.3       2.6X
       */
 
     val dim4 = broadcast(sqlContext.range(M)
@@ -227,34 +244,36 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
 
     /**
     Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
-    Join w 2 longs:                     Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
+    Join w 2 longs duplicated:          Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
     -------------------------------------------------------------------------------------------
-    Join w 2 longs duplicated codegen=false 13066 / 13710          8.0         124.6       1.0X
-    Join w 2 longs duplicated codegen=true    7122 / 7277         14.7          67.9       1.8X
+    Join w 2 longs duplicated codegen=false      6420 / 6587          3.3         306.1       1.0X
+    Join w 2 longs duplicated codegen=true      2080 / 2139         10.1          99.2       3.1X
      */
 
     runBenchmark("outer join w long", N) {
-      sqlContext.range(N).join(dim, (col("id") bitwiseAND M) === col("k"), "left").count()
+      sqlContext.range(N).join(dim, (col("id") % M) === col("k"), "left").count()
     }
 
     /**
+    Java HotSpot(TM) 64-Bit Server VM 1.7.0_60-b19 on Mac OS X 10.9.5
     Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
     outer join w long:                  Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
     -------------------------------------------------------------------------------------------
-    outer join w long codegen=false        15280 / 16497          6.9         145.7       1.0X
-    outer join w long codegen=true            769 /  796        136.3           7.3      19.9X
+    outer join w long codegen=false          5667 / 5780          3.7         270.2       1.0X
+    outer join w long codegen=true            216 /  226         97.2          10.3      26.3X
       */
 
     runBenchmark("semi join w long", N) {
-      sqlContext.range(N).join(dim, (col("id") bitwiseAND M) === col("k"), "leftsemi").count()
+      sqlContext.range(N).join(dim, (col("id") % M) === col("k"), "leftsemi").count()
     }
 
     /**
+    Java HotSpot(TM) 64-Bit Server VM 1.7.0_60-b19 on Mac OS X 10.9.5
     Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
     semi join w long:                   Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
     -------------------------------------------------------------------------------------------
-    semi join w long codegen=false           5804 / 5969         18.1          55.3       1.0X
-    semi join w long codegen=true             814 /  934        128.8           7.8       7.1X
+    semi join w long codegen=false           4690 / 4953          4.5         223.7       1.0X
+    semi join w long codegen=true             211 /  229         99.2          10.1      22.2X
      */
   }
 
@@ -303,11 +322,12 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
     }
 
     /**
+    Java HotSpot(TM) 64-Bit Server VM 1.7.0_60-b19 on Mac OS X 10.9.5
     Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
     shuffle hash join:                  Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
     -------------------------------------------------------------------------------------------
-    shuffle hash join codegen=false          1168 / 1902          3.6         278.6       1.0X
-    shuffle hash join codegen=true            850 / 1196          4.9         202.8       1.4X
+    shuffle hash join codegen=false          1538 / 1742          2.7         366.7       1.0X
+    shuffle hash join codegen=true            892 / 1329          4.7         212.6       1.7X
      */
   }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
index ed4cc1c4c4..ed87a99439 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
@@ -34,20 +34,20 @@ class HashedRelationSuite extends SparkFunSuite with SharedSQLContext {
     val schema = StructType(StructField("a", IntegerType, true) :: Nil)
     val data = Array(InternalRow(0), InternalRow(1), InternalRow(2), InternalRow(2))
     val toUnsafe = UnsafeProjection.create(schema)
-    val unsafeData = data.map(toUnsafe(_).copy()).toArray
+    val unsafeData = data.map(toUnsafe(_).copy())
 
     val buildKey = Seq(BoundReference(0, IntegerType, false))
     val keyGenerator = UnsafeProjection.create(buildKey)
     val hashed = UnsafeHashedRelation(unsafeData.iterator, keyGenerator, 1)
     assert(hashed.isInstanceOf[UnsafeHashedRelation])
 
-    assert(hashed.get(unsafeData(0)) === CompactBuffer[InternalRow](unsafeData(0)))
-    assert(hashed.get(unsafeData(1)) === CompactBuffer[InternalRow](unsafeData(1)))
+    assert(hashed.get(unsafeData(0)).toArray === Array(unsafeData(0)))
+    assert(hashed.get(unsafeData(1)).toArray === Array(unsafeData(1)))
     assert(hashed.get(toUnsafe(InternalRow(10))) === null)
 
     val data2 = CompactBuffer[InternalRow](unsafeData(2).copy())
     data2 += unsafeData(2).copy()
-    assert(hashed.get(unsafeData(2)) === data2)
+    assert(hashed.get(unsafeData(2)).toArray === data2.toArray)
 
     val os = new ByteArrayOutputStream()
     val out = new ObjectOutputStream(os)
@@ -56,10 +56,10 @@ class HashedRelationSuite extends SparkFunSuite with SharedSQLContext {
     val in = new ObjectInputStream(new ByteArrayInputStream(os.toByteArray))
     val hashed2 = new UnsafeHashedRelation()
     hashed2.readExternal(in)
-    assert(hashed2.get(unsafeData(0)) === CompactBuffer[InternalRow](unsafeData(0)))
-    assert(hashed2.get(unsafeData(1)) === CompactBuffer[InternalRow](unsafeData(1)))
+    assert(hashed2.get(unsafeData(0)).toArray === Array(unsafeData(0)))
+    assert(hashed2.get(unsafeData(1)).toArray === Array(unsafeData(1)))
     assert(hashed2.get(toUnsafe(InternalRow(10))) === null)
-    assert(hashed2.get(unsafeData(2)) === data2)
+    assert(hashed2.get(unsafeData(2)).toArray === data2)
 
     val os2 = new ByteArrayOutputStream()
     val out2 = new ObjectOutputStream(os2)
-- 
GitLab