From 95e1ab223e87fc216f3256d404fe3be50d111a9d Mon Sep 17 00:00:00 2001
From: Davies Liu <davies@databricks.com>
Date: Thu, 18 Feb 2016 15:15:06 -0800
Subject: [PATCH] [SPARK-13237] [SQL] generated broadcast outer join

This PR supports codegen for broadcast outer joins.

To reduce code duplication, this PR merges HashOuterJoin into HashJoin (and likewise BroadcastHashOuterJoin into BroadcastHashJoin).
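
The merged operator takes the join type explicitly, i.e.
BroadcastHashJoin(leftKeys, rightKeys, joinType, buildSide, condition, left, right), and the
planner now emits it for the Inner, LeftOuter and RightOuter broadcast cases. As an
illustrative sketch (table names are placeholders, not part of this patch), a broadcast
left outer join like the following now goes through the generated code path:

    import org.apache.spark.sql.functions.broadcast

    // dim is a small dimension table, fact is the large streamed table (placeholder names)
    fact.join(broadcast(dim), fact("k") === dim("k"), "left_outer").count()

See the new "outer join w long" case in BenchmarkWholeStageCodegen for measured numbers.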

Author: Davies Liu <davies@databricks.com>

Closes #11130 from davies/gen_out.
---
 .../spark/sql/execution/SparkStrategies.scala |  16 +-
 .../sql/execution/WholeStageCodegen.scala     |   8 +-
 .../execution/joins/BroadcastHashJoin.scala   | 253 ++++++++++++++----
 .../joins/BroadcastHashOuterJoin.scala        | 121 ---------
 .../spark/sql/execution/joins/HashJoin.scala  | 111 +++++++-
 .../sql/execution/joins/HashOuterJoin.scala   | 153 -----------
 .../org/apache/spark/sql/JoinSuite.scala      |   5 +-
 .../BenchmarkWholeStageCodegen.scala          | 131 ++++++++-
 .../execution/joins/BroadcastJoinSuite.scala  |   2 +-
 .../sql/execution/joins/InnerJoinSuite.scala  |   2 +-
 .../sql/execution/joins/OuterJoinSuite.scala  |   9 +-
 .../execution/metric/SQLMetricsSuite.scala    |   8 +-
 12 files changed, 448 insertions(+), 371 deletions(-)
 delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala
 delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 042c99db4d..382654afac 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -108,12 +108,12 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
       // --- Inner joins --------------------------------------------------------------------------
 
       case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, CanBroadcast(right)) =>
-        joins.BroadcastHashJoin(
-          leftKeys, rightKeys, BuildRight, condition, planLater(left), planLater(right)) :: Nil
+        Seq(joins.BroadcastHashJoin(
+          leftKeys, rightKeys, Inner, BuildRight, condition, planLater(left), planLater(right)))
 
       case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, CanBroadcast(left), right) =>
-        joins.BroadcastHashJoin(
-          leftKeys, rightKeys, BuildLeft, condition, planLater(left), planLater(right)) :: Nil
+        Seq(joins.BroadcastHashJoin(
+          leftKeys, rightKeys, Inner, BuildLeft, condition, planLater(left), planLater(right)))
 
       case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right)
         if RowOrdering.isOrderable(leftKeys) =>
@@ -124,13 +124,13 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
 
       case ExtractEquiJoinKeys(
           LeftOuter, leftKeys, rightKeys, condition, left, CanBroadcast(right)) =>
-        joins.BroadcastHashOuterJoin(
-          leftKeys, rightKeys, LeftOuter, condition, planLater(left), planLater(right)) :: Nil
+        Seq(joins.BroadcastHashJoin(
+          leftKeys, rightKeys, LeftOuter, BuildRight, condition, planLater(left), planLater(right)))
 
       case ExtractEquiJoinKeys(
           RightOuter, leftKeys, rightKeys, condition, CanBroadcast(left), right) =>
-        joins.BroadcastHashOuterJoin(
-          leftKeys, rightKeys, RightOuter, condition, planLater(left), planLater(right)) :: Nil
+        Seq(joins.BroadcastHashJoin(
+          leftKeys, rightKeys, RightOuter, BuildLeft, condition, planLater(left), planLater(right)))
 
       case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right)
         if RowOrdering.isOrderable(leftKeys) =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
index f35efb5b24..8626f54eb4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.plans.physical.Partitioning
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution.aggregate.TungstenAggregate
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, BuildLeft, BuildRight}
-import org.apache.spark.sql.execution.metric.{LongSQLMetric, LongSQLMetricValue, SQLMetric}
+import org.apache.spark.sql.execution.metric.LongSQLMetricValue
 
 /**
   * An interface for those physical operators that support codegen.
@@ -38,7 +38,7 @@ trait CodegenSupport extends SparkPlan {
   /** Prefix used in the current operator's variable names. */
   private def variablePrefix: String = this match {
     case _: TungstenAggregate => "agg"
-    case _: BroadcastHashJoin => "bhj"
+    case _: BroadcastHashJoin => "join"
     case _ => nodeName.toLowerCase
   }
 
@@ -391,9 +391,9 @@ private[sql] case class CollapseCodegenStages(sqlContext: SQLContext) extends Ru
           var inputs = ArrayBuffer[SparkPlan]()
           val combined = plan.transform {
             // The build side can't be compiled together
-            case b @ BroadcastHashJoin(_, _, BuildLeft, _, left, right) =>
+            case b @ BroadcastHashJoin(_, _, _, BuildLeft, _, left, right) =>
               b.copy(left = apply(left))
-            case b @ BroadcastHashJoin(_, _, BuildRight, _, left, right) =>
+            case b @ BroadcastHashJoin(_, _, _, BuildRight, _, left, right) =>
               b.copy(right = apply(right))
             case p if !supportCodegen(p) =>
               val input = apply(p)  // collapse them recursively
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
index 985e74011d..a64da22580 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
@@ -24,8 +24,9 @@ import org.apache.spark.TaskContext
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{BindReferences, BoundReference, Expression, UnsafeRow}
+import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, GenerateUnsafeProjection}
+import org.apache.spark.sql.catalyst.plans.{Inner, JoinType, LeftOuter, RightOuter}
 import org.apache.spark.sql.catalyst.plans.physical.{Distribution, Partitioning, UnspecifiedDistribution}
 import org.apache.spark.sql.execution.{BinaryNode, CodegenSupport, SparkPlan, SQLExecution}
 import org.apache.spark.sql.execution.metric.SQLMetrics
@@ -41,6 +42,7 @@ import org.apache.spark.util.collection.CompactBuffer
 case class BroadcastHashJoin(
     leftKeys: Seq[Expression],
     rightKeys: Seq[Expression],
+    joinType: JoinType,
     buildSide: BuildSide,
     condition: Option[Expression],
     left: SparkPlan,
@@ -105,75 +107,144 @@ case class BroadcastHashJoin(
     val broadcastRelation = Await.result(broadcastFuture, timeout)
 
     streamedPlan.execute().mapPartitions { streamedIter =>
-      val hashedRelation = broadcastRelation.value
-      TaskContext.get().taskMetrics().incPeakExecutionMemory(hashedRelation.getMemorySize)
-      hashJoin(streamedIter, hashedRelation, numOutputRows)
+      val joinedRow = new JoinedRow()
+      val hashTable = broadcastRelation.value
+      TaskContext.get().taskMetrics().incPeakExecutionMemory(hashTable.getMemorySize)
+      val keyGenerator = streamSideKeyGenerator
+      val resultProj = createResultProjection
+
+      joinType match {
+        case Inner =>
+          hashJoin(streamedIter, hashTable, numOutputRows)
+
+        case LeftOuter =>
+          streamedIter.flatMap { currentRow =>
+            val rowKey = keyGenerator(currentRow)
+            joinedRow.withLeft(currentRow)
+            leftOuterIterator(rowKey, joinedRow, hashTable.get(rowKey), resultProj, numOutputRows)
+          }
+
+        case RightOuter =>
+          streamedIter.flatMap { currentRow =>
+            val rowKey = keyGenerator(currentRow)
+            joinedRow.withRight(currentRow)
+            rightOuterIterator(rowKey, hashTable.get(rowKey), joinedRow, resultProj, numOutputRows)
+          }
+
+        case x =>
+          throw new IllegalArgumentException(
+            s"BroadcastHashJoin should not take $x as the JoinType")
+      }
     }
   }
 
-  private var broadcastRelation: Broadcast[HashedRelation] = _
-  // the term for hash relation
-  private var relationTerm: String = _
-
   override def upstream(): RDD[InternalRow] = {
     streamedPlan.asInstanceOf[CodegenSupport].upstream()
   }
 
   override def doProduce(ctx: CodegenContext): String = {
+    streamedPlan.asInstanceOf[CodegenSupport].produce(ctx, this)
+  }
+
+  override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
+    if (joinType == Inner) {
+      codegenInner(ctx, input)
+    } else {
+      // LeftOuter and RightOuter
+      codegenOuter(ctx, input)
+    }
+  }
+
+  /**
+    * Returns a tuple of the broadcast HashedRelation and the variable name for it.
+    */
+  private def prepareBroadcast(ctx: CodegenContext): (Broadcast[HashedRelation], String) = {
     // create a name for HashedRelation
-    broadcastRelation = Await.result(broadcastFuture, timeout)
+    val broadcastRelation = Await.result(broadcastFuture, timeout)
     val broadcast = ctx.addReferenceObj("broadcast", broadcastRelation)
-    relationTerm = ctx.freshName("relation")
+    val relationTerm = ctx.freshName("relation")
     val clsName = broadcastRelation.value.getClass.getName
     ctx.addMutableState(clsName, relationTerm,
       s"""
          | $relationTerm = ($clsName) $broadcast.value();
          | incPeakExecutionMemory($relationTerm.getMemorySize());
        """.stripMargin)
-
-    s"""
-       | ${streamedPlan.asInstanceOf[CodegenSupport].produce(ctx, this)}
-     """.stripMargin
+    (broadcastRelation, relationTerm)
   }
 
-  override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
-    // generate the key as UnsafeRow or Long
+  /**
+    * Returns the code for generating the join key for the stream side, and an expression of
+    * whether the key has any null in it.
+    */
+  private def genStreamSideJoinKey(
+      ctx: CodegenContext,
+      input: Seq[ExprCode]): (ExprCode, String) = {
     ctx.currentVars = input
-    val (keyVal, anyNull) = if (canJoinKeyFitWithinLong) {
+    if (canJoinKeyFitWithinLong) {
+      // generate the join key as Long
       val expr = rewriteKeyExpr(streamedKeys).head
       val ev = BindReferences.bindReference(expr, streamedPlan.output).gen(ctx)
       (ev, ev.isNull)
     } else {
+      // generate the join key as UnsafeRow
       val keyExpr = streamedKeys.map(BindReferences.bindReference(_, streamedPlan.output))
       val ev = GenerateUnsafeProjection.createCode(ctx, keyExpr)
       (ev, s"${ev.value}.anyNull()")
     }
+  }
 
-    // find the matches from HashedRelation
-    val matched = ctx.freshName("matched")
-
-    // create variables for output
+  /**
+    * Generates the code for the variables of the build side.
+    */
+  private def genBuildSideVars(ctx: CodegenContext, matched: String): Seq[ExprCode] = {
     ctx.currentVars = null
     ctx.INPUT_ROW = matched
-    val buildColumns = buildPlan.output.zipWithIndex.map { case (a, i) =>
-      BoundReference(i, a.dataType, a.nullable).gen(ctx)
+    buildPlan.output.zipWithIndex.map { case (a, i) =>
+      val ev = BoundReference(i, a.dataType, a.nullable).gen(ctx)
+      if (joinType == Inner) {
+        ev
+      } else {
+        // the variables are needed even if there are no matched rows
+        val isNull = ctx.freshName("isNull")
+        val value = ctx.freshName("value")
+        val code = s"""
+          |boolean $isNull = true;
+          |${ctx.javaType(a.dataType)} $value = ${ctx.defaultValue(a.dataType)};
+          |if ($matched != null) {
+          |  ${ev.code}
+          |  $isNull = ${ev.isNull};
+          |  $value = ${ev.value};
+          |}
+         """.stripMargin
+        ExprCode(code, isNull, value)
+      }
     }
+  }
+
+  /**
+    * Generates the code for Inner join.
+    */
+  private def codegenInner(ctx: CodegenContext, input: Seq[ExprCode]): String = {
+    val (broadcastRelation, relationTerm) = prepareBroadcast(ctx)
+    val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
+    val matched = ctx.freshName("matched")
+    val buildVars = genBuildSideVars(ctx, matched)
     val resultVars = buildSide match {
-      case BuildLeft => buildColumns ++ input
-      case BuildRight => input ++ buildColumns
+      case BuildLeft => buildVars ++ input
+      case BuildRight => input ++ buildVars
     }
-
     val numOutput = metricTerm(ctx, "numOutputRows")
+
     val outputCode = if (condition.isDefined) {
       // filter the output via condition
       ctx.currentVars = resultVars
       val ev = BindReferences.bindReference(condition.get, this.output).gen(ctx)
       s"""
-         | ${ev.code}
-         | if (!${ev.isNull} && ${ev.value}) {
-         |   $numOutput.add(1);
-         |   ${consume(ctx, resultVars)}
-         | }
+         |${ev.code}
+         |if (!${ev.isNull} && ${ev.value}) {
+         |  $numOutput.add(1);
+         |  ${consume(ctx, resultVars)}
+         |}
        """.stripMargin
     } else {
       s"""
@@ -184,36 +255,110 @@ case class BroadcastHashJoin(
 
     if (broadcastRelation.value.isInstanceOf[UniqueHashedRelation]) {
       s"""
-         | // generate join key
-         | ${keyVal.code}
-         | // find matches from HashedRelation
-         | UnsafeRow $matched = $anyNull ? null: (UnsafeRow)$relationTerm.getValue(${keyVal.value});
-         | if ($matched != null) {
-         |   ${buildColumns.map(_.code).mkString("\n")}
-         |   $outputCode
-         | }
-     """.stripMargin
+         |// generate join key for stream side
+         |${keyEv.code}
+         |// find matches from HashedRelation
+         |UnsafeRow $matched = $anyNull ? null: (UnsafeRow)$relationTerm.getValue(${keyEv.value});
+         |if ($matched != null) {
+         |  ${buildVars.map(_.code).mkString("\n")}
+         |  $outputCode
+         |}
+       """.stripMargin
+
+    } else {
+      val matches = ctx.freshName("matches")
+      val bufferType = classOf[CompactBuffer[UnsafeRow]].getName
+      val i = ctx.freshName("i")
+      val size = ctx.freshName("size")
+      s"""
+         |// generate join key for stream side
+         |${keyEv.code}
+         |// find matches from HashRelation
+         |$bufferType $matches = $anyNull ? null : ($bufferType)$relationTerm.get(${keyEv.value});
+         |if ($matches != null) {
+         |  int $size = $matches.size();
+         |  for (int $i = 0; $i < $size; $i++) {
+         |    UnsafeRow $matched = (UnsafeRow) $matches.apply($i);
+         |    ${buildVars.map(_.code).mkString("\n")}
+         |    $outputCode
+         |  }
+         |}
+       """.stripMargin
+    }
+  }
+
+
+  /**
+    * Generates the code for left or right outer join.
+    */
+  private def codegenOuter(ctx: CodegenContext, input: Seq[ExprCode]): String = {
+    val (broadcastRelation, relationTerm) = prepareBroadcast(ctx)
+    val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
+    val matched = ctx.freshName("matched")
+    val buildVars = genBuildSideVars(ctx, matched)
+    val resultVars = buildSide match {
+      case BuildLeft => buildVars ++ input
+      case BuildRight => input ++ buildVars
+    }
+    val numOutput = metricTerm(ctx, "numOutputRows")
+
+    // filter the output via condition
+    val conditionPassed = ctx.freshName("conditionPassed")
+    val checkCondition = if (condition.isDefined) {
+      ctx.currentVars = resultVars
+      val ev = BindReferences.bindReference(condition.get, this.output).gen(ctx)
+      s"""
+         |boolean $conditionPassed = true;
+         |if ($matched != null) {
+         |  ${ev.code}
+         |  $conditionPassed = !${ev.isNull} && ${ev.value};
+         |}
+       """.stripMargin
+    } else {
+      s"final boolean $conditionPassed = true;"
+    }
+
+    if (broadcastRelation.value.isInstanceOf[UniqueHashedRelation]) {
+      s"""
+         |// generate join key for stream side
+         |${keyEv.code}
+         |// find matches from HashedRelation
+         |UnsafeRow $matched = $anyNull ? null: (UnsafeRow)$relationTerm.getValue(${keyEv.value});
+         |${buildVars.map(_.code).mkString("\n")}
+         |${checkCondition.trim}
+         |if (!$conditionPassed) {
+         |  // reset to null
+         |  ${buildVars.map(v => s"${v.isNull} = true;").mkString("\n")}
+         |}
+         |$numOutput.add(1);
+         |${consume(ctx, resultVars)}
+       """.stripMargin
 
     } else {
       val matches = ctx.freshName("matches")
       val bufferType = classOf[CompactBuffer[UnsafeRow]].getName
       val i = ctx.freshName("i")
       val size = ctx.freshName("size")
+      val found = ctx.freshName("found")
       s"""
-         | // generate join key
-         | ${keyVal.code}
-         | // find matches from HashRelation
-         | $bufferType $matches = ${anyNull} ? null :
-         |  ($bufferType) $relationTerm.get(${keyVal.value});
-         | if ($matches != null) {
-         |   int $size = $matches.size();
-         |   for (int $i = 0; $i < $size; $i++) {
-         |     UnsafeRow $matched = (UnsafeRow) $matches.apply($i);
-         |     ${buildColumns.map(_.code).mkString("\n")}
-         |     $outputCode
-         |   }
-         | }
-     """.stripMargin
+         |// generate join key for stream side
+         |${keyEv.code}
+         |// find matches from HashRelation
+         |$bufferType $matches = $anyNull ? null : ($bufferType)$relationTerm.get(${keyEv.value});
+         |int $size = $matches != null ? $matches.size() : 0;
+         |boolean $found = false;
+         |// the last iteration of this loop is to emit an empty row if there are no matched rows.
+         |for (int $i = 0; $i <= $size; $i++) {
+         |  UnsafeRow $matched = $i < $size ? (UnsafeRow) $matches.apply($i) : null;
+         |  ${buildVars.map(_.code).mkString("\n")}
+         |  ${checkCondition.trim}
+         |  if ($conditionPassed && ($i < $size || !$found)) {
+         |    $found = true;
+         |    $numOutput.add(1);
+         |    ${consume(ctx, resultVars)}
+         |  }
+         |}
+       """.stripMargin
     }
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala
deleted file mode 100644
index 5e8c8ca043..0000000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala
+++ /dev/null
@@ -1,121 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.joins
-
-import scala.concurrent._
-import scala.concurrent.duration._
-
-import org.apache.spark.{InternalAccumulator, TaskContext}
-import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.{JoinType, LeftOuter, RightOuter}
-import org.apache.spark.sql.catalyst.plans.physical.{Distribution, Partitioning, UnspecifiedDistribution}
-import org.apache.spark.sql.execution.{BinaryNode, SparkPlan, SQLExecution}
-import org.apache.spark.sql.execution.metric.SQLMetrics
-
-/**
- * Performs a outer hash join for two child relations.  When the output RDD of this operator is
- * being constructed, a Spark job is asynchronously started to calculate the values for the
- * broadcasted relation.  This data is then placed in a Spark broadcast variable.  The streamed
- * relation is not shuffled.
- */
-case class BroadcastHashOuterJoin(
-    leftKeys: Seq[Expression],
-    rightKeys: Seq[Expression],
-    joinType: JoinType,
-    condition: Option[Expression],
-    left: SparkPlan,
-    right: SparkPlan) extends BinaryNode with HashOuterJoin {
-
-  override private[sql] lazy val metrics = Map(
-    "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
-
-  val timeout = {
-    val timeoutValue = sqlContext.conf.broadcastTimeout
-    if (timeoutValue < 0) {
-      Duration.Inf
-    } else {
-      timeoutValue.seconds
-    }
-  }
-
-  override def requiredChildDistribution: Seq[Distribution] =
-    UnspecifiedDistribution :: UnspecifiedDistribution :: Nil
-
-  override def outputPartitioning: Partitioning = streamedPlan.outputPartitioning
-
-  // Use lazy so that we won't do broadcast when calling explain but still cache the broadcast value
-  // for the same query.
-  @transient
-  private lazy val broadcastFuture = {
-    // broadcastFuture is used in "doExecute". Therefore we can get the execution id correctly here.
-    val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
-    Future {
-      // This will run in another thread. Set the execution id so that we can connect these jobs
-      // with the correct execution.
-      SQLExecution.withExecutionId(sparkContext, executionId) {
-        // Note that we use .execute().collect() because we don't want to convert data to Scala
-        // types
-        val input: Array[InternalRow] = buildPlan.execute().map { row =>
-          row.copy()
-        }.collect()
-        val hashed = HashedRelation(input.iterator, buildKeyGenerator, input.size)
-        sparkContext.broadcast(hashed)
-      }
-    }(BroadcastHashJoin.broadcastHashJoinExecutionContext)
-  }
-
-  protected override def doPrepare(): Unit = {
-    broadcastFuture
-  }
-
-  override def doExecute(): RDD[InternalRow] = {
-    val numOutputRows = longMetric("numOutputRows")
-
-    val broadcastRelation = Await.result(broadcastFuture, timeout)
-
-    streamedPlan.execute().mapPartitions { streamedIter =>
-      val joinedRow = new JoinedRow()
-      val hashTable = broadcastRelation.value
-      val keyGenerator = streamedKeyGenerator
-      TaskContext.get().taskMetrics().incPeakExecutionMemory(hashTable.getMemorySize)
-
-      val resultProj = resultProjection
-      joinType match {
-        case LeftOuter =>
-          streamedIter.flatMap(currentRow => {
-            val rowKey = keyGenerator(currentRow)
-            joinedRow.withLeft(currentRow)
-            leftOuterIterator(rowKey, joinedRow, hashTable.get(rowKey), resultProj, numOutputRows)
-          })
-
-        case RightOuter =>
-          streamedIter.flatMap(currentRow => {
-            val rowKey = keyGenerator(currentRow)
-            joinedRow.withRight(currentRow)
-            rightOuterIterator(rowKey, hashTable.get(rowKey), joinedRow, resultProj, numOutputRows)
-          })
-
-        case x =>
-          throw new IllegalArgumentException(
-            s"BroadcastHashOuterJoin should not take $x as the JoinType")
-      }
-    }
-  }
-}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
index 332a748d3b..2fe9c06cc9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
@@ -21,20 +21,38 @@ import java.util.NoSuchElementException
 
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.metric.LongSQLMetric
-import org.apache.spark.sql.types.{IntegralType, LongType}
+import org.apache.spark.sql.types.{IntegerType, IntegralType, LongType}
+import org.apache.spark.util.collection.CompactBuffer
 
 trait HashJoin {
   self: SparkPlan =>
 
   val leftKeys: Seq[Expression]
   val rightKeys: Seq[Expression]
+  val joinType: JoinType
   val buildSide: BuildSide
   val condition: Option[Expression]
   val left: SparkPlan
   val right: SparkPlan
 
+  override def output: Seq[Attribute] = {
+    joinType match {
+      case Inner =>
+        left.output ++ right.output
+      case LeftOuter =>
+        left.output ++ right.output.map(_.withNullability(true))
+      case RightOuter =>
+        left.output.map(_.withNullability(true)) ++ right.output
+      case FullOuter =>
+        left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true))
+      case x =>
+        throw new IllegalArgumentException(s"HashJoin should not take $x as the JoinType")
+    }
+  }
+
   protected lazy val (buildPlan, streamedPlan) = buildSide match {
     case BuildLeft => (left, right)
     case BuildRight => (right, left)
@@ -45,8 +63,6 @@ trait HashJoin {
     case BuildRight => (rightKeys, leftKeys)
   }
 
-  override def output: Seq[Attribute] = left.output ++ right.output
-
   /**
     * Try to rewrite the key as LongType so we can use getLong(), if they key can fit with a long.
     *
@@ -67,8 +83,17 @@ trait HashJoin {
             width = dt.defaultSize
           } else {
             val bits = dt.defaultSize * 8
+            // The hashCode of a Long is (l >> 32) ^ l.toInt, so a long whose high 32 bits equal
+            // its low 32 bits has a hash code of 0. To avoid the worst case, where keys built
+            // from two equal ints all hash to 0, we rotate the bits of the second one.
+            val rotated = if (e.dataType == IntegerType) {
+              // (e >>> 15) | (e << 17)
+              BitwiseOr(ShiftRightUnsigned(e, Literal(15)), ShiftLeft(e, Literal(17)))
+            } else {
+              e
+            }
             keyExpr = BitwiseOr(ShiftLeft(keyExpr, Literal(bits)),
-              BitwiseAnd(Cast(e, LongType), Literal((1L << bits) - 1)))
+              BitwiseAnd(Cast(rotated, LongType), Literal((1L << bits) - 1)))
             width -= bits
           }
         // TODO: support BooleanType, DateType and TimestampType
@@ -97,11 +122,13 @@ trait HashJoin {
     (r: InternalRow) => true
   }
 
+  protected def createResultProjection: (InternalRow) => InternalRow =
+    UnsafeProjection.create(self.schema)
+
   protected def hashJoin(
       streamIter: Iterator[InternalRow],
       hashedRelation: HashedRelation,
-      numOutputRows: LongSQLMetric): Iterator[InternalRow] =
-  {
+      numOutputRows: LongSQLMetric): Iterator[InternalRow] = {
     new Iterator[InternalRow] {
       private[this] var currentStreamedRow: InternalRow = _
       private[this] var currentHashMatches: Seq[InternalRow] = _
@@ -109,8 +136,7 @@ trait HashJoin {
 
       // Mutable per row objects.
       private[this] val joinRow = new JoinedRow
-      private[this] val resultProjection: (InternalRow) => InternalRow =
-        UnsafeProjection.create(self.schema)
+      private[this] val resultProjection = createResultProjection
 
       private[this] val joinKeys = streamSideKeyGenerator
 
@@ -163,4 +189,73 @@ trait HashJoin {
       }
     }
   }
+
+  @transient protected[this] lazy val EMPTY_LIST = CompactBuffer[InternalRow]()
+
+  @transient private[this] lazy val leftNullRow = new GenericInternalRow(left.output.length)
+  @transient private[this] lazy val rightNullRow = new GenericInternalRow(right.output.length)
+
+  protected[this] def leftOuterIterator(
+      key: InternalRow,
+      joinedRow: JoinedRow,
+      rightIter: Iterable[InternalRow],
+      resultProjection: InternalRow => InternalRow,
+      numOutputRows: LongSQLMetric): Iterator[InternalRow] = {
+    val ret: Iterable[InternalRow] = {
+      if (!key.anyNull) {
+        val temp = if (rightIter != null) {
+          rightIter.collect {
+            case r if boundCondition(joinedRow.withRight(r)) => {
+              numOutputRows += 1
+              resultProjection(joinedRow).copy()
+            }
+          }
+        } else {
+          List.empty
+        }
+        if (temp.isEmpty) {
+          numOutputRows += 1
+          resultProjection(joinedRow.withRight(rightNullRow)) :: Nil
+        } else {
+          temp
+        }
+      } else {
+        numOutputRows += 1
+        resultProjection(joinedRow.withRight(rightNullRow)) :: Nil
+      }
+    }
+    ret.iterator
+  }
+
+  protected[this] def rightOuterIterator(
+      key: InternalRow,
+      leftIter: Iterable[InternalRow],
+      joinedRow: JoinedRow,
+      resultProjection: InternalRow => InternalRow,
+      numOutputRows: LongSQLMetric): Iterator[InternalRow] = {
+    val ret: Iterable[InternalRow] = {
+      if (!key.anyNull) {
+        val temp = if (leftIter != null) {
+          leftIter.collect {
+            case l if boundCondition(joinedRow.withLeft(l)) => {
+              numOutputRows += 1
+              resultProjection(joinedRow).copy()
+            }
+          }
+        } else {
+          List.empty
+        }
+        if (temp.isEmpty) {
+          numOutputRows += 1
+          resultProjection(joinedRow.withLeft(leftNullRow)) :: Nil
+        } else {
+          temp
+        }
+      } else {
+        numOutputRows += 1
+        resultProjection(joinedRow.withLeft(leftNullRow)) :: Nil
+      }
+    }
+    ret.iterator
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala
deleted file mode 100644
index 9e614309de..0000000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala
+++ /dev/null
@@ -1,153 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.joins
-
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans._
-import org.apache.spark.sql.execution.SparkPlan
-import org.apache.spark.sql.execution.metric.LongSQLMetric
-import org.apache.spark.util.collection.CompactBuffer
-
-
-trait HashOuterJoin {
-  self: SparkPlan =>
-
-  val leftKeys: Seq[Expression]
-  val rightKeys: Seq[Expression]
-  val joinType: JoinType
-  val condition: Option[Expression]
-  val left: SparkPlan
-  val right: SparkPlan
-
-  override def output: Seq[Attribute] = {
-    joinType match {
-      case LeftOuter =>
-        left.output ++ right.output.map(_.withNullability(true))
-      case RightOuter =>
-        left.output.map(_.withNullability(true)) ++ right.output
-      case FullOuter =>
-        left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true))
-      case x =>
-        throw new IllegalArgumentException(s"HashOuterJoin should not take $x as the JoinType")
-    }
-  }
-
-  protected[this] lazy val (buildPlan, streamedPlan) = joinType match {
-    case RightOuter => (left, right)
-    case LeftOuter => (right, left)
-    case x =>
-      throw new IllegalArgumentException(
-        s"HashOuterJoin should not take $x as the JoinType")
-  }
-
-  protected[this] lazy val (buildKeys, streamedKeys) = joinType match {
-    case RightOuter => (leftKeys, rightKeys)
-    case LeftOuter => (rightKeys, leftKeys)
-    case x =>
-      throw new IllegalArgumentException(
-        s"HashOuterJoin should not take $x as the JoinType")
-  }
-
-  protected def buildKeyGenerator: Projection =
-    UnsafeProjection.create(buildKeys, buildPlan.output)
-
-  protected[this] def streamedKeyGenerator: Projection =
-    UnsafeProjection.create(streamedKeys, streamedPlan.output)
-
-  protected[this] def resultProjection: InternalRow => InternalRow =
-    UnsafeProjection.create(output, output)
-
-  @transient private[this] lazy val DUMMY_LIST = CompactBuffer[InternalRow](null)
-  @transient protected[this] lazy val EMPTY_LIST = CompactBuffer[InternalRow]()
-
-  @transient private[this] lazy val leftNullRow = new GenericInternalRow(left.output.length)
-  @transient private[this] lazy val rightNullRow = new GenericInternalRow(right.output.length)
-  @transient private[this] lazy val boundCondition = if (condition.isDefined) {
-    newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output)
-  } else {
-    (row: InternalRow) => true
-  }
-
-  // TODO we need to rewrite all of the iterators with our own implementation instead of the Scala
-  // iterator for performance purpose.
-
-  protected[this] def leftOuterIterator(
-      key: InternalRow,
-      joinedRow: JoinedRow,
-      rightIter: Iterable[InternalRow],
-      resultProjection: InternalRow => InternalRow,
-      numOutputRows: LongSQLMetric): Iterator[InternalRow] = {
-    val ret: Iterable[InternalRow] = {
-      if (!key.anyNull) {
-        val temp = if (rightIter != null) {
-          rightIter.collect {
-            case r if boundCondition(joinedRow.withRight(r)) => {
-              numOutputRows += 1
-              resultProjection(joinedRow).copy()
-            }
-          }
-        } else {
-          List.empty
-        }
-        if (temp.isEmpty) {
-          numOutputRows += 1
-          resultProjection(joinedRow.withRight(rightNullRow)) :: Nil
-        } else {
-          temp
-        }
-      } else {
-        numOutputRows += 1
-        resultProjection(joinedRow.withRight(rightNullRow)) :: Nil
-      }
-    }
-    ret.iterator
-  }
-
-  protected[this] def rightOuterIterator(
-      key: InternalRow,
-      leftIter: Iterable[InternalRow],
-      joinedRow: JoinedRow,
-      resultProjection: InternalRow => InternalRow,
-      numOutputRows: LongSQLMetric): Iterator[InternalRow] = {
-    val ret: Iterable[InternalRow] = {
-      if (!key.anyNull) {
-        val temp = if (leftIter != null) {
-          leftIter.collect {
-            case l if boundCondition(joinedRow.withLeft(l)) => {
-              numOutputRows += 1
-              resultProjection(joinedRow).copy()
-            }
-          }
-        } else {
-          List.empty
-        }
-        if (temp.isEmpty) {
-          numOutputRows += 1
-          resultProjection(joinedRow.withLeft(leftNullRow)) :: Nil
-        } else {
-          temp
-        }
-      } else {
-        numOutputRows += 1
-        resultProjection(joinedRow.withLeft(leftNullRow)) :: Nil
-      }
-    }
-    ret.iterator
-  }
-}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index 9a3c262e94..92ff7e73fa 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -46,7 +46,6 @@ class JoinSuite extends QueryTest with SharedSQLContext {
     val operators = physical.collect {
       case j: LeftSemiJoinHash => j
       case j: BroadcastHashJoin => j
-      case j: BroadcastHashOuterJoin => j
       case j: LeftSemiJoinBNL => j
       case j: CartesianProduct => j
       case j: BroadcastNestedLoopJoin => j
@@ -123,9 +122,9 @@ class JoinSuite extends QueryTest with SharedSQLContext {
       ("SELECT * FROM testData LEFT JOIN testData2 ON key = a",
         classOf[SortMergeOuterJoin]),
       ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2",
-        classOf[BroadcastHashOuterJoin]),
+        classOf[BroadcastHashJoin]),
       ("SELECT * FROM testData right join testData2 ON key = a and key = 2",
-        classOf[BroadcastHashOuterJoin])
+        classOf[BroadcastHashJoin])
     ).foreach { case (query, joinClass) => assertJoin(query, joinClass) }
     sql("UNCACHE TABLE testData")
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
index 4a151179bf..bcac660a35 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.execution
 
+import java.util.HashMap
+
 import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite}
 import org.apache.spark.memory.{StaticMemoryManager, TaskMemoryManager}
 import org.apache.spark.sql.SQLContext
@@ -124,37 +126,65 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
 
   ignore("broadcast hash join") {
     val N = 100 << 20
-    val dim = broadcast(sqlContext.range(1 << 16).selectExpr("id as k", "cast(id as string) as v"))
+    val M = 1 << 16
+    val dim = broadcast(sqlContext.range(M).selectExpr("id as k", "cast(id as string) as v"))
 
     runBenchmark("Join w long", N) {
-      sqlContext.range(N).join(dim, (col("id") % 60000) === col("k")).count()
+      sqlContext.range(N).join(dim, (col("id") bitwiseAND M) === col("k")).count()
     }
 
     /*
     Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
-    BroadcastHashJoin:                  Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
+    Join w long:                        Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
     -------------------------------------------------------------------------------------------
-    Join w long codegen=false        10174 / 10317         10.0         100.0       1.0X
-    Join w long codegen=true           1069 / 1107         98.0          10.2       9.5X
+    Join w long codegen=false                5744 / 5814         18.3          54.8       1.0X
+    Join w long codegen=true                  735 /  853        142.7           7.0       7.8X
     */
 
-    val dim2 = broadcast(sqlContext.range(1 << 16)
+    val dim2 = broadcast(sqlContext.range(M)
       .selectExpr("cast(id as int) as k1", "cast(id as int) as k2", "cast(id as string) as v"))
 
     runBenchmark("Join w 2 ints", N) {
       sqlContext.range(N).join(dim2,
-        (col("id") bitwiseAND 60000).cast(IntegerType) === col("k1")
-          && (col("id") bitwiseAND 50000).cast(IntegerType) === col("k2")).count()
+        (col("id") bitwiseAND M).cast(IntegerType) === col("k1")
+          && (col("id") bitwiseAND M).cast(IntegerType) === col("k2")).count()
     }
 
     /**
     Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
-    BroadcastHashJoin:                  Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
+    Join w 2 ints:                      Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
     -------------------------------------------------------------------------------------------
-    Join w 2 ints codegen=false           11435 / 11530          9.0         111.1       1.0X
-    Join w 2 ints codegen=true              1265 / 1424         82.0          12.2       9.0X
+    Join w 2 ints codegen=false              7159 / 7224         14.6          68.3       1.0X
+    Join w 2 ints codegen=true               1135 / 1197         92.4          10.8       6.3X
     */
 
+    val dim3 = broadcast(sqlContext.range(M)
+      .selectExpr("id as k1", "id as k2", "cast(id as string) as v"))
+
+    runBenchmark("Join w 2 longs", N) {
+      sqlContext.range(N).join(dim3,
+        (col("id") bitwiseAND M) === col("k1") && (col("id") bitwiseAND M) === col("k2"))
+        .count()
+    }
+
+    /**
+    Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
+    Join w 2 longs:                      Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
+    -------------------------------------------------------------------------------------------
+    Join w 2 longs codegen=false              7877 / 8358         13.3          75.1       1.0X
+    Join w 2 longs codegen=true               3877 / 3937         27.0          37.0       2.0X
+      */
+    runBenchmark("outer join w long", N) {
+      sqlContext.range(N).join(dim, (col("id") bitwiseAND M) === col("k"), "left").count()
+    }
+
+    /**
+    Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
+    outer join w long:                  Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
+    -------------------------------------------------------------------------------------------
+    outer join w long codegen=false        15280 / 16497          6.9         145.7       1.0X
+    outer join w long codegen=true            769 /  796        136.3           7.3      19.9X
+      */
   }
 
   ignore("rube") {
@@ -175,7 +205,7 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
   }
 
   ignore("hash and BytesToBytesMap") {
-    val N = 50 << 20
+    val N = 10 << 20
 
     val benchmark = new Benchmark("BytesToBytesMap", N)
 
@@ -227,6 +257,80 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
       }
     }
 
+    benchmark.addCase("Java HashMap (Long)") { iter =>
+      var i = 0
+      val keyBytes = new Array[Byte](16)
+      val valueBytes = new Array[Byte](16)
+      val value = new UnsafeRow(1)
+      value.pointTo(valueBytes, Platform.BYTE_ARRAY_OFFSET, 16)
+      value.setInt(0, 555)
+      val map = new HashMap[Long, UnsafeRow]()
+      while (i < 65536) {
+        value.setInt(0, i)
+        map.put(i.toLong, value)
+        i += 1
+      }
+      var s = 0
+      i = 0
+      while (i < N) {
+        if (map.get(i % 100000) != null) {
+          s += 1
+        }
+        i += 1
+      }
+    }
+
+    benchmark.addCase("Java HashMap (two ints) ") { iter =>
+      var i = 0
+      val valueBytes = new Array[Byte](16)
+      val value = new UnsafeRow(1)
+      value.pointTo(valueBytes, Platform.BYTE_ARRAY_OFFSET, 16)
+      value.setInt(0, 555)
+      val map = new HashMap[Long, UnsafeRow]()
+      while (i < 65536) {
+        value.setInt(0, i)
+        val key = (i.toLong << 32) + Integer.rotateRight(i, 15)
+        map.put(key, value)
+        i += 1
+      }
+      var s = 0
+      i = 0
+      while (i < N) {
+        val key = ((i & 100000).toLong << 32) + Integer.rotateRight(i & 100000, 15)
+        if (map.get(key) != null) {
+          s += 1
+        }
+        i += 1
+      }
+    }
+
+    benchmark.addCase("Java HashMap (UnsafeRow)") { iter =>
+      var i = 0
+      val keyBytes = new Array[Byte](16)
+      val valueBytes = new Array[Byte](16)
+      val key = new UnsafeRow(1)
+      key.pointTo(keyBytes, Platform.BYTE_ARRAY_OFFSET, 16)
+      val value = new UnsafeRow(1)
+      value.pointTo(valueBytes, Platform.BYTE_ARRAY_OFFSET, 16)
+      value.setInt(0, 555)
+      val map = new HashMap[UnsafeRow, UnsafeRow]()
+      while (i < 65536) {
+        key.setInt(0, i)
+        value.setInt(0, i)
+        map.put(key, value.copy())
+        i += 1
+      }
+      var s = 0
+      i = 0
+      while (i < N) {
+        key.setInt(0, i % 100000)
+        if (map.get(key) != null) {
+          s += 1
+        }
+        i += 1
+      }
+    }
+
     Seq("off", "on").foreach { heap =>
       benchmark.addCase(s"BytesToBytesMap ($heap Heap)") { iter =>
         val taskMemoryManager = new TaskMemoryManager(
@@ -268,6 +372,9 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
     hash                                      651 /  678         80.0          12.5       1.0X
     fast hash                                 336 /  343        155.9           6.4       1.9X
     arrayEqual                                417 /  428        125.0           8.0       1.6X
+    Java HashMap (Long)                       145 /  168         72.2          13.8       0.8X
+    Java HashMap (two ints)                   157 /  164         66.8          15.0       0.8X
+    Java HashMap (UnsafeRow)                  538 /  573         19.5          51.3       0.2X
     BytesToBytesMap (off Heap)               2594 / 2664         20.2          49.5       0.2X
     BytesToBytesMap (on Heap)                2693 / 2989         19.5          51.4       0.2X
       */
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
index aee8e84db5..e25b5e0610 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
@@ -73,7 +73,7 @@ class BroadcastJoinSuite extends QueryTest with BeforeAndAfterAll {
   }
 
   test("unsafe broadcast hash outer join updates peak execution memory") {
-    testBroadcastJoin[BroadcastHashOuterJoin]("unsafe broadcast hash outer join", "left_outer")
+    testBroadcastJoin[BroadcastHashJoin]("unsafe broadcast hash outer join", "left_outer")
   }
 
   test("unsafe broadcast left semi join updates peak execution memory") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
index 149f34dbd7..e22a810a6b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
@@ -88,7 +88,7 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext {
         leftPlan: SparkPlan,
         rightPlan: SparkPlan,
         side: BuildSide) = {
-      joins.BroadcastHashJoin(leftKeys, rightKeys, side, boundCondition, leftPlan, rightPlan)
+      joins.BroadcastHashJoin(leftKeys, rightKeys, Inner, side, boundCondition, leftPlan, rightPlan)
     }
 
     def makeSortMergeJoin(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
index 3d3e9a7b90..f4b01fbad0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
@@ -75,11 +75,16 @@ class OuterJoinSuite extends SparkPlanTest with SharedSQLContext {
     }
 
     if (joinType != FullOuter) {
-      test(s"$testName using BroadcastHashOuterJoin") {
+      test(s"$testName using BroadcastHashJoin") {
+        val buildSide = joinType match {
+          case LeftOuter => BuildRight
+          case RightOuter => BuildLeft
+        }
         extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) =>
           withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
             checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
-              BroadcastHashOuterJoin(leftKeys, rightKeys, joinType, boundCondition, left, right),
+              BroadcastHashJoin(
+                leftKeys, rightKeys, joinType, buildSide, boundCondition, left, right),
               expectedAnswer.map(Row.fromTuple),
               sortAnswers = true)
           }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
index f4bc9e501c..46bb699b78 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
@@ -209,20 +209,20 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext {
     )
   }
 
-  test("BroadcastHashOuterJoin metrics") {
+  test("BroadcastHashJoin(outer) metrics") {
     val df1 = Seq((1, "a"), (1, "b"), (4, "c")).toDF("key", "value")
     val df2 = Seq((1, "a"), (1, "b"), (2, "c"), (3, "d")).toDF("key2", "value")
     // Assume the execution plan is
-    // ... -> BroadcastHashOuterJoin(nodeId = 0)
+    // ... -> BroadcastHashJoin(nodeId = 0)
     val df = df1.join(broadcast(df2), $"key" === $"key2", "left_outer")
     testSparkPlanMetrics(df, 2, Map(
-      0L -> ("BroadcastHashOuterJoin", Map(
+      0L -> ("BroadcastHashJoin", Map(
         "number of output rows" -> 5L)))
     )
 
     val df3 = df1.join(broadcast(df2), $"key" === $"key2", "right_outer")
     testSparkPlanMetrics(df3, 2, Map(
-      0L -> ("BroadcastHashOuterJoin", Map(
+      0L -> ("BroadcastHashJoin", Map(
         "number of output rows" -> 6L)))
     )
   }
-- 
GitLab