Skip to content
Snippets Groups Projects
Commit 2584ea5b authored by Reynold Xin's avatar Reynold Xin
Browse files

[SPARK-3469] Make sure all TaskCompletionListener are called even with failures

This is necessary because we rely on this callback interface to clean resources up. The old behavior would lead to resource leaks.

Note that this also changes the fault semantics of TaskCompletionListener. Previously failures in TaskCompletionListeners would result in the task being reported immediately. With this change, we report the exception at the end, and the reported exception is a TaskCompletionListenerException that contains all the exception messages.

Author: Reynold Xin <rxin@apache.org>

Closes #2343 from rxin/taskcontext-callback and squashes the following commits:

a3845b2 [Reynold Xin] Mark TaskCompletionListenerException as private[spark].
ac5baea [Reynold Xin] Removed obsolete comment.
aa68ea4 [Reynold Xin] Throw an exception if task completion callback fails.
29b6162 [Reynold Xin] oops compilation failed.
1cb444d [Reynold Xin] [SPARK-3469] Call all TaskCompletionListeners even if some fail.
parent 6d887db7
No related branches found
No related tags found
No related merge requests found
......@@ -21,7 +21,7 @@ import scala.collection.mutable.ArrayBuffer
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.util.TaskCompletionListener
import org.apache.spark.util.{TaskCompletionListenerException, TaskCompletionListener}
/**
......@@ -41,7 +41,7 @@ class TaskContext(
val attemptId: Long,
val runningLocally: Boolean = false,
private[spark] val taskMetrics: TaskMetrics = TaskMetrics.empty)
extends Serializable {
extends Serializable with Logging {
@deprecated("use partitionId", "0.8.1")
def splitId = partitionId
......@@ -103,8 +103,20 @@ class TaskContext(
/** Marks the task as completed and triggers the listeners. */
private[spark] def markTaskCompleted(): Unit = {
completed = true
val errorMsgs = new ArrayBuffer[String](2)
// Process complete callbacks in the reverse order of registration
onCompleteCallbacks.reverse.foreach { _.onTaskCompletion(this) }
onCompleteCallbacks.reverse.foreach { listener =>
try {
listener.onTaskCompletion(this)
} catch {
case e: Throwable =>
errorMsgs += e.getMessage
logError("Error in TaskCompletionListener", e)
}
}
if (errorMsgs.nonEmpty) {
throw new TaskCompletionListenerException(errorMsgs)
}
}
/** Marks the task for interruption, i.e. cancellation. */
......
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.util
/**
* Exception thrown when there is an exception in
* executing the callback in TaskCompletionListener.
*/
private[spark]
class TaskCompletionListenerException(errorMessages: Seq[String]) extends Exception {
override def getMessage: String = {
if (errorMessages.size == 1) {
errorMessages.head
} else {
errorMessages.zipWithIndex.map { case (msg, i) => s"Exception $i: $msg" }.mkString("\n")
}
}
}
......@@ -17,16 +17,20 @@
package org.apache.spark.scheduler
import org.mockito.Mockito._
import org.mockito.Matchers.any
import org.scalatest.FunSuite
import org.scalatest.BeforeAndAfter
import org.apache.spark._
import org.apache.spark.rdd.RDD
import org.apache.spark.util.Utils
import org.apache.spark.util.{TaskCompletionListenerException, TaskCompletionListener}
class TaskContextSuite extends FunSuite with BeforeAndAfter with LocalSparkContext {
test("Calls executeOnCompleteCallbacks after failure") {
test("calls TaskCompletionListener after failure") {
TaskContextSuite.completed = false
sc = new SparkContext("local", "test")
val rdd = new RDD[String](sc, List()) {
......@@ -45,6 +49,20 @@ class TaskContextSuite extends FunSuite with BeforeAndAfter with LocalSparkConte
}
assert(TaskContextSuite.completed === true)
}
test("all TaskCompletionListeners should be called even if some fail") {
val context = new TaskContext(0, 0, 0)
val listener = mock(classOf[TaskCompletionListener])
context.addTaskCompletionListener(_ => throw new Exception("blah"))
context.addTaskCompletionListener(listener)
context.addTaskCompletionListener(_ => throw new Exception("blah"))
intercept[TaskCompletionListenerException] {
context.markTaskCompleted()
}
verify(listener, times(1)).onTaskCompletion(any())
}
}
private object TaskContextSuite {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment