Commit 06a56df2 authored by Davies Liu, committed by Shixiong Zhu

[SPARK-18188] add checksum for blocks of broadcast


## What changes were proposed in this pull request?

A TorrentBroadcast is serialized and compressed first, then split into fixed-size blocks. If any block is corrupted while being fetched from a remote executor, decompression or deserialization fails without indicating which block is corrupt. The corrupt block is also kept in the block manager and reported to the driver, so other tasks (in the same executor or in different executors) will fail because of it as well.

This PR adds a checksum for each block and checks it after fetching a block from a remote executor, because the corruption most likely happens in the network. When corruption is detected, the block is thrown away and an exception is thrown to fail the task, which will then be retried.
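
The per-block scheme described above can be illustrated outside Spark with a minimal, self-contained sketch. It reuses `java.util.zip.Adler32` as the diff does, but the object name `BroadcastChecksumSketch`, the helpers `blockify` and `verify`, and the use of `IllegalStateException` (in place of `SparkException`) are assumptions made for illustration, not part of the patch.

```scala
import java.nio.ByteBuffer
import java.util.zip.Adler32

// Illustrative sketch only: names other than Adler32 and ByteBuffer are hypothetical.
object BroadcastChecksumSketch {

  /** Adler-32 of the bytes between position and limit; the buffer's position is not moved. */
  def calcChecksum(block: ByteBuffer): Int = {
    val adler = new Adler32()
    if (block.hasArray) {
      // Heap buffer: read the backing array in place.
      adler.update(block.array(), block.arrayOffset() + block.position(), block.remaining())
    } else {
      // Direct buffer: copy through a duplicate so the caller's position stays untouched.
      val bytes = new Array[Byte](block.remaining())
      block.duplicate().get(bytes)
      adler.update(bytes)
    }
    adler.getValue.toInt
  }

  /** Hypothetical helper: split already-serialized bytes into fixed-size blocks. */
  def blockify(data: Array[Byte], blockSize: Int): Array[ByteBuffer] =
    data.grouped(blockSize).map(chunk => ByteBuffer.wrap(chunk)).toArray

  /** Hypothetical helper: reject a fetched block whose checksum differs from the recorded one. */
  def verify(index: Int, fetched: ByteBuffer, checksums: Array[Int]): Unit = {
    val sum = calcChecksum(fetched)
    if (sum != checksums(index)) {
      throw new IllegalStateException(s"corrupt block $index: $sum != ${checksums(index)}")
    }
  }
}
```

In this sketch the writer would compute `blocks.map(BroadcastChecksumSketch.calcChecksum)` once, and a reader would call `verify` on each fetched block before storing it, so a mismatch names the specific block instead of surfacing later as a decompression error.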

Added a config for it: `spark.broadcast.checksum`, which is true by default.
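
As a configuration fragment, the flag could be turned off like this. The snippet assumes Spark is on the classpath, and the application name is a placeholder rather than anything taken from the patch.

```scala
import org.apache.spark.SparkConf

// Fragment, not a full program: disables broadcast block checksums (enabled by default).
// "broadcast-checksum-example" is a placeholder application name.
val conf = new SparkConf()
  .setAppName("broadcast-checksum-example")
  .set("spark.broadcast.checksum", "false")
```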

## How was this patch tested?

Existing tests.

Author: Davies Liu <davies@databricks.com>

Closes #15935 from davies/broadcast_checksum.

(cherry picked from commit 7d5cb3af)
Signed-off-by: Shixiong Zhu <shixiong@databricks.com>
parent ea6957da
@@ -19,6 +19,7 @@ package org.apache.spark.broadcast
import java.io._
import java.nio.ByteBuffer
import java.util.zip.Adler32
import scala.collection.JavaConverters._
import scala.reflect.ClassTag
@@ -77,6 +78,7 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
    }
    // Note: use getSizeAsKb (not bytes) to maintain compatibility if no units are provided
    blockSize = conf.getSizeAsKb("spark.broadcast.blockSize", "4m").toInt * 1024
    checksumEnabled = conf.getBoolean("spark.broadcast.checksum", true)
  }
  setConf(SparkEnv.get.conf)
@@ -85,10 +87,27 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
  /** Total number of blocks this broadcast variable contains. */
  private val numBlocks: Int = writeBlocks(obj)

  /** Whether to generate checksum for blocks or not. */
  private var checksumEnabled: Boolean = false
  /** The checksum for all the blocks. */
  private var checksums: Array[Int] = _

  override protected def getValue() = {
    _value
  }

  private def calcChecksum(block: ByteBuffer): Int = {
    val adler = new Adler32()
    if (block.hasArray) {
      adler.update(block.array, block.arrayOffset + block.position, block.limit - block.position)
    } else {
      val bytes = new Array[Byte](block.remaining())
      block.duplicate.get(bytes)
      adler.update(bytes)
    }
    adler.getValue.toInt
  }

  /**
   * Divide the object into multiple blocks and put those blocks in the block manager.
   *
@@ -105,7 +124,13 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
    }
    val blocks =
      TorrentBroadcast.blockifyObject(value, blockSize, SparkEnv.get.serializer, compressionCodec)
    if (checksumEnabled) {
      checksums = new Array[Int](blocks.length)
    }
    blocks.zipWithIndex.foreach { case (block, i) =>
      if (checksumEnabled) {
        checksums(i) = calcChecksum(block)
      }
      val pieceId = BroadcastBlockId(id, "piece" + i)
      val bytes = new ChunkedByteBuffer(block.duplicate())
      if (!blockManager.putBytes(pieceId, bytes, MEMORY_AND_DISK_SER, tellMaster = true)) {
@@ -135,6 +160,13 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
        case None =>
          bm.getRemoteBytes(pieceId) match {
            case Some(b) =>
              if (checksumEnabled) {
                val sum = calcChecksum(b.chunks(0))
                if (sum != checksums(pid)) {
                  throw new SparkException(s"corrupt remote block $pieceId of $broadcastId:" +
                    s" $sum != ${checksums(pid)}")
                }
              }
              // We found the block from remote executors/driver's BlockManager, so put the block
              // in this executor's BlockManager.
              if (!bm.putBytes(pieceId, b, StorageLevel.MEMORY_AND_DISK_SER, tellMaster = true)) {
......