diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index 83170f7c5a4ab9d1ce22976a1fcd10a9779dfbb8..2499c11a65b0e6999a2298d80e45be80906e5baa 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -17,6 +17,7 @@ package org.apache.spark.storage +import java.io.{InputStream, IOException} import java.util.concurrent.LinkedBlockingQueue import scala.collection.mutable.{ArrayBuffer, HashSet, Queue} @@ -289,17 +290,22 @@ final class ShuffleBlockFetcherIterator( } val iteratorTry: Try[Iterator[Any]] = result match { - case FailureFetchResult(_, e) => Failure(e) - case SuccessFetchResult(blockId, _, buf) => { - val is = blockManager.wrapForCompression(blockId, buf.createInputStream()) - val iter = serializer.newInstance().deserializeStream(is).asIterator - Success(CompletionIterator[Any, Iterator[Any]](iter, { - // Once the iterator is exhausted, release the buffer and set currentResult to null - // so we don't release it again in cleanup. - currentResult = null - buf.release() - })) - } + case FailureFetchResult(_, e) => + Failure(e) + case SuccessFetchResult(blockId, _, buf) => + // There is a chance that createInputStream can fail (e.g. fetching a local file that does + // not exist, SPARK-4085). In that case, we should propagate the right exception so + // the scheduler gets a FetchFailedException. + Try(buf.createInputStream()).map { is0 => + val is = blockManager.wrapForCompression(blockId, is0) + val iter = serializer.newInstance().deserializeStream(is).asIterator + CompletionIterator[Any, Iterator[Any]](iter, { + // Once the iterator is exhausted, release the buffer and set currentResult to null + // so we don't release it again in cleanup. + currentResult = null + buf.release() + }) + } } (result.blockId, iteratorTry) diff --git a/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala b/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala index cc3592ee43a35f758e44fb29c3f98df058b14af5..bac6fdbcdc976f39c35581b28e379fa4c66e54e2 100644 --- a/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala +++ b/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala @@ -17,8 +17,6 @@ package org.apache.spark -import java.util.concurrent.atomic.AtomicInteger - import org.scalatest.BeforeAndAfterAll import org.apache.spark.network.TransportContext diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala index 5a133c0490444a6e4daec5ba8e5542c55ac5ac90..58a96245a9b53cd26a7ed1ef651f1cfa4d0e7b4e 100644 --- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala @@ -23,6 +23,7 @@ import org.scalatest.Matchers import org.apache.spark.ShuffleSuite.NonJavaSerializableClass import org.apache.spark.rdd.{CoGroupedRDD, OrderedRDDFunctions, RDD, ShuffledRDD, SubtractedRDD} import org.apache.spark.serializer.KryoSerializer +import org.apache.spark.storage.{ShuffleDataBlockId, ShuffleBlockId} import org.apache.spark.util.MutablePair abstract class ShuffleSuite extends FunSuite with Matchers with LocalSparkContext { @@ -263,6 +264,28 @@ abstract class ShuffleSuite extends FunSuite with Matchers with LocalSparkContex } } } + + test("[SPARK-4085] rerun map stage if reduce stage cannot find its local shuffle file") { + val myConf = conf.clone().set("spark.test.noStageRetry", "false") + sc = new SparkContext("local", "test", myConf) + val rdd = sc.parallelize(1 to 10, 2).map((_, 1)).reduceByKey(_ + _) + rdd.count() + + // Delete one of the local shuffle blocks. + val hashFile = sc.env.blockManager.diskBlockManager.getFile(new ShuffleBlockId(0, 0, 0)) + val sortFile = sc.env.blockManager.diskBlockManager.getFile(new ShuffleDataBlockId(0, 0, 0)) + assert(hashFile.exists() || sortFile.exists()) + + if (hashFile.exists()) { + hashFile.delete() + } + if (sortFile.exists()) { + sortFile.delete() + } + + // This count should retry the execution of the previous stage and rerun shuffle. + rdd.count() + } } object ShuffleSuite {