diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
index 2eef6a7c1038d14508b8c951ae86d86efb2a4828..2cf46e82b029d046e9b47bb6281a01d6d1075871 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
@@ -26,7 +26,7 @@ import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
 
 import org.apache.spark.{Logging, SparkEnv}
-import org.apache.spark.serializer.Serializer
+import org.apache.spark.serializer.{KryoDeserializationStream, KryoSerializationStream, Serializer}
 import org.apache.spark.storage.{BlockId, BlockManager, DiskBlockManager, DiskBlockObjectWriter}
 
 /**
@@ -333,7 +333,18 @@ private[spark] class ExternalAppendOnlyMap[K, V, C](
       if (!eof) {
         try {
           if (objectsRead == serializerBatchSize) {
-            deserializeStream = ser.deserializeStream(compressedStream)
+            val newInputStream = deserializeStream match {
+              case stream: KryoDeserializationStream =>
+                // Kryo's deserialization stream keeps an internal buffer that pre-fetches from
+                // the underlying stream. We need to capture this buffer and feed it to the new
+                // deserialization stream so that those bytes are not lost.
+                val kryoInput = stream.input
+                val remainingBytes = kryoInput.limit() - kryoInput.position()
+                val extraBuf = kryoInput.readBytes(remainingBytes)
+                new SequenceInputStream(new ByteArrayInputStream(extraBuf), compressedStream)
+              case _ => compressedStream
+            }
+            deserializeStream = ser.deserializeStream(newInputStream)
             objectsRead = 0
           }
           objectsRead += 1
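
For context, here is a minimal standalone Scala sketch (not part of the patch) of the buffer-splice technique the second hunk applies. Kryo's Input pre-fetches more bytes from the underlying stream than the objects read so far actually consume, so opening a fresh deserialization stream directly over the underlying stream at a batch boundary would silently drop those pre-fetched bytes. The sketch drives Kryo's public Input/Output API directly; the object name KryoBatchBoundaryDemo and the two-object "batches" are illustrative assumptions, not code from the patch.

    import java.io.{ByteArrayInputStream, ByteArrayOutputStream, InputStream, SequenceInputStream}
    import com.esotericsoftware.kryo.Kryo
    import com.esotericsoftware.kryo.io.{Input, Output}

    object KryoBatchBoundaryDemo {
      def main(args: Array[String]): Unit = {
        val kryo = new Kryo()

        // Serialize two "batches" of two integers back to back into one byte
        // stream, mimicking the batched spill format read by readNextItem().
        val bytes = new ByteArrayOutputStream()
        val out = new Output(bytes)
        (1 to 4).foreach(i => kryo.writeObject(out, Int.box(i)))
        out.flush()

        val underlying: InputStream = new ByteArrayInputStream(bytes.toByteArray)

        // Read the first "batch". Kryo's Input buffers ahead, pulling in more
        // bytes from `underlying` than these two objects actually consume.
        val firstInput = new Input(underlying)
        val a = kryo.readObject(firstInput, classOf[Integer])
        val b = kryo.readObject(firstInput, classOf[Integer])

        // Capture the pre-fetched-but-unconsumed bytes and splice them back in
        // front of the underlying stream before opening the next batch's Input,
        // exactly as the patch does with SequenceInputStream.
        val remaining = firstInput.limit() - firstInput.position()
        val extra = firstInput.readBytes(remaining)
        val spliced = new SequenceInputStream(new ByteArrayInputStream(extra), underlying)

        val secondInput = new Input(spliced)
        val c = kryo.readObject(secondInput, classOf[Integer])
        val d = kryo.readObject(secondInput, classOf[Integer])
        println(s"$a $b $c $d")  // prints "1 2 3 4": no bytes lost at the boundary
      }
    }

Wrapping `new Input(underlying)` for the second batch without the splice would lose the buffered bytes and typically fail with a KryoException; the SequenceInputStream approach avoids copying the rest of the stream, paying only for the small leftover buffer.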