Commit 58b764b7 authored by Kay Ousterhout

Addressed Matei's code review comments

parent c75eb14f
@@ -26,10 +26,7 @@ import java.nio.ByteBuffer
 import org.apache.spark.util.Utils

 // Task result. Also contains updates to accumulator variables.
-// TODO: Use of distributed cache to return result is a hack to get around
-// what seems to be a bug with messages over 60KB in libprocess; fix it
-private[spark]
-sealed abstract class TaskResult[T]
+private[spark] sealed trait TaskResult[T]

 /** A reference to a DirectTaskResult that has been stored in the worker's BlockManager. */
 private[spark]
...
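For context, the replaced abstract class becomes a sealed trait with the two implementations referenced later in this commit (DirectTaskResult and IndirectTaskResult). A minimal sketch of that hierarchy, with the constructor fields assumed for illustration rather than copied from the Spark source:

private[spark] sealed trait TaskResult[T]

/** A reference to a DirectTaskResult that has been stored in the worker's BlockManager. */
private[spark] case class IndirectTaskResult[T](blockId: String) extends TaskResult[T]

/** A result small enough to be sent straight back to the driver (fields elided in this sketch). */
private[spark] class DirectTaskResult[T] extends TaskResult[T]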
@@ -100,7 +100,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
     System.getProperty("spark.scheduler.mode", "FIFO"))

   // This is a var so that we can reset it for testing purposes.
-  private[spark] var taskResultResolver = new TaskResultResolver(sc.env, this)
+  private[spark] var taskResultGetter = new TaskResultGetter(sc.env, this)

   override def setListener(listener: TaskSchedulerListener) {
     this.listener = listener
@@ -267,10 +267,10 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
         activeTaskSets.get(taskSetId).foreach { taskSet =>
           if (state == TaskState.FINISHED) {
             taskSet.removeRunningTask(tid)
-            taskResultResolver.enqueueSuccessfulTask(taskSet, tid, serializedData)
+            taskResultGetter.enqueueSuccessfulTask(taskSet, tid, serializedData)
           } else if (Set(TaskState.FAILED, TaskState.KILLED, TaskState.LOST).contains(state)) {
             taskSet.removeRunningTask(tid)
-            taskResultResolver.enqueueFailedTask(taskSet, tid, state, serializedData)
+            taskResultGetter.enqueueFailedTask(taskSet, tid, state, serializedData)
           }
         }
       case None =>
@@ -338,8 +338,8 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
     if (jarServer != null) {
       jarServer.stop()
     }
-    if (taskResultResolver != null) {
-      taskResultResolver.stop()
+    if (taskResultGetter != null) {
+      taskResultGetter.stop()
     }

     // sleeping for an arbitrary 5 seconds : to ensure that messages are sent out.
...
@@ -25,7 +25,6 @@ import scala.collection.mutable.HashMap
 import scala.collection.mutable.HashSet
 import scala.math.max
 import scala.math.min
-import scala.Some

 import org.apache.spark._
 import org.apache.spark.TaskState.TaskState
@@ -458,8 +457,6 @@ private[spark] class ClusterTaskSetManager(
     removeRunningTask(tid)
     val index = info.index
     info.markFailed()
-    // Count failed attempts only on FAILED and LOST state (not on KILLED)
-    var countFailedTaskAttempt = (state == TaskState.FAILED || state == TaskState.LOST)
     if (!successful(index)) {
       logInfo("Lost TID %s (task %s:%d)".format(tid, taskSet.id, index))
       copiesRunning(index) -= 1
@@ -505,7 +502,6 @@ private[spark] class ClusterTaskSetManager(
         case TaskResultLost =>
           logInfo("Lost result for TID %s on host %s".format(tid, info.host))
-          countFailedTaskAttempt = true
           sched.listener.taskEnded(tasks(index), TaskResultLost, null, null, info, null)

         case _ => {}
@@ -513,7 +509,7 @@ private[spark] class ClusterTaskSetManager(
       }
       // On non-fetch failures, re-enqueue the task as pending for a max number of retries
       addPendingTask(index)
-      if (countFailedTaskAttempt) {
+      if (state != TaskState.KILLED) {
         numFailures(index) += 1
         if (numFailures(index) > MAX_TASK_FAILURES) {
           logError("Task %s:%d failed more than %d times; aborting job".format(
...
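The net effect of these hunks is that the separate countFailedTaskAttempt flag goes away: a failed attempt is now counted for every terminal state except KILLED, which also covers the new TaskResultLost case without extra bookkeeping. A hedged one-line illustration of that rule (not the actual Spark method):

import org.apache.spark.TaskState
import org.apache.spark.TaskState.TaskState

object FailureCountingRule {
  // Attempts count toward MAX_TASK_FAILURES unless the task was deliberately killed,
  // so FAILED, LOST, and a FINISHED task whose result was lost all consume a retry.
  def countsTowardFailures(state: TaskState): Boolean = state != TaskState.KILLED
}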
@@ -26,17 +26,16 @@ import org.apache.spark.scheduler.{DirectTaskResult, IndirectTaskResult, TaskRes
 import org.apache.spark.serializer.SerializerInstance

 /**
- * Runs a thread pool that deserializes and remotely fetches (if neceessary) task results.
+ * Runs a thread pool that deserializes and remotely fetches (if necessary) task results.
  */
-private[spark] class TaskResultResolver(sparkEnv: SparkEnv, scheduler: ClusterScheduler)
+private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: ClusterScheduler)
   extends Logging {

-  private val MIN_THREADS = 20
-  private val MAX_THREADS = 60
-  private val KEEP_ALIVE_SECONDS = 60
+  private val MIN_THREADS = System.getProperty("spark.resultGetter.minThreads", "4").toInt
+  private val MAX_THREADS = System.getProperty("spark.resultGetter.maxThreads", "4").toInt
   private val getTaskResultExecutor = new ThreadPoolExecutor(
     MIN_THREADS,
     MAX_THREADS,
-    KEEP_ALIVE_SECONDS,
+    0L,
     TimeUnit.SECONDS,
     new LinkedBlockingDeque[Runnable],
     new ResultResolverThreadFactory)
...
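The fixed 20/60-thread pool becomes a pool whose bounds are read from system properties (defaulting to 4) with no keep-alive for idle threads. A self-contained sketch of that construction, substituting the JDK's default thread factory for the ResultResolverThreadFactory referenced in the diff so the snippet compiles on its own:

import java.util.concurrent.{Executors, LinkedBlockingDeque, ThreadPoolExecutor, TimeUnit}

object ResultGetterPoolSketch {
  // Pool bounds come from system properties; both default to 4 threads.
  private val minThreads = System.getProperty("spark.resultGetter.minThreads", "4").toInt
  private val maxThreads = System.getProperty("spark.resultGetter.maxThreads", "4").toInt

  // Keep-alive of 0 seconds: threads beyond the core size exit as soon as they go idle.
  val getTaskResultExecutor = new ThreadPoolExecutor(
    minThreads,
    maxThreads,
    0L,
    TimeUnit.SECONDS,
    new LinkedBlockingDeque[Runnable],
    Executors.defaultThreadFactory())
}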
@@ -253,6 +253,23 @@ class ClusterTaskSetManagerSuite extends FunSuite with LocalSparkContext with Lo
     assert(manager.resourceOffer("exec2", "host2", 1, ANY) === None)
   }

+  test("task result lost") {
+    sc = new SparkContext("local", "test")
+    val sched = new FakeClusterScheduler(sc, ("exec1", "host1"))
+    val taskSet = createTaskSet(1)
+    val clock = new FakeClock
+    val manager = new ClusterTaskSetManager(sched, taskSet, clock)
+
+    assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 0)
+
+    // Tell it the task has finished but the result was lost.
+    manager.handleFailedTask(0, TaskState.FINISHED, Some(TaskResultLost))
+    assert(sched.endedTasks(0) === TaskResultLost)
+
+    // Re-offer the host -- now we should get task 0 again.
+    assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 0)
+  }
+
   /**
    * Utility method to create a TaskSet, potentially setting a particular sequence of preferred
    * locations for each task (given as varargs) if this sequence is not empty.
...
@@ -15,7 +15,7 @@
  * limitations under the License.
  */

-package org.apache.spark.scheduler
+package org.apache.spark.scheduler.cluster

 import java.nio.ByteBuffer
@@ -23,16 +23,16 @@ import org.scalatest.BeforeAndAfter
 import org.scalatest.FunSuite

 import org.apache.spark.{LocalSparkContext, SparkContext, SparkEnv}
-import org.apache.spark.scheduler.cluster.{ClusterScheduler, ClusterTaskSetManager, TaskResultResolver}
+import org.apache.spark.scheduler.{DirectTaskResult, IndirectTaskResult, TaskResult}

 /**
- * Removes the TaskResult from the BlockManager before delegating to a normal TaskResultResolver.
+ * Removes the TaskResult from the BlockManager before delegating to a normal TaskResultGetter.
  *
  * Used to test the case where a BlockManager evicts the task result (or dies) before the
  * TaskResult is retrieved.
  */
-class ResultDeletingTaskResultResolver(sparkEnv: SparkEnv, scheduler: ClusterScheduler)
-  extends TaskResultResolver(sparkEnv, scheduler) {
+class ResultDeletingTaskResultGetter(sparkEnv: SparkEnv, scheduler: ClusterScheduler)
+  extends TaskResultGetter(sparkEnv, scheduler) {
   var removedResult = false

   override def enqueueSuccessfulTask(
@@ -44,7 +44,7 @@ class ResultDeletingTaskResultResolver(sparkEnv: SparkEnv, scheduler: ClusterSch
         case IndirectTaskResult(blockId) =>
           sparkEnv.blockManager.master.removeBlock(blockId)
         case directResult: DirectTaskResult[_] =>
-          taskSetManager.abort("Expect only indirect results")
+          taskSetManager.abort("Internal error: expect only indirect results")
       }
       serializedData.rewind()
       removedResult = true
@@ -56,9 +56,11 @@ class ResultDeletingTaskResultResolver(sparkEnv: SparkEnv, scheduler: ClusterSch
 /**
  * Tests related to handling task results (both direct and indirect).
  */
-class TaskResultResolverSuite extends FunSuite with BeforeAndAfter with LocalSparkContext {
+class TaskResultGetterSuite extends FunSuite with BeforeAndAfter with LocalSparkContext {

-  before {
+  override def beforeAll() {
+    super.beforeAll()
+
     // Set the Akka frame size to be as small as possible (it must be an integer, so 1 is as small
     // as we can make it) so the tests don't take too long.
     System.setProperty("spark.akka.frameSize", "1")
@@ -67,6 +69,11 @@ class TaskResultResolverSuite extends FunSuite with BeforeAndAfter with LocalSpa
     sc = new SparkContext("local-cluster[1,1,512]", "test")
   }

+  override def afterAll() {
+    super.afterAll()
+    System.clearProperty("spark.akka.frameSize")
+  }
+
   test("handling results smaller than Akka frame size") {
     val result = sc.parallelize(Seq(1), 1).map(x => 2 * x).reduce((x, y) => x)
     assert(result === 2)
@@ -93,7 +100,7 @@ class TaskResultResolverSuite extends FunSuite with BeforeAndAfter with LocalSpa
       assert(false, "Expect local cluster to use ClusterScheduler")
       throw new ClassCastException
     }
-    scheduler.taskResultResolver = new ResultDeletingTaskResultResolver(sc.env, scheduler)
+    scheduler.taskResultGetter = new ResultDeletingTaskResultGetter(sc.env, scheduler)
     val akkaFrameSize =
       sc.env.actorSystem.settings.config.getBytes("akka.remote.netty.message-frame-size").toInt
     val result = sc.parallelize(Seq(1), 1).map(x => 1.to(akkaFrameSize).toArray).reduce((x, y) => x)
...