diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 042c99db4dcffd23ee24973764de3c69703565d1..382654afacb896d55cddc8d57170a9351955348c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -108,12 +108,12 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
       // --- Inner joins --------------------------------------------------------------------------
 
       case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, CanBroadcast(right)) =>
-        joins.BroadcastHashJoin(
-          leftKeys, rightKeys, BuildRight, condition, planLater(left), planLater(right)) :: Nil
+        Seq(joins.BroadcastHashJoin(
+          leftKeys, rightKeys, Inner, BuildRight, condition, planLater(left), planLater(right)))
 
       case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, CanBroadcast(left), right) =>
-        joins.BroadcastHashJoin(
-          leftKeys, rightKeys, BuildLeft, condition, planLater(left), planLater(right)) :: Nil
+        Seq(joins.BroadcastHashJoin(
+          leftKeys, rightKeys, Inner, BuildLeft, condition, planLater(left), planLater(right)))
 
       case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right)
         if RowOrdering.isOrderable(leftKeys) =>
@@ -124,13 +124,13 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
 
       case ExtractEquiJoinKeys(
           LeftOuter, leftKeys, rightKeys, condition, left, CanBroadcast(right)) =>
-        joins.BroadcastHashOuterJoin(
-          leftKeys, rightKeys, LeftOuter, condition, planLater(left), planLater(right)) :: Nil
+        Seq(joins.BroadcastHashJoin(
+          leftKeys, rightKeys, LeftOuter, BuildRight, condition, planLater(left), planLater(right)))
 
       case ExtractEquiJoinKeys(
           RightOuter, leftKeys, rightKeys, condition, CanBroadcast(left), right) =>
-        joins.BroadcastHashOuterJoin(
-          leftKeys, rightKeys, RightOuter, condition, planLater(left), planLater(right)) :: Nil
+        Seq(joins.BroadcastHashJoin(
+          leftKeys, rightKeys, RightOuter, BuildLeft, condition, planLater(left), planLater(right)))
 
       case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right)
         if RowOrdering.isOrderable(leftKeys) =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
index f35efb5b24b1f2d8815a70032a50f64c2a962eb0..8626f54eb413cd2fcb0b849aa4695520fddff0bb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.plans.physical.Partitioning
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution.aggregate.TungstenAggregate
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, BuildLeft, BuildRight}
-import org.apache.spark.sql.execution.metric.{LongSQLMetric, LongSQLMetricValue, SQLMetric}
+import org.apache.spark.sql.execution.metric.LongSQLMetricValue
 
 /**
   * An interface for those physical operators that support codegen.
@@ -38,7 +38,7 @@ trait CodegenSupport extends SparkPlan {
   /** Prefix used in the current operator's variable names. */
   private def variablePrefix: String = this match {
     case _: TungstenAggregate => "agg"
-    case _: BroadcastHashJoin => "bhj"
+    case _: BroadcastHashJoin => "join"
     case _ => nodeName.toLowerCase
   }
 
@@ -391,9 +391,9 @@ private[sql] case class CollapseCodegenStages(sqlContext: SQLContext) extends Ru
           var inputs = ArrayBuffer[SparkPlan]()
           val combined = plan.transform {
             // The build side can't be compiled together
-            case b @ BroadcastHashJoin(_, _, BuildLeft, _, left, right) =>
+            case b @ BroadcastHashJoin(_, _, _, BuildLeft, _, left, right) =>
               b.copy(left = apply(left))
-            case b @ BroadcastHashJoin(_, _, BuildRight, _, left, right) =>
+            case b @ BroadcastHashJoin(_, _, _, BuildRight, _, left, right) =>
               b.copy(right = apply(right))
             case p if !supportCodegen(p) =>
               val input = apply(p)  // collapse them recursively
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
index 985e74011daa7e0e9f2350e910548c68a7f200b3..a64da225800a308830761ca11812665309f56f33 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
@@ -24,8 +24,9 @@ import org.apache.spark.TaskContext
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{BindReferences, BoundReference, Expression, UnsafeRow}
+import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, GenerateUnsafeProjection}
+import org.apache.spark.sql.catalyst.plans.{Inner, JoinType, LeftOuter, RightOuter}
 import org.apache.spark.sql.catalyst.plans.physical.{Distribution, Partitioning, UnspecifiedDistribution}
 import org.apache.spark.sql.execution.{BinaryNode, CodegenSupport, SparkPlan, SQLExecution}
 import org.apache.spark.sql.execution.metric.SQLMetrics
@@ -41,6 +42,7 @@ import org.apache.spark.util.collection.CompactBuffer
 case class BroadcastHashJoin(
     leftKeys: Seq[Expression],
     rightKeys: Seq[Expression],
+    joinType: JoinType,
     buildSide: BuildSide,
     condition: Option[Expression],
     left: SparkPlan,
@@ -105,75 +107,144 @@ case class BroadcastHashJoin(
     val broadcastRelation = Await.result(broadcastFuture, timeout)
 
     streamedPlan.execute().mapPartitions { streamedIter =>
-      val hashedRelation = broadcastRelation.value
-      TaskContext.get().taskMetrics().incPeakExecutionMemory(hashedRelation.getMemorySize)
-      hashJoin(streamedIter, hashedRelation, numOutputRows)
+      val joinedRow = new JoinedRow()
+      val hashTable = broadcastRelation.value
+      TaskContext.get().taskMetrics().incPeakExecutionMemory(hashTable.getMemorySize)
+      val keyGenerator = streamSideKeyGenerator
+      val resultProj = createResultProjection
+
+      joinType match {
+        case Inner =>
+          hashJoin(streamedIter, hashTable, numOutputRows)
+
+        case LeftOuter =>
+          streamedIter.flatMap { currentRow =>
+            val rowKey = keyGenerator(currentRow)
+            joinedRow.withLeft(currentRow)
+            leftOuterIterator(rowKey, joinedRow, hashTable.get(rowKey), resultProj, numOutputRows)
+          }
+
+        case RightOuter =>
+          streamedIter.flatMap { currentRow =>
+            val rowKey = keyGenerator(currentRow)
+            joinedRow.withRight(currentRow)
+            rightOuterIterator(rowKey, hashTable.get(rowKey), joinedRow, resultProj, numOutputRows)
+          }
+
+        case x =>
+          throw new IllegalArgumentException(
+            s"BroadcastHashJoin should not take $x as the JoinType")
+      }
     }
   }
 
-  private var broadcastRelation: Broadcast[HashedRelation] = _
-  // the term for hash relation
-  private var relationTerm: String = _
-
   override def upstream(): RDD[InternalRow] = {
     streamedPlan.asInstanceOf[CodegenSupport].upstream()
   }
 
   override def doProduce(ctx: CodegenContext): String = {
+    streamedPlan.asInstanceOf[CodegenSupport].produce(ctx, this)
+  }
+
+  override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
+    if (joinType == Inner) {
+      codegenInner(ctx, input)
+    } else {
+      // LeftOuter and RightOuter
+      codegenOuter(ctx, input)
+    }
+  }
+
+  /**
+    * Returns a tuple of the broadcast HashedRelation and the variable name for it.
+    */
+  private def prepareBroadcast(ctx: CodegenContext): (Broadcast[HashedRelation], String) = {
     // create a name for HashedRelation
-    broadcastRelation = Await.result(broadcastFuture, timeout)
+    val broadcastRelation = Await.result(broadcastFuture, timeout)
     val broadcast = ctx.addReferenceObj("broadcast", broadcastRelation)
-    relationTerm = ctx.freshName("relation")
+    val relationTerm = ctx.freshName("relation")
     val clsName = broadcastRelation.value.getClass.getName
     ctx.addMutableState(clsName, relationTerm,
       s"""
          | $relationTerm = ($clsName) $broadcast.value();
          | incPeakExecutionMemory($relationTerm.getMemorySize());
        """.stripMargin)
-
-    s"""
-       | ${streamedPlan.asInstanceOf[CodegenSupport].produce(ctx, this)}
-     """.stripMargin
+    (broadcastRelation, relationTerm)
   }
 
-  override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
-    // generate the key as UnsafeRow or Long
+  /**
+    * Returns the code for generating the join key for the stream side, and an expression of
+    * whether the key has any null in it.
+    */
+  private def genStreamSideJoinKey(
+      ctx: CodegenContext,
+      input: Seq[ExprCode]): (ExprCode, String) = {
     ctx.currentVars = input
-    val (keyVal, anyNull) = if (canJoinKeyFitWithinLong) {
+    if (canJoinKeyFitWithinLong) {
+      // generate the join key as Long
       val expr = rewriteKeyExpr(streamedKeys).head
       val ev = BindReferences.bindReference(expr, streamedPlan.output).gen(ctx)
       (ev, ev.isNull)
     } else {
+      // generate the join key as UnsafeRow
       val keyExpr = streamedKeys.map(BindReferences.bindReference(_, streamedPlan.output))
       val ev = GenerateUnsafeProjection.createCode(ctx, keyExpr)
       (ev, s"${ev.value}.anyNull()")
     }
+  }
 
-    // find the matches from HashedRelation
-    val matched = ctx.freshName("matched")
-
-    // create variables for output
+  /**
+    * Generates the code for the variables of the build side.
+    */
+  private def genBuildSideVars(ctx: CodegenContext, matched: String): Seq[ExprCode] = {
     ctx.currentVars = null
     ctx.INPUT_ROW = matched
-    val buildColumns = buildPlan.output.zipWithIndex.map { case (a, i) =>
-      BoundReference(i, a.dataType, a.nullable).gen(ctx)
+    buildPlan.output.zipWithIndex.map { case (a, i) =>
+      val ev = BoundReference(i, a.dataType, a.nullable).gen(ctx)
+      if (joinType == Inner) {
+        ev
+      } else {
+        // the variables are needed even if there are no matched rows
+        val isNull = ctx.freshName("isNull")
+        val value = ctx.freshName("value")
+        val code = s"""
+          |boolean $isNull = true;
+          |${ctx.javaType(a.dataType)} $value = ${ctx.defaultValue(a.dataType)};
+          |if ($matched != null) {
+          |  ${ev.code}
+          |  $isNull = ${ev.isNull};
+          |  $value = ${ev.value};
+          |}
+         """.stripMargin
+        ExprCode(code, isNull, value)
+      }
     }
+  }
+
+  /**
+    * Generates the code for Inner join.
+    */
+  private def codegenInner(ctx: CodegenContext, input: Seq[ExprCode]): String = {
+    val (broadcastRelation, relationTerm) = prepareBroadcast(ctx)
+    val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
+    val matched = ctx.freshName("matched")
+    val buildVars = genBuildSideVars(ctx, matched)
     val resultVars = buildSide match {
-      case BuildLeft => buildColumns ++ input
-      case BuildRight => input ++ buildColumns
+      case BuildLeft => buildVars ++ input
+      case BuildRight => input ++ buildVars
     }
-
     val numOutput = metricTerm(ctx, "numOutputRows")
+
     val outputCode = if (condition.isDefined) {
       // filter the output via condition
       ctx.currentVars = resultVars
       val ev = BindReferences.bindReference(condition.get, this.output).gen(ctx)
       s"""
-         | ${ev.code}
-         | if (!${ev.isNull} && ${ev.value}) {
-         |   $numOutput.add(1);
-         |   ${consume(ctx, resultVars)}
-         | }
+         |${ev.code}
+         |if (!${ev.isNull} && ${ev.value}) {
+         |  $numOutput.add(1);
+         |  ${consume(ctx, resultVars)}
+         |}
        """.stripMargin
     } else {
       s"""
@@ -184,36 +255,110 @@ case class BroadcastHashJoin(
 
     if (broadcastRelation.value.isInstanceOf[UniqueHashedRelation]) {
       s"""
-         | // generate join key
-         | ${keyVal.code}
-         | // find matches from HashedRelation
-         | UnsafeRow $matched = $anyNull ? null: (UnsafeRow)$relationTerm.getValue(${keyVal.value});
-         | if ($matched != null) {
-         |   ${buildColumns.map(_.code).mkString("\n")}
-         |   $outputCode
-         | }
-     """.stripMargin
+         |// generate join key for stream side
+         |${keyEv.code}
+         |// find matches from HashedRelation
+         |UnsafeRow $matched = $anyNull ? null: (UnsafeRow)$relationTerm.getValue(${keyEv.value});
+         |if ($matched != null) {
+         |  ${buildVars.map(_.code).mkString("\n")}
+         |  $outputCode
+         |}
+       """.stripMargin
+
+    } else {
+      val matches = ctx.freshName("matches")
+      val bufferType = classOf[CompactBuffer[UnsafeRow]].getName
+      val i = ctx.freshName("i")
+      val size = ctx.freshName("size")
+      s"""
+         |// generate join key for stream side
+         |${keyEv.code}
+         |// find matches from HashedRelation
+         |$bufferType $matches = $anyNull ? null : ($bufferType)$relationTerm.get(${keyEv.value});
+         |if ($matches != null) {
+         |  int $size = $matches.size();
+         |  for (int $i = 0; $i < $size; $i++) {
+         |    UnsafeRow $matched = (UnsafeRow) $matches.apply($i);
+         |    ${buildVars.map(_.code).mkString("\n")}
+         |    $outputCode
+         |  }
+         |}
+       """.stripMargin
+    }
+  }
+
+
+  /**
+    * Generates the code for left or right outer join.
+    */
+  private def codegenOuter(ctx: CodegenContext, input: Seq[ExprCode]): String = {
+    val (broadcastRelation, relationTerm) = prepareBroadcast(ctx)
+    val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
+    val matched = ctx.freshName("matched")
+    val buildVars = genBuildSideVars(ctx, matched)
+    val resultVars = buildSide match {
+      case BuildLeft => buildVars ++ input
+      case BuildRight => input ++ buildVars
+    }
+    val numOutput = metricTerm(ctx, "numOutputRows")
+
+    // filter the output via condition
+    val conditionPassed = ctx.freshName("conditionPassed")
+    val checkCondition = if (condition.isDefined) {
+      ctx.currentVars = resultVars
+      val ev = BindReferences.bindReference(condition.get, this.output).gen(ctx)
+      s"""
+         |boolean $conditionPassed = true;
+         |if ($matched != null) {
+         |  ${ev.code}
+         |  $conditionPassed = !${ev.isNull} && ${ev.value};
+         |}
+       """.stripMargin
+    } else {
+      s"final boolean $conditionPassed = true;"
+    }
+
+    if (broadcastRelation.value.isInstanceOf[UniqueHashedRelation]) {
+      s"""
+         |// generate join key for stream side
+         |${keyEv.code}
+         |// find matches from HashedRelation
+         |UnsafeRow $matched = $anyNull ? null: (UnsafeRow)$relationTerm.getValue(${keyEv.value});
+         |${buildVars.map(_.code).mkString("\n")}
+         |${checkCondition.trim}
+         |if (!$conditionPassed) {
+         |  // reset to null
+         |  ${buildVars.map(v => s"${v.isNull} = true;").mkString("\n")}
+         |}
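+         |// exactly one row is emitted per stream-side row: either the matched values or nulls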
+         |$numOutput.add(1);
+         |${consume(ctx, resultVars)}
+       """.stripMargin
 
     } else {
       val matches = ctx.freshName("matches")
       val bufferType = classOf[CompactBuffer[UnsafeRow]].getName
       val i = ctx.freshName("i")
       val size = ctx.freshName("size")
+      val found = ctx.freshName("found")
       s"""
-         | // generate join key
-         | ${keyVal.code}
-         | // find matches from HashRelation
-         | $bufferType $matches = ${anyNull} ? null :
-         |  ($bufferType) $relationTerm.get(${keyVal.value});
-         | if ($matches != null) {
-         |   int $size = $matches.size();
-         |   for (int $i = 0; $i < $size; $i++) {
-         |     UnsafeRow $matched = (UnsafeRow) $matches.apply($i);
-         |     ${buildColumns.map(_.code).mkString("\n")}
-         |     $outputCode
-         |   }
-         | }
-     """.stripMargin
+         |// generate join key for stream side
+         |${keyEv.code}
+         |// find matches from HashedRelation
+         |$bufferType $matches = $anyNull ? null : ($bufferType)$relationTerm.get(${keyEv.value});
+         |int $size = $matches != null ? $matches.size() : 0;
+         |boolean $found = false;
+         |// the last iteration of this loop emits the null-padded row if there are no matched rows.
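+         |// $found records whether any build-side row passed the condition, so that the
+         |// null-padded row is emitted only when no match was produced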
+         |for (int $i = 0; $i <= $size; $i++) {
+         |  UnsafeRow $matched = $i < $size ? (UnsafeRow) $matches.apply($i) : null;
+         |  ${buildVars.map(_.code).mkString("\n")}
+         |  ${checkCondition.trim}
+         |  if ($conditionPassed && ($i < $size || !$found)) {
+         |    $found = true;
+         |    $numOutput.add(1);
+         |    ${consume(ctx, resultVars)}
+         |  }
+         |}
+       """.stripMargin
     }
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala
deleted file mode 100644
index 5e8c8ca0436293276ae5bfcdbab59c80a0b9d71c..0000000000000000000000000000000000000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala
+++ /dev/null
@@ -1,121 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.joins
-
-import scala.concurrent._
-import scala.concurrent.duration._
-
-import org.apache.spark.{InternalAccumulator, TaskContext}
-import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.{JoinType, LeftOuter, RightOuter}
-import org.apache.spark.sql.catalyst.plans.physical.{Distribution, Partitioning, UnspecifiedDistribution}
-import org.apache.spark.sql.execution.{BinaryNode, SparkPlan, SQLExecution}
-import org.apache.spark.sql.execution.metric.SQLMetrics
-
-/**
- * Performs a outer hash join for two child relations.  When the output RDD of this operator is
- * being constructed, a Spark job is asynchronously started to calculate the values for the
- * broadcasted relation.  This data is then placed in a Spark broadcast variable.  The streamed
- * relation is not shuffled.
- */
-case class BroadcastHashOuterJoin(
-    leftKeys: Seq[Expression],
-    rightKeys: Seq[Expression],
-    joinType: JoinType,
-    condition: Option[Expression],
-    left: SparkPlan,
-    right: SparkPlan) extends BinaryNode with HashOuterJoin {
-
-  override private[sql] lazy val metrics = Map(
-    "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
-
-  val timeout = {
-    val timeoutValue = sqlContext.conf.broadcastTimeout
-    if (timeoutValue < 0) {
-      Duration.Inf
-    } else {
-      timeoutValue.seconds
-    }
-  }
-
-  override def requiredChildDistribution: Seq[Distribution] =
-    UnspecifiedDistribution :: UnspecifiedDistribution :: Nil
-
-  override def outputPartitioning: Partitioning = streamedPlan.outputPartitioning
-
-  // Use lazy so that we won't do broadcast when calling explain but still cache the broadcast value
-  // for the same query.
-  @transient
-  private lazy val broadcastFuture = {
-    // broadcastFuture is used in "doExecute". Therefore we can get the execution id correctly here.
-    val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
-    Future {
-      // This will run in another thread. Set the execution id so that we can connect these jobs
-      // with the correct execution.
-      SQLExecution.withExecutionId(sparkContext, executionId) {
-        // Note that we use .execute().collect() because we don't want to convert data to Scala
-        // types
-        val input: Array[InternalRow] = buildPlan.execute().map { row =>
-          row.copy()
-        }.collect()
-        val hashed = HashedRelation(input.iterator, buildKeyGenerator, input.size)
-        sparkContext.broadcast(hashed)
-      }
-    }(BroadcastHashJoin.broadcastHashJoinExecutionContext)
-  }
-
-  protected override def doPrepare(): Unit = {
-    broadcastFuture
-  }
-
-  override def doExecute(): RDD[InternalRow] = {
-    val numOutputRows = longMetric("numOutputRows")
-
-    val broadcastRelation = Await.result(broadcastFuture, timeout)
-
-    streamedPlan.execute().mapPartitions { streamedIter =>
-      val joinedRow = new JoinedRow()
-      val hashTable = broadcastRelation.value
-      val keyGenerator = streamedKeyGenerator
-      TaskContext.get().taskMetrics().incPeakExecutionMemory(hashTable.getMemorySize)
-
-      val resultProj = resultProjection
-      joinType match {
-        case LeftOuter =>
-          streamedIter.flatMap(currentRow => {
-            val rowKey = keyGenerator(currentRow)
-            joinedRow.withLeft(currentRow)
-            leftOuterIterator(rowKey, joinedRow, hashTable.get(rowKey), resultProj, numOutputRows)
-          })
-
-        case RightOuter =>
-          streamedIter.flatMap(currentRow => {
-            val rowKey = keyGenerator(currentRow)
-            joinedRow.withRight(currentRow)
-            rightOuterIterator(rowKey, hashTable.get(rowKey), joinedRow, resultProj, numOutputRows)
-          })
-
-        case x =>
-          throw new IllegalArgumentException(
-            s"BroadcastHashOuterJoin should not take $x as the JoinType")
-      }
-    }
-  }
-}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
index 332a748d3bfc06f79bebddd3d16fde6a1d950918..2fe9c06cc95375427a99d9e3d518b5454eddaed6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
@@ -21,20 +21,38 @@ import java.util.NoSuchElementException
 
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.metric.LongSQLMetric
-import org.apache.spark.sql.types.{IntegralType, LongType}
+import org.apache.spark.sql.types.{IntegerType, IntegralType, LongType}
+import org.apache.spark.util.collection.CompactBuffer
 
 trait HashJoin {
   self: SparkPlan =>
 
   val leftKeys: Seq[Expression]
   val rightKeys: Seq[Expression]
+  val joinType: JoinType
   val buildSide: BuildSide
   val condition: Option[Expression]
   val left: SparkPlan
   val right: SparkPlan
 
+  override def output: Seq[Attribute] = {
+    joinType match {
+      case Inner =>
+        left.output ++ right.output
+      case LeftOuter =>
+        left.output ++ right.output.map(_.withNullability(true))
+      case RightOuter =>
+        left.output.map(_.withNullability(true)) ++ right.output
+      case FullOuter =>
+        left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true))
+      case x =>
+        throw new IllegalArgumentException(s"HashJoin should not take $x as the JoinType")
+    }
+  }
+
   protected lazy val (buildPlan, streamedPlan) = buildSide match {
     case BuildLeft => (left, right)
     case BuildRight => (right, left)
@@ -45,8 +63,6 @@ trait HashJoin {
     case BuildRight => (rightKeys, leftKeys)
   }
 
-  override def output: Seq[Attribute] = left.output ++ right.output
-
   /**
     * Try to rewrite the key as LongType so we can use getLong(), if the key can fit within a long.
     *
@@ -67,8 +83,17 @@ trait HashJoin {
             width = dt.defaultSize
           } else {
             val bits = dt.defaultSize * 8
+            // The hashCode of a Long is (l >>> 32) ^ l.toInt, so a long whose high 32 bits equal
+            // its low 32 bits hashes to 0. To avoid that worst case for keys built from two
+            // identical ints, we rotate the bits of the second one.
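+            // For example, without the rotation the key (x, x) would be packed as
+            // (x.toLong << 32) | (x & 0xFFFFFFFFL), whose hashCode (x ^ x) is always 0;
+            // rotating the second int keeps the two halves different.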
+            val rotated = if (e.dataType == IntegerType) {
+              // (e >>> 15) | (e << 17)
+              BitwiseOr(ShiftRightUnsigned(e, Literal(15)), ShiftLeft(e, Literal(17)))
+            } else {
+              e
+            }
             keyExpr = BitwiseOr(ShiftLeft(keyExpr, Literal(bits)),
-              BitwiseAnd(Cast(e, LongType), Literal((1L << bits) - 1)))
+              BitwiseAnd(Cast(rotated, LongType), Literal((1L << bits) - 1)))
             width -= bits
           }
         // TODO: support BooleanType, DateType and TimestampType
@@ -97,11 +122,13 @@ trait HashJoin {
     (r: InternalRow) => true
   }
 
+  protected def createResultProjection: (InternalRow) => InternalRow =
+    UnsafeProjection.create(self.schema)
+
   protected def hashJoin(
       streamIter: Iterator[InternalRow],
       hashedRelation: HashedRelation,
-      numOutputRows: LongSQLMetric): Iterator[InternalRow] =
-  {
+      numOutputRows: LongSQLMetric): Iterator[InternalRow] = {
     new Iterator[InternalRow] {
       private[this] var currentStreamedRow: InternalRow = _
       private[this] var currentHashMatches: Seq[InternalRow] = _
@@ -109,8 +136,7 @@ trait HashJoin {
 
       // Mutable per row objects.
       private[this] val joinRow = new JoinedRow
-      private[this] val resultProjection: (InternalRow) => InternalRow =
-        UnsafeProjection.create(self.schema)
+      private[this] val resultProjection = createResultProjection
 
       private[this] val joinKeys = streamSideKeyGenerator
 
@@ -163,4 +189,73 @@ trait HashJoin {
       }
     }
   }
+
+  @transient protected[this] lazy val EMPTY_LIST = CompactBuffer[InternalRow]()
+
+  @transient private[this] lazy val leftNullRow = new GenericInternalRow(left.output.length)
+  @transient private[this] lazy val rightNullRow = new GenericInternalRow(right.output.length)
+
+  protected[this] def leftOuterIterator(
+      key: InternalRow,
+      joinedRow: JoinedRow,
+      rightIter: Iterable[InternalRow],
+      resultProjection: InternalRow => InternalRow,
+      numOutputRows: LongSQLMetric): Iterator[InternalRow] = {
+    val ret: Iterable[InternalRow] = {
+      if (!key.anyNull) {
+        val temp = if (rightIter != null) {
+          rightIter.collect {
+            case r if boundCondition(joinedRow.withRight(r)) => {
+              numOutputRows += 1
+              resultProjection(joinedRow).copy()
+            }
+          }
+        } else {
+          List.empty
+        }
+        if (temp.isEmpty) {
+          numOutputRows += 1
+          resultProjection(joinedRow.withRight(rightNullRow)) :: Nil
+        } else {
+          temp
+        }
+      } else {
+        numOutputRows += 1
+        resultProjection(joinedRow.withRight(rightNullRow)) :: Nil
+      }
+    }
+    ret.iterator
+  }
+
+  protected[this] def rightOuterIterator(
+      key: InternalRow,
+      leftIter: Iterable[InternalRow],
+      joinedRow: JoinedRow,
+      resultProjection: InternalRow => InternalRow,
+      numOutputRows: LongSQLMetric): Iterator[InternalRow] = {
+    val ret: Iterable[InternalRow] = {
+      if (!key.anyNull) {
+        val temp = if (leftIter != null) {
+          leftIter.collect {
+            case l if boundCondition(joinedRow.withLeft(l)) => {
+              numOutputRows += 1
+              resultProjection(joinedRow).copy()
+            }
+          }
+        } else {
+          List.empty
+        }
+        if (temp.isEmpty) {
+          numOutputRows += 1
+          resultProjection(joinedRow.withLeft(leftNullRow)) :: Nil
+        } else {
+          temp
+        }
+      } else {
+        numOutputRows += 1
+        resultProjection(joinedRow.withLeft(leftNullRow)) :: Nil
+      }
+    }
+    ret.iterator
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala
deleted file mode 100644
index 9e614309de129e5d5203118685b67ffeff6d11c7..0000000000000000000000000000000000000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala
+++ /dev/null
@@ -1,153 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.joins
-
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans._
-import org.apache.spark.sql.execution.SparkPlan
-import org.apache.spark.sql.execution.metric.LongSQLMetric
-import org.apache.spark.util.collection.CompactBuffer
-
-
-trait HashOuterJoin {
-  self: SparkPlan =>
-
-  val leftKeys: Seq[Expression]
-  val rightKeys: Seq[Expression]
-  val joinType: JoinType
-  val condition: Option[Expression]
-  val left: SparkPlan
-  val right: SparkPlan
-
-  override def output: Seq[Attribute] = {
-    joinType match {
-      case LeftOuter =>
-        left.output ++ right.output.map(_.withNullability(true))
-      case RightOuter =>
-        left.output.map(_.withNullability(true)) ++ right.output
-      case FullOuter =>
-        left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true))
-      case x =>
-        throw new IllegalArgumentException(s"HashOuterJoin should not take $x as the JoinType")
-    }
-  }
-
-  protected[this] lazy val (buildPlan, streamedPlan) = joinType match {
-    case RightOuter => (left, right)
-    case LeftOuter => (right, left)
-    case x =>
-      throw new IllegalArgumentException(
-        s"HashOuterJoin should not take $x as the JoinType")
-  }
-
-  protected[this] lazy val (buildKeys, streamedKeys) = joinType match {
-    case RightOuter => (leftKeys, rightKeys)
-    case LeftOuter => (rightKeys, leftKeys)
-    case x =>
-      throw new IllegalArgumentException(
-        s"HashOuterJoin should not take $x as the JoinType")
-  }
-
-  protected def buildKeyGenerator: Projection =
-    UnsafeProjection.create(buildKeys, buildPlan.output)
-
-  protected[this] def streamedKeyGenerator: Projection =
-    UnsafeProjection.create(streamedKeys, streamedPlan.output)
-
-  protected[this] def resultProjection: InternalRow => InternalRow =
-    UnsafeProjection.create(output, output)
-
-  @transient private[this] lazy val DUMMY_LIST = CompactBuffer[InternalRow](null)
-  @transient protected[this] lazy val EMPTY_LIST = CompactBuffer[InternalRow]()
-
-  @transient private[this] lazy val leftNullRow = new GenericInternalRow(left.output.length)
-  @transient private[this] lazy val rightNullRow = new GenericInternalRow(right.output.length)
-  @transient private[this] lazy val boundCondition = if (condition.isDefined) {
-    newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output)
-  } else {
-    (row: InternalRow) => true
-  }
-
-  // TODO we need to rewrite all of the iterators with our own implementation instead of the Scala
-  // iterator for performance purpose.
-
-  protected[this] def leftOuterIterator(
-      key: InternalRow,
-      joinedRow: JoinedRow,
-      rightIter: Iterable[InternalRow],
-      resultProjection: InternalRow => InternalRow,
-      numOutputRows: LongSQLMetric): Iterator[InternalRow] = {
-    val ret: Iterable[InternalRow] = {
-      if (!key.anyNull) {
-        val temp = if (rightIter != null) {
-          rightIter.collect {
-            case r if boundCondition(joinedRow.withRight(r)) => {
-              numOutputRows += 1
-              resultProjection(joinedRow).copy()
-            }
-          }
-        } else {
-          List.empty
-        }
-        if (temp.isEmpty) {
-          numOutputRows += 1
-          resultProjection(joinedRow.withRight(rightNullRow)) :: Nil
-        } else {
-          temp
-        }
-      } else {
-        numOutputRows += 1
-        resultProjection(joinedRow.withRight(rightNullRow)) :: Nil
-      }
-    }
-    ret.iterator
-  }
-
-  protected[this] def rightOuterIterator(
-      key: InternalRow,
-      leftIter: Iterable[InternalRow],
-      joinedRow: JoinedRow,
-      resultProjection: InternalRow => InternalRow,
-      numOutputRows: LongSQLMetric): Iterator[InternalRow] = {
-    val ret: Iterable[InternalRow] = {
-      if (!key.anyNull) {
-        val temp = if (leftIter != null) {
-          leftIter.collect {
-            case l if boundCondition(joinedRow.withLeft(l)) => {
-              numOutputRows += 1
-              resultProjection(joinedRow).copy()
-            }
-          }
-        } else {
-          List.empty
-        }
-        if (temp.isEmpty) {
-          numOutputRows += 1
-          resultProjection(joinedRow.withLeft(leftNullRow)) :: Nil
-        } else {
-          temp
-        }
-      } else {
-        numOutputRows += 1
-        resultProjection(joinedRow.withLeft(leftNullRow)) :: Nil
-      }
-    }
-    ret.iterator
-  }
-}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index 9a3c262e9485d0542e5043b209430abccde54797..92ff7e73fad88cc11ca2972cdfc0bf0e28b0b3f2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -46,7 +46,6 @@ class JoinSuite extends QueryTest with SharedSQLContext {
     val operators = physical.collect {
       case j: LeftSemiJoinHash => j
       case j: BroadcastHashJoin => j
-      case j: BroadcastHashOuterJoin => j
       case j: LeftSemiJoinBNL => j
       case j: CartesianProduct => j
       case j: BroadcastNestedLoopJoin => j
@@ -123,9 +122,9 @@ class JoinSuite extends QueryTest with SharedSQLContext {
       ("SELECT * FROM testData LEFT JOIN testData2 ON key = a",
         classOf[SortMergeOuterJoin]),
       ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2",
-        classOf[BroadcastHashOuterJoin]),
+        classOf[BroadcastHashJoin]),
       ("SELECT * FROM testData right join testData2 ON key = a and key = 2",
-        classOf[BroadcastHashOuterJoin])
+        classOf[BroadcastHashJoin])
     ).foreach { case (query, joinClass) => assertJoin(query, joinClass) }
     sql("UNCACHE TABLE testData")
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
index 4a151179bf6f288118dab6a99faa321d729ee055..bcac660a35a6561977ee60bb1aaea8101e6e4f07 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.execution
 
+import java.util.HashMap
+
 import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite}
 import org.apache.spark.memory.{StaticMemoryManager, TaskMemoryManager}
 import org.apache.spark.sql.SQLContext
@@ -124,37 +126,65 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
 
   ignore("broadcast hash join") {
     val N = 100 << 20
-    val dim = broadcast(sqlContext.range(1 << 16).selectExpr("id as k", "cast(id as string) as v"))
+    val M = 1 << 16
+    val dim = broadcast(sqlContext.range(M).selectExpr("id as k", "cast(id as string) as v"))
 
     runBenchmark("Join w long", N) {
-      sqlContext.range(N).join(dim, (col("id") % 60000) === col("k")).count()
+      sqlContext.range(N).join(dim, (col("id") bitwiseAND M) === col("k")).count()
     }
 
     /*
     Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
-    BroadcastHashJoin:                  Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
+    Join w long:                        Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
     -------------------------------------------------------------------------------------------
-    Join w long codegen=false        10174 / 10317         10.0         100.0       1.0X
-    Join w long codegen=true           1069 / 1107         98.0          10.2       9.5X
+    Join w long codegen=false                5744 / 5814         18.3          54.8       1.0X
+    Join w long codegen=true                  735 /  853        142.7           7.0       7.8X
     */
 
-    val dim2 = broadcast(sqlContext.range(1 << 16)
+    val dim2 = broadcast(sqlContext.range(M)
       .selectExpr("cast(id as int) as k1", "cast(id as int) as k2", "cast(id as string) as v"))
 
     runBenchmark("Join w 2 ints", N) {
       sqlContext.range(N).join(dim2,
-        (col("id") bitwiseAND 60000).cast(IntegerType) === col("k1")
-          && (col("id") bitwiseAND 50000).cast(IntegerType) === col("k2")).count()
+        (col("id") bitwiseAND M).cast(IntegerType) === col("k1")
+          && (col("id") bitwiseAND M).cast(IntegerType) === col("k2")).count()
     }
 
     /**
     Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
-    BroadcastHashJoin:                  Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
+    Join w 2 ints:                      Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
     -------------------------------------------------------------------------------------------
-    Join w 2 ints codegen=false           11435 / 11530          9.0         111.1       1.0X
-    Join w 2 ints codegen=true              1265 / 1424         82.0          12.2       9.0X
+    Join w 2 ints codegen=false              7159 / 7224         14.6          68.3       1.0X
+    Join w 2 ints codegen=true               1135 / 1197         92.4          10.8       6.3X
     */
 
+    val dim3 = broadcast(sqlContext.range(M)
+      .selectExpr("id as k1", "id as k2", "cast(id as string) as v"))
+
+    runBenchmark("Join w 2 longs", N) {
+      sqlContext.range(N).join(dim3,
+        (col("id") bitwiseAND M) === col("k1") && (col("id") bitwiseAND M) === col("k2"))
+        .count()
+    }
+
+    /**
+    Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
+    Join w 2 longs:                      Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
+    -------------------------------------------------------------------------------------------
+    Join w 2 longs codegen=false              7877 / 8358         13.3          75.1       1.0X
+    Join w 2 longs codegen=true               3877 / 3937         27.0          37.0       2.0X
+    */
+    runBenchmark("outer join w long", N) {
+      sqlContext.range(N).join(dim, (col("id") bitwiseAND M) === col("k"), "left").count()
+    }
+
+    /**
+    Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
+    outer join w long:                  Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
+    -------------------------------------------------------------------------------------------
+    outer join w long codegen=false        15280 / 16497          6.9         145.7       1.0X
+    outer join w long codegen=true            769 /  796        136.3           7.3      19.9X
+    */
   }
 
   ignore("rube") {
@@ -175,7 +205,7 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
   }
 
   ignore("hash and BytesToBytesMap") {
-    val N = 50 << 20
+    val N = 10 << 20
 
     val benchmark = new Benchmark("BytesToBytesMap", N)
 
@@ -227,6 +257,80 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
       }
     }
 
+    benchmark.addCase("Java HashMap (Long)") { iter =>
+      var i = 0
+      val keyBytes = new Array[Byte](16)
+      val valueBytes = new Array[Byte](16)
+      val value = new UnsafeRow(1)
+      value.pointTo(valueBytes, Platform.BYTE_ARRAY_OFFSET, 16)
+      value.setInt(0, 555)
+      val map = new HashMap[Long, UnsafeRow]()
+      while (i < 65536) {
+        value.setInt(0, i)
+        map.put(i.toLong, value)
+        i += 1
+      }
+      var s = 0
+      i = 0
+      while (i < N) {
+        if (map.get(i % 100000) != null) {
+          s += 1
+        }
+        i += 1
+      }
+    }
+
+    benchmark.addCase("Java HashMap (two ints) ") { iter =>
+      var i = 0
+      val valueBytes = new Array[Byte](16)
+      val value = new UnsafeRow(1)
+      value.pointTo(valueBytes, Platform.BYTE_ARRAY_OFFSET, 16)
+      value.setInt(0, 555)
+      val map = new HashMap[Long, UnsafeRow]()
+      while (i < 65536) {
+        value.setInt(0, i)
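+        // pack the two ints into one long key, rotating the second one (mirroring the rotation
+        // in HashJoin.rewriteKeyExpr) so that equal halves do not hash to 0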
+        val key = (i.toLong << 32) + Integer.rotateRight(i, 15)
+        map.put(key, value)
+        i += 1
+      }
+      var s = 0
+      i = 0
+      while (i < N) {
+        val key = ((i & 100000).toLong << 32) + Integer.rotateRight(i & 100000, 15)
+        if (map.get(key) != null) {
+          s += 1
+        }
+        i += 1
+      }
+    }
+
+    benchmark.addCase("Java HashMap (UnsafeRow)") { iter =>
+      var i = 0
+      val keyBytes = new Array[Byte](16)
+      val valueBytes = new Array[Byte](16)
+      val key = new UnsafeRow(1)
+      key.pointTo(keyBytes, Platform.BYTE_ARRAY_OFFSET, 16)
+      val value = new UnsafeRow(1)
+      value.pointTo(valueBytes, Platform.BYTE_ARRAY_OFFSET, 16)
+      value.setInt(0, 555)
+      val map = new HashMap[UnsafeRow, UnsafeRow]()
+      while (i < 65536) {
+        key.setInt(0, i)
+        value.setInt(0, i)
+        map.put(key, value.copy())
+        i += 1
+      }
+      var s = 0
+      i = 0
+      while (i < N) {
+        key.setInt(0, i % 100000)
+        if (map.get(key) != null) {
+          s += 1
+        }
+        i += 1
+      }
+    }
+
     Seq("off", "on").foreach { heap =>
       benchmark.addCase(s"BytesToBytesMap ($heap Heap)") { iter =>
         val taskMemoryManager = new TaskMemoryManager(
@@ -268,6 +372,9 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
     hash                                      651 /  678         80.0          12.5       1.0X
     fast hash                                 336 /  343        155.9           6.4       1.9X
     arrayEqual                                417 /  428        125.0           8.0       1.6X
+    Java HashMap (Long)                       145 /  168         72.2          13.8       0.8X
+    Java HashMap (two ints)                   157 /  164         66.8          15.0       0.8X
+    Java HashMap (UnsafeRow)                  538 /  573         19.5          51.3       0.2X
     BytesToBytesMap (off Heap)               2594 / 2664         20.2          49.5       0.2X
     BytesToBytesMap (on Heap)                2693 / 2989         19.5          51.4       0.2X
       */
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
index aee8e84db56e2ebde37c46a82cf8f08ef13b9c72..e25b5e0610ea12c4e362835e56e15290b9e8a9bd 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
@@ -73,7 +73,7 @@ class BroadcastJoinSuite extends QueryTest with BeforeAndAfterAll {
   }
 
   test("unsafe broadcast hash outer join updates peak execution memory") {
-    testBroadcastJoin[BroadcastHashOuterJoin]("unsafe broadcast hash outer join", "left_outer")
+    testBroadcastJoin[BroadcastHashJoin]("unsafe broadcast hash outer join", "left_outer")
   }
 
   test("unsafe broadcast left semi join updates peak execution memory") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
index 149f34dbd748fad1f6e796f3a1b933317f5b5c3d..e22a810a6b42fe205707b715bd0420a27a22fb60 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
@@ -88,7 +88,7 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext {
         leftPlan: SparkPlan,
         rightPlan: SparkPlan,
         side: BuildSide) = {
-      joins.BroadcastHashJoin(leftKeys, rightKeys, side, boundCondition, leftPlan, rightPlan)
+      joins.BroadcastHashJoin(leftKeys, rightKeys, Inner, side, boundCondition, leftPlan, rightPlan)
     }
 
     def makeSortMergeJoin(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
index 3d3e9a7b90928552e562e75be95768c5a9f3ecaf..f4b01fbad05859d57e8bfb4d4912a8ad6031e5d2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
@@ -75,11 +75,16 @@ class OuterJoinSuite extends SparkPlanTest with SharedSQLContext {
     }
 
     if (joinType != FullOuter) {
-      test(s"$testName using BroadcastHashOuterJoin") {
+      test(s"$testName using BroadcastHashJoin") {
+        val buildSide = joinType match {
+          case LeftOuter => BuildRight
+          case RightOuter => BuildLeft
+        }
         extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) =>
           withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
             checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
-              BroadcastHashOuterJoin(leftKeys, rightKeys, joinType, boundCondition, left, right),
+              BroadcastHashJoin(
+                leftKeys, rightKeys, joinType, buildSide, boundCondition, left, right),
               expectedAnswer.map(Row.fromTuple),
               sortAnswers = true)
           }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
index f4bc9e501c21c77e471f049c86de2b30acdd828b..46bb699b780a9b57b67006e3af190531c326cd70 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
@@ -209,20 +209,20 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext {
     )
   }
 
-  test("BroadcastHashOuterJoin metrics") {
+  test("BroadcastHashJoin(outer) metrics") {
     val df1 = Seq((1, "a"), (1, "b"), (4, "c")).toDF("key", "value")
     val df2 = Seq((1, "a"), (1, "b"), (2, "c"), (3, "d")).toDF("key2", "value")
     // Assume the execution plan is
-    // ... -> BroadcastHashOuterJoin(nodeId = 0)
+    // ... -> BroadcastHashJoin(nodeId = 0)
     val df = df1.join(broadcast(df2), $"key" === $"key2", "left_outer")
     testSparkPlanMetrics(df, 2, Map(
-      0L -> ("BroadcastHashOuterJoin", Map(
+      0L -> ("BroadcastHashJoin", Map(
         "number of output rows" -> 5L)))
     )
 
     val df3 = df1.join(broadcast(df2), $"key" === $"key2", "right_outer")
     testSparkPlanMetrics(df3, 2, Map(
-      0L -> ("BroadcastHashOuterJoin", Map(
+      0L -> ("BroadcastHashJoin", Map(
         "number of output rows" -> 6L)))
     )
   }