diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
index b9d83495d29b63eaa7af3f135f41bf2c21a21d78..8b2e26cdd94fbfd2712330cfc102da573324e788 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
@@ -42,24 +42,21 @@ private[spark] class BlockStoreShuffleReader[K, C](
 
   /** Read the combined key-values for this reduce task */
   override def read(): Iterator[Product2[K, C]] = {
-    val blockFetcherItr = new ShuffleBlockFetcherIterator(
+    val wrappedStreams = new ShuffleBlockFetcherIterator(
       context,
       blockManager.shuffleClient,
       blockManager,
       mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition),
+      serializerManager.wrapStream,
       // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility
       SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024,
-      SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue))
-
-    // Wrap the streams for compression and encryption based on configuration
-    val wrappedStreams = blockFetcherItr.map { case (blockId, inputStream) =>
-      serializerManager.wrapStream(blockId, inputStream)
-    }
+      SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue),
+      SparkEnv.get.conf.getBoolean("spark.shuffle.detectCorrupt", true))
 
     val serializerInstance = dep.serializer.newInstance()
 
     // Create a key/value iterator for each stream
-    val recordIter = wrappedStreams.flatMap { wrappedStream =>
+    val recordIter = wrappedStreams.flatMap { case (blockId, wrappedStream) =>
       // Note: the asKeyValueIterator below wraps a key/value iterator inside of a
       // NextIterator. The NextIterator makes sure that close() is called on the
       // underlying InputStream when all records have been read.
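The reader now hands serializerManager.wrapStream to the iterator, which returns (blockId, wrappedStream) pairs instead of raw streams. Below is a rough sketch of what a (BlockId, InputStream) => InputStream wrapper can look like; the identity form is what the updated tests pass, while the GZIP variant is purely illustrative (Spark's real wrapper also layers decryption):

    import java.io.InputStream
    import java.util.zip.GZIPInputStream
    import org.apache.spark.storage.BlockId

    // Identity wrapper: pass the fetched stream through untouched (as in the tests below).
    val identityWrapper: (BlockId, InputStream) => InputStream = (_, in) => in

    // Illustrative wrapper: layer a decompressing stream over the fetched bytes. Spark's own
    // serializerManager.wrapStream composes decryption and the configured compression codec.
    val gzipWrapper: (BlockId, InputStream) => InputStream =
      (_, in) => new GZIPInputStream(in)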
diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
index 269c12d6da4447314f867c364326dc6b50c292da..b720aaee7c7e9ef82a56cbedb13bd585a3dc1c5a 100644
--- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
@@ -17,19 +17,21 @@
 
 package org.apache.spark.storage
 
-import java.io.InputStream
+import java.io.{InputStream, IOException}
+import java.nio.ByteBuffer
 import java.util.concurrent.LinkedBlockingQueue
 import javax.annotation.concurrent.GuardedBy
 
+import scala.collection.mutable
 import scala.collection.mutable.{ArrayBuffer, HashSet, Queue}
-import scala.util.control.NonFatal
 
 import org.apache.spark.{SparkException, TaskContext}
 import org.apache.spark.internal.Logging
-import org.apache.spark.network.buffer.ManagedBuffer
+import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
 import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient}
 import org.apache.spark.shuffle.FetchFailedException
 import org.apache.spark.util.Utils
+import org.apache.spark.util.io.ChunkedByteBufferOutputStream
 
 /**
  * An iterator that fetches multiple blocks. For local blocks, it fetches from the local block
@@ -47,8 +49,10 @@ import org.apache.spark.util.Utils
  * @param blocksByAddress list of blocks to fetch grouped by the [[BlockManagerId]].
  *                        For each block we also require the size (in bytes as a long field) in
  *                        order to throttle the memory usage.
+ * @param streamWrapper A function to wrap the returned input stream.
  * @param maxBytesInFlight max size (in bytes) of remote blocks to fetch at any given point.
  * @param maxReqsInFlight max number of remote requests to fetch blocks at any given point.
+ * @param detectCorrupt whether to detect any corruption in fetched blocks.
  */
 private[spark]
 final class ShuffleBlockFetcherIterator(
@@ -56,8 +60,10 @@ final class ShuffleBlockFetcherIterator(
     shuffleClient: ShuffleClient,
     blockManager: BlockManager,
     blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])],
+    streamWrapper: (BlockId, InputStream) => InputStream,
     maxBytesInFlight: Long,
-    maxReqsInFlight: Int)
+    maxReqsInFlight: Int,
+    detectCorrupt: Boolean)
   extends Iterator[(BlockId, InputStream)] with Logging {
 
   import ShuffleBlockFetcherIterator._
@@ -94,7 +100,7 @@ final class ShuffleBlockFetcherIterator(
    * Current [[FetchResult]] being processed. We track this so we can release the current buffer
    * in case of a runtime exception when processing the current buffer.
    */
-  @volatile private[this] var currentResult: FetchResult = null
+  @volatile private[this] var currentResult: SuccessFetchResult = null
 
   /**
    * Queue of fetch requests to issue; we'll pull requests off this gradually to make sure that
@@ -108,6 +114,12 @@ final class ShuffleBlockFetcherIterator(
   /** Current number of requests in flight */
   private[this] var reqsInFlight = 0
 
+  /**
+   * The blocks that could not be decompressed successfully. Tracked so that we retry each
+   * corrupted block at most once.
+   */
+  private[this] val corruptedBlocks = mutable.HashSet[BlockId]()
+
   private[this] val shuffleMetrics = context.taskMetrics().createTempShuffleReadMetrics()
 
   /**
@@ -123,9 +135,8 @@ final class ShuffleBlockFetcherIterator(
   // The currentResult is set to null to prevent releasing the buffer again on cleanup()
   private[storage] def releaseCurrentResultBuffer(): Unit = {
     // Release the current buffer if necessary
-    currentResult match {
-      case SuccessFetchResult(_, _, _, buf, _) => buf.release()
-      case _ =>
+    if (currentResult != null) {
+      currentResult.buf.release()
     }
     currentResult = null
   }
@@ -305,40 +316,84 @@ final class ShuffleBlockFetcherIterator(
    */
   override def next(): (BlockId, InputStream) = {
     numBlocksProcessed += 1
-    val startFetchWait = System.currentTimeMillis()
-    currentResult = results.take()
-    val result = currentResult
-    val stopFetchWait = System.currentTimeMillis()
-    shuffleMetrics.incFetchWaitTime(stopFetchWait - startFetchWait)
-
-    result match {
-      case SuccessFetchResult(_, address, size, buf, isNetworkReqDone) =>
-        if (address != blockManager.blockManagerId) {
-          shuffleMetrics.incRemoteBytesRead(buf.size)
-          shuffleMetrics.incRemoteBlocksFetched(1)
-        }
-        bytesInFlight -= size
-        if (isNetworkReqDone) {
-          reqsInFlight -= 1
-          logDebug("Number of requests in flight " + reqsInFlight)
-        }
-      case _ =>
-    }
-    // Send fetch requests up to maxBytesInFlight
-    fetchUpToMaxBytes()
 
-    result match {
-      case FailureFetchResult(blockId, address, e) =>
-        throwFetchFailedException(blockId, address, e)
+    var result: FetchResult = null
+    var input: InputStream = null
+    // Take the next fetched result and try to decompress it to detect data corruption.
+    // If the block is corrupt, fetch it one more time; if the second attempt is also corrupt,
+    // throw a FetchFailedException so that the previous stage can be retried.
+    // For a local shuffle block, throw a FetchFailedException on the first IOException.
+    while (result == null) {
+      val startFetchWait = System.currentTimeMillis()
+      result = results.take()
+      val stopFetchWait = System.currentTimeMillis()
+      shuffleMetrics.incFetchWaitTime(stopFetchWait - startFetchWait)
 
-      case SuccessFetchResult(blockId, address, _, buf, _) =>
-        try {
-          (result.blockId, new BufferReleasingInputStream(buf.createInputStream(), this))
-        } catch {
-          case NonFatal(t) =>
-            throwFetchFailedException(blockId, address, t)
-        }
+      result match {
+        case r @ SuccessFetchResult(blockId, address, size, buf, isNetworkReqDone) =>
+          if (address != blockManager.blockManagerId) {
+            shuffleMetrics.incRemoteBytesRead(buf.size)
+            shuffleMetrics.incRemoteBlocksFetched(1)
+          }
+          bytesInFlight -= size
+          if (isNetworkReqDone) {
+            reqsInFlight -= 1
+            logDebug("Number of requests in flight " + reqsInFlight)
+          }
+
+          val in = try {
+            buf.createInputStream()
+          } catch {
+            // This exception can only be thrown by a local shuffle block
+            case e: IOException =>
+              assert(buf.isInstanceOf[FileSegmentManagedBuffer])
+              logError("Failed to create input stream from local block", e)
+              buf.release()
+              throwFetchFailedException(blockId, address, e)
+          }
+
+          input = streamWrapper(blockId, in)
+          // Only copy the stream if it is wrapped by compression or encryption and the block is
+          // small (compressed size under maxBytesInFlight / 3), so the copy fits in memory.
+          if (detectCorrupt && !input.eq(in) && size < maxBytesInFlight / 3) {
+            val originalInput = input
+            val out = new ChunkedByteBufferOutputStream(64 * 1024, ByteBuffer.allocate)
+            try {
+              // Decompress the whole block at once to detect any corruption. This increases
+              // memory usage and could therefore increase the chance of OOM.
+              // TODO: manage the memory used here, and spill it into disk in case of OOM.
+              Utils.copyStream(input, out)
+              out.close()
+              input = out.toChunkedByteBuffer.toInputStream(dispose = true)
+            } catch {
+              case e: IOException =>
+                buf.release()
+                if (buf.isInstanceOf[FileSegmentManagedBuffer]
+                  || corruptedBlocks.contains(blockId)) {
+                  throwFetchFailedException(blockId, address, e)
+                } else {
+                  logWarning(s"got an corrupted block $blockId from $address, fetch again", e)
+                  corruptedBlocks += blockId
+                  fetchRequests += FetchRequest(address, Array((blockId, size)))
+                  result = null
+                }
+            } finally {
+              // TODO: release the buf here to free memory earlier
+              originalInput.close()
+              in.close()
+            }
+          }
+
+        case FailureFetchResult(blockId, address, e) =>
+          throwFetchFailedException(blockId, address, e)
+      }
+
+      // Send fetch requests up to maxBytesInFlight
+      fetchUpToMaxBytes()
     }
+
+    currentResult = result.asInstanceOf[SuccessFetchResult]
+    (currentResult.blockId, new BufferReleasingInputStream(input, this))
   }
 
   private def fetchUpToMaxBytes(): Unit = {
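The corruption check added to next() above amounts to eagerly draining the wrapped (decompressing) stream, so a truncated or corrupted block fails here, where the iterator can still schedule a one-time retry, rather than deep inside the record iterator. A simplified standalone sketch of that step, using a plain ByteArrayOutputStream instead of ChunkedByteBufferOutputStream; the helper name detectCorruption is illustrative:

    import java.io.{ByteArrayInputStream, ByteArrayOutputStream, IOException, InputStream}

    // Drain the wrapped stream fully so that any decompression IOException surfaces now.
    // On success, return a replayable in-memory copy of the decompressed bytes.
    def detectCorruption(wrapped: InputStream): Either[IOException, InputStream] = {
      val out = new ByteArrayOutputStream()
      val buf = new Array[Byte](64 * 1024)
      try {
        var n = wrapped.read(buf)
        while (n != -1) {
          out.write(buf, 0, n)
          n = wrapped.read(buf)
        }
        Right(new ByteArrayInputStream(out.toByteArray))
      } catch {
        case e: IOException => Left(e) // caller decides: retry once, or fail the fetch
      } finally {
        wrapped.close()
      }
    }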
diff --git a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala
index da08661d137d0c11bd14fb07e4fb2d8179b7ec1e..7572cac39317c714bb738e3ded7b4681397899db 100644
--- a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala
+++ b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala
@@ -151,7 +151,7 @@ private[spark] class ChunkedByteBuffer(var chunks: Array[ByteBuffer]) {
  * @param dispose if true, `ChunkedByteBuffer.dispose()` will be called at the end of the stream
  *                in order to close any memory-mapped files which back the buffer.
  */
-private class ChunkedByteBufferInputStream(
+private[spark] class ChunkedByteBufferInputStream(
     var chunkedByteBuffer: ChunkedByteBuffer,
     dispose: Boolean)
   extends InputStream {
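ChunkedByteBufferInputStream is widened to private[spark] so the iterator can hand back the in-memory copy it built during the corruption check. A small round-trip sketch of the pattern used in next() above (the method name bufferStream is illustrative):

    import java.io.InputStream
    import java.nio.ByteBuffer
    import org.apache.spark.util.Utils
    import org.apache.spark.util.io.ChunkedByteBufferOutputStream

    // Buffer a stream into 64 KB chunks, then expose it again as an InputStream.
    // dispose = true releases the chunks once the returned stream has been fully read.
    def bufferStream(in: InputStream): InputStream = {
      val out = new ChunkedByteBufferOutputStream(64 * 1024, ByteBuffer.allocate)
      Utils.copyStream(in, out)
      out.close()
      out.toChunkedByteBuffer.toInputStream(dispose = true)
    }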
diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
index e3ec99685f73c8eb7cff116a9278116ac705d6ec..e56e440380a54ce236bccd1bdaa99fbfdc5bfcd1 100644
--- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.storage
 
-import java.io.InputStream
+import java.io.{File, InputStream, IOException}
 import java.util.concurrent.Semaphore
 
 import scala.concurrent.ExecutionContext.Implicits.global
@@ -31,8 +31,9 @@ import org.scalatest.PrivateMethodTester
 
 import org.apache.spark.{SparkFunSuite, TaskContext}
 import org.apache.spark.network._
-import org.apache.spark.network.buffer.ManagedBuffer
+import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
 import org.apache.spark.network.shuffle.BlockFetchingListener
+import org.apache.spark.network.util.LimitedInputStream
 import org.apache.spark.shuffle.FetchFailedException
 
 
@@ -63,7 +64,10 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
   // Create a mock managed buffer for testing
   def createMockManagedBuffer(): ManagedBuffer = {
     val mockManagedBuffer = mock(classOf[ManagedBuffer])
-    when(mockManagedBuffer.createInputStream()).thenReturn(mock(classOf[InputStream]))
+    val in = mock(classOf[InputStream])
+    when(in.read(any())).thenReturn(1)
+    when(in.read(any(), any(), any())).thenReturn(1)
+    when(mockManagedBuffer.createInputStream()).thenReturn(in)
     mockManagedBuffer
   }
 
@@ -99,8 +103,10 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
       transfer,
       blockManager,
       blocksByAddress,
+      (_, in) => in,
       48 * 1024 * 1024,
-      Int.MaxValue)
+      Int.MaxValue,
+      true)
 
     // 3 local blocks fetched in initialization
     verify(blockManager, times(3)).getBlockData(any())
@@ -172,8 +178,10 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
       transfer,
       blockManager,
       blocksByAddress,
+      (_, in) => in,
       48 * 1024 * 1024,
-      Int.MaxValue)
+      Int.MaxValue,
+      true)
 
     verify(blocks(ShuffleBlockId(0, 0, 0)), times(0)).release()
     iterator.next()._2.close() // close() first block's input stream
@@ -201,9 +209,9 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
     // Make sure remote blocks would return
     val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2)
     val blocks = Map[BlockId, ManagedBuffer](
-      ShuffleBlockId(0, 0, 0) -> mock(classOf[ManagedBuffer]),
-      ShuffleBlockId(0, 1, 0) -> mock(classOf[ManagedBuffer]),
-      ShuffleBlockId(0, 2, 0) -> mock(classOf[ManagedBuffer])
+      ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer(),
+      ShuffleBlockId(0, 1, 0) -> createMockManagedBuffer(),
+      ShuffleBlockId(0, 2, 0) -> createMockManagedBuffer()
     )
 
     // Semaphore to coordinate event sequence in two different threads.
@@ -235,8 +243,10 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
       transfer,
       blockManager,
       blocksByAddress,
+      (_, in) => in,
       48 * 1024 * 1024,
-      Int.MaxValue)
+      Int.MaxValue,
+      true)
 
     // Continue only after the mock calls onBlockFetchFailure
     sem.acquire()
@@ -247,4 +257,148 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
     intercept[FetchFailedException] { iterator.next() }
     intercept[FetchFailedException] { iterator.next() }
   }
+
+  test("retry corrupt blocks") {
+    val blockManager = mock(classOf[BlockManager])
+    val localBmId = BlockManagerId("test-client", "test-client", 1)
+    doReturn(localBmId).when(blockManager).blockManagerId
+
+    // Make sure remote blocks would return
+    val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2)
+    val blocks = Map[BlockId, ManagedBuffer](
+      ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer(),
+      ShuffleBlockId(0, 1, 0) -> createMockManagedBuffer(),
+      ShuffleBlockId(0, 2, 0) -> createMockManagedBuffer()
+    )
+
+    // Semaphore to coordinate event sequence in two different threads.
+    val sem = new Semaphore(0)
+
+    val corruptStream = mock(classOf[InputStream])
+    when(corruptStream.read(any(), any(), any())).thenThrow(new IOException("corrupt"))
+    val corruptBuffer = mock(classOf[ManagedBuffer])
+    when(corruptBuffer.createInputStream()).thenReturn(corruptStream)
+    val corruptLocalBuffer = new FileSegmentManagedBuffer(null, new File("a"), 0, 100)
+
+    val transfer = mock(classOf[BlockTransferService])
+    when(transfer.fetchBlocks(any(), any(), any(), any(), any())).thenAnswer(new Answer[Unit] {
+      override def answer(invocation: InvocationOnMock): Unit = {
+        val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
+        Future {
+          // Return one good block, then a corrupt remote block and a corrupt local block.
+          listener.onBlockFetchSuccess(
+            ShuffleBlockId(0, 0, 0).toString, blocks(ShuffleBlockId(0, 0, 0)))
+          listener.onBlockFetchSuccess(
+            ShuffleBlockId(0, 1, 0).toString, corruptBuffer)
+          listener.onBlockFetchSuccess(
+            ShuffleBlockId(0, 2, 0).toString, corruptLocalBuffer)
+          sem.release()
+        }
+      }
+    })
+
+    val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
+      (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq))
+
+    val taskContext = TaskContext.empty()
+    val iterator = new ShuffleBlockFetcherIterator(
+      taskContext,
+      transfer,
+      blockManager,
+      blocksByAddress,
+      (_, in) => new LimitedInputStream(in, 100),
+      48 * 1024 * 1024,
+      Int.MaxValue,
+      true)
+
+    // Continue only after the mock has delivered all of the blocks
+    sem.acquire()
+
+    // The first block should be returned without an exception
+    val (id1, _) = iterator.next()
+    assert(id1 === ShuffleBlockId(0, 0, 0))
+
+    when(transfer.fetchBlocks(any(), any(), any(), any(), any())).thenAnswer(new Answer[Unit] {
+      override def answer(invocation: InvocationOnMock): Unit = {
+        val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
+        Future {
+          // On the retry, return the corrupt block again.
+          listener.onBlockFetchSuccess(
+            ShuffleBlockId(0, 1, 0).toString, corruptBuffer)
+          sem.release()
+        }
+      }
+    })
+
+    // The next block is the corrupt local block (the second block is corrupt and gets retried)
+    intercept[FetchFailedException] { iterator.next() }
+
+    sem.acquire()
+    intercept[FetchFailedException] { iterator.next() }
+  }
+
+  test("retry corrupt blocks (disabled)") {
+    val blockManager = mock(classOf[BlockManager])
+    val localBmId = BlockManagerId("test-client", "test-client", 1)
+    doReturn(localBmId).when(blockManager).blockManagerId
+
+    // Make sure remote blocks would return
+    val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2)
+    val blocks = Map[BlockId, ManagedBuffer](
+      ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer(),
+      ShuffleBlockId(0, 1, 0) -> createMockManagedBuffer(),
+      ShuffleBlockId(0, 2, 0) -> createMockManagedBuffer()
+    )
+
+    // Semaphore to coordinate event sequence in two different threads.
+    val sem = new Semaphore(0)
+
+    val corruptStream = mock(classOf[InputStream])
+    when(corruptStream.read(any(), any(), any())).thenThrow(new IOException("corrupt"))
+    val corruptBuffer = mock(classOf[ManagedBuffer])
+    when(corruptBuffer.createInputStream()).thenReturn(corruptStream)
+
+    val transfer = mock(classOf[BlockTransferService])
+    when(transfer.fetchBlocks(any(), any(), any(), any(), any())).thenAnswer(new Answer[Unit] {
+      override def answer(invocation: InvocationOnMock): Unit = {
+        val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
+        Future {
+          // Return one good block, then the corrupt block twice.
+          listener.onBlockFetchSuccess(
+            ShuffleBlockId(0, 0, 0).toString, blocks(ShuffleBlockId(0, 0, 0)))
+          listener.onBlockFetchSuccess(
+            ShuffleBlockId(0, 1, 0).toString, corruptBuffer)
+          listener.onBlockFetchSuccess(
+            ShuffleBlockId(0, 2, 0).toString, corruptBuffer)
+          sem.release()
+        }
+      }
+    })
+
+    val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
+      (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq))
+
+    val taskContext = TaskContext.empty()
+    val iterator = new ShuffleBlockFetcherIterator(
+      taskContext,
+      transfer,
+      blockManager,
+      blocksByAddress,
+      (_, in) => new LimitedInputStream(in, 100),
+      48 * 1024 * 1024,
+      Int.MaxValue,
+      false)
+
+    // Continue only after the mock has delivered all of the blocks
+    sem.acquire()
+
+    // The first block should be returned without an exception
+    val (id1, _) = iterator.next()
+    assert(id1 === ShuffleBlockId(0, 0, 0))
+    val (id2, _) = iterator.next()
+    assert(id2 === ShuffleBlockId(0, 1, 0))
+    val (id3, _) = iterator.next()
+    assert(id3 === ShuffleBlockId(0, 2, 0))
+  }
+
 }
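The new behavior is gated by spark.shuffle.detectCorrupt, which BlockStoreShuffleReader reads with a default of true. A minimal usage sketch for disabling the extra decompression pass (the app name is illustrative):

    import org.apache.spark.SparkConf

    // Opt out of the eager decompression check added by this patch (enabled by default).
    val conf = new SparkConf()
      .setAppName("shuffle-corruption-example") // illustrative
      .set("spark.shuffle.detectCorrupt", "false")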