diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
index ea02076b41a6ffd2903891a1764d5f9f5dd9344a..6c0196c21a0d16d04a9818f30855b7066ba5d500 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.execution.SparkSqlSerializer
 import org.apache.spark.sql.execution.metric.LongSQLMetric
 import org.apache.spark.unsafe.Platform
 import org.apache.spark.unsafe.map.BytesToBytesMap
-import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager}
+import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator, MemoryLocation, TaskMemoryManager}
 import org.apache.spark.util.Utils
 import org.apache.spark.util.collection.CompactBuffer
 import org.apache.spark.{SparkConf, SparkEnv}
@@ -247,40 +247,70 @@ private[joins] final class UnsafeHashedRelation(
   }
 
   override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException {
-    out.writeInt(hashTable.size())
-
-    val iter = hashTable.entrySet().iterator()
-    while (iter.hasNext) {
-      val entry = iter.next()
-      val key = entry.getKey
-      val values = entry.getValue
-
-      // write all the values as single byte array
-      var totalSize = 0L
-      var i = 0
-      while (i < values.length) {
-        totalSize += values(i).getSizeInBytes + 4 + 4
-        i += 1
+    if (binaryMap != null) {
+      // This could happen when a cached broadcast object needs to be dumped to disk to free memory
+      out.writeInt(binaryMap.numElements())
+
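+      // The map's key/value bytes may live in off-heap memory; stage them in this reusable
+      // on-heap buffer so they can be passed to out.write().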
+      var buffer = new Array[Byte](64)
+      def write(addr: MemoryLocation, length: Int): Unit = {
+        if (buffer.length < length) {
+          buffer = new Array[Byte](length)
+        }
+        Platform.copyMemory(addr.getBaseObject, addr.getBaseOffset,
+          buffer, Platform.BYTE_ARRAY_OFFSET, length)
+        out.write(buffer, 0, length)
       }
-      assert(totalSize < Integer.MAX_VALUE, "values are too big")
-
-      // [key size] [values size] [key bytes] [values bytes]
-      out.writeInt(key.getSizeInBytes)
-      out.writeInt(totalSize.toInt)
-      out.write(key.getBytes)
-      i = 0
-      while (i < values.length) {
-        // [num of fields] [num of bytes] [row bytes]
-        // write the integer in native order, so they can be read by UNSAFE.getInt()
-        if (ByteOrder.nativeOrder() == ByteOrder.BIG_ENDIAN) {
-          out.writeInt(values(i).numFields())
-          out.writeInt(values(i).getSizeInBytes)
-        } else {
-          out.writeInt(Integer.reverseBytes(values(i).numFields()))
-          out.writeInt(Integer.reverseBytes(values(i).getSizeInBytes))
+
+      val iter = binaryMap.iterator()
+      while (iter.hasNext) {
+        val loc = iter.next()
+        // [key size] [values size] [key bytes] [values bytes]
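+        // The bytes are dumped verbatim, matching the layout of the hashTable branch below.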
+        out.writeInt(loc.getKeyLength)
+        out.writeInt(loc.getValueLength)
+        write(loc.getKeyAddress, loc.getKeyLength)
+        write(loc.getValueAddress, loc.getValueLength)
+      }
+
+    } else {
+      assert(hashTable != null)
+      out.writeInt(hashTable.size())
+
+      val iter = hashTable.entrySet().iterator()
+      while (iter.hasNext) {
+        val entry = iter.next()
+        val key = entry.getKey
+        val values = entry.getValue
+
+        // write all the values as a single byte array
+        var totalSize = 0L
+        var i = 0
+        while (i < values.length) {
+          totalSize += values(i).getSizeInBytes + 4 + 4
+          i += 1
+        }
+        assert(totalSize < Integer.MAX_VALUE, "values are too big")
+
+        // [key size] [values size] [key bytes] [values bytes]
+        out.writeInt(key.getSizeInBytes)
+        out.writeInt(totalSize.toInt)
+        out.write(key.getBytes)
+        i = 0
+        while (i < values.length) {
+          // [num of fields] [num of bytes] [row bytes]
+          // write the integers in native order, so they can be read by UNSAFE.getInt()
+          if (ByteOrder.nativeOrder() == ByteOrder.BIG_ENDIAN) {
+            out.writeInt(values(i).numFields())
+            out.writeInt(values(i).getSizeInBytes)
+          } else {
+            out.writeInt(Integer.reverseBytes(values(i).numFields()))
+            out.writeInt(Integer.reverseBytes(values(i).getSizeInBytes))
+          }
+          out.write(values(i).getBytes)
+          i += 1
         }
-        out.write(values(i).getBytes)
-        i += 1
       }
     }
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
index c635b2d51f46442daca0c5fcb5489ee01ae3cb01..d33a967093ca5edd66b58ea43d0ee76cd4bd4325 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
@@ -102,6 +102,15 @@ class HashedRelationSuite extends SparkFunSuite {
     assert(hashed2.get(toUnsafe(InternalRow(10))) === null)
     assert(hashed2.get(unsafeData(2)) === data2)
     assert(numDataRows.value.value === data.length)
+
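+    // Serializing the relation we just deserialized should reproduce the original bytes.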
+    val os2 = new ByteArrayOutputStream()
+    val out2 = new ObjectOutputStream(os2)
+    hashed2.asInstanceOf[UnsafeHashedRelation].writeExternal(out2)
+    out2.flush()
+    // This depends on BytesToBytesMap.iterator() returning items in exactly the same order
+    // as they were inserted
+    assert(java.util.Arrays.equals(os2.toByteArray, os.toByteArray))
   }
 
   test("test serialization empty hash map") {
@@ -119,5 +128,12 @@
     val toUnsafe = UnsafeProjection.create(schema)
     val row = toUnsafe(InternalRow(0))
     assert(hashed2.get(row) === null)
+
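+    // Re-serializing the deserialized empty relation should also produce identical bytes.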
+    val os2 = new ByteArrayOutputStream()
+    val out2 = new ObjectOutputStream(os2)
+    hashed2.writeExternal(out2)
+    out2.flush()
+    assert(java.util.Arrays.equals(os2.toByteArray, os.toByteArray))
   }
 }