diff --git a/core/src/main/scala/org/apache/spark/CacheManager.scala b/core/src/main/scala/org/apache/spark/CacheManager.scala
index c7893f288b4b5f606b9cb917a001dce3dfe13b81..811610c657b62475c378588ccd5f0d4945f0b036 100644
--- a/core/src/main/scala/org/apache/spark/CacheManager.scala
+++ b/core/src/main/scala/org/apache/spark/CacheManager.scala
@@ -47,7 +47,12 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging {
           if (loading.contains(key)) {
             logInfo("Another thread is loading %s, waiting for it to finish...".format(key))
             while (loading.contains(key)) {
-              try {loading.wait()} catch {case _ : Throwable =>}
+              try {
+                loading.wait()
+              } catch {
+                case e: Exception =>
+                  logWarning(s"Got an exception while waiting for another thread to load $key", e)
+              }
             }
             logInfo("Finished waiting for %s".format(key))
             /* See whether someone else has successfully loaded it. The main way this would fail
@@ -72,7 +77,9 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging {
           val computedValues = rdd.computeOrReadCheckpoint(split, context)
 
           // Persist the result, so long as the task is not running locally
-          if (context.runningLocally) { return computedValues }
+          if (context.runningLocally) {
+            return computedValues
+          }
 
           // Keep track of blocks with updated statuses
           var updatedBlocks = Seq[(BlockId, BlockStatus)]()
@@ -88,7 +95,7 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging {
               updatedBlocks = blockManager.put(key, computedValues, storageLevel, tellMaster = true)
               blockManager.get(key) match {
                 case Some(values) =>
-                  new InterruptibleIterator(context, values.asInstanceOf[Iterator[T]])
+                  values.asInstanceOf[Iterator[T]]
                 case None =>
                   logInfo("Failure to store %s".format(key))
                   throw new Exception("Block manager failed to return persisted valued")
@@ -107,7 +114,7 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging {
           val metrics = context.taskMetrics
           metrics.updatedBlocks = Some(updatedBlocks)
 
-          returnValue
+          new InterruptibleIterator(context, returnValue)
 
         } finally {
           loading.synchronized {
diff --git a/core/src/main/scala/org/apache/spark/InterruptibleIterator.scala b/core/src/main/scala/org/apache/spark/InterruptibleIterator.scala
index fd1802ba2f984c9caecf7607dd523d6019a9e1e1..ec11dbbffaaf8a4dc1a7119b816b0319e87ff6fb 100644
--- a/core/src/main/scala/org/apache/spark/InterruptibleIterator.scala
+++ b/core/src/main/scala/org/apache/spark/InterruptibleIterator.scala
@@ -24,7 +24,17 @@ package org.apache.spark
 private[spark] class InterruptibleIterator[+T](val context: TaskContext, val delegate: Iterator[T])
   extends Iterator[T] {
 
-  def hasNext: Boolean = !context.interrupted && delegate.hasNext
+  def hasNext: Boolean = {
+    // TODO(aarondav/rxin): Check Thread.interrupted instead of context.interrupted if interrupt
+    // is allowed. The assumption is that Thread.interrupted does not have a memory fence in read
+    // (just a volatile field in C), while context.interrupted is a volatile in the JVM, which
+    // introduces an expensive read fence.
+    if (context.interrupted) {
+      throw new TaskKilledException
+    } else {
+      delegate.hasNext
+    }
+  }
 
   def next(): T = delegate.next()
 }
diff --git a/core/src/main/scala/org/apache/spark/TaskKilledException.scala b/core/src/main/scala/org/apache/spark/TaskKilledException.scala
new file mode 100644
index 0000000000000000000000000000000000000000..cbd6b2866e4f987e8ef320ab4e55b560805b1a3a
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/TaskKilledException.scala
@@ -0,0 +1,23 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+/**
+ * Exception for a task getting killed.
+ */
+private[spark] class TaskKilledException extends RuntimeException
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index 914bc205cebe277a915fd010013997e743192ebb..272bcda5f8f2f4d6c2b883bf7d30555c46380a27 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -161,8 +161,6 @@ private[spark] class Executor(
   class TaskRunner(execBackend: ExecutorBackend, taskId: Long, serializedTask: ByteBuffer)
     extends Runnable {
 
-    object TaskKilledException extends Exception
-
     @volatile private var killed = false
     @volatile private var task: Task[Any] = _
 
@@ -200,7 +198,7 @@ private[spark] class Executor(
           // causes a NonLocalReturnControl exception to be thrown. The NonLocalReturnControl
           // exception will be caught by the catch block, leading to an incorrect ExceptionFailure
           // for the task.
-          throw TaskKilledException
+          throw new TaskKilledException
         }
 
         attemptedTask = Some(task)
@@ -214,7 +212,7 @@ private[spark] class Executor(
 
         // If the task has been killed, let's fail it.
         if (task.killed) {
-          throw TaskKilledException
+          throw new TaskKilledException
         }
 
         val resultSer = SparkEnv.get.serializer.newInstance()
@@ -257,7 +255,7 @@ private[spark] class Executor(
           execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))
         }
 
-        case TaskKilledException | _: InterruptedException if task.killed => {
+        case _: TaskKilledException | _: InterruptedException if task.killed => {
           logInfo("Executor killed task " + taskId)
           execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled))
         }
diff --git a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala
index 16cfdf11c4a385cb943620f45fcd801734e70c03..2c8ef405c944cb626b70bc9919f9e37c57b5b56a 100644
--- a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala
+++ b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala
@@ -84,6 +84,35 @@ class JobCancellationSuite extends FunSuite with ShouldMatchers with BeforeAndAf
     assert(sc.parallelize(1 to 10, 2).count === 10)
   }
 
+  test("do not put partially executed partitions into cache") {
+    // In this test case, we create a scenario in which a partition is only partially executed,
+    // and make sure CacheManager does not put that partially executed partition into the
+    // BlockManager.
+    import JobCancellationSuite._
+    sc = new SparkContext("local", "test")
+
+    // Run from 1 to 10, and then block and wait for the task to be killed.
+    val rdd = sc.parallelize(1 to 1000, 2).map { x =>
+      if (x > 10) {
+        taskStartedSemaphore.release()
+        taskCancelledSemaphore.acquire()
+      }
+      x
+    }.cache()
+
+    val rdd1 = rdd.map(x => x)
+
+    future {
+      taskStartedSemaphore.acquire()
+      sc.cancelAllJobs()
+      taskCancelledSemaphore.release(100000)
+    }
+
+    intercept[SparkException] { rdd1.count() }
+    // If the partial block is put into cache, rdd.count() would return a number less than 1000.
+    assert(rdd.count() === 1000)
+  }
+
   test("job group") {
     sc = new SparkContext("local[2]", "test")
 
@@ -114,7 +143,6 @@ class JobCancellationSuite extends FunSuite with ShouldMatchers with BeforeAndAf
     assert(jobB.get() === 100)
   }
 
-
   test("job group with interruption") {
     sc = new SparkContext("local[2]", "test")
 
@@ -145,15 +173,14 @@ class JobCancellationSuite extends FunSuite with ShouldMatchers with BeforeAndAf
     assert(jobB.get() === 100)
   }
 
-/*
-  test("two jobs sharing the same stage") {
+  ignore("two jobs sharing the same stage") {
     // sem1: make sure cancel is issued after some tasks are launched
     // sem2: make sure the first stage is not finished until cancel is issued
     val sem1 = new Semaphore(0)
     val sem2 = new Semaphore(0)
 
     sc = new SparkContext("local[2]", "test")
-    sc.dagScheduler.addSparkListener(new SparkListener {
+    sc.addSparkListener(new SparkListener {
       override def onTaskStart(taskStart: SparkListenerTaskStart) {
         sem1.release()
       }
@@ -179,7 +206,7 @@ class JobCancellationSuite extends FunSuite with ShouldMatchers with BeforeAndAf
     intercept[SparkException] { f1.get() }
     intercept[SparkException] { f2.get() }
   }
- */
+
   def testCount() {
     // Cancel before launching any tasks
     {
@@ -238,3 +265,9 @@ class JobCancellationSuite extends FunSuite with ShouldMatchers with BeforeAndAf
     }
   }
 }
+
+
+object JobCancellationSuite {
+  val taskStartedSemaphore = new Semaphore(0)
+  val taskCancelledSemaphore = new Semaphore(0)
+}
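
The pattern this patch relies on is easiest to see in isolation. Below is a minimal, self-contained Scala sketch (illustrative only, not part of the patch): `KilledException` and the `AtomicBoolean` flag are stand-ins for `TaskKilledException` and `TaskContext.interrupted`. It shows how wrapping iteration in a fail-fast check turns a kill into an exception, so a consumer that materializes the iterator (such as the cache path in CacheManager) cannot mistake a partially produced partition for a complete one.

import java.util.concurrent.atomic.AtomicBoolean

// Illustrative sketch only -- names are hypothetical stand-ins, not Spark APIs.
object InterruptionSketch {

  class KilledException extends RuntimeException

  // Wraps a delegate iterator and fails fast once `killed` is set, so a consumer
  // that materializes the iterator never sees a partial sequence as if it were complete.
  class CancellableIterator[T](killed: AtomicBoolean, delegate: Iterator[T]) extends Iterator[T] {
    def hasNext: Boolean = {
      if (killed.get()) throw new KilledException
      delegate.hasNext
    }
    def next(): T = delegate.next()
  }

  def main(args: Array[String]): Unit = {
    val killed = new AtomicBoolean(false)
    val it = new CancellableIterator(killed, Iterator.range(0, 1000))

    var consumed = 0
    try {
      while (it.hasNext) {
        it.next()
        consumed += 1
        if (consumed == 10) killed.set(true)  // simulate a kill arriving mid-partition
      }
    } catch {
      case _: KilledException =>
        println(s"stopped after $consumed elements; the partial result is never treated as a full partition")
    }
  }
}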