diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
index e8d6d587b4824173b7d18da1751aa06269a9cffc..f350784378795b852255fdab797ad5c92d79daa7 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
@@ -19,6 +19,7 @@ package org.apache.spark.broadcast
 
 import java.io._
 import java.nio.ByteBuffer
+import java.util.zip.Adler32
 
 import scala.collection.JavaConverters._
 import scala.reflect.ClassTag
@@ -77,6 +78,7 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
     }
     // Note: use getSizeAsKb (not bytes) to maintain compatibility if no units are provided
     blockSize = conf.getSizeAsKb("spark.broadcast.blockSize", "4m").toInt * 1024
+    checksumEnabled = conf.getBoolean("spark.broadcast.checksum", true)
   }
   setConf(SparkEnv.get.conf)
 
@@ -85,10 +87,27 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
   /** Total number of blocks this broadcast variable contains. */
   private val numBlocks: Int = writeBlocks(obj)
 
+  /** Whether to generate checksum for blocks or not. */
+  private var checksumEnabled: Boolean = false
+  /** The checksum for all the blocks. */
+  private var checksums: Array[Int] = _
+
   override protected def getValue() = {
     _value
   }
 
+  private def calcChecksum(block: ByteBuffer): Int = {
+    val adler = new Adler32()
+    if (block.hasArray) {
+      adler.update(block.array, block.arrayOffset + block.position, block.limit - block.position)
+    } else {
+      val bytes = new Array[Byte](block.remaining())
+      block.duplicate.get(bytes)
+      adler.update(bytes)
+    }
+    adler.getValue.toInt
+  }
+
   /**
    * Divide the object into multiple blocks and put those blocks in the block manager.
    *
@@ -105,7 +124,13 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
     }
     val blocks =
       TorrentBroadcast.blockifyObject(value, blockSize, SparkEnv.get.serializer, compressionCodec)
+    if (checksumEnabled) {
+      checksums = new Array[Int](blocks.length)
+    }
     blocks.zipWithIndex.foreach { case (block, i) =>
+      if (checksumEnabled) {
+        checksums(i) = calcChecksum(block)
+      }
       val pieceId = BroadcastBlockId(id, "piece" + i)
       val bytes = new ChunkedByteBuffer(block.duplicate())
       if (!blockManager.putBytes(pieceId, bytes, MEMORY_AND_DISK_SER, tellMaster = true)) {
@@ -135,6 +160,13 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
         case None =>
           bm.getRemoteBytes(pieceId) match {
             case Some(b) =>
+              if (checksumEnabled) {
+                val sum = calcChecksum(b.chunks(0))
+                if (sum != checksums(pid)) {
+                  throw new SparkException(s"corrupt remote block $pieceId of $broadcastId:" +
+                    s" $sum != ${checksums(pid)}")
+                }
+              }
               // We found the block from remote executors/driver's BlockManager, so put the block
               // in this executor's BlockManager.
               if (!bm.putBytes(pieceId, b, StorageLevel.MEMORY_AND_DISK_SER, tellMaster = true)) {
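
A minimal standalone sketch of the pattern this patch introduces (not part of the diff itself): compute an Adler32 checksum over a block's ByteBuffer without consuming it, record the value on the writing side, and compare it on the reading side. ChecksumSketch and its main method are illustrative names only; in the patch a mismatch throws a SparkException instead of printing.

// Standalone sketch, assuming nothing beyond the JDK; names are hypothetical.
import java.nio.ByteBuffer
import java.util.zip.Adler32

object ChecksumSketch {
  // Mirrors calcChecksum in the patch: use the backing array when available,
  // otherwise copy the remaining bytes from a duplicate so the original
  // buffer's position is left untouched.
  def calcChecksum(block: ByteBuffer): Int = {
    val adler = new Adler32()
    if (block.hasArray) {
      adler.update(block.array, block.arrayOffset + block.position, block.limit - block.position)
    } else {
      val bytes = new Array[Byte](block.remaining())
      block.duplicate.get(bytes)
      adler.update(bytes)
    }
    adler.getValue.toInt
  }

  def main(args: Array[String]): Unit = {
    val original = ByteBuffer.wrap("broadcast block payload".getBytes("UTF-8"))
    val expected = calcChecksum(original)    // what the writing side would record

    // Simulate a corrupted remote copy by changing one byte of the payload.
    val corrupted = ByteBuffer.wrap("broadcast block pAyload".getBytes("UTF-8"))
    val actual = calcChecksum(corrupted)     // what the fetching side computes

    // The patch raises an error here; the sketch just reports the mismatch.
    if (actual != expected) {
      println(s"corrupt block detected: $actual != $expected")
    }
  }
}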