diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala
index 1f99fbedde411d8f9ecc738818702007a5144869..d3bcfad7c3de0c88188caa079fbfad978e580825 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala
@@ -26,8 +26,8 @@ import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
 import org.apache.spark.sql.execution.metric.SQLMetrics
 
 /**
- * Build the right table's join keys into a HashSet, and iteratively go through the left
- * table, to find the if join keys are in the Hash set.
+ * Build the right table's join keys into a HashedRelation, and iterate over the left
+ * table to check whether the join keys are in the HashedRelation.
  */
 case class BroadcastLeftSemiJoinHash(
     leftKeys: Seq[Expression],
@@ -40,29 +40,18 @@ case class BroadcastLeftSemiJoinHash(
     "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
 
   override def requiredChildDistribution: Seq[Distribution] = {
-    val mode = if (condition.isEmpty) {
-      HashSetBroadcastMode(rightKeys, right.output)
-    } else {
-      HashedRelationBroadcastMode(canJoinKeyFitWithinLong = false, rightKeys, right.output)
-    }
+    val mode = HashedRelationBroadcastMode(canJoinKeyFitWithinLong = false, rightKeys, right.output)
     UnspecifiedDistribution :: BroadcastDistribution(mode) :: Nil
   }
 
   protected override def doExecute(): RDD[InternalRow] = {
     val numOutputRows = longMetric("numOutputRows")
 
-    if (condition.isEmpty) {
-      val broadcastedRelation = right.executeBroadcast[java.util.Set[InternalRow]]()
-      left.execute().mapPartitionsInternal { streamIter =>
-        hashSemiJoin(streamIter, broadcastedRelation.value, numOutputRows)
-      }
-    } else {
-      val broadcastedRelation = right.executeBroadcast[HashedRelation]()
-      left.execute().mapPartitionsInternal { streamIter =>
-        val hashedRelation = broadcastedRelation.value
-        TaskContext.get().taskMetrics().incPeakExecutionMemory(hashedRelation.getMemorySize)
-        hashSemiJoin(streamIter, hashedRelation, numOutputRows)
-      }
+    val broadcastedRelation = right.executeBroadcast[HashedRelation]()
+    left.execute().mapPartitionsInternal { streamIter =>
+      val hashedRelation = broadcastedRelation.value
+      TaskContext.get().taskMetrics().incPeakExecutionMemory(hashedRelation.getMemorySize)
+      hashSemiJoin(streamIter, hashedRelation, numOutputRows)
     }
   }
 }
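
The broadcast variant now builds the right side as a HashedRelation in both the conditioned and unconditioned cases, so the two doExecute branches collapse into one. A minimal plain-Scala sketch of the unified probe-side filter (the collection types and names below are illustrative stand-ins, not Spark's InternalRow-based API):

    // Illustrative only: a hash left-semi join over plain Scala collections,
    // mirroring the shape of hashSemiJoin after this change.
    object SemiJoinSketch {
      def semiJoin[K, L, R](
          left: Iterator[L],
          hashed: Map[K, Seq[R]],                 // plays the role of HashedRelation
          leftKey: L => K,
          condition: Option[(L, R) => Boolean]): Iterator[L] = {
        left.filter { row =>
          val matches = hashed.get(leftKey(row))
          // No extra condition: any key match qualifies. Otherwise at least
          // one matched build row must also satisfy the condition.
          matches.exists(rows => condition.forall(c => rows.exists(r => c(row, r))))
        }
      }

      def main(args: Array[String]): Unit = {
        val right = Map(1 -> Seq("a"), 2 -> Seq("b"))
        println(semiJoin[Int, Int, String](Iterator(1, 2, 3), right, identity, None).toList)
        // prints List(1, 2)
      }
    }

(Spark's real implementation additionally rejects stream rows whose key contains a null, via key.anyNull.)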
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala
index 1cb6a00617c5eb521ff7805deafa5361af83320f..3eed6e3e11131c593f63c694dd7615c5006fb45e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala
@@ -43,24 +43,6 @@ trait HashSemiJoin {
   @transient private lazy val boundCondition =
     newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output)
 
-  protected def buildKeyHashSet(
-      buildIter: Iterator[InternalRow]): java.util.Set[InternalRow] = {
-    HashSemiJoin.buildKeyHashSet(rightKeys, right.output, buildIter)
-  }
-
-  protected def hashSemiJoin(
-    streamIter: Iterator[InternalRow],
-    hashSet: java.util.Set[InternalRow],
-    numOutputRows: LongSQLMetric): Iterator[InternalRow] = {
-    val joinKeys = leftKeyGenerator
-    streamIter.filter(current => {
-      val key = joinKeys(current)
-      val r = !key.anyNull && hashSet.contains(key)
-      if (r) numOutputRows += 1
-      r
-    })
-  }
-
   protected def hashSemiJoin(
       streamIter: Iterator[InternalRow],
       hashedRelation: HashedRelation,
@@ -70,44 +52,11 @@ trait HashSemiJoin {
     streamIter.filter { current =>
       val key = joinKeys(current)
       lazy val rowBuffer = hashedRelation.get(key)
-      val r = !key.anyNull && rowBuffer != null && rowBuffer.exists {
+      val r = !key.anyNull && rowBuffer != null && (condition.isEmpty || rowBuffer.exists {
         (row: InternalRow) => boundCondition(joinedRow(current, row))
-      }
+      })
       if (r) numOutputRows += 1
       r
     }
   }
 }
-
-private[execution] object HashSemiJoin {
-  def buildKeyHashSet(
-    keys: Seq[Expression],
-    attributes: Seq[Attribute],
-    rows: Iterator[InternalRow]): java.util.HashSet[InternalRow] = {
-    val hashSet = new java.util.HashSet[InternalRow]()
-
-    // Create a Hash set of buildKeys
-    val key = UnsafeProjection.create(keys, attributes)
-    while (rows.hasNext) {
-      val currentRow = rows.next()
-      val rowKey = key(currentRow)
-      if (!rowKey.anyNull) {
-        val keyExists = hashSet.contains(rowKey)
-        if (!keyExists) {
-          hashSet.add(rowKey.copy())
-        }
-      }
-    }
-    hashSet
-  }
-}
-
-/** HashSetBroadcastMode requires that the input rows are broadcasted as a set. */
-private[execution] case class HashSetBroadcastMode(
-    keys: Seq[Expression],
-    attributes: Seq[Attribute]) extends BroadcastMode {
-
-  override def transform(rows: Array[InternalRow]): java.util.HashSet[InternalRow] = {
-    HashSemiJoin.buildKeyHashSet(keys, attributes, rows.iterator)
-  }
-}
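
With buildKeyHashSet, the Set-based hashSemiJoin overload, and HashSetBroadcastMode all removed, the condition.isEmpty short-circuit above lets the HashedRelation path subsume the old set-membership test: for a non-null key, membership in the old key set coincides with a successful probe of a relation built from the same rows. A small plain-Scala sketch of that equivalence (the Map stand-in for HashedRelation is hypothetical):

    import scala.collection.mutable

    object SetVsRelation {
      def main(args: Array[String]): Unit = {
        val buildKeys = Seq(1, 1, 2)
        val keySet = mutable.HashSet(buildKeys: _*)      // old HashSet path
        val relation = buildKeys.groupBy(identity)       // stand-in for HashedRelation
        val probes = Seq(1, 2, 3)
        // Both paths admit exactly the same probe keys.
        assert(probes.map(keySet.contains) == probes.map(relation.contains))
      }
    }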
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala
index d8d3045ccf5c310516693515c2f0ae8644372048..242ed612327cc35324a193abbd716e7219c995cc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala
@@ -25,8 +25,8 @@ import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
 import org.apache.spark.sql.execution.metric.SQLMetrics
 
 /**
- * Build the right table's join keys into a HashSet, and iteratively go through the left
- * table, to find the if join keys are in the Hash set.
+ * Build the right table's join keys into a HashedRelation, and iterate over the left
+ * table to check whether the join keys are in the HashedRelation.
  */
 case class LeftSemiJoinHash(
     leftKeys: Seq[Expression],
@@ -47,13 +47,8 @@ case class LeftSemiJoinHash(
     val numOutputRows = longMetric("numOutputRows")
 
     right.execute().zipPartitions(left.execute()) { (buildIter, streamIter) =>
-      if (condition.isEmpty) {
-        val hashSet = buildKeyHashSet(buildIter)
-        hashSemiJoin(streamIter, hashSet, numOutputRows)
-      } else {
-        val hashRelation = HashedRelation(buildIter, rightKeyGenerator)
-        hashSemiJoin(streamIter, hashRelation, numOutputRows)
-      }
+      val hashRelation = HashedRelation(buildIter, rightKeyGenerator)
+      hashSemiJoin(streamIter, hashRelation, numOutputRows)
     }
   }
 }
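
The shuffled variant follows the same simplification: every partition now builds a HashedRelation from the build-side iterator and streams the probe side through it, with no separate set-based branch for the unconditioned case. A hypothetical plain-Scala analogue of the per-partition work done inside zipPartitions:

    object PartitionSemiJoinSketch {
      // Stand-in for one partition: build the hash side, then probe it.
      def joinPartition[K, V](
          buildIter: Iterator[(K, V)],
          streamIter: Iterator[(K, V)]): Iterator[(K, V)] = {
        val hashed = buildIter.toSeq.groupBy(_._1)   // analogue of HashedRelation(buildIter, ...)
        streamIter.filter { case (k, _) => hashed.contains(k) }
      }

      def main(args: Array[String]): Unit = {
        val build = Iterator(1 -> "x", 2 -> "y")
        val stream = Iterator(1 -> "a", 3 -> "b")
        println(joinPartition(build, stream).toList)  // List((1,a))
      }
    }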