diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala
index 645eb48d5a51bf11866d0b03024d8dcf115c2a17..5f8a6f88717224cf53353894b2c1ab4ad45a9a96 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala
@@ -40,10 +40,6 @@ abstract class UnsafeRowJoiner {
  */
 object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), UnsafeRowJoiner] {
 
-  def dump(word: Long): String = {
-    Seq.tabulate(64) { i => if ((word >> i) % 2 == 0) "0" else "1" }.reverse.mkString
-  }
-
   override protected def create(in: (StructType, StructType)): UnsafeRowJoiner = {
     create(in._1, in._2)
   }
@@ -56,76 +52,45 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U
   }
 
   def create(schema1: StructType, schema2: StructType): UnsafeRowJoiner = {
-    val ctx = newCodeGenContext()
     val offset = PlatformDependent.BYTE_ARRAY_OFFSET
+    val getLong = "PlatformDependent.UNSAFE.getLong"
+    val putLong = "PlatformDependent.UNSAFE.putLong"
 
     val bitset1Words = (schema1.size + 63) / 64
     val bitset2Words = (schema2.size + 63) / 64
     val outputBitsetWords = (schema1.size + schema2.size + 63) / 64
     val bitset1Remainder = schema1.size % 64
-    val bitset2Remainder = schema2.size % 64
 
     // The number of words we can reduce when we concat two rows together.
     // The only reduction comes from merging the bitset portion of the two rows, saving 1 word.
     val sizeReduction = bitset1Words + bitset2Words - outputBitsetWords
 
-    // --------------------- copy bitset from row 1 ----------------------- //
-    val copyBitset1 = Seq.tabulate(bitset1Words) { i =>
-      s"""
-         |PlatformDependent.UNSAFE.putLong(buf, ${offset + i * 8},
-         |  PlatformDependent.UNSAFE.getLong(obj1, ${offset + i * 8}));
-       """.stripMargin
-    }.mkString
-
-
-    // --------------------- copy bitset from row 2 ----------------------- //
-    var copyBitset2 = ""
-    if (bitset1Remainder == 0) {
-      copyBitset2 += Seq.tabulate(bitset2Words) { i =>
-        s"""
-           |PlatformDependent.UNSAFE.putLong(buf, ${offset + (bitset1Words + i) * 8},
-           |  PlatformDependent.UNSAFE.getLong(obj2, ${offset + i * 8}));
-         """.stripMargin
-      }.mkString
-    } else {
-      copyBitset2 = Seq.tabulate(bitset2Words) { i =>
-        s"""
-           |long bs2w$i = PlatformDependent.UNSAFE.getLong(obj2, ${offset + i * 8});
-           |long bs2w${i}p1 = (bs2w$i << $bitset1Remainder) & ~((1L << $bitset1Remainder) - 1);
-           |long bs2w${i}p2 = (bs2w$i >>> ${64 - bitset1Remainder});
-         """.stripMargin
-      }.mkString
-
-      copyBitset2 += Seq.tabulate(bitset2Words) { i =>
-        val currentOffset = offset + (bitset1Words + i - 1) * 8
-        if (i == 0) {
-          if (bitset1Words > 0) {
-            s"""
-               |PlatformDependent.UNSAFE.putLong(buf, $currentOffset,
-               |  bs2w${i}p1 | PlatformDependent.UNSAFE.getLong(obj1, $currentOffset));
-            """.stripMargin
-          } else {
-            s"""
-               |PlatformDependent.UNSAFE.putLong(buf, $currentOffset + 8, bs2w${i}p1);
-            """.stripMargin
-          }
+    // --------------------- copy bitset from row 1 and row 2 --------------------------- //
+    val copyBitset = Seq.tabulate(outputBitsetWords) { i =>
+      val bits = if (bitset1Remainder > 0) {
+        if (i < bitset1Words - 1) {
+          s"$getLong(obj1, offset1 + ${i * 8})"
+        } else if (i == bitset1Words - 1) {
+          // combine last word of bitset1 and first word of bitset2
+          s"$getLong(obj1, offset1 + ${i * 8}) | ($getLong(obj2, offset2) << $bitset1Remainder)"
+        } else if (i - bitset1Words < bitset2Words - 1) {
+          // combine next two words of bitset2
+          s"($getLong(obj2, offset2 + ${(i - bitset1Words) * 8}) >>> (64 - $bitset1Remainder))" +
+            s"| ($getLong(obj2, offset2 + ${(i - bitset1Words + 1) * 8}) << $bitset1Remainder)"
+        } else {
+          // last word of bitset2
+          s"$getLong(obj2, offset2 + ${(i - bitset1Words) * 8}) >>> (64 - $bitset1Remainder)"
+        }
+      } else {
+        // they are aligned by word
+        if (i < bitset1Words) {
+          s"$getLong(obj1, offset1 + ${i * 8})"
         } else {
-          s"""
-             |PlatformDependent.UNSAFE.putLong(buf, $currentOffset, bs2w${i}p1 | bs2w${i - 1}p2);
-          """.stripMargin
+          s"$getLong(obj2, offset2 + ${(i - bitset1Words) * 8})"
         }
-      }.mkString("\n")
-
-      if (bitset2Words > 0 &&
-        (bitset2Remainder == 0 || bitset2Remainder > (64 - bitset1Remainder))) {
-        val lastWord = bitset2Words - 1
-        copyBitset2 +=
-          s"""
-             |PlatformDependent.UNSAFE.putLong(buf, ${offset + (outputBitsetWords - 1) * 8},
-             |  bs2w${lastWord}p2);
-          """.stripMargin
       }
-    }
+      s"$putLong(buf, ${offset + i * 8}, $bits);"
+    }.mkString("\n")
 
     // --------------------- copy fixed length portion from row 1 ----------------------- //
     var cursor = offset + outputBitsetWords * 8
@@ -149,10 +114,10 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U
     cursor += schema2.size * 8
 
     // --------------------- copy variable length portion from row 1 ----------------------- //
+    val numBytesBitsetAndFixedRow1 = (bitset1Words + schema1.size) * 8
     val copyVariableLengthRow1 = s"""
        |// Copy variable length data for row1
-       |long numBytesBitsetAndFixedRow1 = ${(bitset1Words + schema1.size) * 8};
-       |long numBytesVariableRow1 = row1.getSizeInBytes() - numBytesBitsetAndFixedRow1;
+       |long numBytesVariableRow1 = row1.getSizeInBytes() - $numBytesBitsetAndFixedRow1;
        |PlatformDependent.copyMemory(
        |  obj1, offset1 + ${(bitset1Words + schema1.size) * 8},
        |  buf, $cursor,
@@ -160,10 +125,10 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U
      """.stripMargin
 
     // --------------------- copy variable length portion from row 2 ----------------------- //
+    val numBytesBitsetAndFixedRow2 = (bitset2Words + schema2.size) * 8
     val copyVariableLengthRow2 = s"""
        |// Copy variable length data for row2
-       |long numBytesBitsetAndFixedRow2 = ${(bitset2Words + schema2.size) * 8};
-       |long numBytesVariableRow2 = row2.getSizeInBytes() - numBytesBitsetAndFixedRow2;
+       |long numBytesVariableRow2 = row2.getSizeInBytes() - $numBytesBitsetAndFixedRow2;
        |PlatformDependent.copyMemory(
        |  obj2, offset2 + ${(bitset2Words + schema2.size) * 8},
        |  buf, $cursor + numBytesVariableRow1,
@@ -183,12 +148,11 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U
           if (i < schema1.size) {
             s"${(outputBitsetWords - bitset1Words + schema2.size) * 8}L"
           } else {
-            s"${(outputBitsetWords - bitset2Words + schema1.size) * 8}L + numBytesVariableRow1"
+            s"(${(outputBitsetWords - bitset2Words + schema1.size) * 8}L + numBytesVariableRow1)"
           }
         val cursor = offset + outputBitsetWords * 8 + i * 8
         s"""
-           |PlatformDependent.UNSAFE.putLong(buf, $cursor,
-           |  PlatformDependent.UNSAFE.getLong(buf, $cursor) + ($shift << 32));
+           |$putLong(buf, $cursor, $getLong(buf, $cursor) + ($shift << 32));
          """.stripMargin
       }
     }.mkString
@@ -217,8 +181,7 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U
        |    final Object obj2 = row2.getBaseObject();
        |    final long offset2 = row2.getBaseOffset();
        |
-       |    $copyBitset1
-       |    $copyBitset2
+       |    $copyBitset
        |    $copyFixedLengthRow1
        |    $copyFixedLengthRow2
        |    $copyVariableLengthRow1
@@ -233,7 +196,6 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U
      """.stripMargin
 
     logDebug(s"SpecificUnsafeRowJoiner($schema1, $schema2):\n${CodeFormatter.format(code)}")
-    // println(CodeFormatter.format(code))
 
     val c = compile(code)
     c.generate(Array.empty).asInstanceOf[UnsafeRowJoiner]
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerBitsetSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerBitsetSuite.scala
index 76d9d991ed0dc902b371e788180479b75b750d42..718a2acc8281d10e87f45d5ca6be3f5aadaf249a 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerBitsetSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerBitsetSuite.scala
@@ -22,6 +22,7 @@ import scala.util.Random
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
 import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.PlatformDependent
 
 /**
  * A test suite for the bitset portion of the row concatenation.
@@ -91,8 +92,9 @@ class GenerateUnsafeRowJoinerBitsetSuite extends SparkFunSuite {
   private def createUnsafeRow(numFields: Int): UnsafeRow = {
     val row = new UnsafeRow
     val sizeInBytes = numFields * 8 + ((numFields + 63) / 64) * 8
-    val buf = new Array[Byte](sizeInBytes)
-    row.pointTo(buf, numFields, sizeInBytes)
+    val offset = numFields * 8
+    val buf = new Array[Byte](sizeInBytes + offset)
+    row.pointTo(buf, PlatformDependent.BYTE_ARRAY_OFFSET + offset, numFields, sizeInBytes)
     row
   }
 
@@ -133,6 +135,7 @@ class GenerateUnsafeRowJoinerBitsetSuite extends SparkFunSuite {
          |input1: ${set1.mkString}
          |input2: ${set2.mkString}
          |output: ${out.mkString}
+         |expect: ${set1.mkString}${set2.mkString}
        """.stripMargin
     }