diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
index 8ddac19bf1b57bc418b5bf6343148c4f9f49efab..05c5e2f4cd77bda29fd5b033dd4eeee137a7aa85 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
@@ -63,45 +63,16 @@ trait HashJoin {
   protected lazy val (buildKeys, streamedKeys) = {
     require(leftKeys.map(_.dataType) == rightKeys.map(_.dataType),
       "Join keys from two sides should have same types")
-    val lkeys = rewriteKeyExpr(leftKeys).map(BindReferences.bindReference(_, left.output))
-    val rkeys = rewriteKeyExpr(rightKeys).map(BindReferences.bindReference(_, right.output))
+    val lkeys = HashJoin.rewriteKeyExpr(leftKeys).map(BindReferences.bindReference(_, left.output))
+    val rkeys = HashJoin.rewriteKeyExpr(rightKeys)
+      .map(BindReferences.bindReference(_, right.output))
     buildSide match {
       case BuildLeft => (lkeys, rkeys)
       case BuildRight => (rkeys, lkeys)
     }
   }
 
-  /**
-   * Try to rewrite the key as LongType so we can use getLong(), if they key can fit with a long.
-   *
-   * If not, returns the original expressions.
-   */
-  private def rewriteKeyExpr(keys: Seq[Expression]): Seq[Expression] = {
-    var keyExpr: Expression = null
-    var width = 0
-    keys.foreach { e =>
-      e.dataType match {
-        case dt: IntegralType if dt.defaultSize <= 8 - width =>
-          if (width == 0) {
-            if (e.dataType != LongType) {
-              keyExpr = Cast(e, LongType)
-            } else {
-              keyExpr = e
-            }
-            width = dt.defaultSize
-          } else {
-            val bits = dt.defaultSize * 8
-            keyExpr = BitwiseOr(ShiftLeft(keyExpr, Literal(bits)),
-              BitwiseAnd(Cast(e, LongType), Literal((1L << bits) - 1)))
-            width -= bits
-          }
-        // TODO: support BooleanType, DateType and TimestampType
-        case other =>
-          return keys
-      }
-    }
-    keyExpr :: Nil
-  }
+
 
   protected def buildSideKeyGenerator(): Projection =
     UnsafeProjection.create(buildKeys)
@@ -247,3 +218,31 @@ trait HashJoin {
     }
   }
 }
+
+object HashJoin {
+  /**
+   * Try to rewrite the key as LongType so we can use getLong(), if they key can fit with a long.
+   *
+   * If not, returns the original expressions.
+   */
+  private[joins] def rewriteKeyExpr(keys: Seq[Expression]): Seq[Expression] = {
+    assert(keys.nonEmpty)
+    // TODO: support BooleanType, DateType and TimestampType
+    if (keys.exists(!_.dataType.isInstanceOf[IntegralType])
+      || keys.map(_.dataType.defaultSize).sum > 8) {
+      return keys
+    }
+
+    var keyExpr: Expression = if (keys.head.dataType != LongType) {
+      Cast(keys.head, LongType)
+    } else {
+      keys.head
+    }
+    keys.tail.foreach { e =>
+      val bits = e.dataType.defaultSize * 8
+      keyExpr = BitwiseOr(ShiftLeft(keyExpr, Literal(bits)),
+        BitwiseAnd(Cast(e, LongType), Literal((1L << bits) - 1)))
+    }
+    keyExpr :: Nil
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
index 97adffa8ce1012f91df160278b5b3d42d0ddcc67..83db81ea3f1c268859424afdc66811cdf366dda1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
@@ -21,11 +21,13 @@ import scala.reflect.ClassTag
 
 import org.apache.spark.AccumulatorSuite
 import org.apache.spark.sql.{Dataset, QueryTest, Row, SparkSession}
+import org.apache.spark.sql.catalyst.expressions.{BitwiseAnd, BitwiseOr, Cast, Literal, ShiftLeft}
 import org.apache.spark.sql.execution.exchange.EnsureRequirements
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SQLTestUtils
+import org.apache.spark.sql.types.{LongType, ShortType}
 
 /**
  * Test various broadcast join operators.
@@ -153,4 +155,49 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils {
       cases.foreach(assertBroadcastJoin)
     }
   }
+
+  test("join key rewritten") {
+    val l = Literal(1L)
+    val i = Literal(2)
+    val s = Literal.create(3, ShortType)
+    val ss = Literal("hello")
+
+    assert(HashJoin.rewriteKeyExpr(l :: Nil) === l :: Nil)
+    assert(HashJoin.rewriteKeyExpr(l :: l :: Nil) === l :: l :: Nil)
+    assert(HashJoin.rewriteKeyExpr(l :: i :: Nil) === l :: i :: Nil)
+
+    assert(HashJoin.rewriteKeyExpr(i :: Nil) === Cast(i, LongType) :: Nil)
+    assert(HashJoin.rewriteKeyExpr(i :: l :: Nil) === i :: l :: Nil)
+    assert(HashJoin.rewriteKeyExpr(i :: i :: Nil) ===
+      BitwiseOr(ShiftLeft(Cast(i, LongType), Literal(32)),
+        BitwiseAnd(Cast(i, LongType), Literal((1L << 32) - 1))) :: Nil)
+    assert(HashJoin.rewriteKeyExpr(i :: i :: i :: Nil) === i :: i :: i :: Nil)
+
+    assert(HashJoin.rewriteKeyExpr(s :: Nil) === Cast(s, LongType) :: Nil)
+    assert(HashJoin.rewriteKeyExpr(s :: l :: Nil) === s :: l :: Nil)
+    assert(HashJoin.rewriteKeyExpr(s :: s :: Nil) ===
+      BitwiseOr(ShiftLeft(Cast(s, LongType), Literal(16)),
+        BitwiseAnd(Cast(s, LongType), Literal((1L << 16) - 1))) :: Nil)
+    assert(HashJoin.rewriteKeyExpr(s :: s :: s :: Nil) ===
+      BitwiseOr(ShiftLeft(
+        BitwiseOr(ShiftLeft(Cast(s, LongType), Literal(16)),
+          BitwiseAnd(Cast(s, LongType), Literal((1L << 16) - 1))),
+        Literal(16)),
+        BitwiseAnd(Cast(s, LongType), Literal((1L << 16) - 1))) :: Nil)
+    assert(HashJoin.rewriteKeyExpr(s :: s :: s :: s :: Nil) ===
+      BitwiseOr(ShiftLeft(
+        BitwiseOr(ShiftLeft(
+          BitwiseOr(ShiftLeft(Cast(s, LongType), Literal(16)),
+            BitwiseAnd(Cast(s, LongType), Literal((1L << 16) - 1))),
+          Literal(16)),
+          BitwiseAnd(Cast(s, LongType), Literal((1L << 16) - 1))),
+        Literal(16)),
+        BitwiseAnd(Cast(s, LongType), Literal((1L << 16) - 1))) :: Nil)
+    assert(HashJoin.rewriteKeyExpr(s :: s :: s :: s :: s :: Nil) ===
+      s :: s :: s :: s :: s :: Nil)
+
+    assert(HashJoin.rewriteKeyExpr(ss :: Nil) === ss :: Nil)
+    assert(HashJoin.rewriteKeyExpr(l :: ss :: Nil) === l :: ss :: Nil)
+    assert(HashJoin.rewriteKeyExpr(i :: ss :: Nil) === i :: ss :: Nil)
+  }
 }