diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
index 2eef6a7c1038d14508b8c951ae86d86efb2a4828..2cf46e82b029d046e9b47bb6281a01d6d1075871 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
@@ -26,7 +26,7 @@ import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
 
 import org.apache.spark.{Logging, SparkEnv}
-import org.apache.spark.serializer.Serializer
+import org.apache.spark.serializer.{KryoDeserializationStream, KryoSerializationStream, Serializer}
 import org.apache.spark.storage.{BlockId, BlockManager, DiskBlockManager, DiskBlockObjectWriter}
 
 /**
@@ -333,7 +333,19 @@ private[spark] class ExternalAppendOnlyMap[K, V, C](
       if (!eof) {
         try {
           if (objectsRead == serializerBatchSize) {
-            deserializeStream = ser.deserializeStream(compressedStream)
+            val newInputStream = deserializeStream match {
+              case stream: KryoDeserializationStream =>
+                // Kryo's deserialization stream keeps an internal buffer that pre-fetches from
+                // the underlying stream. We need to capture the unread bytes in this buffer and
+                // feed them to the new deserialization stream so that they are not lost.
+                val kryoInput = stream.input
+                val remainingBytes = kryoInput.limit() - kryoInput.position()
+                val extraBuf = kryoInput.readBytes(remainingBytes)
+                new java.io.SequenceInputStream(
+                  new java.io.ByteArrayInputStream(extraBuf), compressedStream)
+              case _ => compressedStream
+            }
+            deserializeStream = ser.deserializeStream(newInputStream)
             objectsRead = 0
           }
           objectsRead += 1
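
For context, the following standalone sketch (not part of the patch) demonstrates the
buffer-stitching technique used above, directly against Kryo's own Input/Output API rather
than Spark's wrappers. The object name KryoBatchDemo and the two-batch setup are illustrative
only; it assumes com.esotericsoftware.kryo (the version Spark shipped at the time) is on the
classpath.

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, InputStream, SequenceInputStream}

import com.esotericsoftware.kryo.Kryo
import com.esotericsoftware.kryo.io.{Input, Output}

object KryoBatchDemo {
  def main(args: Array[String]): Unit = {
    val kryo = new Kryo()

    // Write two "batches" of two integers each, back to back, into one byte
    // stream, mimicking how spilled map data is written in serializer batches.
    val bytes = new ByteArrayOutputStream()
    val output = new Output(bytes)
    (1 to 4).foreach(i => kryo.writeObject(output, Integer.valueOf(i)))
    output.flush()

    val underlying: InputStream = new ByteArrayInputStream(bytes.toByteArray)

    // Read the first batch. Kryo's Input eagerly fills its internal buffer, so
    // it has pulled more bytes out of `underlying` than the two objects we
    // actually deserialized.
    val firstInput = new Input(underlying)
    val batch1 = Seq(kryo.readObject(firstInput, classOf[Integer]),
      kryo.readObject(firstInput, classOf[Integer]))

    // Opening `new Input(underlying)` here would silently skip the bytes
    // already sitting in firstInput's buffer. Instead, capture the unread
    // slice (position..limit) and stitch it back in front of the stream.
    val remainingBytes = firstInput.limit() - firstInput.position()
    val extraBuf = firstInput.readBytes(remainingBytes)
    val stitched: InputStream =
      new SequenceInputStream(new ByteArrayInputStream(extraBuf), underlying)

    val secondInput = new Input(stitched)
    val batch2 = Seq(kryo.readObject(secondInput, classOf[Integer]),
      kryo.readObject(secondInput, classOf[Integer]))

    println(s"batch1 = $batch1, batch2 = $batch2")  // batch1 = List(1, 2), batch2 = List(3, 4)
  }
}

Reading the second batch through a fresh Input opened directly on `underlying` would miss the
pre-fetched bytes; stitching the leftover slice in front with SequenceInputStream is the same
move the patch makes with compressedStream when it rolls over to a new deserialization stream.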